diff --git a/src/wagtail_personalisation/adapters.py b/src/wagtail_personalisation/adapters.py index 57b0287..1590a1b 100644 --- a/src/wagtail_personalisation/adapters.py +++ b/src/wagtail_personalisation/adapters.py @@ -81,7 +81,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): segments = ( Segment.objects - .filter(status=Segment.STATUS_ENABLED) + .enabled() .filter(persistent=True) .in_bulk(segment_ids)) @@ -166,7 +166,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): still apply to the requesting visitor. """ - enabled_segments = Segment.objects.filter(status=Segment.STATUS_ENABLED) + enabled_segments = Segment.objects.enabled() rule_models = AbstractBaseRule.get_descendant_models() current_segments = self.get_segments() @@ -196,8 +196,6 @@ SEGMENT_ADAPTER_CLASS = import_string(getattr( def get_segment_adapter(request): """Return the Segment Adapter for the given request""" - try: - return request.segment_adapter - except AttributeError: + if not hasattr(request, 'segment_adapter'): request.segment_adapter = SEGMENT_ADAPTER_CLASS(request) - return request.segment_adapter + return request.segment_adapter diff --git a/src/wagtail_personalisation/models.py b/src/wagtail_personalisation/models.py index ef26c90..230f09a 100644 --- a/src/wagtail_personalisation/models.py +++ b/src/wagtail_personalisation/models.py @@ -88,6 +88,13 @@ class Segment(ClusterableModel): rule_model._default_manager.filter(segment=self)) return segment_rules + def toggle(self, save=True): + self.status = ( + self.STATUS_ENABLED if self.status == self.STATUS_DISABLED + else self.STATUS_DISABLED) + if save: + self.save() + class PersonalisablePageMixin(models.Model): """The personalisable page model. Allows creation of variants with linked diff --git a/src/wagtail_personalisation/views.py b/src/wagtail_personalisation/views.py index 2aebb0c..5cc7961 100644 --- a/src/wagtail_personalisation/views.py +++ b/src/wagtail_personalisation/views.py @@ -98,12 +98,7 @@ def toggle(request, segment_id): if request.user.has_perm('wagtailadmin.access_admin'): segment = get_object_or_404(Segment, pk=segment_id) - if segment.status == Segment.STATUS_ENABLED: - segment.status = Segment.STATUS_DISABLED - elif segment.status == Segment.STATUS_DISABLED: - segment.status = Segment.STATUS_ENABLED - - segment.save() + segment.toggle() return HttpResponseRedirect(request.META.get('HTTP_REFERER', '/')) diff --git a/src/wagtail_personalisation/wagtail_hooks.py b/src/wagtail_personalisation/wagtail_hooks.py index ac7fc56..09fde05 100644 --- a/src/wagtail_personalisation/wagtail_hooks.py +++ b/src/wagtail_personalisation/wagtail_hooks.py @@ -75,7 +75,7 @@ def serve_variation(page, request, serve_args, serve_kwargs): user_segments = adapter.get_segments() if user_segments: - variations = _check_for_variations(user_segments, page) + variations = page.variants_for_segments(user_segments) if variations: variation = variations[0] @@ -85,22 +85,6 @@ def serve_variation(page, request, serve_args, serve_kwargs): return variation.serve(request, *serve_args, **serve_kwargs) -def _check_for_variations(segments, page): - """Check whether there are variations available for the provided segments - on the page being served. - - :param segments: The segments applicable to the request. - :type segments: list of wagtail_personalisation.models.Segment - :param page: The page being served - :type page: wagtail_personalisation.models.PersonalisablePage or - wagtail.wagtailcore.models.Page - :returns: A variant of the requested page matching the segments or None - :rtype: wagtail_personalisation.models.PersonalisablePage or None - - """ - return page.variants_for_segments(segments) - - @hooks.register('register_page_listing_buttons') def page_listing_variant_buttons(page, page_perms, is_parent=False): """Adds page listing buttons to personalisable pages. Shows variants for diff --git a/tests/unit/test_hooks.py b/tests/unit/test_hooks.py index 12e10f0..000a3ab 100644 --- a/tests/unit/test_hooks.py +++ b/tests/unit/test_hooks.py @@ -1,13 +1,12 @@ import pytest -from wagtail_personalisation import wagtail_hooks as hooks from wagtail_personalisation.models import Segment @pytest.mark.django_db -def test_check_for_variations(segmented_page): +def test_variants(segmented_page): segments = Segment.objects.all() page = segmented_page.canonical_page - variations = hooks._check_for_variations(segments, page) + variations = page.variants_for_segments(segments) assert variations