7

Add unittests for the session adapter

This commit is contained in:
Michael van Tellingen
2017-05-31 16:08:07 +02:00
committed by Michael van Tellingen
parent e107d73716
commit 03073eb004
5 changed files with 68 additions and 10 deletions

View File

@ -78,6 +78,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
def __init__(self, request): def __init__(self, request):
super(SessionSegmentsAdapter, self).__init__(request) super(SessionSegmentsAdapter, self).__init__(request)
self.request.session.setdefault('segments', []) self.request.session.setdefault('segments', [])
self._segment_cache = None
def get_segments(self): def get_segments(self):
"""Return the persistent segments stored in the request session. """Return the persistent segments stored in the request session.
@ -86,6 +87,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
:rtype: list of wagtail_personalisation.models.Segment or empty list :rtype: list of wagtail_personalisation.models.Segment or empty list
""" """
if self._segment_cache is not None:
return self._segment_cache
raw_segments = self.request.session['segments'] raw_segments = self.request.session['segments']
segment_ids = [segment['id'] for segment in raw_segments] segment_ids = [segment['id'] for segment in raw_segments]
@ -95,7 +99,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
.filter(persistent=True) .filter(persistent=True)
.in_bulk(segment_ids)) .in_bulk(segment_ids))
return [segments[pk] for pk in segment_ids if pk in segments] retval = [segments[pk] for pk in segment_ids if pk in segments]
self._segment_cache = retval
return retval
def set_segments(self, segments): def set_segments(self, segments):
"""Set the currently active segments """Set the currently active segments
@ -104,6 +110,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
:type segments: list of wagtail_personalisation.models.Segment :type segments: list of wagtail_personalisation.models.Segment
""" """
cache_segments = []
serialized_segments = [] serialized_segments = []
segment_ids = set() segment_ids = set()
for segment in segments: for segment in segments:
@ -111,10 +118,12 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
if serialized['id'] in segment_ids: if serialized['id'] in segment_ids:
continue continue
cache_segments.append(segment)
serialized_segments.append(serialized) serialized_segments.append(serialized)
segment_ids.add(segment.pk) segment_ids.add(segment.pk)
self.request.session['segments'] = serialized_segments self.request.session['segments'] = serialized_segments
self._segment_cache = cache_segments
def get_segment_by_id(self, segment_id): def get_segment_by_id(self, segment_id):
"""Find and return a single segment from the request session. """Find and return a single segment from the request session.
@ -125,11 +134,8 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
:rtype: wagtail_personalisation.models.Segment or None :rtype: wagtail_personalisation.models.Segment or None
""" """
try: segments = self.get_segments()
return next(item for item in self.request.session['segments'] return next((s for s in segments if s.pk == segment_id), None)
if item['id'] == segment_id)
except StopIteration:
return None
def add_page_visit(self, page): def add_page_visit(self, page):
"""Mark the page as visited by the user""" """Mark the page as visited by the user"""

View File

@ -2,8 +2,6 @@ from __future__ import absolute_import, unicode_literals
import pytest import pytest
from wagtail.wagtailcore.models import Page, Site from wagtail.wagtailcore.models import Page, Site
from wagtail_factories import SiteFactory
from tests.factories.page import PageFactory
pytest_plugins = [ pytest_plugins = [
'tests.fixtures' 'tests.fixtures'
@ -16,4 +14,3 @@ def django_db_setup(django_db_setup, django_db_blocker):
# Remove some initial data that is brought by the sandbox module # Remove some initial data that is brought by the sandbox module
Site.objects.all().delete() Site.objects.all().delete()
Page.objects.all().exclude(depth=1).delete() Page.objects.all().exclude(depth=1).delete()

View File

@ -7,7 +7,7 @@ from wagtail_personalisation import models
class SegmentFactory(factory.DjangoModelFactory): class SegmentFactory(factory.DjangoModelFactory):
name = 'TestSegment' name = 'TestSegment'
status = 'enabled' status = models.Segment.STATUS_ENABLED
class Meta: class Meta:
model = models.Segment model = models.Segment

View File

@ -1,5 +1,9 @@
import pytest import pytest
from django.contrib.auth.models import AnonymousUser
from django.contrib.messages.storage.fallback import FallbackStorage
from django.contrib.sessions.backends.db import SessionStore
from django.test.client import RequestFactory as BaseRequestFactory
from tests.factories.page import PageFactory from tests.factories.page import PageFactory
from tests.factories.segment import SegmentFactory from tests.factories.segment import SegmentFactory
from tests.factories.site import SiteFactory from tests.factories.site import SiteFactory
@ -18,3 +22,19 @@ def segmented_page(site):
page = PageFactory(parent=site.root_page) page = PageFactory(parent=site.root_page)
segment = SegmentFactory() segment = SegmentFactory()
return page.copy_for_segment(segment) return page.copy_for_segment(segment)
@pytest.fixture()
def rf():
"""RequestFactory instance"""
return RequestFactory()
class RequestFactory(BaseRequestFactory):
def request(self, user=None, **request):
request = super(RequestFactory, self).request(**request)
request.user = AnonymousUser()
request.session = SessionStore()
request._messages = FallbackStorage(request)
return request

View File

@ -0,0 +1,35 @@
import pytest
from wagtail_personalisation import adapters
from tests.factories.segment import SegmentFactory
@pytest.mark.django_db
def test_get_segments(rf, monkeypatch):
request = rf.get('/')
adapter = adapters.SessionSegmentsAdapter(request)
segment_1 = SegmentFactory(name='segment-1', persistent=True)
segment_2 = SegmentFactory(name='segment-2', persistent=True)
adapter.set_segments([segment_1, segment_2])
assert len(request.session['segments']) == 2
segments = adapter.get_segments()
assert segments == [segment_1, segment_2]
@pytest.mark.django_db
def test_get_segment_by_id(rf, monkeypatch):
request = rf.get('/')
adapter = adapters.SessionSegmentsAdapter(request)
segment_1 = SegmentFactory(name='segment-1', persistent=True)
segment_2 = SegmentFactory(name='segment-2', persistent=True)
adapter.set_segments([segment_1, segment_2])
segment_x = adapter.get_segment_by_id(segment_2.pk)
assert segment_x == segment_2