7

improve SessionSegmentAdapter

This commit is contained in:
Paul J Stevens
2018-05-26 16:04:11 +02:00
parent 2a48eb3498
commit 0bdb80f25a
2 changed files with 38 additions and 11 deletions

View File

@ -66,6 +66,17 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
self.request.session.setdefault('segments', []) self.request.session.setdefault('segments', [])
self._segment_cache = None self._segment_cache = None
def _segments(self, ids=None):
if not ids:
ids = []
segments = (
Segment.objects
.enabled()
.filter(persistent=True)
.filter(pk__in=ids)
)
return segments
def get_segments(self, key="segments"): def get_segments(self, key="segments"):
"""Return the persistent segments stored in the request session. """Return the persistent segments stored in the request session.
@ -83,16 +94,12 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
raw_segments = self.request.session[key] raw_segments = self.request.session[key]
segment_ids = [segment['id'] for segment in raw_segments] segment_ids = [segment['id'] for segment in raw_segments]
segments = ( segments = self._segments(ids=segment_ids)
Segment.objects
.enabled()
.filter(persistent=True)
.in_bulk(segment_ids))
retval = [segments[pk] for pk in segment_ids if pk in segments] result = list(segments)
if key == "segments": if key == "segments":
self._segment_cache = retval self._segment_cache = result
return retval return result
def set_segments(self, segments, key="segments"): def set_segments(self, segments, key="segments"):
"""Set the currently active segments """Set the currently active segments
@ -128,9 +135,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
:rtype: wagtail_personalisation.models.Segment or None :rtype: wagtail_personalisation.models.Segment or None
""" """
for segment in self.get_segments(): segments = self._segments(ids=[segment_id])
if segment.pk == segment_id: if segments.exists():
return segment return segments.get()
def add_page_visit(self, page): def add_page_visit(self, page):
"""Mark the page as visited by the user""" """Mark the page as visited by the user"""
@ -180,6 +187,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
current_segments = self.get_segments() current_segments = self.get_segments()
excluded_segments = self.get_segments("excluded_segments") excluded_segments = self.get_segments("excluded_segments")
current_segments = list(
set(current_segments) - set(excluded_segments)
)
# Run tests on all remaining enabled segments to verify applicability. # Run tests on all remaining enabled segments to verify applicability.
additional_segments = [] additional_segments = []

View File

@ -20,6 +20,23 @@ def test_get_segments(rf):
assert segments == [segment_1, segment_2] assert segments == [segment_1, segment_2]
@pytest.mark.django_db
def test_get_segments_session(rf):
request = rf.get('/')
adapter = adapters.SessionSegmentsAdapter(request)
segment_1 = SegmentFactory(name='segment-1', persistent=True)
segment_2 = SegmentFactory(name='segment-2', persistent=True)
adapter.set_segments([segment_1, segment_2])
assert len(request.session['segments']) == 2
adapter._segment_cache = None
segments = adapter.get_segments()
assert segments == [segment_1, segment_2]
@pytest.mark.django_db @pytest.mark.django_db
def test_get_segment_by_id(rf): def test_get_segment_by_id(rf):
request = rf.get('/') request = rf.get('/')