diff --git a/src/wagtail_personalisation/rules.py b/src/wagtail_personalisation/rules.py index 5316002..6461ea7 100644 --- a/src/wagtail_personalisation/rules.py +++ b/src/wagtail_personalisation/rules.py @@ -2,17 +2,23 @@ from __future__ import absolute_import, unicode_literals import re from datetime import datetime +from importlib import import_module from django.apps import apps +from django.conf import settings +from django.contrib.sessions.models import Session from django.db import models from django.template.defaultfilters import slugify from django.utils.encoding import force_text, python_2_unicode_compatible from django.utils.translation import ugettext_lazy as _ +from django.test.client import RequestFactory from modelcluster.fields import ParentalKey from user_agents import parse from wagtail.wagtailadmin.edit_handlers import ( FieldPanel, FieldRowPanel, PageChooserPanel) +SessionStore = import_module(settings.SESSION_ENGINE).SessionStore + @python_2_unicode_compatible class AbstractBaseRule(models.Model): @@ -221,17 +227,33 @@ class VisitCountRule(AbstractBaseRule): verbose_name = _('Visit count Rule') def test_user(self, request, user=None): + # Local import for cyclic import + from wagtail_personalisation.adapters import ( + get_segment_adapter, SessionSegmentsAdapter, SEGMENT_ADAPTER_CLASS) + if user: - # This rule currently does not support testing a user directly - # TODO: Make this test a user directly when the rule uses - # historical data + # Create a fake request so we can use the adapter + request = RequestFactory().get('/') + request.session = SessionStore() + + # If we're using the session adapter check for an active session + if SEGMENT_ADAPTER_CLASS == SessionSegmentsAdapter: + sessions = Session.objects.iterator() + for session in sessions: + session_data = session.get_decoded() + if session_data.get('_auth_user_id') == str(user.id): + request.session = SessionStore( + session_key=session.session_key) + break + + request.user = user + elif not request: + # Return false if we don't have a user or a request return False + operator = self.operator segment_count = self.count - # Local import for cyclic import - from wagtail_personalisation.adapters import get_segment_adapter - adapter = get_segment_adapter(request) visit_count = adapter.get_visit_count(self.counted_page) diff --git a/tests/unit/test_rules_visitcount.py b/tests/unit/test_rules_visitcount.py index a4d7d60..f153e3c 100644 --- a/tests/unit/test_rules_visitcount.py +++ b/tests/unit/test_rules_visitcount.py @@ -1,5 +1,8 @@ import pytest +from tests.factories.rule import VisitCountRuleFactory +from tests.factories.segment import SegmentFactory + @pytest.mark.django_db def test_visit_count(site, client): @@ -20,3 +23,29 @@ def test_visit_count(site, client): visit_count = client.session['visit_count'] assert visit_count[0]['count'] == 2 assert visit_count[1]['count'] == 1 + + +@pytest.mark.django_db +def test_visit_count_call_test_user_with_user(site, client, user): + segment = SegmentFactory(name='VisitCount') + rule = VisitCountRuleFactory(counted_page=site.root_page, segment=segment) + + session = client.session + session['visit_count'] = [{'path': '/', 'count': 2}] + session.save() + client.force_login(user) + + assert rule.test_user(None, user) + + +@pytest.mark.django_db +def test_visit_count_call_test_user_with_user_or_request_fails(site, client, user): + segment = SegmentFactory(name='VisitCount') + rule = VisitCountRuleFactory(counted_page=site.root_page, segment=segment) + + session = client.session + session['visit_count'] = [{'path': '/', 'count': 2}] + session.save() + client.force_login(user) + + assert not rule.test_user(None)