diff --git a/CHANGES.txt b/CHANGES.txt index 4a9e019f338..a59d789ef40 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -4,6 +4,9 @@ CHANGES 0.16.0 (XX-XX-XXXX) ------------------- +- Support new `verify_fingerprint` param of TCPConnector to enable verifying + ssl certificates via md5, sha1, or sha256 fingerprint + - Setup uploaded filename if field value is binary and transfer encoding is not specified #349 diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 392309b2166..42f72678b93 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -8,7 +8,9 @@ import traceback import warnings +from binascii import hexlify, unhexlify from collections import defaultdict +from hashlib import md5, sha1, sha256 from itertools import chain from math import ceil @@ -17,6 +19,7 @@ from .errors import ServerDisconnectedError from .errors import HttpProxyError, ProxyConnectionError from .errors import ClientOSError, ClientTimeoutError +from .errors import FingerprintMismatch from .helpers import BasicAuth @@ -25,6 +28,12 @@ PY_34 = sys.version_info >= (3, 4) PY_343 = sys.version_info >= (3, 4, 3) +HASHFUNC_BY_DIGESTLEN = { + 16: md5, + 20: sha1, + 32: sha256, +} + class Connection(object): @@ -347,13 +356,16 @@ class TCPConnector(BaseConnector): """TCP connector. :param bool verify_ssl: Set to True to check ssl certifications. + :param str verify_fingerprint: Set to the md5, sha1, or sha256 fingerprint + (as a hexadecimal string) of the expected certificate (DER-encoded) + to verify the cert matches. May be interspersed with colons. :param bool resolve: Set to True to do DNS lookup for host name. :param family: socket address family :param args: see :class:`BaseConnector` :param kwargs: see :class:`BaseConnector` """ - def __init__(self, *, verify_ssl=True, + def __init__(self, *, verify_ssl=True, verify_fingerprint=None, resolve=False, family=socket.AF_INET, ssl_context=None, **kwargs): super().__init__(**kwargs) @@ -364,6 +376,16 @@ def __init__(self, *, verify_ssl=True, "verify_ssl=False or specify ssl_context, not both.") self._verify_ssl = verify_ssl + + if verify_fingerprint: + verify_fingerprint = verify_fingerprint.replace(':', '').lower() + digestlen, odd = divmod(len(verify_fingerprint), 2) + if odd or digestlen not in HASHFUNC_BY_DIGESTLEN: + raise ValueError('Fingerprint is of invalid length.') + self._hashfunc = HASHFUNC_BY_DIGESTLEN[digestlen] + self._fingerprint_bytes = unhexlify(verify_fingerprint) + + self._verify_fingerprint = verify_fingerprint self._ssl_context = ssl_context self._family = family self._resolve = resolve @@ -374,6 +396,11 @@ def verify_ssl(self): """Do check for ssl certifications?""" return self._verify_ssl + @property + def verify_fingerprint(self): + """Verify ssl cert fingerprint matches?""" + return self._verify_fingerprint + @property def ssl_context(self): """SSLContext instance for https requests. @@ -464,11 +491,25 @@ def _create_connection(self, req): for hinfo in hosts: try: - return (yield from self._loop.create_connection( - self._factory, hinfo['host'], hinfo['port'], + host = hinfo['host'] + port = hinfo['port'] + conn = yield from self._loop.create_connection( + self._factory, host, port, ssl=sslcontext, family=hinfo['family'], proto=hinfo['proto'], flags=hinfo['flags'], - server_hostname=hinfo['hostname'] if sslcontext else None)) + server_hostname=hinfo['hostname'] if sslcontext else None) + if req.ssl and self._verify_fingerprint: + transport = conn[0] + sock = transport.get_extra_info('socket') + # gives DER-encoded cert as a sequence of bytes (or None) + cert = sock.getpeercert(binary_form=True) + got = cert and self._hashfunc(cert).digest() + expected = self._fingerprint_bytes + if expected != got: + got = got and hexlify(got).decode('ascii') + expected = hexlify(expected).decode('ascii') + raise FingerprintMismatch(expected, got, host, port) + return conn except OSError as e: exc = e else: diff --git a/aiohttp/errors.py b/aiohttp/errors.py index 5c148638c1f..2509645d51d 100644 --- a/aiohttp/errors.py +++ b/aiohttp/errors.py @@ -170,3 +170,13 @@ class LineLimitExceededParserError(ParserError): def __init__(self, msg, limit): super().__init__(msg) self.limit = limit + + +class FingerprintMismatch(Exception): + """SSL certificate does not match expected fingerprint.""" + + def __init__(self, expected, got, host, port): + self.expected = expected + self.got = got + self.host = host + self.port = port diff --git a/docs/client.rst b/docs/client.rst index 8304cf0de01..4c658698dac 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -404,13 +404,21 @@ If you need to setup custom ssl parameters (use own certification files for example) you can create a :class:`ssl.SSLContext` instance and pass it into the connector:: - >>> sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - >>> sslcontext.verify_mode = ssl.CERT_REQUIRED - >>> sslcontext.load_verify_locations("/etc/ssl/certs/ca-bundle.crt") + >>> sslcontext = ssl.create_default_context(cafile='/path/to/ca-bundle.crt') >>> conn = aiohttp.TCPConnector(ssl_context=sslcontext) >>> r = yield from aiohttp.request( ... 'get', 'https://example.com', connector=conn) +You may also verify certificates via fingerprint:: + + >>> fp = '...' # hex str of md5, sha1, or sha256 of expected cert (in DER) + >>> conn = aiohttp.TCPConnector(verify_fingerprint=fp) + >>> r = yield from aiohttp.request( + ... 'get', 'https://MITMed.com', connector=conn) + Traceback (most recent call last)\: + ... + FingerprintMismatch(...) + Unix domain sockets ------------------- diff --git a/tests/sample.crt.der b/tests/sample.crt.der new file mode 100644 index 00000000000..ce22b75b9e0 Binary files /dev/null and b/tests/sample.crt.der differ diff --git a/tests/test_connector.py b/tests/test_connector.py index fc81d559045..aa867bc57de 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -12,6 +12,7 @@ import aiohttp from aiohttp import client from aiohttp import test_utils +from aiohttp.errors import FingerprintMismatch from aiohttp.client import ClientResponse, ClientRequest from aiohttp.connector import Connection @@ -452,10 +453,30 @@ def test_cleanup3(self): def test_tcp_connector_ctor(self): conn = aiohttp.TCPConnector(loop=self.loop) self.assertTrue(conn.verify_ssl) + self.assertIs(conn.verify_fingerprint, None) self.assertFalse(conn.resolve) self.assertEqual(conn.family, socket.AF_INET) self.assertEqual(conn.resolved_hosts, {}) + def test_tcp_connector_verify_fingerprint(self): + # sha1 fingerprint of ./sample.crt.der + fpgood = '7393fd3aed081d6fa9ae71391ae3c57f89e76cf9' + fpbad = 'badbadbadbadbadbadbadbadbadbadbadbadbad1' + for fp in (fpgood, fpbad): + conn = aiohttp.TCPConnector(loop=self.loop, verify_ssl=False, + verify_fingerprint=fp) + with test_utils.run_server(self.loop, use_ssl=True) as httpd: + coro = client.request('get', httpd.url('method', 'get'), + connector=conn, loop=self.loop) + if fp == fpgood: + # should not raise + self.loop.run_until_complete(coro) + else: + with self.assertRaises(FingerprintMismatch) as cm: + self.loop.run_until_complete(coro) + self.assertEqual(cm.exception.expected, fpbad) + self.assertEqual(cm.exception.got, fpgood) + def test_tcp_connector_clear_resolved_hosts(self): conn = aiohttp.TCPConnector(loop=self.loop) info = object()