diff --git a/src/wagtail_personalisation/adapters.py b/src/wagtail_personalisation/adapters.py index 5120472..8511627 100644 --- a/src/wagtail_personalisation/adapters.py +++ b/src/wagtail_personalisation/adapters.py @@ -78,6 +78,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): def __init__(self, request): super(SessionSegmentsAdapter, self).__init__(request) self.request.session.setdefault('segments', []) + self._segment_cache = None def get_segments(self): """Return the persistent segments stored in the request session. @@ -86,6 +87,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): :rtype: list of wagtail_personalisation.models.Segment or empty list """ + if self._segment_cache is not None: + return self._segment_cache + raw_segments = self.request.session['segments'] segment_ids = [segment['id'] for segment in raw_segments] @@ -95,7 +99,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): .filter(persistent=True) .in_bulk(segment_ids)) - return [segments[pk] for pk in segment_ids if pk in segments] + retval = [segments[pk] for pk in segment_ids if pk in segments] + self._segment_cache = retval + return retval def set_segments(self, segments): """Set the currently active segments @@ -104,6 +110,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): :type segments: list of wagtail_personalisation.models.Segment """ + cache_segments = [] serialized_segments = [] segment_ids = set() for segment in segments: @@ -111,10 +118,12 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): if serialized['id'] in segment_ids: continue + cache_segments.append(segment) serialized_segments.append(serialized) segment_ids.add(segment.pk) self.request.session['segments'] = serialized_segments + self._segment_cache = cache_segments def get_segment_by_id(self, segment_id): """Find and return a single segment from the request session. @@ -125,11 +134,8 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): :rtype: wagtail_personalisation.models.Segment or None """ - try: - return next(item for item in self.request.session['segments'] - if item['id'] == segment_id) - except StopIteration: - return None + segments = self.get_segments() + return next((s for s in segments if s.pk == segment_id), None) def add_page_visit(self, page): """Mark the page as visited by the user""" diff --git a/tests/conftest.py b/tests/conftest.py index 0537186..8942e2e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,8 +2,6 @@ from __future__ import absolute_import, unicode_literals import pytest from wagtail.wagtailcore.models import Page, Site -from wagtail_factories import SiteFactory -from tests.factories.page import PageFactory pytest_plugins = [ 'tests.fixtures' @@ -16,4 +14,3 @@ def django_db_setup(django_db_setup, django_db_blocker): # Remove some initial data that is brought by the sandbox module Site.objects.all().delete() Page.objects.all().exclude(depth=1).delete() - diff --git a/tests/factories/segment.py b/tests/factories/segment.py index 5ea0b6c..f92f579 100644 --- a/tests/factories/segment.py +++ b/tests/factories/segment.py @@ -7,7 +7,7 @@ from wagtail_personalisation import models class SegmentFactory(factory.DjangoModelFactory): name = 'TestSegment' - status = 'enabled' + status = models.Segment.STATUS_ENABLED class Meta: model = models.Segment diff --git a/tests/fixtures.py b/tests/fixtures.py index d023261..0e3e437 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,5 +1,9 @@ import pytest +from django.contrib.auth.models import AnonymousUser +from django.contrib.messages.storage.fallback import FallbackStorage +from django.contrib.sessions.backends.db import SessionStore +from django.test.client import RequestFactory as BaseRequestFactory from tests.factories.page import PageFactory from tests.factories.segment import SegmentFactory from tests.factories.site import SiteFactory @@ -18,3 +22,19 @@ def segmented_page(site): page = PageFactory(parent=site.root_page) segment = SegmentFactory() return page.copy_for_segment(segment) + + +@pytest.fixture() +def rf(): + """RequestFactory instance""" + return RequestFactory() + + +class RequestFactory(BaseRequestFactory): + + def request(self, user=None, **request): + request = super(RequestFactory, self).request(**request) + request.user = AnonymousUser() + request.session = SessionStore() + request._messages = FallbackStorage(request) + return request diff --git a/tests/unit/test_adapter_session.py b/tests/unit/test_adapter_session.py new file mode 100644 index 0000000..75f4bcc --- /dev/null +++ b/tests/unit/test_adapter_session.py @@ -0,0 +1,35 @@ +import pytest + +from wagtail_personalisation import adapters +from tests.factories.segment import SegmentFactory + + +@pytest.mark.django_db +def test_get_segments(rf, monkeypatch): + request = rf.get('/') + + adapter = adapters.SessionSegmentsAdapter(request) + + segment_1 = SegmentFactory(name='segment-1', persistent=True) + segment_2 = SegmentFactory(name='segment-2', persistent=True) + + adapter.set_segments([segment_1, segment_2]) + assert len(request.session['segments']) == 2 + + segments = adapter.get_segments() + assert segments == [segment_1, segment_2] + + +@pytest.mark.django_db +def test_get_segment_by_id(rf, monkeypatch): + request = rf.get('/') + + adapter = adapters.SessionSegmentsAdapter(request) + + segment_1 = SegmentFactory(name='segment-1', persistent=True) + segment_2 = SegmentFactory(name='segment-2', persistent=True) + + adapter.set_segments([segment_1, segment_2]) + + segment_x = adapter.get_segment_by_id(segment_2.pk) + assert segment_x == segment_2