diff --git a/src/wagtail_personalisation/adapters.py b/src/wagtail_personalisation/adapters.py index b82aeee..0225ced 100644 --- a/src/wagtail_personalisation/adapters.py +++ b/src/wagtail_personalisation/adapters.py @@ -166,17 +166,13 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): def update_visit_count(self): """Update the visit count for all segments in the request session.""" segments = self.request.session['segments'] - for seg in segments: - try: - segment = Segment.objects.get(pk=seg['id']) - segment.visit_count = F('visit_count') + 1 - segment.save() + segment_pks = [s['id'] for s in segments] - except Segment.DoesNotExist: - # The segment no longer exists. - # Remove it from the request session. - self.request.session['segments'][:] = [ - item for item in segments if item.get('id') != seg['id']] + # Update counts + (Segment.objects + .enabled() + .filter(pk__in=segment_pks) + .update(visit_count=F('visit_count') + 1)) def refresh(self): """Retrieve the request session segments and verify whether or not they diff --git a/src/wagtail_personalisation/models.py b/src/wagtail_personalisation/models.py index bd31692..b77d643 100644 --- a/src/wagtail_personalisation/models.py +++ b/src/wagtail_personalisation/models.py @@ -10,13 +10,17 @@ from wagtail.utils.decorators import cached_classmethod from wagtail.wagtailadmin.edit_handlers import ( FieldPanel, FieldRowPanel, InlinePanel, MultiFieldPanel, ObjectList, PageChooserPanel, TabbedInterface) -from wagtail.wagtailcore.models import Page from wagtail_personalisation.forms import AdminPersonalisablePageForm from wagtail_personalisation.rules import AbstractBaseRule from wagtail_personalisation.utils import count_active_days +class SegmentQuerySet(models.QuerySet): + def enabled(self): + return self.filter(status=self.model.STATUS_ENABLED) + + @python_2_unicode_compatible class Segment(ClusterableModel): """The segment model.""" @@ -43,6 +47,8 @@ class Segment(ClusterableModel): help_text=_("Should the segment match all the rules or just one of them?") ) + objects = SegmentQuerySet.as_manager() + def __init__(self, *args, **kwargs): Segment.panels = [ MultiFieldPanel([ diff --git a/tests/unit/test_adapter_session.py b/tests/unit/test_adapter_session.py index bf4fcf1..fa9bf96 100644 --- a/tests/unit/test_adapter_session.py +++ b/tests/unit/test_adapter_session.py @@ -1,11 +1,11 @@ import pytest -from wagtail_personalisation import adapters from tests.factories.segment import SegmentFactory +from wagtail_personalisation import adapters @pytest.mark.django_db -def test_get_segments(rf, monkeypatch): +def test_get_segments(rf): request = rf.get('/') adapter = adapters.SessionSegmentsAdapter(request) @@ -21,7 +21,7 @@ def test_get_segments(rf, monkeypatch): @pytest.mark.django_db -def test_get_segment_by_id(rf, monkeypatch): +def test_get_segment_by_id(rf): request = rf.get('/') adapter = adapters.SessionSegmentsAdapter(request) @@ -36,7 +36,7 @@ def test_get_segment_by_id(rf, monkeypatch): @pytest.mark.django_db -def test_refresh_removes_disabled(rf, monkeypatch): +def test_refresh_removes_disabled(rf): request = rf.get('/') adapter = adapters.SessionSegmentsAdapter(request) @@ -52,3 +52,51 @@ def test_refresh_removes_disabled(rf, monkeypatch): adapter.refresh() assert adapter.get_segments() == [segment_2] + + +@pytest.mark.django_db +def test_add_page_visit(rf, site): + request = rf.get('/') + + adapter = adapters.SessionSegmentsAdapter(request) + adapter.add_page_visit(site.root_page) + + assert request.session['visit_count'][0]['count'] == 1 + + adapter.add_page_visit(site.root_page) + assert request.session['visit_count'][0]['count'] == 2 + + assert adapter.get_visit_count() == 2 + + +@pytest.mark.django_db +def test_update_visit_count(rf, site): + request = rf.get('/') + + adapter = adapters.SessionSegmentsAdapter(request) + + segment_1 = SegmentFactory(name='segment-1', persistent=True, visit_count=0) + segment_2 = SegmentFactory(name='segment-2', persistent=True, visit_count=0) + + adapter.set_segments([segment_1, segment_2]) + adapter.update_visit_count() + + segment_1.refresh_from_db() + segment_2.refresh_from_db() + + assert segment_1.visit_count == 1 + assert segment_2.visit_count == 1 + + +@pytest.mark.django_db +def test_update_visit_count_deleted_segment(rf, site): + request = rf.get('/') + + adapter = adapters.SessionSegmentsAdapter(request) + + segment_1 = SegmentFactory(name='segment-1', persistent=True, visit_count=0) + segment_2 = SegmentFactory(name='segment-2', persistent=True, visit_count=0) + + adapter.set_segments([segment_1, segment_2]) + segment_2.delete() + adapter.update_visit_count()