Co-authored-by: nick.moreton <nick.moreton@torchbox.com> Co-authored-by: Nick Moreton <nick.moreton@torchbox.com> Reviewed-on: #1
204 lines
7.3 KiB
Python
204 lines
7.3 KiB
Python
from importlib.util import find_spec
|
|
from unittest.mock import MagicMock, call, patch
|
|
|
|
import pytest
|
|
|
|
from tests.factories.rule import OriginCountryRuleFactory
|
|
from tests.factories.segment import SegmentFactory
|
|
from wagtail_personalisation.rules import get_geoip_module
|
|
|
|
skip_if_geoip2_installed = pytest.mark.skipif(
|
|
find_spec("geoip2"), reason="requires GeoIP2 to be not installed"
|
|
)
|
|
|
|
skip_if_geoip2_not_installed = pytest.mark.skipif(
|
|
not find_spec("geoip2"), reason="requires GeoIP2 to be installed."
|
|
)
|
|
|
|
|
|
@pytest.mark.django_db
|
|
def test_get_cloudflare_country_with_header(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="GB")
|
|
request = rf.get("/", HTTP_CF_IPCOUNTRY="PL")
|
|
assert rule.get_cloudflare_country(request) == "pl"
|
|
|
|
|
|
@pytest.mark.django_db
|
|
def test_get_cloudflare_country_with_no_header(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="GB")
|
|
request = rf.get("/")
|
|
assert "HTTP_CF_IPCOUNTRY" not in request.META
|
|
assert rule.get_cloudflare_country(request) is None
|
|
|
|
|
|
@pytest.mark.django_db
|
|
def test_get_cloudfront_country_with_header(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="GB")
|
|
request = rf.get("/", HTTP_CLOUDFRONT_VIEWER_COUNTRY="BY")
|
|
assert rule.get_cloudfront_country(request) == "by"
|
|
|
|
|
|
@pytest.mark.django_db
|
|
def test_get_cloudfront_country_with_no_header(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="GB")
|
|
request = rf.get("/")
|
|
assert "HTTP_CLOUDFRONT_VIEWER_COUNTRY" not in request.META
|
|
assert rule.get_cloudfront_country(request) is None
|
|
|
|
|
|
@pytest.mark.django_db
|
|
def test_get_geoip_country_with_remote_addr(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="GB")
|
|
request = rf.get("/", REMOTE_ADDR="173.254.89.34")
|
|
geoip_mock = MagicMock()
|
|
with patch(
|
|
"wagtail_personalisation.rules.get_geoip_module", return_value=geoip_mock
|
|
) as geoip_import_mock:
|
|
rule.get_geoip_country(request)
|
|
geoip_import_mock.assert_called_once()
|
|
geoip_mock.assert_called_once()
|
|
assert geoip_mock.mock_calls[1] == call().country_code("173.254.89.34")
|
|
|
|
|
|
@pytest.mark.django_db
|
|
def test_get_country_calls_all_methods(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="GB")
|
|
request = rf.get("/")
|
|
|
|
@patch.object(rule, "get_geoip_country", return_value="")
|
|
@patch.object(rule, "get_cloudflare_country", return_value="")
|
|
@patch.object(rule, "get_cloudfront_country", return_value="")
|
|
def test_mock(cloudfront_mock, cloudflare_mock, geoip_mock):
|
|
country = rule.get_country(request)
|
|
cloudflare_mock.assert_called_once_with(request)
|
|
cloudfront_mock.assert_called_once_with(request)
|
|
geoip_mock.assert_called_once_with(request)
|
|
assert country is None
|
|
|
|
test_mock()
|
|
|
|
|
|
@pytest.mark.django_db
|
|
def test_get_country_does_not_use_all_detection_methods_unnecessarily(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="GB")
|
|
request = rf.get("/")
|
|
|
|
@patch.object(rule, "get_geoip_country", return_value="")
|
|
@patch.object(rule, "get_cloudflare_country", return_value="t1")
|
|
@patch.object(rule, "get_cloudfront_country", return_value="")
|
|
def test_mock(cloudfront_mock, cloudflare_mock, geoip_mock):
|
|
country = rule.get_country(request)
|
|
cloudflare_mock.assert_called_once_with(request)
|
|
cloudfront_mock.assert_not_called()
|
|
geoip_mock.assert_not_called()
|
|
assert country == "t1"
|
|
|
|
test_mock()
|
|
|
|
|
|
@pytest.mark.django_db
|
|
def test_test_user_calls_get_country(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="GB")
|
|
request = rf.get("/")
|
|
with patch.object(rule, "get_country") as get_country_mock:
|
|
rule.test_user(request)
|
|
get_country_mock.assert_called_once_with(request)
|
|
|
|
|
|
@pytest.mark.django_db
|
|
def test_test_user_returns_true_if_cloudflare_country_match(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="GB")
|
|
request = rf.get("/", HTTP_CF_IPCOUNTRY="GB")
|
|
assert rule.test_user(request) is True
|
|
|
|
|
|
@pytest.mark.django_db
|
|
def test_test_user_returns_false_if_cloudflare_country_doesnt_match(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="GB")
|
|
request = rf.get("/", HTTP_CF_IPCOUNTRY="NL")
|
|
assert not rule.test_user(request)
|
|
|
|
|
|
@pytest.mark.django_db
|
|
def test_test_user_returns_false_if_cloudfront_country_doesnt_match(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="GB")
|
|
request = rf.get("/", HTTP_CLOUDFRONT_VIEWER_COUNTRY="NL")
|
|
assert rule.test_user(request) is False
|
|
|
|
|
|
@pytest.mark.django_db
|
|
def test_test_user_returns_true_if_cloudfront_country_matches(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="se")
|
|
request = rf.get("/", HTTP_CLOUDFRONT_VIEWER_COUNTRY="SE")
|
|
assert rule.test_user(request) is True
|
|
|
|
|
|
@skip_if_geoip2_not_installed
|
|
@pytest.mark.django_db
|
|
def test_test_user_geoip_module_matches(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="se")
|
|
request = rf.get("/", REMOTE_ADDR="123.120.0.2")
|
|
GeoIP2Mock = MagicMock()
|
|
GeoIP2Mock().configure_mock(**{"country_code.return_value": "SE"})
|
|
GeoIP2Mock.reset_mock()
|
|
with patch(
|
|
"wagtail_personalisation.rules.get_geoip_module", return_value=GeoIP2Mock
|
|
):
|
|
assert rule.test_user(request) is True
|
|
assert GeoIP2Mock.mock_calls == [
|
|
call(),
|
|
call().country_code("123.120.0.2"),
|
|
]
|
|
|
|
|
|
@skip_if_geoip2_not_installed
|
|
@pytest.mark.django_db
|
|
def test_test_user_geoip_module_does_not_match(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="nl")
|
|
request = rf.get("/", REMOTE_ADDR="123.120.0.2")
|
|
GeoIP2Mock = MagicMock()
|
|
GeoIP2Mock().configure_mock(**{"country_code.return_value": "SE"})
|
|
GeoIP2Mock.reset_mock()
|
|
with patch(
|
|
"wagtail_personalisation.rules.get_geoip_module", return_value=GeoIP2Mock
|
|
):
|
|
assert rule.test_user(request) is False
|
|
assert GeoIP2Mock.mock_calls == [call(), call().country_code("123.120.0.2")]
|
|
|
|
|
|
@skip_if_geoip2_installed
|
|
@pytest.mark.django_db
|
|
def test_test_user_does_not_use_geoip_module_if_disabled(rf):
|
|
segment = SegmentFactory(name="Test segment")
|
|
rule = OriginCountryRuleFactory(segment=segment, country="se")
|
|
request = rf.get("/", REMOTE_ADDR="123.120.0.2")
|
|
assert rule.test_user(request) is False
|
|
|
|
|
|
@skip_if_geoip2_installed
|
|
def test_get_geoip_module_disabled():
|
|
with pytest.raises(ImportError):
|
|
from django.contrib.gis.geoip2 import GeoIP2 # noqa
|
|
assert get_geoip_module() is None
|
|
|
|
|
|
@skip_if_geoip2_not_installed
|
|
def test_get_geoip_module_enabled():
|
|
from django.contrib.gis.geoip2 import GeoIP2
|
|
|
|
assert get_geoip_module() is GeoIP2
|