Add unittests for the session adapter
This commit is contained in:
committed by
Michael van Tellingen
parent
e107d73716
commit
03073eb004
@ -78,6 +78,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
|||||||
def __init__(self, request):
|
def __init__(self, request):
|
||||||
super(SessionSegmentsAdapter, self).__init__(request)
|
super(SessionSegmentsAdapter, self).__init__(request)
|
||||||
self.request.session.setdefault('segments', [])
|
self.request.session.setdefault('segments', [])
|
||||||
|
self._segment_cache = None
|
||||||
|
|
||||||
def get_segments(self):
|
def get_segments(self):
|
||||||
"""Return the persistent segments stored in the request session.
|
"""Return the persistent segments stored in the request session.
|
||||||
@ -86,6 +87,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
|||||||
:rtype: list of wagtail_personalisation.models.Segment or empty list
|
:rtype: list of wagtail_personalisation.models.Segment or empty list
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if self._segment_cache is not None:
|
||||||
|
return self._segment_cache
|
||||||
|
|
||||||
raw_segments = self.request.session['segments']
|
raw_segments = self.request.session['segments']
|
||||||
segment_ids = [segment['id'] for segment in raw_segments]
|
segment_ids = [segment['id'] for segment in raw_segments]
|
||||||
|
|
||||||
@ -95,7 +99,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
|||||||
.filter(persistent=True)
|
.filter(persistent=True)
|
||||||
.in_bulk(segment_ids))
|
.in_bulk(segment_ids))
|
||||||
|
|
||||||
return [segments[pk] for pk in segment_ids if pk in segments]
|
retval = [segments[pk] for pk in segment_ids if pk in segments]
|
||||||
|
self._segment_cache = retval
|
||||||
|
return retval
|
||||||
|
|
||||||
def set_segments(self, segments):
|
def set_segments(self, segments):
|
||||||
"""Set the currently active segments
|
"""Set the currently active segments
|
||||||
@ -104,6 +110,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
|||||||
:type segments: list of wagtail_personalisation.models.Segment
|
:type segments: list of wagtail_personalisation.models.Segment
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
cache_segments = []
|
||||||
serialized_segments = []
|
serialized_segments = []
|
||||||
segment_ids = set()
|
segment_ids = set()
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
@ -111,10 +118,12 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
|||||||
if serialized['id'] in segment_ids:
|
if serialized['id'] in segment_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
cache_segments.append(segment)
|
||||||
serialized_segments.append(serialized)
|
serialized_segments.append(serialized)
|
||||||
segment_ids.add(segment.pk)
|
segment_ids.add(segment.pk)
|
||||||
|
|
||||||
self.request.session['segments'] = serialized_segments
|
self.request.session['segments'] = serialized_segments
|
||||||
|
self._segment_cache = cache_segments
|
||||||
|
|
||||||
def get_segment_by_id(self, segment_id):
|
def get_segment_by_id(self, segment_id):
|
||||||
"""Find and return a single segment from the request session.
|
"""Find and return a single segment from the request session.
|
||||||
@ -125,11 +134,8 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
|||||||
:rtype: wagtail_personalisation.models.Segment or None
|
:rtype: wagtail_personalisation.models.Segment or None
|
||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
segments = self.get_segments()
|
||||||
return next(item for item in self.request.session['segments']
|
return next((s for s in segments if s.pk == segment_id), None)
|
||||||
if item['id'] == segment_id)
|
|
||||||
except StopIteration:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def add_page_visit(self, page):
|
def add_page_visit(self, page):
|
||||||
"""Mark the page as visited by the user"""
|
"""Mark the page as visited by the user"""
|
||||||
|
@ -2,8 +2,6 @@ from __future__ import absolute_import, unicode_literals
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from wagtail.wagtailcore.models import Page, Site
|
from wagtail.wagtailcore.models import Page, Site
|
||||||
from wagtail_factories import SiteFactory
|
|
||||||
from tests.factories.page import PageFactory
|
|
||||||
|
|
||||||
pytest_plugins = [
|
pytest_plugins = [
|
||||||
'tests.fixtures'
|
'tests.fixtures'
|
||||||
@ -16,4 +14,3 @@ def django_db_setup(django_db_setup, django_db_blocker):
|
|||||||
# Remove some initial data that is brought by the sandbox module
|
# Remove some initial data that is brought by the sandbox module
|
||||||
Site.objects.all().delete()
|
Site.objects.all().delete()
|
||||||
Page.objects.all().exclude(depth=1).delete()
|
Page.objects.all().exclude(depth=1).delete()
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from wagtail_personalisation import models
|
|||||||
|
|
||||||
class SegmentFactory(factory.DjangoModelFactory):
|
class SegmentFactory(factory.DjangoModelFactory):
|
||||||
name = 'TestSegment'
|
name = 'TestSegment'
|
||||||
status = 'enabled'
|
status = models.Segment.STATUS_ENABLED
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = models.Segment
|
model = models.Segment
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from django.contrib.auth.models import AnonymousUser
|
||||||
|
from django.contrib.messages.storage.fallback import FallbackStorage
|
||||||
|
from django.contrib.sessions.backends.db import SessionStore
|
||||||
|
from django.test.client import RequestFactory as BaseRequestFactory
|
||||||
from tests.factories.page import PageFactory
|
from tests.factories.page import PageFactory
|
||||||
from tests.factories.segment import SegmentFactory
|
from tests.factories.segment import SegmentFactory
|
||||||
from tests.factories.site import SiteFactory
|
from tests.factories.site import SiteFactory
|
||||||
@ -18,3 +22,19 @@ def segmented_page(site):
|
|||||||
page = PageFactory(parent=site.root_page)
|
page = PageFactory(parent=site.root_page)
|
||||||
segment = SegmentFactory()
|
segment = SegmentFactory()
|
||||||
return page.copy_for_segment(segment)
|
return page.copy_for_segment(segment)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def rf():
|
||||||
|
"""RequestFactory instance"""
|
||||||
|
return RequestFactory()
|
||||||
|
|
||||||
|
|
||||||
|
class RequestFactory(BaseRequestFactory):
|
||||||
|
|
||||||
|
def request(self, user=None, **request):
|
||||||
|
request = super(RequestFactory, self).request(**request)
|
||||||
|
request.user = AnonymousUser()
|
||||||
|
request.session = SessionStore()
|
||||||
|
request._messages = FallbackStorage(request)
|
||||||
|
return request
|
||||||
|
35
tests/unit/test_adapter_session.py
Normal file
35
tests/unit/test_adapter_session.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from wagtail_personalisation import adapters
|
||||||
|
from tests.factories.segment import SegmentFactory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_get_segments(rf, monkeypatch):
|
||||||
|
request = rf.get('/')
|
||||||
|
|
||||||
|
adapter = adapters.SessionSegmentsAdapter(request)
|
||||||
|
|
||||||
|
segment_1 = SegmentFactory(name='segment-1', persistent=True)
|
||||||
|
segment_2 = SegmentFactory(name='segment-2', persistent=True)
|
||||||
|
|
||||||
|
adapter.set_segments([segment_1, segment_2])
|
||||||
|
assert len(request.session['segments']) == 2
|
||||||
|
|
||||||
|
segments = adapter.get_segments()
|
||||||
|
assert segments == [segment_1, segment_2]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_get_segment_by_id(rf, monkeypatch):
|
||||||
|
request = rf.get('/')
|
||||||
|
|
||||||
|
adapter = adapters.SessionSegmentsAdapter(request)
|
||||||
|
|
||||||
|
segment_1 = SegmentFactory(name='segment-1', persistent=True)
|
||||||
|
segment_2 = SegmentFactory(name='segment-2', persistent=True)
|
||||||
|
|
||||||
|
adapter.set_segments([segment_1, segment_2])
|
||||||
|
|
||||||
|
segment_x = adapter.get_segment_by_id(segment_2.pk)
|
||||||
|
assert segment_x == segment_2
|
Reference in New Issue
Block a user