diff --git a/.travis.yml b/.travis.yml index 6a10249..7d20c4a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,8 +8,12 @@ matrix: env: TOXENV=lint - python: 3.6 env: TOXENV=py36-django20-wagtail20 + - python: 3.6 + env: TOXENV=py36-django20-wagtail20-geoip2 - python: 3.6 env: TOXENV=py36-django20-wagtail21 + - python: 3.6 + env: TOXENV=py36-django20-wagtail21-geoip2 install: - pip install tox codecov diff --git a/docs/default_rules.rst b/docs/default_rules.rst index 3bca625..a42c3dd 100644 --- a/docs/default_rules.rst +++ b/docs/default_rules.rst @@ -131,3 +131,47 @@ Is logged in Whether the user is logged in or logged out. ================== ========================================================== ``wagtail_personalisation.rules.UserIsLoggedInRule`` + + +Origin country rule +------------------- + +The origin country rule allows you to match visitors based on the origin +country of their request. This rule requires to have set up a way to detect +countries beforehand. + +================== ========================================================== +Option Description +================== ========================================================== +Country What country user's request comes from. +================== ========================================================== + +You must have one of the following configurations set up in order to +make it work. + +- Cloudflare IP Geolocation - ``cf-ipcountry`` HTTP header set with a value of + the alpha-2 country format. +- CloudFront Geo-Targeting - ``cloudfront-viewer-country`` header set with a + value of the alpha-2 country format. +- The last fallback is to use GeoIP2 module that is included with Django. This + requires setting up an IP database beforehand, see the Django's + `GeoIP2 instructions `_ + for more information. It will use IP of the request, using HTTP header + the ``x-forwarded-for`` HTTP header and ``REMOTE_ADDR`` server value as a + fallback. If you want to use a custom logic when obtaining IP address, please + set the ``WAGTAIL_PERSONALISATION_IP_FUNCTION`` setting to the function that takes a + request as an argument, e.g. + + .. code-block:: python + + # settings.py + + WAGTAIL_PERSONALISATION_IP_FUNCTION = 'yourproject.utils.get_client_ip' + + + # yourproject/utils.py + + def get_client_ip(request): + return request['HTTP_CF_CONNECTING_IP'] + +``wagtail_personalisation.rules.OriginCountryRule`` diff --git a/setup.py b/setup.py index 13cd1ff..dd8ab90 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ install_requires = [ 'wagtail>=2.0,<2.2', 'user-agents>=1.1.0', 'wagtailfontawesome>=1.1.3', + 'pycountry', ] tests_require = [ diff --git a/src/wagtail_personalisation/migrations/0024_origincountryrule.py b/src/wagtail_personalisation/migrations/0024_origincountryrule.py new file mode 100644 index 0000000..12d9bfb --- /dev/null +++ b/src/wagtail_personalisation/migrations/0024_origincountryrule.py @@ -0,0 +1,26 @@ +# Generated by Django 2.0.6 on 2018-08-10 14:39 + +from django.db import migrations, models +import django.db.models.deletion +import modelcluster.fields + + +class Migration(migrations.Migration): + + dependencies = [ + ('wagtail_personalisation', '0023_personalisablepagemetadata_variant_cascade'), + ] + + operations = [ + migrations.CreateModel( + name='OriginCountryRule', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('country', models.CharField(choices=[('aw', 'Aruba'), ('af', 'Afghanistan'), ('ao', 'Angola'), ('ai', 'Anguilla'), ('ax', 'Åland Islands'), ('al', 'Albania'), ('ad', 'Andorra'), ('ae', 'United Arab Emirates'), ('ar', 'Argentina'), ('am', 'Armenia'), ('as', 'American Samoa'), ('aq', 'Antarctica'), ('tf', 'French Southern Territories'), ('ag', 'Antigua and Barbuda'), ('au', 'Australia'), ('at', 'Austria'), ('az', 'Azerbaijan'), ('bi', 'Burundi'), ('be', 'Belgium'), ('bj', 'Benin'), ('bq', 'Bonaire, Sint Eustatius and Saba'), ('bf', 'Burkina Faso'), ('bd', 'Bangladesh'), ('bg', 'Bulgaria'), ('bh', 'Bahrain'), ('bs', 'Bahamas'), ('ba', 'Bosnia and Herzegovina'), ('bl', 'Saint Barthélemy'), ('by', 'Belarus'), ('bz', 'Belize'), ('bm', 'Bermuda'), ('bo', 'Bolivia, Plurinational State of'), ('br', 'Brazil'), ('bb', 'Barbados'), ('bn', 'Brunei Darussalam'), ('bt', 'Bhutan'), ('bv', 'Bouvet Island'), ('bw', 'Botswana'), ('cf', 'Central African Republic'), ('ca', 'Canada'), ('cc', 'Cocos (Keeling) Islands'), ('ch', 'Switzerland'), ('cl', 'Chile'), ('cn', 'China'), ('ci', "Côte d'Ivoire"), ('cm', 'Cameroon'), ('cd', 'Congo, The Democratic Republic of the'), ('cg', 'Congo'), ('ck', 'Cook Islands'), ('co', 'Colombia'), ('km', 'Comoros'), ('cv', 'Cabo Verde'), ('cr', 'Costa Rica'), ('cu', 'Cuba'), ('cw', 'Curaçao'), ('cx', 'Christmas Island'), ('ky', 'Cayman Islands'), ('cy', 'Cyprus'), ('cz', 'Czechia'), ('de', 'Germany'), ('dj', 'Djibouti'), ('dm', 'Dominica'), ('dk', 'Denmark'), ('do', 'Dominican Republic'), ('dz', 'Algeria'), ('ec', 'Ecuador'), ('eg', 'Egypt'), ('er', 'Eritrea'), ('eh', 'Western Sahara'), ('es', 'Spain'), ('ee', 'Estonia'), ('et', 'Ethiopia'), ('fi', 'Finland'), ('fj', 'Fiji'), ('fk', 'Falkland Islands (Malvinas)'), ('fr', 'France'), ('fo', 'Faroe Islands'), ('fm', 'Micronesia, Federated States of'), ('ga', 'Gabon'), ('gb', 'United Kingdom'), ('ge', 'Georgia'), ('gg', 'Guernsey'), ('gh', 'Ghana'), ('gi', 'Gibraltar'), ('gn', 'Guinea'), ('gp', 'Guadeloupe'), ('gm', 'Gambia'), ('gw', 'Guinea-Bissau'), ('gq', 'Equatorial Guinea'), ('gr', 'Greece'), ('gd', 'Grenada'), ('gl', 'Greenland'), ('gt', 'Guatemala'), ('gf', 'French Guiana'), ('gu', 'Guam'), ('gy', 'Guyana'), ('hk', 'Hong Kong'), ('hm', 'Heard Island and McDonald Islands'), ('hn', 'Honduras'), ('hr', 'Croatia'), ('ht', 'Haiti'), ('hu', 'Hungary'), ('id', 'Indonesia'), ('im', 'Isle of Man'), ('in', 'India'), ('io', 'British Indian Ocean Territory'), ('ie', 'Ireland'), ('ir', 'Iran, Islamic Republic of'), ('iq', 'Iraq'), ('is', 'Iceland'), ('il', 'Israel'), ('it', 'Italy'), ('jm', 'Jamaica'), ('je', 'Jersey'), ('jo', 'Jordan'), ('jp', 'Japan'), ('kz', 'Kazakhstan'), ('ke', 'Kenya'), ('kg', 'Kyrgyzstan'), ('kh', 'Cambodia'), ('ki', 'Kiribati'), ('kn', 'Saint Kitts and Nevis'), ('kr', 'Korea, Republic of'), ('kw', 'Kuwait'), ('la', "Lao People's Democratic Republic"), ('lb', 'Lebanon'), ('lr', 'Liberia'), ('ly', 'Libya'), ('lc', 'Saint Lucia'), ('li', 'Liechtenstein'), ('lk', 'Sri Lanka'), ('ls', 'Lesotho'), ('lt', 'Lithuania'), ('lu', 'Luxembourg'), ('lv', 'Latvia'), ('mo', 'Macao'), ('mf', 'Saint Martin (French part)'), ('ma', 'Morocco'), ('mc', 'Monaco'), ('md', 'Moldova, Republic of'), ('mg', 'Madagascar'), ('mv', 'Maldives'), ('mx', 'Mexico'), ('mh', 'Marshall Islands'), ('mk', 'Macedonia, Republic of'), ('ml', 'Mali'), ('mt', 'Malta'), ('mm', 'Myanmar'), ('me', 'Montenegro'), ('mn', 'Mongolia'), ('mp', 'Northern Mariana Islands'), ('mz', 'Mozambique'), ('mr', 'Mauritania'), ('ms', 'Montserrat'), ('mq', 'Martinique'), ('mu', 'Mauritius'), ('mw', 'Malawi'), ('my', 'Malaysia'), ('yt', 'Mayotte'), ('na', 'Namibia'), ('nc', 'New Caledonia'), ('ne', 'Niger'), ('nf', 'Norfolk Island'), ('ng', 'Nigeria'), ('ni', 'Nicaragua'), ('nu', 'Niue'), ('nl', 'Netherlands'), ('no', 'Norway'), ('np', 'Nepal'), ('nr', 'Nauru'), ('nz', 'New Zealand'), ('om', 'Oman'), ('pk', 'Pakistan'), ('pa', 'Panama'), ('pn', 'Pitcairn'), ('pe', 'Peru'), ('ph', 'Philippines'), ('pw', 'Palau'), ('pg', 'Papua New Guinea'), ('pl', 'Poland'), ('pr', 'Puerto Rico'), ('kp', "Korea, Democratic People's Republic of"), ('pt', 'Portugal'), ('py', 'Paraguay'), ('ps', 'Palestine, State of'), ('pf', 'French Polynesia'), ('qa', 'Qatar'), ('re', 'Réunion'), ('ro', 'Romania'), ('ru', 'Russian Federation'), ('rw', 'Rwanda'), ('sa', 'Saudi Arabia'), ('sd', 'Sudan'), ('sn', 'Senegal'), ('sg', 'Singapore'), ('gs', 'South Georgia and the South Sandwich Islands'), ('sh', 'Saint Helena, Ascension and Tristan da Cunha'), ('sj', 'Svalbard and Jan Mayen'), ('sb', 'Solomon Islands'), ('sl', 'Sierra Leone'), ('sv', 'El Salvador'), ('sm', 'San Marino'), ('so', 'Somalia'), ('pm', 'Saint Pierre and Miquelon'), ('rs', 'Serbia'), ('ss', 'South Sudan'), ('st', 'Sao Tome and Principe'), ('sr', 'Suriname'), ('sk', 'Slovakia'), ('si', 'Slovenia'), ('se', 'Sweden'), ('sz', 'Swaziland'), ('sx', 'Sint Maarten (Dutch part)'), ('sc', 'Seychelles'), ('sy', 'Syrian Arab Republic'), ('tc', 'Turks and Caicos Islands'), ('td', 'Chad'), ('tg', 'Togo'), ('th', 'Thailand'), ('tj', 'Tajikistan'), ('tk', 'Tokelau'), ('tm', 'Turkmenistan'), ('tl', 'Timor-Leste'), ('to', 'Tonga'), ('tt', 'Trinidad and Tobago'), ('tn', 'Tunisia'), ('tr', 'Turkey'), ('tv', 'Tuvalu'), ('tw', 'Taiwan, Province of China'), ('tz', 'Tanzania, United Republic of'), ('ug', 'Uganda'), ('ua', 'Ukraine'), ('um', 'United States Minor Outlying Islands'), ('uy', 'Uruguay'), ('us', 'United States'), ('uz', 'Uzbekistan'), ('va', 'Holy See (Vatican City State)'), ('vc', 'Saint Vincent and the Grenadines'), ('ve', 'Venezuela, Bolivarian Republic of'), ('vg', 'Virgin Islands, British'), ('vi', 'Virgin Islands, U.S.'), ('vn', 'Viet Nam'), ('vu', 'Vanuatu'), ('wf', 'Wallis and Futuna'), ('ws', 'Samoa'), ('ye', 'Yemen'), ('za', 'South Africa'), ('zm', 'Zambia'), ('zw', 'Zimbabwe')], help_text='Select origin country of the request that this rule will match against. This rule will only work if you use Cloudflare or CloudFront IP geolocation or if GeoIP2 module is configured.', max_length=2)), + ('segment', modelcluster.fields.ParentalKey(on_delete=django.db.models.deletion.CASCADE, related_name='wagtail_personalisation_origincountryrules', to='wagtail_personalisation.Segment')), + ], + options={ + 'verbose_name': 'origin country rule', + }, + ), + ] diff --git a/src/wagtail_personalisation/rules.py b/src/wagtail_personalisation/rules.py index 0b391dc..d4d0801 100644 --- a/src/wagtail_personalisation/rules.py +++ b/src/wagtail_personalisation/rules.py @@ -1,9 +1,11 @@ from __future__ import absolute_import, unicode_literals +import logging import re from datetime import datetime from importlib import import_module +import pycountry from django.apps import apps from django.conf import settings from django.contrib.sessions.models import Session @@ -18,8 +20,28 @@ from user_agents import parse from wagtail.admin.edit_handlers import ( FieldPanel, FieldRowPanel, PageChooserPanel) +from wagtail_personalisation.utils import get_client_ip + SessionStore = import_module(settings.SESSION_ENGINE).SessionStore +logger = logging.getLogger(__name__) + + +def get_geoip_module(): + try: + from django.contrib.gis.geoip2 import GeoIP2 + return GeoIP2 + except ImportError: + logger.exception( + 'GeoIP module is disabled. To use GeoIP for the origin\n' + 'country personaliastion rule please set it up as per ' + 'documentation:\n' + 'https://docs.djangoproject.com/en/stable/ref/contrib/gis/' + 'geoip2/.\n' + 'Wagtail-personalisation also works with Cloudflare and\n' + 'CloudFront country detection, so you should not see this\n' + 'warning if you use one of those.') + @python_2_unicode_compatible class AbstractBaseRule(models.Model): @@ -408,3 +430,65 @@ class UserIsLoggedInRule(AbstractBaseRule): 'title': _('These visitors are'), 'value': _('Logged in') if self.is_logged_in else _('Not logged in'), } + + +COUNTRY_CHOICES = [(country.alpha_2.lower(), country.name) + for country in pycountry.countries] + + +class OriginCountryRule(AbstractBaseRule): + """ + Test user against the country or origin of their request. + + Using this rule requires setting up GeoIP2 on Django or using + CloudFlare or CloudFront geolocation detection. + """ + country = models.CharField( + max_length=2, choices=COUNTRY_CHOICES, + help_text=_("Select origin country of the request that this rule will " + "match against. This rule will only work if you use " + "Cloudflare or CloudFront IP geolocation or if GeoIP2 " + "module is configured.") + ) + + class Meta: + verbose_name = _("origin country rule") + + def get_cloudflare_country(self, request): + """ + Get country code that has been detected by Cloudflare. + + Guide to the functionality: + https://support.cloudflare.com/hc/en-us/articles/200168236-What-does-Cloudflare-IP-Geolocation-do- + """ + try: + return request.META['HTTP_CF_IPCOUNTRY'].lower() + except KeyError: + pass + + def get_cloudfront_country(self, request): + try: + return request.META['HTTP_CLOUDFRONT_VIEWER_COUNTRY'].lower() + except KeyError: + pass + + def get_geoip_country(self, request): + GeoIP2 = get_geoip_module() + if GeoIP2 is None: + return False + return GeoIP2().country_code(get_client_ip(request)).lower() + + def get_country(self, request): + # Prioritise CloudFlare and CloudFront country detection over GeoIP. + functions = ( + self.get_cloudflare_country, + self.get_cloudfront_country, + self.get_geoip_country, + ) + for function in functions: + result = function(request) + if result: + return result + + def test_user(self, request=None): + return (self.get_country(request) or '') == self.country.lower() diff --git a/src/wagtail_personalisation/utils.py b/src/wagtail_personalisation/utils.py index 7fc6448..d818182 100644 --- a/src/wagtail_personalisation/utils.py +++ b/src/wagtail_personalisation/utils.py @@ -1,8 +1,10 @@ import time +from django.conf import settings from django.db.models import F from django.template.base import FilterExpression, kwarg_re from django.utils import timezone +from django.utils.module_loading import import_string def impersonate_other_page(page, other_page): @@ -116,3 +118,17 @@ def can_delete_pages(pages, user): if not variant.permissions_for_user(user).can_delete(): return False return True + + +def get_client_ip(request): + try: + func = import_string(settings.WAGTAIL_PERSONALISATION_IP_FUNCTION) + except AttributeError: + pass + else: + return func(request) + try: + x_forwarded_for = request.META['HTTP_X_FORWARDED_FOR'] + return x_forwarded_for.split(',')[-1].strip() + except KeyError: + return request.META['REMOTE_ADDR'] diff --git a/tests/factories/rule.py b/tests/factories/rule.py index 5d602d0..ef8665c 100644 --- a/tests/factories/rule.py +++ b/tests/factories/rule.py @@ -46,3 +46,8 @@ class VisitCountRuleFactory(factory.DjangoModelFactory): class Meta: model = rules.VisitCountRule + + +class OriginCountryRuleFactory(factory.DjangoModelFactory): + class Meta: + model = rules.OriginCountryRule diff --git a/tests/unit/test_rules_country_origin.py b/tests/unit/test_rules_country_origin.py new file mode 100644 index 0000000..809d3a7 --- /dev/null +++ b/tests/unit/test_rules_country_origin.py @@ -0,0 +1,203 @@ +from importlib.util import find_spec +from unittest.mock import call, MagicMock, patch + +import pytest + +from tests.factories.rule import OriginCountryRuleFactory +from tests.factories.segment import SegmentFactory +from wagtail_personalisation.rules import get_geoip_module + + +skip_if_geoip2_installed = pytest.mark.skipif( + find_spec('geoip2'), reason='requires GeoIP2 to be not installed' +) + +skip_if_geoip2_not_installed = pytest.mark.skipif( + not find_spec('geoip2'), reason='requires GeoIP2 to be installed.' +) + + +@pytest.mark.django_db +def test_get_cloudflare_country_with_header(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='GB') + request = rf.get('/', HTTP_CF_IPCOUNTRY='PL') + assert rule.get_cloudflare_country(request) == 'pl' + + +@pytest.mark.django_db +def test_get_cloudflare_country_with_no_header(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='GB') + request = rf.get('/') + assert 'HTTP_CF_IPCOUNTRY' not in request.META + assert rule.get_cloudflare_country(request) is None + + +@pytest.mark.django_db +def test_get_cloudfront_country_with_header(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='GB') + request = rf.get('/', HTTP_CLOUDFRONT_VIEWER_COUNTRY='BY') + assert rule.get_cloudfront_country(request) == 'by' + + +@pytest.mark.django_db +def test_get_cloudfront_country_with_no_header(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='GB') + request = rf.get('/') + assert 'HTTP_CLOUDFRONT_VIEWER_COUNTRY' not in request.META + assert rule.get_cloudfront_country(request) is None + + +@pytest.mark.django_db +def test_get_geoip_country_with_remote_addr(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='GB') + request = rf.get('/', REMOTE_ADDR='173.254.89.34') + geoip_mock = MagicMock() + with patch('wagtail_personalisation.rules.get_geoip_module', + return_value=geoip_mock) as geoip_import_mock: + rule.get_geoip_country(request) + geoip_import_mock.assert_called_once() + geoip_mock.assert_called_once() + assert geoip_mock.mock_calls[1] == call().country_code('173.254.89.34') + + +@pytest.mark.django_db +def test_get_country_calls_all_methods(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='GB') + request = rf.get('/') + + @patch.object(rule, 'get_geoip_country', return_value='') + @patch.object(rule, 'get_cloudflare_country', return_value='') + @patch.object(rule, 'get_cloudfront_country', return_value='') + def test_mock(cloudfront_mock, cloudflare_mock, geoip_mock): + country = rule.get_country(request) + cloudflare_mock.assert_called_once_with(request) + cloudfront_mock.assert_called_once_with(request) + geoip_mock.assert_called_once_with(request) + assert country is None + + test_mock() + + +@pytest.mark.django_db +def test_get_country_does_not_use_all_detection_methods_unnecessarily(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='GB') + request = rf.get('/') + + @patch.object(rule, 'get_geoip_country', return_value='') + @patch.object(rule, 'get_cloudflare_country', return_value='t1') + @patch.object(rule, 'get_cloudfront_country', return_value='') + def test_mock(cloudfront_mock, cloudflare_mock, geoip_mock): + country = rule.get_country(request) + cloudflare_mock.assert_called_once_with(request) + cloudfront_mock.assert_not_called() + geoip_mock.assert_not_called() + assert country == 't1' + + test_mock() + + +@pytest.mark.django_db +def test_test_user_calls_get_country(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='GB') + request = rf.get('/') + with patch.object(rule, 'get_country') as get_country_mock: + rule.test_user(request) + get_country_mock.assert_called_once_with(request) + + +@pytest.mark.django_db +def test_test_user_returns_true_if_cloudflare_country_match(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='GB') + request = rf.get('/', HTTP_CF_IPCOUNTRY='GB') + assert rule.test_user(request) is True + + +@pytest.mark.django_db +def test_test_user_returns_false_if_cloudflare_country_doesnt_match(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='GB') + request = rf.get('/', HTTP_CF_IPCOUNTRY='NL') + assert not rule.test_user(request) + + +@pytest.mark.django_db +def test_test_user_returns_false_if_cloudfront_country_doesnt_match(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='GB') + request = rf.get('/', HTTP_CLOUDFRONT_VIEWER_COUNTRY='NL') + assert rule.test_user(request) is False + + +@pytest.mark.django_db +def test_test_user_returns_true_if_cloudfront_country_matches(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='se') + request = rf.get('/', HTTP_CLOUDFRONT_VIEWER_COUNTRY='SE') + assert rule.test_user(request) is True + + +@skip_if_geoip2_not_installed +@pytest.mark.django_db +def test_test_user_geoip_module_matches(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='se') + request = rf.get('/', REMOTE_ADDR='123.120.0.2') + GeoIP2Mock = MagicMock() + GeoIP2Mock().configure_mock(**{'country_code.return_value': 'SE'}) + GeoIP2Mock.reset_mock() + with patch('wagtail_personalisation.rules.get_geoip_module', + return_value=GeoIP2Mock): + assert rule.test_user(request) is True + assert GeoIP2Mock.mock_calls == [ + call(), + call().country_code('123.120.0.2'), + ] + + +@skip_if_geoip2_not_installed +@pytest.mark.django_db +def test_test_user_geoip_module_does_not_match(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='nl') + request = rf.get('/', REMOTE_ADDR='123.120.0.2') + GeoIP2Mock = MagicMock() + GeoIP2Mock().configure_mock(**{'country_code.return_value': 'SE'}) + GeoIP2Mock.reset_mock() + with patch('wagtail_personalisation.rules.get_geoip_module', + return_value=GeoIP2Mock): + assert rule.test_user(request) is False + assert GeoIP2Mock.mock_calls == [ + call(), + call().country_code('123.120.0.2') + ] + + +@skip_if_geoip2_installed +@pytest.mark.django_db +def test_test_user_does_not_use_geoip_module_if_disabled(rf): + segment = SegmentFactory(name='Test segment') + rule = OriginCountryRuleFactory(segment=segment, country='se') + request = rf.get('/', REMOTE_ADDR='123.120.0.2') + assert rule.test_user(request) is False + + +@skip_if_geoip2_installed +def test_get_geoip_module_disabled(): + with pytest.raises(ImportError): + from django.contrib.gis.geoip2 import GeoIP2 # noqa + assert get_geoip_module() is None + + +@skip_if_geoip2_not_installed +def test_get_geoip_module_enabled(): + from django.contrib.gis.geoip2 import GeoIP2 + assert get_geoip_module() is GeoIP2 diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index be2131c..1d0c182 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,8 +1,10 @@ import pytest +from django.test import override_settings + from tests.factories.page import ContentPageFactory from wagtail_personalisation.utils import ( - can_delete_pages, impersonate_other_page) + can_delete_pages, get_client_ip, impersonate_other_page) @pytest.fixture @@ -36,3 +38,29 @@ def test_can_delete_pages_with_superuser(rf, user, segmented_page): @pytest.mark.django_db def test_cannot_delete_pages_with_standard_user(user, segmented_page): assert not can_delete_pages([segmented_page], user) + + +def test_get_client_ip_with_remote_addr(rf): + request = rf.get('/', REMOTE_ADDR='173.231.235.87') + assert get_client_ip(request) == '173.231.235.87' + + +def test_get_client_ip_with_x_forwarded_for(rf): + request = rf.get('/', HTTP_X_FORWARDED_FOR='173.231.235.87', + REMOTE_ADDR='10.0.23.24') + assert get_client_ip(request) == '173.231.235.87' + + +@override_settings( + WAGTAIL_PERSONALISATION_IP_FUNCTION='some.non.existent.path' +) +def test_get_client_ip_custom_get_client_ip_function_does_not_exist(rf): + with pytest.raises(ImportError): + get_client_ip(rf.get('/')) + + +@override_settings( + WAGTAIL_PERSONALISATION_IP_FUNCTION='tests.utils.get_custom_ip' +) +def test_get_client_ip_custom_get_client_ip_used(rf): + assert get_client_ip(rf.get('/')) == '123.123.123.123' diff --git a/tests/utils.py b/tests/utils.py index c7ed3ef..e1a5e49 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,3 +5,7 @@ def render_template(value, **context): template = engines['django'].from_string(value) request = context.pop('request', None) return template.render(context, request) + + +def get_custom_ip(request): + return '123.123.123.123' diff --git a/tox.ini b/tox.ini index cd3fed6..159f784 100644 --- a/tox.ini +++ b/tox.ini @@ -1,14 +1,15 @@ [tox] -envlist = py{36}-django{20}-wagtail{20,21},lint +envlist = py{36}-django{20}-wagtail{20,21}{,-geoip2},lint [testenv] basepython = python3.6 -commands = coverage run --parallel -m pytest {posargs} +commands = coverage run --parallel -m pytest -rs {posargs} extras = test deps = django20: django>=2.0,<2.1 wagtail20: wagtail>=2.0,<2.1 wagtail21: wagtail>=2.1,<2.2 + geoip2: geoip2 [testenv:coverage-report] basepython = python3.6