7

improve SessionSegmentAdapter

This commit is contained in:
Paul J Stevens
2018-05-26 16:04:11 +02:00
parent 2a48eb3498
commit 0bdb80f25a
2 changed files with 38 additions and 11 deletions

View File

@ -66,6 +66,17 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
self.request.session.setdefault('segments', [])
self._segment_cache = None
def _segments(self, ids=None):
if not ids:
ids = []
segments = (
Segment.objects
.enabled()
.filter(persistent=True)
.filter(pk__in=ids)
)
return segments
def get_segments(self, key="segments"):
"""Return the persistent segments stored in the request session.
@ -83,16 +94,12 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
raw_segments = self.request.session[key]
segment_ids = [segment['id'] for segment in raw_segments]
segments = (
Segment.objects
.enabled()
.filter(persistent=True)
.in_bulk(segment_ids))
segments = self._segments(ids=segment_ids)
retval = [segments[pk] for pk in segment_ids if pk in segments]
result = list(segments)
if key == "segments":
self._segment_cache = retval
return retval
self._segment_cache = result
return result
def set_segments(self, segments, key="segments"):
"""Set the currently active segments
@ -128,9 +135,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
:rtype: wagtail_personalisation.models.Segment or None
"""
for segment in self.get_segments():
if segment.pk == segment_id:
return segment
segments = self._segments(ids=[segment_id])
if segments.exists():
return segments.get()
def add_page_visit(self, page):
"""Mark the page as visited by the user"""
@ -180,6 +187,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
current_segments = self.get_segments()
excluded_segments = self.get_segments("excluded_segments")
current_segments = list(
set(current_segments) - set(excluded_segments)
)
# Run tests on all remaining enabled segments to verify applicability.
additional_segments = []