diff --git a/src/wagtail_personalisation/adapters.py b/src/wagtail_personalisation/adapters.py index 2c9ce72..0739daa 100644 --- a/src/wagtail_personalisation/adapters.py +++ b/src/wagtail_personalisation/adapters.py @@ -25,11 +25,11 @@ class BaseSegmentsAdapter(object): """Prepare the adapter for segment storage.""" return None - def get_all_segments(self): + def get_segments(self): """Return the segments stored in the adapter storage.""" return None - def get_segment(self): + def get_segment_by_id(self): """Return a single segment stored in the adapter storage.""" return None @@ -79,7 +79,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): super(SessionSegmentsAdapter, self).__init__(request) self.request.session.setdefault('segments', []) - def get_all_segments(self): + def get_segments(self): """Return the segments stored in the request session. :returns: The segments in the request session @@ -88,7 +88,11 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): """ return self.request.session['segments'] - def get_segment(self, segment_id): + def set_segments(self, segments): + """Set the currently active segments""" + self.request.session['segments'] = segments + + def get_segment_by_id(self, segment_id): """Find and return a single segment from the request session. :param segment_id: The primary key of the segment @@ -175,7 +179,8 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): """ enabled_segments = Segment.objects.filter(status=Segment.STATUS_ENABLED) persistent_segments = enabled_segments.filter(persistent=True) - session_segments = self.request.session['segments'] + + session_segments = self.get_segments() rules = AbstractBaseRule.__subclasses__() # Create a list to store the new request session segments and @@ -196,8 +201,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): if not any(seg['id'] == segdict['id'] for seg in new_segments): new_segments.append(segdict) - self.request.session['segments'] = new_segments - + self.set_segments(new_segments) self.update_visit_count() diff --git a/src/wagtail_personalisation/blocks.py b/src/wagtail_personalisation/blocks.py index 0faa135..b3ad67c 100644 --- a/src/wagtail_personalisation/blocks.py +++ b/src/wagtail_personalisation/blocks.py @@ -33,7 +33,7 @@ class PersonalisedStructBlock(blocks.StructBlock): """ request = context['request'] adapter = get_segment_adapter(request) - user_segments = adapter.get_all_segments() + user_segments = adapter.get_segments() if value['segment']: for segment in user_segments: diff --git a/src/wagtail_personalisation/templatetags/wagtail_personalisation_tags.py b/src/wagtail_personalisation/templatetags/wagtail_personalisation_tags.py index 00e166c..ac3b1f5 100644 --- a/src/wagtail_personalisation/templatetags/wagtail_personalisation_tags.py +++ b/src/wagtail_personalisation/templatetags/wagtail_personalisation_tags.py @@ -1,8 +1,8 @@ from django import template from django.template import TemplateSyntaxError - from django.utils.safestring import mark_safe +from wagtail_personalisation.adapters import get_segment_adapter from wagtail_personalisation.models import Segment from wagtail_personalisation.utils import parse_tag @@ -48,11 +48,11 @@ class SegmentNode(template.Node): return "" # Check if user has segment - user_segment = context['request'].segment_adapter.get_segment(segment_id=segment.pk) + adapter = get_segment_adapter(context['request']) + user_segment = adapter.get_segment_by_id(segment_id=segment.pk) if not user_segment: - return "" + return '' content = self.nodelist.render(context) content = mark_safe(content) - return content diff --git a/src/wagtail_personalisation/wagtail_hooks.py b/src/wagtail_personalisation/wagtail_hooks.py index 6a4bf60..a8f1175 100644 --- a/src/wagtail_personalisation/wagtail_hooks.py +++ b/src/wagtail_personalisation/wagtail_hooks.py @@ -74,7 +74,7 @@ def serve_variation(page, request, serve_args, serve_kwargs): user_segments = [] adapter = get_segment_adapter(request) - for segment in adapter.get_all_segments(): + for segment in adapter.get_segments(): try: user_segment = Segment.objects.get( pk=segment['id'], status=Segment.STATUS_ENABLED)