From cbc2ec7270bfc0fb7f87ad796563eee4692ce3f7 Mon Sep 17 00:00:00 2001 From: Michael van Tellingen Date: Wed, 31 May 2017 15:48:49 +0200 Subject: [PATCH] Simplify saving/retrieving user segments --- src/wagtail_personalisation/adapters.py | 73 +++++++++----------- src/wagtail_personalisation/wagtail_hooks.py | 10 +-- 2 files changed, 33 insertions(+), 50 deletions(-) diff --git a/src/wagtail_personalisation/adapters.py b/src/wagtail_personalisation/adapters.py index 0739daa..e0fc1df 100644 --- a/src/wagtail_personalisation/adapters.py +++ b/src/wagtail_personalisation/adapters.py @@ -80,17 +80,37 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): self.request.session.setdefault('segments', []) def get_segments(self): - """Return the segments stored in the request session. + """Return the persistent segments stored in the request session. :returns: The segments in the request session :rtype: list of wagtail_personalisation.models.Segment or empty list """ - return self.request.session['segments'] + raw_segments = self.request.session['segments'] + segment_ids = [segment['id'] for segment in raw_segments] + + segments = ( + Segment.objects + .filter(status=Segment.STATUS_ENABLED) + .filter(persistent=True) + .in_bulk(segment_ids)) + + return [segments[pk] for pk in segment_ids if pk in segments] def set_segments(self, segments): """Set the currently active segments""" - self.request.session['segments'] = segments + + serialized_segments = [] + segment_ids = set() + for segment in segments: + serialized = create_segment_dictionary(segment) + if serialized['id'] in segment_ids: + continue + + serialized_segments.append(serialized) + segment_ids.add(segment.pk) + + self.request.session['segments'] = serialized_segments def get_segment_by_id(self, segment_id): """Find and return a single segment from the request session. @@ -107,30 +127,6 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): except StopIteration: return None - def add(self, segment): - """Add a segment to the request session. - - :param segment: The segment to add to the request session - :type segment: wagtail_personalisation.models.Segment - - """ - - def check_if_segmented(item): - """Check if the user has been segmented. - - :param item: The segment to check for - :type item: wagtail_personalisation.models.Segment - :returns: Whether the segment is in the request session - :rtype: bool - - """ - return any(seg['encoded_name'] == item.encoded_name() - for seg in self.request.session['segments']) - - if not check_if_segmented(segment): - segdict = create_segment_dictionary(segment) - self.request.session['segments'].append(segdict) - def add_page_visit(self, page): """Mark the page as visited by the user""" visit_count = self.request.session.setdefault('visit_count', []) @@ -177,30 +173,25 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): still apply to the requesting visitor. """ - enabled_segments = Segment.objects.filter(status=Segment.STATUS_ENABLED) - persistent_segments = enabled_segments.filter(persistent=True) + all_segments = Segment.objects.filter(status=Segment.STATUS_ENABLED) - session_segments = self.get_segments() + current_segments = self.get_segments() rules = AbstractBaseRule.__subclasses__() - # Create a list to store the new request session segments and - # re-apply all persistent segments (if they are still enabled). - new_segments = [segment for segment in session_segments - if persistent_segments.filter(id=segment['id']).exists()] - # Run tests on all remaining enabled segments to verify applicability. - for segment in enabled_segments: + additional_segments = [] + for segment in all_segments: segment_rules = [] for rule in rules: segment_rules += rule.objects.filter(segment=segment) - result = self._test_rules(segment_rules, self.request, - match_any=segment.match_any) + + result = self._test_rules( + segment_rules, self.request, match_any=segment.match_any) if result: - segdict = create_segment_dictionary(segment) - if not any(seg['id'] == segdict['id'] for seg in new_segments): - new_segments.append(segdict) + additional_segments.append(segment) + new_segments = current_segments + additional_segments self.set_segments(new_segments) self.update_visit_count() diff --git a/src/wagtail_personalisation/wagtail_hooks.py b/src/wagtail_personalisation/wagtail_hooks.py index 1f50c52..86a06c7 100644 --- a/src/wagtail_personalisation/wagtail_hooks.py +++ b/src/wagtail_personalisation/wagtail_hooks.py @@ -74,15 +74,7 @@ def serve_variation(page, request, serve_args, serve_kwargs): """ user_segments = [] adapter = get_segment_adapter(request) - - for segment in adapter.get_segments(): - try: - user_segment = Segment.objects.get( - pk=segment['id'], status=Segment.STATUS_ENABLED) - except Segment.DoesNotExist: - user_segment = None - if user_segment: - user_segments.append(user_segment) + user_segments = adapter.get_segments() if len(user_segments) > 0: variations = _check_for_variations(user_segments, page)