From 1f4a4536ab55bab13e58b1062c79fd7b4bd3fb40 Mon Sep 17 00:00:00 2001 From: Todd Dembrey Date: Wed, 1 Nov 2017 16:43:22 +0000 Subject: [PATCH] Make the static elements tracked users only We cannot track anonymous users as the session expires after 10 minutes of inactivity. This also avoids an issue where there is an error when the user's session has expired and they navigate a page --- src/wagtail_personalisation/adapters.py | 9 +-- src/wagtail_personalisation/forms.py | 23 ++++---- .../migrations/0015_static_users.py | 26 +++++++++ src/wagtail_personalisation/models.py | 7 ++- tests/fixtures.py | 5 ++ tests/unit/test_static_dynamic_segments.py | 56 ++++++++++++------- 6 files changed, 86 insertions(+), 40 deletions(-) create mode 100644 src/wagtail_personalisation/migrations/0015_static_users.py diff --git a/src/wagtail_personalisation/adapters.py b/src/wagtail_personalisation/adapters.py index e46ed19..e145cd9 100644 --- a/src/wagtail_personalisation/adapters.py +++ b/src/wagtail_personalisation/adapters.py @@ -175,7 +175,7 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): # Run tests on all remaining enabled segments to verify applicability. additional_segments = [] for segment in enabled_segments: - if segment.is_static and self.request.session.session_key in segment.sessions.values_list('session_key', flat=True): + if segment.is_static and segment.static_users.filter(id=self.request.user.id).exists(): additional_segments.append(segment) elif not segment.is_static or not segment.is_full: segment_rules = [] @@ -186,11 +186,8 @@ class SessionSegmentsAdapter(BaseSegmentsAdapter): match_any=segment.match_any) if result and segment.is_static and not segment.is_full: - session = self.request.session.model.objects.get( - session_key=self.request.session.session_key, - expire_date__gt=timezone.now(), - ) - segment.sessions.add(session) + if self.request.user.is_authenticated(): + segment.static_users.add(self.request.user) if result: additional_segments.append(segment) diff --git a/src/wagtail_personalisation/forms.py b/src/wagtail_personalisation/forms.py index e8e5deb..a8d659a 100644 --- a/src/wagtail_personalisation/forms.py +++ b/src/wagtail_personalisation/forms.py @@ -26,7 +26,7 @@ def user_from_data(user_id): try: return User.objects.get(id=user_id) except User.DoesNotExist: - return AnonymousUser + return AnonymousUser() @@ -78,22 +78,23 @@ class SegmentAdminForm(WagtailAdminModelForm): request.session = SessionStore() adapter = get_segment_adapter(request) - sessions_to_add = [] - sessions = Session.objects.filter(expire_date__gt=timezone.now()).iterator() + users_to_add = [] + sessions = Session.objects.iterator() take_session = takewhile( - lambda x: instance.count == 0 or len(sessions_to_add) <= instance.count, + lambda x: instance.count == 0 or len(users_to_add) <= instance.count, sessions ) for session in take_session: session_data = session.get_decoded() - user = user_from_data(session_data.get('_auth_id')) - request.user = user - request.session = SessionStore(session_key=session.session_key) - passes = adapter._test_rules(instance.get_rules(), request, instance.match_any) - if passes: - sessions_to_add.append(session) + user = user_from_data(session_data.get('_auth_user_id')) + if user.is_authenticated: + request.user = user + request.session = SessionStore(session_key=session.session_key) + passes = adapter._test_rules(instance.get_rules(), request, instance.match_any) + if passes: + users_to_add.append(user) - instance.sessions.add(*sessions_to_add) + instance.static_users.add(*users_to_add) return instance diff --git a/src/wagtail_personalisation/migrations/0015_static_users.py b/src/wagtail_personalisation/migrations/0015_static_users.py new file mode 100644 index 0000000..ea76aa8 --- /dev/null +++ b/src/wagtail_personalisation/migrations/0015_static_users.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.6 on 2017-11-01 15:58 +from __future__ import unicode_literals + +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('wagtail_personalisation', '0013_add_dynamic_static_to_segment'), + ] + + operations = [ + migrations.RemoveField( + model_name='segment', + name='sessions', + ), + migrations.AddField( + model_name='segment', + name='static_users', + field=models.ManyToManyField(to=settings.AUTH_USER_MODEL), + ), + ] diff --git a/src/wagtail_personalisation/models.py b/src/wagtail_personalisation/models.py index e7d5e5f..cc1a6b2 100644 --- a/src/wagtail_personalisation/models.py +++ b/src/wagtail_personalisation/models.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, unicode_literals from django import forms +from django.conf import settings from django.contrib.sessions.models import Session from django.db import models, transaction from django.template.defaultfilters import slugify @@ -82,7 +83,9 @@ class Segment(ClusterableModel): "set until the number is reached. After this no more users will be added." ) ) - sessions = models.ManyToManyField(Session) + static_users = models.ManyToManyField( + settings.AUTH_USER_MODEL, + ) objects = SegmentQuerySet.as_manager() @@ -131,7 +134,7 @@ class Segment(ClusterableModel): @property def is_full(self): - return self.sessions.count() >= self.count + return self.static_users.count() >= self.count def encoded_name(self): """Return a string with a slug for the segment.""" diff --git a/tests/fixtures.py b/tests/fixtures.py index 03efbd9..9f5adfc 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -44,3 +44,8 @@ class RequestFactory(BaseRequestFactory): request.session = SessionStore() request._messages = FallbackStorage(request) return request + + +@pytest.fixture +def user(django_user_model): + return django_user_model.objects.create(username='user') diff --git a/tests/unit/test_static_dynamic_segments.py b/tests/unit/test_static_dynamic_segments.py index fe4b05c..778c0b5 100644 --- a/tests/unit/test_static_dynamic_segments.py +++ b/tests/unit/test_static_dynamic_segments.py @@ -36,9 +36,10 @@ def form_with_data(segment, *rules): @pytest.mark.django_db -def test_session_added_to_static_segment_at_creation(site, client): +def test_session_added_to_static_segment_at_creation(site, client, user): session = client.session session.save() + client.force_login(user) client.get(site.root_page.url) segment = SegmentFactory.build(type=Segment.TYPE_STATIC) @@ -46,17 +47,21 @@ def test_session_added_to_static_segment_at_creation(site, client): form = form_with_data(segment, rule) instance = form.save() - assert session.session_key in instance.sessions.values_list('session_key', flat=True) + assert user in instance.static_users.all() @pytest.mark.django_db -def test_match_any_correct_populates(site, client): +def test_match_any_correct_populates(site, client, django_user_model): + user = django_user_model.objects.create(username='first') session = client.session + client.force_login(user) client.get(site.root_page.url) + other_user = django_user_model.objects.create(username='second') client.cookies.clear() second_session = client.session other_page = site.root_page.get_last_child() + client.force_login(other_user) client.get(other_page.url) segment = SegmentFactory.build(type=Segment.TYPE_STATIC, match_any=True) @@ -66,14 +71,15 @@ def test_match_any_correct_populates(site, client): instance = form.save() assert session.session_key != second_session.session_key - assert session.session_key in instance.sessions.values_list('session_key', flat=True) - assert second_session.session_key in instance.sessions.values_list('session_key', flat=True) + assert user in instance.static_users.all() + assert other_user in instance.static_users.all() @pytest.mark.django_db -def test_mixed_static_dynamic_session_doesnt_generate_at_creation(site, client): +def test_mixed_static_dynamic_session_doesnt_generate_at_creation(site, client, user): session = client.session session.save() + client.force_login(user) client.get(site.root_page.url) segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=1) @@ -85,11 +91,11 @@ def test_mixed_static_dynamic_session_doesnt_generate_at_creation(site, client): form = form_with_data(segment, static_rule, non_static_rule) instance = form.save() - assert not instance.sessions.all() + assert not instance.static_users.all() @pytest.mark.django_db -def test_session_not_added_to_static_segment_after_creation(site, client): +def test_session_not_added_to_static_segment_after_creation(site, client, user): segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=0) rule = VisitCountRule(counted_page=site.root_page) form = form_with_data(segment, rule) @@ -97,13 +103,14 @@ def test_session_not_added_to_static_segment_after_creation(site, client): session = client.session session.save() + client.force_login(user) client.get(site.root_page.url) - assert not instance.sessions.all() + assert not instance.static_users.all() @pytest.mark.django_db -def test_session_added_to_static_segment_after_creation(site, client): +def test_session_added_to_static_segment_after_creation(site, client, user): segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=1) rule = VisitCountRule(counted_page=site.root_page) form = form_with_data(segment, rule) @@ -111,39 +118,45 @@ def test_session_added_to_static_segment_after_creation(site, client): session = client.session session.save() + client.force_login(user) client.get(site.root_page.url) - assert session.session_key in instance.sessions.values_list('session_key', flat=True) + assert user in instance.static_users.all() @pytest.mark.django_db -def test_session_not_added_to_static_segment_after_full(site, client): +def test_session_not_added_to_static_segment_after_full(site, client, django_user_model): + user = django_user_model.objects.create(username='first') + other_user = django_user_model.objects.create(username='second') segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=1) rule = VisitCountRule(counted_page=site.root_page) form = form_with_data(segment, rule) instance = form.save() - assert instance.sessions.count() == 0 + assert not instance.static_users.all() session = client.session + client.force_login(user) client.get(site.root_page.url) - assert instance.sessions.count() == 1 + assert instance.static_users.count() == 1 client.cookies.clear() second_session = client.session + client.force_login(other_user) client.get(site.root_page.url) assert session.session_key != second_session.session_key - assert instance.sessions.count() == 1 - assert session.session_key in instance.sessions.values_list('session_key', flat=True) - assert second_session.session_key not in instance.sessions.values_list('session_key', flat=True) + assert instance.static_users.count() == 1 + assert user in instance.static_users.all() + assert other_user not in instance.static_users.all() @pytest.mark.django_db -def test_sessions_not_added_to_static_segment_if_rule_not_static(client, site): +def test_sessions_not_added_to_static_segment_if_rule_not_static(client, site, user): session = client.session session.save() + client.force_login(user) client.get(site.root_page.url) segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=1) @@ -155,13 +168,14 @@ def test_sessions_not_added_to_static_segment_if_rule_not_static(client, site): form = form_with_data(segment, rule) instance = form.save() - assert not instance.sessions.all() + assert not instance.static_users.all() @pytest.mark.django_db -def test_does_not_calculate_the_segment_again(site, client, mocker): +def test_does_not_calculate_the_segment_again(site, client, mocker, user): session = client.session session.save() + client.force_login(user) client.get(site.root_page.url) segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=2) @@ -169,7 +183,7 @@ def test_does_not_calculate_the_segment_again(site, client, mocker): form = form_with_data(segment, rule) instance = form.save() - assert session.session_key in instance.sessions.values_list('session_key', flat=True) + assert user in instance.static_users.all() mock_test_rule = mocker.patch('wagtail_personalisation.adapters.SessionSegmentsAdapter._test_rules') client.get(site.root_page.url)