diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 622b1df..15f356a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -50,7 +50,7 @@ jobs: python -m tests.test_server & python -m pytest \ - --ignore=tests/functional/test_function.py \ + --ignore=tests/functional/test_functional.py \ --ignore=tests/test_server.py \ --cov requests_ntlm \ --cov-report term-missing \ diff --git a/README.rst b/README.rst index 09f4d4d..4ba3fca 100644 --- a/README.rst +++ b/README.rst @@ -41,10 +41,10 @@ Requirements ------------ - requests_ -- ntlm-auth_ +- pyspnego_ .. _requests: https://github.com/kennethreitz/requests/ -.. _ntlm-auth: https://github.com/jborean93/ntlm-auth +.. _ntlm-auth: https://github.com/jborean93/pyspnego/ Authors ------- diff --git a/requests_ntlm/requests_ntlm.py b/requests_ntlm/requests_ntlm.py index 23db0cc..b800a8b 100644 --- a/requests_ntlm/requests_ntlm.py +++ b/requests_ntlm/requests_ntlm.py @@ -1,14 +1,38 @@ -import binascii -import sys import warnings +import base64 +import typing as t from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.exceptions import UnsupportedAlgorithm -from ntlm_auth import ntlm from requests.auth import AuthBase from requests.packages.urllib3.response import HTTPResponse +import spnego + + +class ShimSessionSecurity: + """Shim used for backwards compatibility with ntlm-auth.""" + + def __init__(self, context: spnego.ContextProxy) -> None: + self._context = context + + def wrap(self, message) -> t.Tuple[bytes, bytes]: + wrap_res = self._context.wrap(message, encrypt=True) + signature = wrap_res.data[:16] + data = wrap_res.data[16:] + + return data, signature + + def unwrap(self, message: bytes, signature: bytes) -> bytes: + data = signature + message + return self._context.unwrap(data).data + + def get_signature(self, message: bytes) -> bytes: + return self._context.sign(message) + + def verify_signature(self, message: bytes, signature: bytes) -> None: + self._context.verify(message, signature) class HttpNtlmAuth(AuthBase): @@ -26,18 +50,7 @@ def __init__(self, username, password, session=None, send_cbt=True): :param str session: Unused. Kept for backwards-compatibility. :param bool send_cbt: Will send the channel bindings over a HTTPS channel (Default: True) """ - if ntlm is None: - raise Exception("NTLM libraries unavailable") - - # parse the username - try: - self.domain, self.username = username.split('\\', 1) - except ValueError: - self.username = username - self.domain = '' - - if self.domain: - self.domain = self.domain.upper() + self.username = username self.password = password self.send_cbt = send_cbt @@ -46,18 +59,30 @@ def __init__(self, username, password, session=None, send_cbt=True): # call requests_ntlm to encrypt and decrypt the messages sent after authentication self.session_security = None - def retry_using_http_NTLM_auth(self, auth_header_field, auth_header, - response, auth_type, args): + def retry_using_http_NTLM_auth( + self, + auth_header_field, + auth_header, + response, + auth_type, + args, + ): # Get the certificate of the server if using HTTPS for CBT server_certificate_hash = self._get_server_cert(response) + cbt = None + if server_certificate_hash: + cbt = spnego.channel_bindings.GssChannelBindings( + application_data=b"tls-server-end-point:" + server_certificate_hash + ) """Attempt to authenticate using HTTP NTLM challenge/response.""" if auth_header in response.request.headers: return response content_length = int( - response.request.headers.get('Content-Length', '0'), base=10) - if hasattr(response.request.body, 'seek'): + response.request.headers.get("Content-Length", "0"), base=10 + ) + if hasattr(response.request.body, "seek"): if content_length > 0: response.request.body.seek(-content_length, 1) else: @@ -69,11 +94,16 @@ def retry_using_http_NTLM_auth(self, auth_header_field, auth_header, response.raw.release_conn() request = response.request.copy() - # ntlm returns the headers as a base64 encoded bytestring. Convert to - # a string. - context = ntlm.Ntlm() - negotiate_message = context.create_negotiate_message(self.domain).decode('ascii') - auth = u'%s %s' % (auth_type, negotiate_message) + client = spnego.client( + self.username, + self.password, + protocol="ntlm", + channel_bindings=cbt, + ) + # Perform the first step of the NTLM authentication + negotiate_message = base64.b64encode(client.step()).decode() + auth = "%s %s" % (auth_type, negotiate_message) + request.headers[auth_header] = auth # A streaming response breaks authentication. @@ -96,42 +126,34 @@ def retry_using_http_NTLM_auth(self, auth_header_field, auth_header, # this is important for some web applications that store # authentication-related info in cookies (it took a long time to # figure out) - if response2.headers.get('set-cookie'): - request.headers['Cookie'] = response2.headers.get('set-cookie') + if response2.headers.get("set-cookie"): + request.headers["Cookie"] = response2.headers.get("set-cookie") # get the challenge auth_header_value = response2.headers[auth_header_field] - auth_strip = auth_type + ' ' + auth_strip = auth_type + " " ntlm_header_value = next( - s for s in (val.lstrip() for val in auth_header_value.split(',')) + s + for s in (val.lstrip() for val in auth_header_value.split(",")) if s.startswith(auth_strip) ).strip() - # Parse the challenge in the ntlm context - context.parse_challenge_message(ntlm_header_value[len(auth_strip):]) + # Parse the challenge in the ntlm context and perform + # the second step of authentication + val = base64.b64decode(ntlm_header_value[len(auth_strip) :].encode()) + authenticate_message = base64.b64encode(client.step(val)).decode() - # build response - # Get the response based on the challenge message - authenticate_message = context.create_authenticate_message( - self.username, - self.password, - self.domain, - server_certificate_hash=server_certificate_hash - ) - authenticate_message = authenticate_message.decode('ascii') - auth = u'%s %s' % (auth_type, authenticate_message) + auth = "%s %s" % (auth_type, authenticate_message) request.headers[auth_header] = auth response3 = response2.connection.send(request, **args) - # Update the history. response3.history.append(response) response3.history.append(response2) - # Get the session_security object created by ntlm-auth for signing and sealing of messages - self.session_security = context.session_security + self.session_security = ShimSessionSecurity(client) return response3 @@ -139,30 +161,28 @@ def response_hook(self, r, **kwargs): """The actual hook handler.""" if r.status_code == 401: # Handle server auth. - www_authenticate = r.headers.get('www-authenticate', '').lower() + www_authenticate = r.headers.get("www-authenticate", "").lower() auth_type = _auth_type_from_header(www_authenticate) if auth_type is not None: return self.retry_using_http_NTLM_auth( - 'www-authenticate', - 'Authorization', + "www-authenticate", + "Authorization", r, auth_type, - kwargs + kwargs, ) elif r.status_code == 407: # If we didn't have server auth, do proxy auth. - proxy_authenticate = r.headers.get( - 'proxy-authenticate', '' - ).lower() + proxy_authenticate = r.headers.get("proxy-authenticate", "").lower() auth_type = _auth_type_from_header(proxy_authenticate) if auth_type is not None: return self.retry_using_http_NTLM_auth( - 'proxy-authenticate', - 'Proxy-authorization', + "proxy-authenticate", + "Proxy-authorization", r, auth_type, - kwargs + kwargs, ) return r @@ -179,36 +199,31 @@ def _get_server_cert(self, response): :return: The hash of the DER encoded certificate at the request_url or None if not a HTTPS endpoint """ if self.send_cbt: - certificate_hash = None raw_response = response.raw if isinstance(raw_response, HTTPResponse): - if sys.version_info > (3, 0): - socket = raw_response._fp.fp.raw._sock - else: - socket = raw_response._fp.fp._sock + socket = raw_response._fp.fp.raw._sock try: server_certificate = socket.getpeercert(True) except AttributeError: pass else: - certificate_hash = _get_certificate_hash(server_certificate) + return _get_certificate_hash(server_certificate) else: warnings.warn( "Requests is running with a non urllib3 backend, cannot retrieve server certificate for CBT", - NoCertificateRetrievedWarning) + NoCertificateRetrievedWarning, + ) - return certificate_hash - else: - return None + return None def __call__(self, r): # we must keep the connection because NTLM authenticates the # connection, not single requests r.headers["Connection"] = "Keep-Alive" - r.register_hook('response', self.response_hook) + r.register_hook("response", self.response_hook) return r @@ -218,10 +233,10 @@ def _auth_type_from_header(header): authentication type to use. We prefer NTLM over Negotiate if the server suppports it. """ - if 'ntlm' in header: - return 'NTLM' - elif 'negotiate' in header: - return 'Negotiate' + if "ntlm" in header: + return "NTLM" + elif "negotiate" in header: + return "Negotiate" return None @@ -233,22 +248,24 @@ def _get_certificate_hash(certificate_der): try: hash_algorithm = cert.signature_hash_algorithm except UnsupportedAlgorithm as ex: - warnings.warn("Failed to get signature algorithm from certificate, " - "unable to pass channel bindings: %s" % str(ex), UnknownSignatureAlgorithmOID) + warnings.warn( + "Failed to get signature algorithm from certificate, " + "unable to pass channel bindings: %s" % str(ex), + UnknownSignatureAlgorithmOID, + ) return None # if the cert signature algorithm is either md5 or sha1 then use sha256 # otherwise use the signature algorithm - if hash_algorithm.name in ['md5', 'sha1']: + if hash_algorithm.name in ["md5", "sha1"]: digest = hashes.Hash(hashes.SHA256(), default_backend()) else: digest = hashes.Hash(hash_algorithm, default_backend()) digest.update(certificate_der) certificate_hash_bytes = digest.finalize() - certificate_hash = binascii.hexlify(certificate_hash_bytes).decode().upper() - return certificate_hash + return certificate_hash_bytes class NoCertificateRetrievedWarning(Warning): diff --git a/requirements.txt b/requirements.txt index ede853a..c8e2e77 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ requests>=2.0.0 -ntlm-auth>=1.0.2 +pyspnego cryptography>=1.3 flask pytest diff --git a/setup.py b/setup.py index d54d929..46c0e5d 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ name="requests_ntlm", version="1.2.0", packages=["requests_ntlm"], - install_requires=["requests>=2.0.0", "ntlm-auth>=1.0.2", "cryptography>=1.3"], + install_requires=["requests>=2.0.0", "pyspnego>=0.1.6", "cryptography>=1.3"], python_requires=">=3.7", provides=["requests_ntlm"], author="Ben Toews", diff --git a/tests/test_utils.py b/tests/test_utils.py index 173846f..c157e21 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,3 +2,6 @@ username = 'username' domain = 'domain' password = 'password' + +# Genearated online as hashlib.md4 may not be available anymore +password_md4 = '8a9d093f14f8701df17732b2bb182c74' \ No newline at end of file diff --git a/tests/unit/test_requests_ntlm.py b/tests/unit/test_requests_ntlm.py index bebcaf5..b3acc0f 100644 --- a/tests/unit/test_requests_ntlm.py +++ b/tests/unit/test_requests_ntlm.py @@ -4,7 +4,7 @@ import requests_ntlm import warnings -from tests.test_utils import domain, username, password +from tests.test_utils import domain, username, password, password_md4 class TestRequestsNtlm(unittest.TestCase): @@ -13,7 +13,8 @@ def setUp(self): self.test_server_url = 'http://localhost:5000/' self.test_server_username = '%s\\%s' % (domain, username) self.test_server_password = password - self.auth_types = ['ntlm','negotiate','both'] + self.auth_types = ['ntlm', 'negotiate', 'both'] + self.hash = password_md4 def test_requests_ntlm(self): for auth_type in self.auth_types: @@ -25,6 +26,17 @@ def test_requests_ntlm(self): self.assertEqual(res.status_code,200, msg='auth_type ' + auth_type) + def test_requests_ntlm_hash(self): + # Test authenticating using an NTLM hash + for auth_type in self.auth_types: + res = requests.get(\ + url = self.test_server_url + auth_type,\ + auth = requests_ntlm.HttpNtlmAuth( + self.test_server_username,\ + "0" * 32 + ":" + self.hash)) + + self.assertEqual(res.status_code,200, msg='auth_type ' + auth_type) + def test_history_is_preserved(self): for auth_type in self.auth_types: res = requests.get(url=self.test_server_url + auth_type, @@ -42,46 +54,6 @@ def test_new_requests_are_used(self): self.assertTrue(res.history[0].request is not res.history[1].request) self.assertTrue(res.history[0].request is not res.request) - def test_username_parse_backslash(self): - test_user = 'domain\\user' - expected_domain = 'DOMAIN' - expected_user = 'user' - - context = requests_ntlm.HttpNtlmAuth(test_user, 'pass') - - actual_domain = context.domain - actual_user = context.username - - assert actual_domain == expected_domain - assert actual_user == expected_user - - def test_username_parse_at(self): - test_user = 'user@domain.com' - # UPN format should not be split, since "stuff after @" not always == domain - # (eg, email address with alt UPN suffix) - expected_domain = '' - expected_user = 'user@domain.com' - - context = requests_ntlm.HttpNtlmAuth(test_user, 'pass') - - actual_domain = context.domain - actual_user = context.username - - assert actual_domain == expected_domain - assert actual_user == expected_user - - def test_username_parse_no_domain(self): - test_user = 'user' - expected_domain = '' - expected_user = 'user' - - context = requests_ntlm.HttpNtlmAuth(test_user, 'pass') - - actual_domain = context.domain - actual_user = context.username - - assert actual_domain == expected_domain - assert actual_user == expected_user class TestCertificateHash(unittest.TestCase): @@ -110,7 +82,7 @@ def test_rsa_md5(self): expected_hash = '2334B8476CBF4E6DFC766A5D5A30D6649C01BAE1662A5C3A130' \ '2A968D7C6B0F6' actual_hash = requests_ntlm.requests_ntlm._get_certificate_hash(base64.b64decode(cert_der)) - assert actual_hash == expected_hash + assert actual_hash == base64.b16decode(expected_hash) def test_rsa_sha1(self): cert_der = b'MIIDGzCCAgOgAwIBAgIQJg/Mf5sR55xApJRK+kabbTANBgkqhkiG9w0' \ @@ -137,7 +109,7 @@ def test_rsa_sha1(self): expected_hash = '14CFE8E4B332B20A343FC840B18F9F6F78926AFE7EC3E7B8E28' \ '969619B1E8F3E' actual_hash = requests_ntlm.requests_ntlm._get_certificate_hash(base64.b64decode(cert_der)) - assert actual_hash == expected_hash + assert actual_hash == base64.b16decode(expected_hash) def test_rsa_sha256(self): cert_der = b'MIIDGzCCAgOgAwIBAgIQWkeAtqoFg6pNWF7xC4YXhTANBgkqhkiG9w0' \ @@ -164,7 +136,7 @@ def test_rsa_sha256(self): expected_hash = '996F3EEA812C1870E30549FF9B86CD87A890B6D8DFDF4A81BEF' \ '9675970DADB26' actual_hash = requests_ntlm.requests_ntlm._get_certificate_hash(base64.b64decode(cert_der)) - assert actual_hash == expected_hash + assert actual_hash == base64.b16decode(expected_hash) def test_rsa_sha384(self): cert_der = b'MIIDGzCCAgOgAwIBAgIQEmj1prSSQYRL2zYBEjsm5jANBgkqhkiG9w0' \ @@ -191,7 +163,7 @@ def test_rsa_sha384(self): expected_hash = '34F303C995286F4B214A9BA6435B69B51ECF3758EABC2A14D7A' \ '43FD237DC2B1A1AD9111C5C965E107507CB4198C09FEC' actual_hash = requests_ntlm.requests_ntlm._get_certificate_hash(base64.b64decode(cert_der)) - assert actual_hash == expected_hash + assert actual_hash == base64.b16decode(expected_hash) def test_rsa_sha512(self): cert_der = b'MIIDGzCCAgOgAwIBAgIQUDHcKGevZohJV+TkIIYC1DANBgkqhkiG9w0' \ @@ -219,7 +191,7 @@ def test_rsa_sha512(self): '00544E1AD2B76FF25CFBE69B1C4E630C3BB0207DF11314C6738' \ 'BCAED7E071D7BFBF2C9DFAB85D' actual_hash = requests_ntlm.requests_ntlm._get_certificate_hash(base64.b64decode(cert_der)) - assert actual_hash == expected_hash + assert actual_hash == base64.b16decode(expected_hash) def test_ecdsa_sha1(self): cert_der = b'MIIBjjCCATSgAwIBAgIQRCJw7nbtvJ5F8wikRmwgizAJBgcqhkjOPQQ' \ @@ -236,7 +208,7 @@ def test_ecdsa_sha1(self): expected_hash = '1EC9AD46DEE9340E4503CFFDB5CD810CB26B778F46BE95D5EAF' \ '999DCB1C45EDA' actual_hash = requests_ntlm.requests_ntlm._get_certificate_hash(base64.b64decode(cert_der)) - assert actual_hash == expected_hash + assert actual_hash == base64.b16decode(expected_hash) def test_ecdsa_sha256(self): cert_der = b'MIIBjzCCATWgAwIBAgIQeNQTxkMgq4BF9tKogIGXUTAKBggqhkjOPQQ' \ @@ -253,7 +225,7 @@ def test_ecdsa_sha256(self): expected_hash = 'FECF1B2585449990D9E3B2C92D3F597EC8354E124EDA751D948' \ '37C2C89A2C155' actual_hash = requests_ntlm.requests_ntlm._get_certificate_hash(base64.b64decode(cert_der)) - assert actual_hash == expected_hash + assert actual_hash == base64.b16decode(expected_hash) def test_ecdsa_sha384(self): cert_der = b'MIIBjzCCATWgAwIBAgIQcO3/jALdQ6BOAoaoseLSCjAKBggqhkjOPQQ' \ @@ -270,7 +242,7 @@ def test_ecdsa_sha384(self): expected_hash = 'D2987AD8F20E8316A831261B74EF7B3E55155D0922E07FFE546' \ '20806982B68A73A5E3C478BAA5E7714135CB26D980749' actual_hash = requests_ntlm.requests_ntlm._get_certificate_hash(base64.b64decode(cert_der)) - assert actual_hash == expected_hash + assert actual_hash == base64.b16decode(expected_hash) def test_ecdsa_sha512(self): cert_der = b'MIIBjjCCATWgAwIBAgIQHVj2AGEwd6pOOSbcf0skQDAKBggqhkjOPQQ' \ @@ -288,7 +260,7 @@ def test_ecdsa_sha512(self): 'F19A5BD8F0B2FAAC861855FBB63A221CC46FC1E226A072411AF' \ '175DDE479281E006878B348059' actual_hash = requests_ntlm.requests_ntlm._get_certificate_hash(base64.b64decode(cert_der)) - assert actual_hash == expected_hash + assert actual_hash == base64.b16decode(expected_hash) def test_invalid_signature_algorithm(self): # Manually edited from test_ecdsa_sha512 to change the OID to '1.2.840.10045.4.3.5'