7

Add Adapter.set_segments() and use it when refreshig the segment

This makes the internal API a bit more consistent
This commit is contained in:
Michael van Tellingen
2017-05-31 15:14:32 +02:00
parent decfc88efe
commit f2aa8879a9
4 changed files with 17 additions and 13 deletions

View File

@ -25,11 +25,11 @@ class BaseSegmentsAdapter(object):
"""Prepare the adapter for segment storage.""" """Prepare the adapter for segment storage."""
return None return None
def get_all_segments(self): def get_segments(self):
"""Return the segments stored in the adapter storage.""" """Return the segments stored in the adapter storage."""
return None return None
def get_segment(self): def get_segment_by_id(self):
"""Return a single segment stored in the adapter storage.""" """Return a single segment stored in the adapter storage."""
return None return None
@ -79,7 +79,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
super(SessionSegmentsAdapter, self).__init__(request) super(SessionSegmentsAdapter, self).__init__(request)
self.request.session.setdefault('segments', []) self.request.session.setdefault('segments', [])
def get_all_segments(self): def get_segments(self):
"""Return the segments stored in the request session. """Return the segments stored in the request session.
:returns: The segments in the request session :returns: The segments in the request session
@ -88,7 +88,11 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
""" """
return self.request.session['segments'] return self.request.session['segments']
def get_segment(self, segment_id): def set_segments(self, segments):
"""Set the currently active segments"""
self.request.session['segments'] = segments
def get_segment_by_id(self, segment_id):
"""Find and return a single segment from the request session. """Find and return a single segment from the request session.
:param segment_id: The primary key of the segment :param segment_id: The primary key of the segment
@ -175,7 +179,8 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
""" """
enabled_segments = Segment.objects.filter(status=Segment.STATUS_ENABLED) enabled_segments = Segment.objects.filter(status=Segment.STATUS_ENABLED)
persistent_segments = enabled_segments.filter(persistent=True) persistent_segments = enabled_segments.filter(persistent=True)
session_segments = self.request.session['segments']
session_segments = self.get_segments()
rules = AbstractBaseRule.__subclasses__() rules = AbstractBaseRule.__subclasses__()
# Create a list to store the new request session segments and # Create a list to store the new request session segments and
@ -196,8 +201,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
if not any(seg['id'] == segdict['id'] for seg in new_segments): if not any(seg['id'] == segdict['id'] for seg in new_segments):
new_segments.append(segdict) new_segments.append(segdict)
self.request.session['segments'] = new_segments self.set_segments(new_segments)
self.update_visit_count() self.update_visit_count()

View File

@ -33,7 +33,7 @@ class PersonalisedStructBlock(blocks.StructBlock):
""" """
request = context['request'] request = context['request']
adapter = get_segment_adapter(request) adapter = get_segment_adapter(request)
user_segments = adapter.get_all_segments() user_segments = adapter.get_segments()
if value['segment']: if value['segment']:
for segment in user_segments: for segment in user_segments:

View File

@ -1,8 +1,8 @@
from django import template from django import template
from django.template import TemplateSyntaxError from django.template import TemplateSyntaxError
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
from wagtail_personalisation.adapters import get_segment_adapter
from wagtail_personalisation.models import Segment from wagtail_personalisation.models import Segment
from wagtail_personalisation.utils import parse_tag from wagtail_personalisation.utils import parse_tag
@ -48,11 +48,11 @@ class SegmentNode(template.Node):
return "" return ""
# Check if user has segment # Check if user has segment
user_segment = context['request'].segment_adapter.get_segment(segment_id=segment.pk) adapter = get_segment_adapter(context['request'])
user_segment = adapter.get_segment_by_id(segment_id=segment.pk)
if not user_segment: if not user_segment:
return "" return ''
content = self.nodelist.render(context) content = self.nodelist.render(context)
content = mark_safe(content) content = mark_safe(content)
return content return content

View File

@ -74,7 +74,7 @@ def serve_variation(page, request, serve_args, serve_kwargs):
user_segments = [] user_segments = []
adapter = get_segment_adapter(request) adapter = get_segment_adapter(request)
for segment in adapter.get_all_segments(): for segment in adapter.get_segments():
try: try:
user_segment = Segment.objects.get( user_segment = Segment.objects.get(
pk=segment['id'], status=Segment.STATUS_ENABLED) pk=segment['id'], status=Segment.STATUS_ENABLED)