7

Add unittests for the session adapter

This commit is contained in:
Michael van Tellingen
2017-05-31 16:08:07 +02:00
committed by Michael van Tellingen
parent e107d73716
commit 03073eb004
5 changed files with 68 additions and 10 deletions

View File

@ -78,6 +78,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
def __init__(self, request):
super(SessionSegmentsAdapter, self).__init__(request)
self.request.session.setdefault('segments', [])
self._segment_cache = None
def get_segments(self):
"""Return the persistent segments stored in the request session.
@ -86,6 +87,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
:rtype: list of wagtail_personalisation.models.Segment or empty list
"""
if self._segment_cache is not None:
return self._segment_cache
raw_segments = self.request.session['segments']
segment_ids = [segment['id'] for segment in raw_segments]
@ -95,7 +99,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
.filter(persistent=True)
.in_bulk(segment_ids))
return [segments[pk] for pk in segment_ids if pk in segments]
retval = [segments[pk] for pk in segment_ids if pk in segments]
self._segment_cache = retval
return retval
def set_segments(self, segments):
"""Set the currently active segments
@ -104,6 +110,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
:type segments: list of wagtail_personalisation.models.Segment
"""
cache_segments = []
serialized_segments = []
segment_ids = set()
for segment in segments:
@ -111,10 +118,12 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
if serialized['id'] in segment_ids:
continue
cache_segments.append(segment)
serialized_segments.append(serialized)
segment_ids.add(segment.pk)
self.request.session['segments'] = serialized_segments
self._segment_cache = cache_segments
def get_segment_by_id(self, segment_id):
"""Find and return a single segment from the request session.
@ -125,11 +134,8 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
:rtype: wagtail_personalisation.models.Segment or None
"""
try:
return next(item for item in self.request.session['segments']
if item['id'] == segment_id)
except StopIteration:
return None
segments = self.get_segments()
return next((s for s in segments if s.pk == segment_id), None)
def add_page_visit(self, page):
"""Mark the page as visited by the user"""