8

Update to use the save method on the form to populate the segments

This commit is contained in:
Todd Dembrey
2017-10-23 15:37:08 +01:00
parent 44cc95617e
commit a116b14d57
3 changed files with 186 additions and 176 deletions

View File

@@ -1,7 +1,29 @@
from importlib import import_module
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from django.contrib.sessions.models import Session
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.test.client import RequestFactory
from django.utils import timezone
from django.utils.lru_cache import lru_cache
from wagtail.wagtailadmin.forms import WagtailAdminModelForm from wagtail.wagtailadmin.forms import WagtailAdminModelForm
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
@lru_cache(maxsize=1000)
def user_from_data(user_id):
User = get_user_model()
try:
return User.objects.get(id=user_id)
except User.DoesNotExist:
return AnonymousUser
class SegmentAdminForm(WagtailAdminModelForm): class SegmentAdminForm(WagtailAdminModelForm):
def clean(self): def clean(self):
cleaned_data = super(SegmentAdminForm, self).clean() cleaned_data = super(SegmentAdminForm, self).clean()
@@ -16,3 +38,22 @@ class SegmentAdminForm(WagtailAdminModelForm):
}) })
return cleaned_data return cleaned_data
def save(self, *args, **kwargs):
instance = super(SegmentAdminForm, self).save(*args, **kwargs)
if instance.can_populate:
request = RequestFactory().get('/')
for session in Session.objects.filter(expire_date__gt=timezone.now()).iterator():
session_data = session.get_decoded()
user = user_from_data(session_data.get('_auth_id'))
request.user = user
request.session = SessionStore(session_key=session.session_key)
all_pass = all(rule.test_user(request) for rule in instance.get_rules() if rule.static)
if all_pass:
instance.sessions.add(session)
instance.frozen = True
instance.save()
return instance

View File

@@ -1,26 +1,20 @@
from __future__ import absolute_import, unicode_literals from __future__ import absolute_import, unicode_literals
from importlib import import_module
from django import forms from django import forms
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from django.contrib.sessions.models import Session from django.contrib.sessions.models import Session
from django.core.exceptions import ValidationError
from django.db import models, transaction from django.db import models, transaction
from django.dispatch import receiver
from django.template.defaultfilters import slugify from django.template.defaultfilters import slugify
from django.test.client import RequestFactory
from django.utils import timezone
from django.utils.encoding import python_2_unicode_compatible from django.utils.encoding import python_2_unicode_compatible
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.lru_cache import lru_cache
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from modelcluster.models import ClusterableModel from modelcluster.models import ClusterableModel
from wagtail.wagtailadmin.edit_handlers import ( from wagtail.wagtailadmin.edit_handlers import (
FieldPanel, FieldRowPanel, InlinePanel, MultiFieldPanel) FieldPanel,
FieldRowPanel,
InlinePanel,
MultiFieldPanel,
)
from wagtail.wagtailcore.models import Page from wagtail.wagtailcore.models import Page
from wagtail_personalisation.rules import AbstractBaseRule from wagtail_personalisation.rules import AbstractBaseRule
@@ -29,23 +23,11 @@ from wagtail_personalisation.utils import count_active_days
from .forms import SegmentAdminForm from .forms import SegmentAdminForm
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
class SegmentQuerySet(models.QuerySet): class SegmentQuerySet(models.QuerySet):
def enabled(self): def enabled(self):
return self.filter(status=self.model.STATUS_ENABLED) return self.filter(status=self.model.STATUS_ENABLED)
@lru_cache(maxsize=1000)
def user_from_data(user_id):
User = get_user_model()
try:
return User.objects.get(id=user_id)
except User.DoesNotExist:
return AnonymousUser
@python_2_unicode_compatible @python_2_unicode_compatible
class Segment(ClusterableModel): class Segment(ClusterableModel):
"""The segment model.""" """The segment model."""
@@ -192,29 +174,6 @@ class Segment(ClusterableModel):
self.save() self.save()
@receiver(models.signals.post_init, sender=Segment)
def populate_sessions_first_time(sender, **kwargs):
instance = kwargs.pop('instance', None)
if instance.can_populate:
request = RequestFactory().get('/')
for session in Session.objects.filter(
expire_date__gt=timezone.now(),
).iterator():
session_data = session.get_decoded()
user = user_from_data(session_data.get('_auth_id'))
request.user = user
request.session = SessionStore(session_key=session.session_key)
all_pass = all(rule.test_user(request) for rule in instance.get_rules() if rule.static)
if all_pass:
instance.sessions.add(session.session_key)
models.signals.post_init.disconnect(populate_sessions_first_time, sender=sender)
instance.frozen = True
instance.save()
models.signals.post_init.connect(populate_sessions_first_time, sender=sender)
class PersonalisablePageMetadata(ClusterableModel): class PersonalisablePageMetadata(ClusterableModel):
"""The personalisable page model. Allows creation of variants with linked """The personalisable page model. Allows creation of variants with linked
segments. segments.

View File

@@ -2,132 +2,16 @@ from __future__ import absolute_import, unicode_literals
import datetime import datetime
import pytest
from django.core.exceptions import ValidationError
from django.forms.models import model_to_dict from django.forms.models import model_to_dict
from tests.factories.segment import SegmentFactory from tests.factories.segment import SegmentFactory
import pytest
from wagtail_personalisation.forms import SegmentAdminForm from wagtail_personalisation.forms import SegmentAdminForm
from wagtail_personalisation.models import Segment from wagtail_personalisation.models import Segment
from wagtail_personalisation.rules import TimeRule, VisitCountRule from wagtail_personalisation.rules import TimeRule, VisitCountRule
@pytest.mark.django_db def form_with_data(segment, *rules):
def test_session_added_to_static_segment_at_creation(rf, site, client):
session = client.session
session.save()
client.get(site.root_page.url)
segment = SegmentFactory(type=Segment.TYPE_STATIC)
VisitCountRule.objects.create(counted_page=site.root_page, segment=segment)
segment.save()
# We need to trigger the post init
segment = Segment.objects.get(id=segment.id)
assert session.session_key in segment.sessions.values_list('session_key', flat=True)
@pytest.mark.django_db
def test_mixed_static_dynamic_session_doesnt_generate_at_creation(rf, site, client):
session = client.session
session.save()
client.get(site.root_page.url)
segment = SegmentFactory(type=Segment.TYPE_STATIC)
VisitCountRule.objects.create(counted_page=site.root_page, segment=segment)
TimeRule.objects.create(
start_time=datetime.time(0, 0, 0),
end_time=datetime.time(23, 59, 59),
segment=segment,
)
segment.save()
# We need to trigger the post init
segment = Segment.objects.get(id=segment.id)
assert not segment.sessions.all()
@pytest.mark.django_db
def test_session_not_added_to_static_segment_after_creation(rf, site, client):
segment = SegmentFactory(type=Segment.TYPE_STATIC)
VisitCountRule.objects.create(counted_page=site.root_page, segment=segment)
segment.save()
session = client.session
session.save()
client.get(site.root_page.url)
assert not segment.sessions.all()
@pytest.mark.django_db
def test_session_added_to_static_segment_after_creation(rf, site, client):
segment = SegmentFactory(type=Segment.TYPE_STATIC, count=1)
VisitCountRule.objects.create(counted_page=site.root_page, segment=segment)
segment.save()
session = client.session
session.save()
client.get(site.root_page.url)
assert session.session_key in segment.sessions.values_list('session_key', flat=True)
@pytest.mark.django_db
def test_session_not_added_to_static_segment_after_full(rf, site, client):
segment = SegmentFactory(type=Segment.TYPE_STATIC, count=1)
VisitCountRule.objects.create(counted_page=site.root_page, segment=segment)
segment.save()
session = client.session
session.save()
client.get(site.root_page.url)
second_session = client.session
second_session.create()
client.get(site.root_page.url)
assert session.session_key != second_session.session_key
assert segment.sessions.count() == 1
assert session.session_key in segment.sessions.values_list('session_key', flat=True)
assert second_session.session_key not in segment.sessions.values_list('session_key', flat=True)
@pytest.mark.django_db
def test_sessions_not_added_to_static_segment_if_rule_not_static(client, site):
session = client.session
session.save()
client.get(site.root_page.url)
segment = SegmentFactory(type=Segment.TYPE_STATIC)
TimeRule.objects.create(
start_time=datetime.time(0, 0, 0),
end_time=datetime.time(23, 59, 59),
segment=segment,
)
segment.save()
assert not segment.sessions.all()
@pytest.mark.django_db
def test_does_not_calculate_the_segment_again(rf, site, client, mocker):
session = client.session
session.save()
client.get(site.root_page.url)
segment = SegmentFactory(type=Segment.TYPE_STATIC, count=2)
VisitCountRule.objects.create(counted_page=site.root_page, segment=segment)
segment.save()
mock_test_rule = mocker.patch('wagtail_personalisation.adapters.SessionSegmentsAdapter._test_rules')
client.get(site.root_page.url)
assert mock_test_rule.call_count == 0
def form_with_data(segment, rule):
model_fields = ['type', 'status', 'count', 'name'] model_fields = ['type', 'status', 'count', 'name']
class TestSegmentAdminForm(SegmentAdminForm): class TestSegmentAdminForm(SegmentAdminForm):
@@ -138,20 +22,147 @@ def form_with_data(segment, rule):
data = model_to_dict(segment, model_fields) data = model_to_dict(segment, model_fields)
for formset in TestSegmentAdminForm().formsets.values(): for formset in TestSegmentAdminForm().formsets.values():
rule_data = {} rule_data = {}
if isinstance(rule, formset.model): for rule in rules:
rule_data = model_to_dict(rule) if isinstance(rule, formset.model):
for key, value in rule_data.items(): rule_data = model_to_dict(rule)
data['{}-0-{}'.format(formset.prefix, key)] = value for key, value in rule_data.items():
data['{}-0-{}'.format(formset.prefix, key)] = value
data['{}-INITIAL_FORMS'.format(formset.prefix)] = 0 data['{}-INITIAL_FORMS'.format(formset.prefix)] = 0
data['{}-TOTAL_FORMS'.format(formset.prefix)] = 1 if rule_data else 0 data['{}-TOTAL_FORMS'.format(formset.prefix)] = 1 if rule_data else 0
return TestSegmentAdminForm(data) return TestSegmentAdminForm(data)
@pytest.mark.django_db
def test_session_added_to_static_segment_at_creation(site, client):
session = client.session
session.save()
client.get(site.root_page.url)
segment = SegmentFactory.build(type=Segment.TYPE_STATIC)
rule = VisitCountRule(counted_page=site.root_page)
form = form_with_data(segment, rule)
instance = form.save()
assert instance.frozen
assert session.session_key in instance.sessions.values_list('session_key', flat=True)
@pytest.mark.django_db
def test_mixed_static_dynamic_session_doesnt_generate_at_creation(site, client):
session = client.session
session.save()
client.get(site.root_page.url)
segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=1)
static_rule = VisitCountRule(counted_page=site.root_page)
non_static_rule = TimeRule(
start_time=datetime.time(0, 0, 0),
end_time=datetime.time(23, 59, 59),
)
form = form_with_data(segment, static_rule, non_static_rule)
instance = form.save()
assert instance.frozen
assert not instance.sessions.all()
@pytest.mark.django_db
def test_session_not_added_to_static_segment_after_creation(site, client):
segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=0)
rule = VisitCountRule(counted_page=site.root_page)
form = form_with_data(segment, rule)
instance = form.save()
session = client.session
session.save()
client.get(site.root_page.url)
assert instance.frozen
assert not instance.sessions.all()
@pytest.mark.django_db
def test_session_added_to_static_segment_after_creation(site, client):
segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=1)
rule = VisitCountRule(counted_page=site.root_page)
form = form_with_data(segment, rule)
instance = form.save()
session = client.session
session.save()
client.get(site.root_page.url)
assert instance.frozen
assert session.session_key in instance.sessions.values_list('session_key', flat=True)
@pytest.mark.django_db
def test_session_not_added_to_static_segment_after_full(site, client):
segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=1)
rule = VisitCountRule(counted_page=site.root_page)
form = form_with_data(segment, rule)
instance = form.save()
assert instance.frozen
assert instance.sessions.count() == 0
session = client.session
session.save()
client.get(site.root_page.url)
assert instance.sessions.count() == 1
second_session = client.session
second_session.create()
client.get(site.root_page.url)
assert session.session_key != second_session.session_key
assert instance.sessions.count() == 1
assert session.session_key in instance.sessions.values_list('session_key', flat=True)
assert second_session.session_key not in instance.sessions.values_list('session_key', flat=True)
@pytest.mark.django_db
def test_sessions_not_added_to_static_segment_if_rule_not_static(client, site):
session = client.session
session.save()
client.get(site.root_page.url)
segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=1)
rule = TimeRule(
start_time=datetime.time(0, 0, 0),
end_time=datetime.time(23, 59, 59),
segment=segment,
)
form = form_with_data(segment, rule)
instance = form.save()
assert instance.frozen
assert not instance.sessions.all()
@pytest.mark.django_db
def test_does_not_calculate_the_segment_again(site, client, mocker):
session = client.session
session.save()
client.get(site.root_page.url)
segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=2)
rule = VisitCountRule(counted_page=site.root_page, segment=segment)
form = form_with_data(segment, rule)
instance = form.save()
assert instance.frozen
assert session.session_key in instance.sessions.values_list('session_key', flat=True)
mock_test_rule = mocker.patch('wagtail_personalisation.adapters.SessionSegmentsAdapter._test_rules')
client.get(site.root_page.url)
assert mock_test_rule.call_count == 0
@pytest.mark.django_db @pytest.mark.django_db
def test_non_static_rules_have_a_count(): def test_non_static_rules_have_a_count():
segment = SegmentFactory(type=Segment.TYPE_STATIC, count=0) segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=0)
rule = TimeRule.objects.create( rule = TimeRule(
start_time=datetime.time(0, 0, 0), start_time=datetime.time(0, 0, 0),
end_time=datetime.time(23, 59, 59), end_time=datetime.time(23, 59, 59),
segment=segment, segment=segment,
@@ -162,19 +173,18 @@ def test_non_static_rules_have_a_count():
@pytest.mark.django_db @pytest.mark.django_db
def test_static_segment_with_static_rules_needs_no_count(site): def test_static_segment_with_static_rules_needs_no_count(site):
segment = SegmentFactory(type=Segment.TYPE_STATIC, count=0) segment = SegmentFactory.build(type=Segment.TYPE_STATIC, count=0)
rule = VisitCountRule.objects.create(counted_page=site.root_page, segment=segment) rule = VisitCountRule(counted_page=site.root_page, segment=segment)
form = form_with_data(segment, rule) form = form_with_data(segment, rule)
assert form.is_valid() assert form.is_valid()
@pytest.mark.django_db @pytest.mark.django_db
def test_dynamic_segment_with_non_static_rules_have_a_count(): def test_dynamic_segment_with_non_static_rules_have_a_count():
segment = SegmentFactory(type=Segment.TYPE_DYNAMIC, count=0) segment = SegmentFactory.build(type=Segment.TYPE_DYNAMIC, count=0)
rule = TimeRule.objects.create( rule = TimeRule(
start_time=datetime.time(0, 0, 0), start_time=datetime.time(0, 0, 0),
end_time=datetime.time(23, 59, 59), end_time=datetime.time(23, 59, 59),
segment=segment,
) )
form = form_with_data(segment, rule) form = form_with_data(segment, rule)
assert form.is_valid(), form.errors assert form.is_valid(), form.errors