"""Tests for ``OriginCountryRule`` country detection.

The rule is exercised against three country sources: the
``HTTP_CF_IPCOUNTRY`` and ``HTTP_CLOUDFRONT_VIEWER_COUNTRY`` request META
keys, and a GeoIP2 lookup of ``REMOTE_ADDR`` obtained through
``get_geoip_module``.
"""
from importlib.util import find_spec
from unittest.mock import MagicMock, call, patch

import pytest

from tests.factories.rule import OriginCountryRuleFactory
from tests.factories.segment import SegmentFactory
from wagtail_personalisation.rules import get_geoip_module

# ``find_spec`` returns a ModuleSpec or None; ``pytest.mark.skipif`` documents
# its condition as a bool (or a string to evaluate), so convert explicitly.
HAS_GEOIP2 = find_spec('geoip2') is not None

skip_if_geoip2_installed = pytest.mark.skipif(
    HAS_GEOIP2, reason='requires GeoIP2 to be not installed'
)

skip_if_geoip2_not_installed = pytest.mark.skipif(
    not HAS_GEOIP2, reason='requires GeoIP2 to be installed.'
)


def _make_rule(country='GB'):
    """Create an ``OriginCountryRule`` for *country* on a fresh test segment."""
    segment = SegmentFactory(name='Test segment')
    return OriginCountryRuleFactory(segment=segment, country=country)


@pytest.mark.django_db
def test_get_cloudflare_country_with_header(rf):
    """The Cloudflare country header is returned lower-cased."""
    rule = _make_rule()
    request = rf.get('/', HTTP_CF_IPCOUNTRY='PL')
    assert rule.get_cloudflare_country(request) == 'pl'


@pytest.mark.django_db
def test_get_cloudflare_country_with_no_header(rf):
    """Without the Cloudflare header no country is detected."""
    rule = _make_rule()
    request = rf.get('/')
    assert 'HTTP_CF_IPCOUNTRY' not in request.META
    assert rule.get_cloudflare_country(request) is None


@pytest.mark.django_db
def test_get_cloudfront_country_with_header(rf):
    """The CloudFront country header is returned lower-cased."""
    rule = _make_rule()
    request = rf.get('/', HTTP_CLOUDFRONT_VIEWER_COUNTRY='BY')
    assert rule.get_cloudfront_country(request) == 'by'


@pytest.mark.django_db
def test_get_cloudfront_country_with_no_header(rf):
    """Without the CloudFront header no country is detected."""
    rule = _make_rule()
    request = rf.get('/')
    assert 'HTTP_CLOUDFRONT_VIEWER_COUNTRY' not in request.META
    assert rule.get_cloudfront_country(request) is None


@pytest.mark.django_db
def test_get_geoip_country_with_remote_addr(rf):
    """GeoIP detection instantiates the GeoIP2 class and queries REMOTE_ADDR."""
    rule = _make_rule()
    request = rf.get('/', REMOTE_ADDR='173.254.89.34')
    geoip_mock = MagicMock()
    with patch('wagtail_personalisation.rules.get_geoip_module',
               return_value=geoip_mock) as geoip_import_mock:
        rule.get_geoip_country(request)
    geoip_import_mock.assert_called_once()
    geoip_mock.assert_called_once()
    # Assert on the child mock rather than indexing ``mock_calls``: a missing
    # call now fails with a readable assertion message, not an IndexError.
    geoip_mock.return_value.country_code.assert_called_once_with(
        '173.254.89.34'
    )


@pytest.mark.django_db
def test_get_country_calls_all_methods(rf):
    """When no method detects a country, all are tried and None is returned."""
    rule = _make_rule()
    request = rf.get('/')
    cloudflare_mock = MagicMock(return_value='')
    cloudfront_mock = MagicMock(return_value='')
    geoip_mock = MagicMock(return_value='')
    with patch.multiple(rule,
                        get_cloudflare_country=cloudflare_mock,
                        get_cloudfront_country=cloudfront_mock,
                        get_geoip_country=geoip_mock):
        country = rule.get_country(request)
    cloudflare_mock.assert_called_once_with(request)
    cloudfront_mock.assert_called_once_with(request)
    geoip_mock.assert_called_once_with(request)
    assert country is None


@pytest.mark.django_db
def test_get_country_does_not_use_all_detection_methods_unnecessarily(rf):
    """Detection stops at the first method that returns a country."""
    rule = _make_rule()
    request = rf.get('/')
    cloudflare_mock = MagicMock(return_value='t1')
    cloudfront_mock = MagicMock(return_value='')
    geoip_mock = MagicMock(return_value='')
    with patch.multiple(rule,
                        get_cloudflare_country=cloudflare_mock,
                        get_cloudfront_country=cloudfront_mock,
                        get_geoip_country=geoip_mock):
        country = rule.get_country(request)
    cloudflare_mock.assert_called_once_with(request)
    cloudfront_mock.assert_not_called()
    geoip_mock.assert_not_called()
    assert country == 't1'


@pytest.mark.django_db
def test_test_user_calls_get_country(rf):
    """``test_user`` delegates country detection to ``get_country``."""
    rule = _make_rule()
    request = rf.get('/')
    with patch.object(rule, 'get_country') as get_country_mock:
        rule.test_user(request)
    get_country_mock.assert_called_once_with(request)


@pytest.mark.django_db
def test_test_user_returns_true_if_cloudflare_country_match(rf):
    rule = _make_rule()
    request = rf.get('/', HTTP_CF_IPCOUNTRY='GB')
    assert rule.test_user(request) is True


@pytest.mark.django_db
def test_test_user_returns_false_if_cloudflare_country_doesnt_match(rf):
    rule = _make_rule()
    request = rf.get('/', HTTP_CF_IPCOUNTRY='NL')
    # Strict identity check, consistent with the other ``test_user`` tests.
    assert rule.test_user(request) is False


@pytest.mark.django_db
def test_test_user_returns_false_if_cloudfront_country_doesnt_match(rf):
    rule = _make_rule()
    request = rf.get('/', HTTP_CLOUDFRONT_VIEWER_COUNTRY='NL')
    assert rule.test_user(request) is False


@pytest.mark.django_db
def test_test_user_returns_true_if_cloudfront_country_matches(rf):
    """The rule country ('se') matches the upper-case header value ('SE')."""
    rule = _make_rule('se')
    request = rf.get('/', HTTP_CLOUDFRONT_VIEWER_COUNTRY='SE')
    assert rule.test_user(request) is True


@skip_if_geoip2_not_installed
@pytest.mark.django_db
def test_test_user_geoip_module_matches(rf):
    """A GeoIP2 lookup returning the rule's country satisfies the rule."""
    rule = _make_rule('se')
    request = rf.get('/', REMOTE_ADDR='123.120.0.2')
    geoip2_mock = MagicMock()
    # Configure the instance's lookup without recording a call on the mock.
    geoip2_mock.return_value.country_code.return_value = 'SE'
    with patch('wagtail_personalisation.rules.get_geoip_module',
               return_value=geoip2_mock):
        assert rule.test_user(request) is True
    assert geoip2_mock.mock_calls == [
        call(),
        call().country_code('123.120.0.2'),
    ]


@skip_if_geoip2_not_installed
@pytest.mark.django_db
def test_test_user_geoip_module_does_not_match(rf):
    """A GeoIP2 lookup returning another country does not satisfy the rule."""
    rule = _make_rule('nl')
    request = rf.get('/', REMOTE_ADDR='123.120.0.2')
    geoip2_mock = MagicMock()
    geoip2_mock.return_value.country_code.return_value = 'SE'
    with patch('wagtail_personalisation.rules.get_geoip_module',
               return_value=geoip2_mock):
        assert rule.test_user(request) is False
    assert geoip2_mock.mock_calls == [
        call(),
        call().country_code('123.120.0.2'),
    ]


@skip_if_geoip2_installed
@pytest.mark.django_db
def test_test_user_does_not_use_geoip_module_if_disabled(rf):
    """Without geoip2 installed, REMOTE_ADDR alone cannot match a country."""
    rule = _make_rule('se')
    request = rf.get('/', REMOTE_ADDR='123.120.0.2')
    assert rule.test_user(request) is False


@skip_if_geoip2_installed
def test_get_geoip_module_disabled():
    """Without geoip2, Django's GeoIP2 is unimportable and the helper gives None."""
    with pytest.raises(ImportError):
        from django.contrib.gis.geoip2 import GeoIP2  # noqa
    assert get_geoip_module() is None


@skip_if_geoip2_not_installed
def test_get_geoip_module_enabled():
    """With geoip2 installed, the helper returns Django's GeoIP2 class."""
    from django.contrib.gis.geoip2 import GeoIP2
    assert get_geoip_module() is GeoIP2