7

Simplify saving/retrieving user segments

This commit is contained in:
Michael van Tellingen
2017-05-31 15:48:49 +02:00
committed by Michael van Tellingen
parent 2450bd45ac
commit cbc2ec7270
2 changed files with 33 additions and 50 deletions

View File

@ -80,17 +80,37 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
self.request.session.setdefault('segments', [])
def get_segments(self):
"""Return the segments stored in the request session.
"""Return the persistent segments stored in the request session.
:returns: The segments in the request session
:rtype: list of wagtail_personalisation.models.Segment or empty list
"""
return self.request.session['segments']
raw_segments = self.request.session['segments']
segment_ids = [segment['id'] for segment in raw_segments]
segments = (
Segment.objects
.filter(status=Segment.STATUS_ENABLED)
.filter(persistent=True)
.in_bulk(segment_ids))
return [segments[pk] for pk in segment_ids if pk in segments]
def set_segments(self, segments):
"""Set the currently active segments"""
self.request.session['segments'] = segments
serialized_segments = []
segment_ids = set()
for segment in segments:
serialized = create_segment_dictionary(segment)
if serialized['id'] in segment_ids:
continue
serialized_segments.append(serialized)
segment_ids.add(segment.pk)
self.request.session['segments'] = serialized_segments
def get_segment_by_id(self, segment_id):
"""Find and return a single segment from the request session.
@ -107,30 +127,6 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
except StopIteration:
return None
def add(self, segment):
"""Add a segment to the request session.
:param segment: The segment to add to the request session
:type segment: wagtail_personalisation.models.Segment
"""
def check_if_segmented(item):
"""Check if the user has been segmented.
:param item: The segment to check for
:type item: wagtail_personalisation.models.Segment
:returns: Whether the segment is in the request session
:rtype: bool
"""
return any(seg['encoded_name'] == item.encoded_name()
for seg in self.request.session['segments'])
if not check_if_segmented(segment):
segdict = create_segment_dictionary(segment)
self.request.session['segments'].append(segdict)
def add_page_visit(self, page):
"""Mark the page as visited by the user"""
visit_count = self.request.session.setdefault('visit_count', [])
@ -177,30 +173,25 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
still apply to the requesting visitor.
"""
enabled_segments = Segment.objects.filter(status=Segment.STATUS_ENABLED)
persistent_segments = enabled_segments.filter(persistent=True)
all_segments = Segment.objects.filter(status=Segment.STATUS_ENABLED)
session_segments = self.get_segments()
current_segments = self.get_segments()
rules = AbstractBaseRule.__subclasses__()
# Create a list to store the new request session segments and
# re-apply all persistent segments (if they are still enabled).
new_segments = [segment for segment in session_segments
if persistent_segments.filter(id=segment['id']).exists()]
# Run tests on all remaining enabled segments to verify applicability.
for segment in enabled_segments:
additional_segments = []
for segment in all_segments:
segment_rules = []
for rule in rules:
segment_rules += rule.objects.filter(segment=segment)
result = self._test_rules(segment_rules, self.request,
match_any=segment.match_any)
result = self._test_rules(
segment_rules, self.request, match_any=segment.match_any)
if result:
segdict = create_segment_dictionary(segment)
if not any(seg['id'] == segdict['id'] for seg in new_segments):
new_segments.append(segdict)
additional_segments.append(segment)
new_segments = current_segments + additional_segments
self.set_segments(new_segments)
self.update_visit_count()

View File

@ -74,15 +74,7 @@ def serve_variation(page, request, serve_args, serve_kwargs):
"""
user_segments = []
adapter = get_segment_adapter(request)
for segment in adapter.get_segments():
try:
user_segment = Segment.objects.get(
pk=segment['id'], status=Segment.STATUS_ENABLED)
except Segment.DoesNotExist:
user_segment = None
if user_segment:
user_segments.append(user_segment)
user_segments = adapter.get_segments()
if len(user_segments) > 0:
variations = _check_for_variations(user_segments, page)