improve SessionSegmentAdapter
This commit is contained in:
@ -66,6 +66,17 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
|||||||
self.request.session.setdefault('segments', [])
|
self.request.session.setdefault('segments', [])
|
||||||
self._segment_cache = None
|
self._segment_cache = None
|
||||||
|
|
||||||
|
def _segments(self, ids=None):
|
||||||
|
if not ids:
|
||||||
|
ids = []
|
||||||
|
segments = (
|
||||||
|
Segment.objects
|
||||||
|
.enabled()
|
||||||
|
.filter(persistent=True)
|
||||||
|
.filter(pk__in=ids)
|
||||||
|
)
|
||||||
|
return segments
|
||||||
|
|
||||||
def get_segments(self, key="segments"):
|
def get_segments(self, key="segments"):
|
||||||
"""Return the persistent segments stored in the request session.
|
"""Return the persistent segments stored in the request session.
|
||||||
|
|
||||||
@ -83,16 +94,12 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
|||||||
raw_segments = self.request.session[key]
|
raw_segments = self.request.session[key]
|
||||||
segment_ids = [segment['id'] for segment in raw_segments]
|
segment_ids = [segment['id'] for segment in raw_segments]
|
||||||
|
|
||||||
segments = (
|
segments = self._segments(ids=segment_ids)
|
||||||
Segment.objects
|
|
||||||
.enabled()
|
|
||||||
.filter(persistent=True)
|
|
||||||
.in_bulk(segment_ids))
|
|
||||||
|
|
||||||
retval = [segments[pk] for pk in segment_ids if pk in segments]
|
result = list(segments)
|
||||||
if key == "segments":
|
if key == "segments":
|
||||||
self._segment_cache = retval
|
self._segment_cache = result
|
||||||
return retval
|
return result
|
||||||
|
|
||||||
def set_segments(self, segments, key="segments"):
|
def set_segments(self, segments, key="segments"):
|
||||||
"""Set the currently active segments
|
"""Set the currently active segments
|
||||||
@ -128,9 +135,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
|||||||
:rtype: wagtail_personalisation.models.Segment or None
|
:rtype: wagtail_personalisation.models.Segment or None
|
||||||
|
|
||||||
"""
|
"""
|
||||||
for segment in self.get_segments():
|
segments = self._segments(ids=[segment_id])
|
||||||
if segment.pk == segment_id:
|
if segments.exists():
|
||||||
return segment
|
return segments.get()
|
||||||
|
|
||||||
def add_page_visit(self, page):
|
def add_page_visit(self, page):
|
||||||
"""Mark the page as visited by the user"""
|
"""Mark the page as visited by the user"""
|
||||||
@ -180,6 +187,9 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
|
|||||||
|
|
||||||
current_segments = self.get_segments()
|
current_segments = self.get_segments()
|
||||||
excluded_segments = self.get_segments("excluded_segments")
|
excluded_segments = self.get_segments("excluded_segments")
|
||||||
|
current_segments = list(
|
||||||
|
set(current_segments) - set(excluded_segments)
|
||||||
|
)
|
||||||
|
|
||||||
# Run tests on all remaining enabled segments to verify applicability.
|
# Run tests on all remaining enabled segments to verify applicability.
|
||||||
additional_segments = []
|
additional_segments = []
|
||||||
|
@ -20,6 +20,23 @@ def test_get_segments(rf):
|
|||||||
assert segments == [segment_1, segment_2]
|
assert segments == [segment_1, segment_2]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
def test_get_segments_session(rf):
|
||||||
|
request = rf.get('/')
|
||||||
|
|
||||||
|
adapter = adapters.SessionSegmentsAdapter(request)
|
||||||
|
|
||||||
|
segment_1 = SegmentFactory(name='segment-1', persistent=True)
|
||||||
|
segment_2 = SegmentFactory(name='segment-2', persistent=True)
|
||||||
|
|
||||||
|
adapter.set_segments([segment_1, segment_2])
|
||||||
|
assert len(request.session['segments']) == 2
|
||||||
|
|
||||||
|
adapter._segment_cache = None
|
||||||
|
segments = adapter.get_segments()
|
||||||
|
assert segments == [segment_1, segment_2]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
def test_get_segment_by_id(rf):
|
def test_get_segment_by_id(rf):
|
||||||
request = rf.get('/')
|
request = rf.get('/')
|
||||||
|
Reference in New Issue
Block a user