7

Add clean method to ensure mixed static segments are valid

This commit is contained in:
Todd Dembrey
2017-10-20 10:57:19 +01:00
parent f339879907
commit cf41be4b76
2 changed files with 52 additions and 5 deletions

View File

@ -7,6 +7,7 @@ from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser from django.contrib.auth.models import AnonymousUser
from django.contrib.sessions.models import Session from django.contrib.sessions.models import Session
from django.core.exceptions import ValidationError
from django.db import models, transaction from django.db import models, transaction
from django.template.defaultfilters import slugify from django.template.defaultfilters import slugify
from django.test.client import RequestFactory from django.test.client import RequestFactory
@ -120,10 +121,22 @@ class Segment(ClusterableModel):
def __str__(self): def __str__(self):
return self.name return self.name
def clean(self):
if self.is_static and not self.is_consistent and not self.count:
raise ValidationError({
'count': _('Static segments with non-static rules must include a count.'),
})
@property @property
def is_static(self): def is_static(self):
return self.type == self.TYPE_STATIC return self.type == self.TYPE_STATIC
@property
def is_consistent(self):
rules = self.get_rules()
all_static = all(rule.static for rule in rules)
return rules and all_static
@property @property
def is_full(self): def is_full(self):
return self.sessions.count() >= self.count return self.sessions.count() >= self.count
@ -170,9 +183,6 @@ class Segment(ClusterableModel):
if self.is_static: if self.is_static:
request = RequestFactory().get('/') request = RequestFactory().get('/')
rules = self.get_rules()
all_static = all(rule.static for rule in rules)
for session in Session.objects.filter( for session in Session.objects.filter(
expire_date__gt=timezone.now(), expire_date__gt=timezone.now(),
).iterator(): ).iterator():
@ -180,8 +190,8 @@ class Segment(ClusterableModel):
user = user_from_data(session_data.get('_auth_id')) user = user_from_data(session_data.get('_auth_id'))
request.user = user request.user = user
request.session = SessionStore(session_key=session.session_key) request.session = SessionStore(session_key=session.session_key)
all_pass = all(rule.test_user(request) for rule in rules if rule.static) all_pass = all(rule.test_user(request) for rule in self.get_rules() if rule.static)
if rules and all_static and all_pass: if not self.is_consistent and all_pass:
self.sessions.add(session) self.sessions.add(session)

View File

@ -4,6 +4,7 @@ import datetime
import pytest import pytest
from django.core.exceptions import ValidationError
from tests.factories.segment import SegmentFactory from tests.factories.segment import SegmentFactory
from wagtail_personalisation.models import Segment from wagtail_personalisation.models import Segment
from wagtail_personalisation.rules import TimeRule, VisitCountRule from wagtail_personalisation.rules import TimeRule, VisitCountRule
@ -116,3 +117,39 @@ def test_does_not_calculate_the_segment_again(rf, site, client, mocker):
mock_test_rule = mocker.patch('wagtail_personalisation.adapters.SessionSegmentsAdapter._test_rules') mock_test_rule = mocker.patch('wagtail_personalisation.adapters.SessionSegmentsAdapter._test_rules')
client.get(site.root_page.url) client.get(site.root_page.url)
assert mock_test_rule.call_count == 0 assert mock_test_rule.call_count == 0
@pytest.mark.django_db
def test_non_static_rules_have_a_count():
segment = SegmentFactory(type=Segment.TYPE_STATIC, count=0)
TimeRule.objects.create(
start_time=datetime.time(0, 0, 0),
end_time=datetime.time(23, 59, 59),
segment=segment,
)
with pytest.raises(ValidationError):
segment.clean()
@pytest.mark.django_db
def test_static_segment_with_static_rules_needs_no_count(site):
segment = SegmentFactory(type=Segment.TYPE_STATIC, count=0)
VisitCountRule.objects.create(counted_page=site.root_page, segment=segment)
try:
segment.clean()
except ValidationError:
pytest.fail('Should not raise ValidationError.')
@pytest.mark.django_db
def test_dynamic_segment_with_non_static_rules_have_a_count():
segment = SegmentFactory(type=Segment.TYPE_DYNAMIC, count=0)
TimeRule.objects.create(
start_time=datetime.time(0, 0, 0),
end_time=datetime.time(23, 59, 59),
segment=segment,
)
try:
segment.clean()
except ValidationError:
pytest.fail('Should not raise ValidationError.')