diff --git a/src/wagtail_personalisation/migrations/0014_add_frozen_to_segment.py b/src/wagtail_personalisation/migrations/0014_add_frozen_to_segment.py new file mode 100644 index 0000000..d594e03 --- /dev/null +++ b/src/wagtail_personalisation/migrations/0014_add_frozen_to_segment.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.6 on 2017-10-20 16:26 +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('wagtail_personalisation', '0013_add_dynamic_static_to_segment'), + ] + + operations = [ + migrations.AddField( + model_name='segment', + name='frozen', + field=models.BooleanField(default=False), + ), + ] diff --git a/src/wagtail_personalisation/models.py b/src/wagtail_personalisation/models.py index 913c363..6ebdfcc 100644 --- a/src/wagtail_personalisation/models.py +++ b/src/wagtail_personalisation/models.py @@ -9,6 +9,7 @@ from django.contrib.auth.models import AnonymousUser from django.contrib.sessions.models import Session from django.core.exceptions import ValidationError from django.db import models, transaction +from django.dispatch import receiver from django.template.defaultfilters import slugify from django.test.client import RequestFactory from django.utils import timezone @@ -98,6 +99,7 @@ class Segment(ClusterableModel): ) ) sessions = models.ManyToManyField(Session) + frozen = models.BooleanField(default=False) objects = SegmentQuerySet.as_manager() @@ -149,6 +151,12 @@ class Segment(ClusterableModel): def is_full(self): return self.sessions.count() >= self.count + @property + def can_populate(self): + return ( + self.id and self.is_static and not self.frozen and self.is_consistent + ) + def encoded_name(self): """Return a string with a slug for the segment.""" return slugify(self.name.lower()) @@ -185,22 +193,28 @@ class Segment(ClusterableModel): if save: self.save() - def save(self, *args, **kwargs): - super(Segment, self).save(*args, **kwargs) - if self.is_static: - request = RequestFactory().get('/') +@receiver(models.signals.post_init, sender=Segment) +def populate_sessions_first_time(sender, **kwargs): + instance = kwargs.pop('instance', None) + if instance.can_populate: + request = RequestFactory().get('/') - for session in Session.objects.filter( - expire_date__gt=timezone.now(), - ).iterator(): - session_data = session.get_decoded() - user = user_from_data(session_data.get('_auth_id')) - request.user = user - request.session = SessionStore(session_key=session.session_key) - all_pass = all(rule.test_user(request) for rule in self.get_rules() if rule.static) - if not self.is_consistent and all_pass: - self.sessions.add(session) + for session in Session.objects.filter( + expire_date__gt=timezone.now(), + ).iterator(): + session_data = session.get_decoded() + user = user_from_data(session_data.get('_auth_id')) + request.user = user + request.session = SessionStore(session_key=session.session_key) + all_pass = all(rule.test_user(request) for rule in instance.get_rules() if rule.static) + if all_pass: + instance.sessions.add(session.session_key) + + models.signals.post_init.disconnect(populate_sessions_first_time, sender=sender) + instance.frozen = True + instance.save() + models.signals.post_init.connect(populate_sessions_first_time, sender=sender) class PersonalisablePageMetadata(ClusterableModel): diff --git a/tests/unit/test_static_dynamic_segments.py b/tests/unit/test_static_dynamic_segments.py index 0364814..a97cc0f 100644 --- a/tests/unit/test_static_dynamic_segments.py +++ b/tests/unit/test_static_dynamic_segments.py @@ -20,6 +20,9 @@ def test_session_added_to_static_segment_at_creation(rf, site, client): VisitCountRule.objects.create(counted_page=site.root_page, segment=segment) segment.save() + # We need to trigger the post init + segment = Segment.objects.get(id=segment.id) + assert session.session_key in segment.sessions.values_list('session_key', flat=True) @@ -38,6 +41,9 @@ def test_mixed_static_dynamic_session_doesnt_generate_at_creation(rf, site, clie ) segment.save() + # We need to trigger the post init + segment = Segment.objects.get(id=segment.id) + assert not segment.sessions.all()