Skip to content

Commit

Permalink
Configure ip headers (#58)
Browse files Browse the repository at this point in the history
Adds a way of configuring the SDK to extract IP headers.
  • Loading branch information
joladev authored Feb 13, 2020
1 parent 464243a commit a0455e3
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 0 deletions.
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- [#59](https://github.com/castle/castle-python/pull/59) drop requests min version in ci
- [#56](https://github.com/castle/castle-python/pull/56) drop special ip header behavior
- [#58](https://github.com/castle/castle-python/pull/58) Adds `ip_header` configuration option

### Breaking Changes:

Expand Down
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ import and configure the library with your Castle API secret.
# some headers are always scrubbed, for security reasons.
configuration.blacklisted = ['HTTP-X-header']
# Castle needs the original IP of the client, not the IP of your proxy or load balancer.
# If that IP is sent as a header you can configure the SDK to extract it automatically.
# Note that format, it should be prefixed with `HTTP`, capitalized and separated by underscores.
configuration.ip_headers = ["HTTP_X_FORWARDED_FOR"]
Tracking
--------

Expand Down
12 changes: 12 additions & 0 deletions castle/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self):
self.blacklisted = []
self.request_timeout = REQUEST_TIMEOUT
self.failover_strategy = 'allow'
self.ip_headers = []

@property
def api_secret(self):
Expand Down Expand Up @@ -111,6 +112,17 @@ def failover_strategy(self, value):
else:
raise ConfigurationError

@property
def ip_headers(self):
return self.__ip_headers

@ip_headers.setter
def ip_headers(self, value):
if isinstance(value, list):
self.__ip_headers = value
else:
raise ConfigurationError


# pylint: disable=invalid-name
configuration = Configuration()
14 changes: 14 additions & 0 deletions castle/extractors/ip.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,23 @@
from castle.configuration import configuration


class ExtractorsIp(object):
def __init__(self, request):
self.request = request

def call(self):
ip_address = self.get_ip_from_headers()
if ip_address:
return ip_address

if hasattr(self.request, 'ip'):
return self.request.ip

return self.request.environ.get('REMOTE_ADDR')

def get_ip_from_headers(self):
for header in configuration.ip_headers:
value = self.request.environ.get(header)
if value:
return value
return None
12 changes: 12 additions & 0 deletions castle/test/configuration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_default_values(self):
self.assertEqual(config.blacklisted, [])
self.assertEqual(config.request_timeout, 500)
self.assertEqual(config.failover_strategy, 'allow')
self.assertEqual(config.ip_headers, [])

def test_api_secret_setter(self):
config = Configuration()
Expand Down Expand Up @@ -80,3 +81,14 @@ def test_failover_strategy_setter_invalid(self):
config = Configuration()
with self.assertRaises(ConfigurationError):
config.failover_strategy = 'invalid'

def test_ip_headers_setter_valid(self):
config = Configuration()
ip_headers = ['HTTP_X_FORWARDED_FOR']
config.ip_headers = ip_headers
self.assertEqual(config.ip_headers, ip_headers)

def test_ip_headers_setter_invalid(self):
config = Configuration()
with self.assertRaises(ConfigurationError):
config.ip_headers = 'invalid'
33 changes: 33 additions & 0 deletions castle/test/extractors/ip_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from castle.test import unittest, mock
from castle.extractors.ip import ExtractorsIp
from castle.configuration import configuration


def request_ip():
Expand All @@ -22,7 +23,23 @@ def request_with_ip_remote_addr():
return req


def request_with_ip_x_forwarded_for():
req = mock.Mock(spec=['environ'])
req.environ = {'HTTP_X_FORWARDED_FOR': request_ip()}
return req


def request_with_ip_cf_connecting_ip():
req = mock.Mock(spec=['environ'])
req.environ = {'HTTP_CF_CONNECTING_IP': request_ip_next()}
return req


class ExtractorsIpTestCase(unittest.TestCase):
@classmethod
def tearDownClass(cls):
configuration.ip_headers = []

def test_extract_ip(self):
self.assertEqual(ExtractorsIp(request()).call(), request_ip())

Expand All @@ -31,3 +48,19 @@ def test_extract_ip_from_wsgi_request_remote_addr(self):
ExtractorsIp(request_with_ip_remote_addr()).call(),
request_ip()
)

def test_extract_ip_from_wsgi_request_configured_ip_header_first(self):
configuration.ip_headers = ["HTTP_CF_CONNECTING_IP"]
self.assertEqual(
ExtractorsIp(request_with_ip_cf_connecting_ip()).call(),
request_ip_next()
)
configuration.ip_headers = []

def test_extract_ip_from_wsgi_request_configured_ip_header_second(self):
configuration.ip_headers = ["HTTP_CF_CONNECTING_IP", "HTTP_X_FORWARDED_FOR"]
self.assertEqual(
ExtractorsIp(request_with_ip_x_forwarded_for()).call(),
request_ip()
)
configuration.ip_headers = []

0 comments on commit a0455e3

Please sign in to comment.