7

Get excluded segments from session and don't check them again

This commit is contained in:
Kaitlyn Crawford
2018-02-12 18:00:38 +02:00
parent 0f0aecf673
commit ea1ecc2a98

View File

@ -66,17 +66,21 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
self.request.session.setdefault('segments', []) self.request.session.setdefault('segments', [])
self._segment_cache = None self._segment_cache = None
def get_segments(self): def get_segments(self, key="segments"):
"""Return the persistent segments stored in the request session. """Return the persistent segments stored in the request session.
:param key: The key under which the segments are stored
:type key: String
:returns: The segments in the request session :returns: The segments in the request session
:rtype: list of wagtail_personalisation.models.Segment or empty list :rtype: list of wagtail_personalisation.models.Segment or empty list
""" """
if self._segment_cache is not None: if key == "segments" and self._segment_cache is not None:
return self._segment_cache return self._segment_cache
raw_segments = self.request.session['segments'] if key not in self.request.session:
return []
raw_segments = self.request.session[key]
segment_ids = [segment['id'] for segment in raw_segments] segment_ids = [segment['id'] for segment in raw_segments]
segments = ( segments = (
@ -86,7 +90,8 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
.in_bulk(segment_ids)) .in_bulk(segment_ids))
retval = [segments[pk] for pk in segment_ids if pk in segments] retval = [segments[pk] for pk in segment_ids if pk in segments]
self._segment_cache = retval if key == "segments":
self._segment_cache = retval
return retval return retval
def set_segments(self, segments, key="segments"): def set_segments(self, segments, key="segments"):
@ -174,14 +179,15 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter):
rule_models = AbstractBaseRule.get_descendant_models() rule_models = AbstractBaseRule.get_descendant_models()
current_segments = self.get_segments() current_segments = self.get_segments()
excluded_segments = [] excluded_segments = self.get_segments("excluded_segments")
# Run tests on all remaining enabled segments to verify applicability. # Run tests on all remaining enabled segments to verify applicability.
additional_segments = [] additional_segments = []
for segment in enabled_segments: for segment in enabled_segments:
if segment.is_static and segment.static_users.filter(id=self.request.user.id).exists(): if segment.is_static and segment.static_users.filter(id=self.request.user.id).exists():
additional_segments.append(segment) additional_segments.append(segment)
elif segment.excluded_users.filter(id=self.request.user.id).exists(): elif (segment.excluded_users.filter(id=self.request.user.id).exists() or
segment in excluded_segments):
continue continue
elif not segment.is_static or not segment.is_full: elif not segment.is_static or not segment.is_full:
segment_rules = [] segment_rules = []