7

Support Wagtail 4.2 (#1)

Co-authored-by: nick.moreton <nick.moreton@torchbox.com>
Co-authored-by: Nick Moreton <nick.moreton@torchbox.com>
Reviewed-on: #1
This commit is contained in:
2023-05-07 03:25:48 +00:00
parent dd4530203f
commit b8d7dd53ae
90 changed files with 2666 additions and 1584 deletions

View File

@ -8,71 +8,72 @@ from tests.factories.segment import SegmentFactory
from wagtail_personalisation.rules import get_geoip_module
skip_if_geoip2_installed = pytest.mark.skipif(
find_spec('geoip2'), reason='requires GeoIP2 to be not installed'
find_spec("geoip2"), reason="requires GeoIP2 to be not installed"
)
skip_if_geoip2_not_installed = pytest.mark.skipif(
not find_spec('geoip2'), reason='requires GeoIP2 to be installed.'
not find_spec("geoip2"), reason="requires GeoIP2 to be installed."
)
@pytest.mark.django_db
def test_get_cloudflare_country_with_header(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='GB')
request = rf.get('/', HTTP_CF_IPCOUNTRY='PL')
assert rule.get_cloudflare_country(request) == 'pl'
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="GB")
request = rf.get("/", HTTP_CF_IPCOUNTRY="PL")
assert rule.get_cloudflare_country(request) == "pl"
@pytest.mark.django_db
def test_get_cloudflare_country_with_no_header(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='GB')
request = rf.get('/')
assert 'HTTP_CF_IPCOUNTRY' not in request.META
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="GB")
request = rf.get("/")
assert "HTTP_CF_IPCOUNTRY" not in request.META
assert rule.get_cloudflare_country(request) is None
@pytest.mark.django_db
def test_get_cloudfront_country_with_header(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='GB')
request = rf.get('/', HTTP_CLOUDFRONT_VIEWER_COUNTRY='BY')
assert rule.get_cloudfront_country(request) == 'by'
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="GB")
request = rf.get("/", HTTP_CLOUDFRONT_VIEWER_COUNTRY="BY")
assert rule.get_cloudfront_country(request) == "by"
@pytest.mark.django_db
def test_get_cloudfront_country_with_no_header(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='GB')
request = rf.get('/')
assert 'HTTP_CLOUDFRONT_VIEWER_COUNTRY' not in request.META
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="GB")
request = rf.get("/")
assert "HTTP_CLOUDFRONT_VIEWER_COUNTRY" not in request.META
assert rule.get_cloudfront_country(request) is None
@pytest.mark.django_db
def test_get_geoip_country_with_remote_addr(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='GB')
request = rf.get('/', REMOTE_ADDR='173.254.89.34')
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="GB")
request = rf.get("/", REMOTE_ADDR="173.254.89.34")
geoip_mock = MagicMock()
with patch('wagtail_personalisation.rules.get_geoip_module',
return_value=geoip_mock) as geoip_import_mock:
with patch(
"wagtail_personalisation.rules.get_geoip_module", return_value=geoip_mock
) as geoip_import_mock:
rule.get_geoip_country(request)
geoip_import_mock.assert_called_once()
geoip_mock.assert_called_once()
assert geoip_mock.mock_calls[1] == call().country_code('173.254.89.34')
assert geoip_mock.mock_calls[1] == call().country_code("173.254.89.34")
@pytest.mark.django_db
def test_get_country_calls_all_methods(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='GB')
request = rf.get('/')
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="GB")
request = rf.get("/")
@patch.object(rule, 'get_geoip_country', return_value='')
@patch.object(rule, 'get_cloudflare_country', return_value='')
@patch.object(rule, 'get_cloudfront_country', return_value='')
@patch.object(rule, "get_geoip_country", return_value="")
@patch.object(rule, "get_cloudflare_country", return_value="")
@patch.object(rule, "get_cloudfront_country", return_value="")
def test_mock(cloudfront_mock, cloudflare_mock, geoip_mock):
country = rule.get_country(request)
cloudflare_mock.assert_called_once_with(request)
@ -85,107 +86,106 @@ def test_get_country_calls_all_methods(rf):
@pytest.mark.django_db
def test_get_country_does_not_use_all_detection_methods_unnecessarily(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='GB')
request = rf.get('/')
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="GB")
request = rf.get("/")
@patch.object(rule, 'get_geoip_country', return_value='')
@patch.object(rule, 'get_cloudflare_country', return_value='t1')
@patch.object(rule, 'get_cloudfront_country', return_value='')
@patch.object(rule, "get_geoip_country", return_value="")
@patch.object(rule, "get_cloudflare_country", return_value="t1")
@patch.object(rule, "get_cloudfront_country", return_value="")
def test_mock(cloudfront_mock, cloudflare_mock, geoip_mock):
country = rule.get_country(request)
cloudflare_mock.assert_called_once_with(request)
cloudfront_mock.assert_not_called()
geoip_mock.assert_not_called()
assert country == 't1'
assert country == "t1"
test_mock()
@pytest.mark.django_db
def test_test_user_calls_get_country(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='GB')
request = rf.get('/')
with patch.object(rule, 'get_country') as get_country_mock:
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="GB")
request = rf.get("/")
with patch.object(rule, "get_country") as get_country_mock:
rule.test_user(request)
get_country_mock.assert_called_once_with(request)
@pytest.mark.django_db
def test_test_user_returns_true_if_cloudflare_country_match(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='GB')
request = rf.get('/', HTTP_CF_IPCOUNTRY='GB')
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="GB")
request = rf.get("/", HTTP_CF_IPCOUNTRY="GB")
assert rule.test_user(request) is True
@pytest.mark.django_db
def test_test_user_returns_false_if_cloudflare_country_doesnt_match(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='GB')
request = rf.get('/', HTTP_CF_IPCOUNTRY='NL')
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="GB")
request = rf.get("/", HTTP_CF_IPCOUNTRY="NL")
assert not rule.test_user(request)
@pytest.mark.django_db
def test_test_user_returns_false_if_cloudfront_country_doesnt_match(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='GB')
request = rf.get('/', HTTP_CLOUDFRONT_VIEWER_COUNTRY='NL')
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="GB")
request = rf.get("/", HTTP_CLOUDFRONT_VIEWER_COUNTRY="NL")
assert rule.test_user(request) is False
@pytest.mark.django_db
def test_test_user_returns_true_if_cloudfront_country_matches(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='se')
request = rf.get('/', HTTP_CLOUDFRONT_VIEWER_COUNTRY='SE')
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="se")
request = rf.get("/", HTTP_CLOUDFRONT_VIEWER_COUNTRY="SE")
assert rule.test_user(request) is True
@skip_if_geoip2_not_installed
@pytest.mark.django_db
def test_test_user_geoip_module_matches(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='se')
request = rf.get('/', REMOTE_ADDR='123.120.0.2')
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="se")
request = rf.get("/", REMOTE_ADDR="123.120.0.2")
GeoIP2Mock = MagicMock()
GeoIP2Mock().configure_mock(**{'country_code.return_value': 'SE'})
GeoIP2Mock().configure_mock(**{"country_code.return_value": "SE"})
GeoIP2Mock.reset_mock()
with patch('wagtail_personalisation.rules.get_geoip_module',
return_value=GeoIP2Mock):
with patch(
"wagtail_personalisation.rules.get_geoip_module", return_value=GeoIP2Mock
):
assert rule.test_user(request) is True
assert GeoIP2Mock.mock_calls == [
call(),
call().country_code('123.120.0.2'),
call().country_code("123.120.0.2"),
]
@skip_if_geoip2_not_installed
@pytest.mark.django_db
def test_test_user_geoip_module_does_not_match(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='nl')
request = rf.get('/', REMOTE_ADDR='123.120.0.2')
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="nl")
request = rf.get("/", REMOTE_ADDR="123.120.0.2")
GeoIP2Mock = MagicMock()
GeoIP2Mock().configure_mock(**{'country_code.return_value': 'SE'})
GeoIP2Mock().configure_mock(**{"country_code.return_value": "SE"})
GeoIP2Mock.reset_mock()
with patch('wagtail_personalisation.rules.get_geoip_module',
return_value=GeoIP2Mock):
with patch(
"wagtail_personalisation.rules.get_geoip_module", return_value=GeoIP2Mock
):
assert rule.test_user(request) is False
assert GeoIP2Mock.mock_calls == [
call(),
call().country_code('123.120.0.2')
]
assert GeoIP2Mock.mock_calls == [call(), call().country_code("123.120.0.2")]
@skip_if_geoip2_installed
@pytest.mark.django_db
def test_test_user_does_not_use_geoip_module_if_disabled(rf):
segment = SegmentFactory(name='Test segment')
rule = OriginCountryRuleFactory(segment=segment, country='se')
request = rf.get('/', REMOTE_ADDR='123.120.0.2')
segment = SegmentFactory(name="Test segment")
rule = OriginCountryRuleFactory(segment=segment, country="se")
request = rf.get("/", REMOTE_ADDR="123.120.0.2")
assert rule.test_user(request) is False
@ -199,4 +199,5 @@ def test_get_geoip_module_disabled():
@skip_if_geoip2_not_installed
def test_get_geoip_module_enabled():
from django.contrib.gis.geoip2 import GeoIP2
assert get_geoip_module() is GeoIP2