Add unittests for the session adapter
This commit is contained in:
committed by
Michael van Tellingen
parent
e107d73716
commit
03073eb004
@ -78,6 +78,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
||||
def __init__(self, request):
|
||||
super(SessionSegmentsAdapter, self).__init__(request)
|
||||
self.request.session.setdefault('segments', [])
|
||||
self._segment_cache = None
|
||||
|
||||
def get_segments(self):
|
||||
"""Return the persistent segments stored in the request session.
|
||||
@ -86,6 +87,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
||||
:rtype: list of wagtail_personalisation.models.Segment or empty list
|
||||
|
||||
"""
|
||||
if self._segment_cache is not None:
|
||||
return self._segment_cache
|
||||
|
||||
raw_segments = self.request.session['segments']
|
||||
segment_ids = [segment['id'] for segment in raw_segments]
|
||||
|
||||
@ -95,7 +99,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
||||
.filter(persistent=True)
|
||||
.in_bulk(segment_ids))
|
||||
|
||||
return [segments[pk] for pk in segment_ids if pk in segments]
|
||||
retval = [segments[pk] for pk in segment_ids if pk in segments]
|
||||
self._segment_cache = retval
|
||||
return retval
|
||||
|
||||
def set_segments(self, segments):
|
||||
"""Set the currently active segments
|
||||
@ -104,6 +110,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
||||
:type segments: list of wagtail_personalisation.models.Segment
|
||||
|
||||
"""
|
||||
cache_segments = []
|
||||
serialized_segments = []
|
||||
segment_ids = set()
|
||||
for segment in segments:
|
||||
@ -111,10 +118,12 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
||||
if serialized['id'] in segment_ids:
|
||||
continue
|
||||
|
||||
cache_segments.append(segment)
|
||||
serialized_segments.append(serialized)
|
||||
segment_ids.add(segment.pk)
|
||||
|
||||
self.request.session['segments'] = serialized_segments
|
||||
self._segment_cache = cache_segments
|
||||
|
||||
def get_segment_by_id(self, segment_id):
|
||||
"""Find and return a single segment from the request session.
|
||||
@ -125,11 +134,8 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
||||
:rtype: wagtail_personalisation.models.Segment or None
|
||||
|
||||
"""
|
||||
try:
|
||||
return next(item for item in self.request.session['segments']
|
||||
if item['id'] == segment_id)
|
||||
except StopIteration:
|
||||
return None
|
||||
segments = self.get_segments()
|
||||
return next((s for s in segments if s.pk == segment_id), None)
|
||||
|
||||
def add_page_visit(self, page):
|
||||
"""Mark the page as visited by the user"""
|
||||
|
@ -2,8 +2,6 @@ from __future__ import absolute_import, unicode_literals
|
||||
|
||||
import pytest
|
||||
from wagtail.wagtailcore.models import Page, Site
|
||||
from wagtail_factories import SiteFactory
|
||||
from tests.factories.page import PageFactory
|
||||
|
||||
pytest_plugins = [
|
||||
'tests.fixtures'
|
||||
@ -16,4 +14,3 @@ def django_db_setup(django_db_setup, django_db_blocker):
|
||||
# Remove some initial data that is brought by the sandbox module
|
||||
Site.objects.all().delete()
|
||||
Page.objects.all().exclude(depth=1).delete()
|
||||
|
||||
|
@ -7,7 +7,7 @@ from wagtail_personalisation import models
|
||||
|
||||
class SegmentFactory(factory.DjangoModelFactory):
|
||||
name = 'TestSegment'
|
||||
status = 'enabled'
|
||||
status = models.Segment.STATUS_ENABLED
|
||||
|
||||
class Meta:
|
||||
model = models.Segment
|
||||
|
@ -1,5 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.contrib.messages.storage.fallback import FallbackStorage
|
||||
from django.contrib.sessions.backends.db import SessionStore
|
||||
from django.test.client import RequestFactory as BaseRequestFactory
|
||||
from tests.factories.page import PageFactory
|
||||
from tests.factories.segment import SegmentFactory
|
||||
from tests.factories.site import SiteFactory
|
||||
@ -18,3 +22,19 @@ def segmented_page(site):
|
||||
page = PageFactory(parent=site.root_page)
|
||||
segment = SegmentFactory()
|
||||
return page.copy_for_segment(segment)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def rf():
|
||||
"""RequestFactory instance"""
|
||||
return RequestFactory()
|
||||
|
||||
|
||||
class RequestFactory(BaseRequestFactory):
|
||||
|
||||
def request(self, user=None, **request):
|
||||
request = super(RequestFactory, self).request(**request)
|
||||
request.user = AnonymousUser()
|
||||
request.session = SessionStore()
|
||||
request._messages = FallbackStorage(request)
|
||||
return request
|
||||
|
35
tests/unit/test_adapter_session.py
Normal file
35
tests/unit/test_adapter_session.py
Normal file
@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
|
||||
from wagtail_personalisation import adapters
|
||||
from tests.factories.segment import SegmentFactory
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_get_segments(rf, monkeypatch):
|
||||
request = rf.get('/')
|
||||
|
||||
adapter = adapters.SessionSegmentsAdapter(request)
|
||||
|
||||
segment_1 = SegmentFactory(name='segment-1', persistent=True)
|
||||
segment_2 = SegmentFactory(name='segment-2', persistent=True)
|
||||
|
||||
adapter.set_segments([segment_1, segment_2])
|
||||
assert len(request.session['segments']) == 2
|
||||
|
||||
segments = adapter.get_segments()
|
||||
assert segments == [segment_1, segment_2]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_get_segment_by_id(rf, monkeypatch):
|
||||
request = rf.get('/')
|
||||
|
||||
adapter = adapters.SessionSegmentsAdapter(request)
|
||||
|
||||
segment_1 = SegmentFactory(name='segment-1', persistent=True)
|
||||
segment_2 = SegmentFactory(name='segment-2', persistent=True)
|
||||
|
||||
adapter.set_segments([segment_1, segment_2])
|
||||
|
||||
segment_x = adapter.get_segment_by_id(segment_2.pk)
|
||||
assert segment_x == segment_2
|
Reference in New Issue
Block a user