diff --git a/jwt/algorithms.py b/jwt/algorithms.py index 9c1a7e803..e3d0a3a83 100644 --- a/jwt/algorithms.py +++ b/jwt/algorithms.py @@ -1,9 +1,14 @@ import hashlib import hmac +import json from .compat import constant_time_compare, string_types, text_type from .exceptions import InvalidKeyError -from .utils import der_to_raw_signature, raw_to_der_signature +from .utils import ( + base64url_encode, base64url_decode, + to_base64url_uint, from_base64url_uint, + der_to_raw_signature, raw_to_der_signature +) try: from cryptography.hazmat.primitives import hashes @@ -11,7 +16,8 @@ load_pem_private_key, load_pem_public_key, load_ssh_public_key ) from cryptography.hazmat.primitives.asymmetric.rsa import ( - RSAPrivateKey, RSAPublicKey + RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers, + rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp ) from cryptography.hazmat.primitives.asymmetric.ec import ( EllipticCurvePrivateKey, EllipticCurvePublicKey @@ -77,6 +83,18 @@ def verify(self, msg, key, sig): """ raise NotImplementedError + def to_jwk(self, key_obj): + """ + Serializes a given RSA key into a JWK + """ + raise NotImplementedError + + def from_jwk(self, jwk): + """ + Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object + """ + return NotImplementedError + class NoneAlgorithm(Algorithm): """ @@ -92,6 +110,14 @@ def prepare_key(self, key): return key + @staticmethod + def to_jwk(key_obj): + return {} + + @staticmethod + def from_jwk(jwk): + return None + def sign(self, msg, key): return b'' @@ -131,6 +157,22 @@ def prepare_key(self, key): return key + @staticmethod + def to_jwk(key_obj): + return json.dumps({ + 'k': base64url_encode(key_obj), + 'typ': 'oct' + }) + + @staticmethod + def from_jwk(jwk): + obj = json.loads(jwk) + + if obj.get('kty') != 'oct': + raise InvalidKeyError('Not an HMAC key') + + return base64url_decode(obj['k']) + def sign(self, msg, key): return hmac.new(key, msg, self.hash_alg).digest() @@ -172,6 +214,101 @@ def prepare_key(self, key): return key + @staticmethod + def to_jwk(key_obj): + obj = None + + if getattr(key_obj, 'private_numbers', None): + # Private key + numbers = key_obj.private_numbers() + + obj = { + 'kty': 'RSA', + 'key_ops': ['sign'], + 'd': to_base64url_uint(numbers.d), + 'p': to_base64url_uint(numbers.p), + 'q': to_base64url_uint(numbers.q), + 'dp': to_base64url_uint(numbers.dmp1), + 'dq': to_base64url_uint(numbers.dmq1), + 'qi': to_base64url_uint(numbers.iqmp) + } + + elif getattr(key_obj, 'verifier', None): + # Public key + numbers = key_obj.public_numbers() + + obj = { + 'kty': 'RSA', + 'use': 'sig', + 'key_ops': ['verify'], + 'n': to_base64url_uint(numbers.n), + 'e': to_base64url_uint(numbers.e) + } + else: + raise InvalidKeyError('Not a public or private key') + + return json.dumps(obj) + + @staticmethod + def from_jwk(jwk): + obj = json.loads(jwk) + + if obj.get('kty') != 'RSA': + raise InvalidKeyError('Not an RSA key') + + if 'd' in obj and 'e' in obj and 'n' in obj: + # Private key + if 'oth' in obj: + raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported') + + other_props = ['p', 'q', 'dp', 'dq', 'qi'] + props_found = [True for prop in other_props if prop in obj] + any_props_found = any(props_found) + + if any_props_found and not all(props_found): + raise InvalidKeyError('RSA key must include all parameters if any are present besides d') + + public_numbers = RSAPublicNumbers( + from_base64url_uint(obj['e']), from_base64url_uint(obj['n']) + ) + + if any_props_found: + numbers = RSAPrivateNumbers( + d=from_base64url_uint(obj['d']), + p=from_base64url_uint(obj['p']), + q=from_base64url_uint(obj['q']), + dmp1=from_base64url_uint(obj['dp']), + dmq1=from_base64url_uint(obj['dq']), + iqmp=from_base64url_uint(obj['qi']), + public_numbers=public_numbers + ) + else: + p, q = rsa_recover_prime_factors( + public_numbers.n, public_numbers.d, public_numbers.e + ) + d = from_base64url_uint(obj['d']) + + numbers = RSAPrivateNumbers( + d=d, + p=p, + q=q, + dmp1=rsa_crt_dmp1(d, p), + dmq1=rsa_crt_dmq1(d, q), + iqmp=rsa_crt_iqmp(p, q), + public_numbers=public_numbers + ) + + return numbers.private_key(default_backend()) + elif 'n' in obj and 'e' in obj: + # Public key + numbers = RSAPublicNumbers( + from_base64url_uint(obj['e']), from_base64url_uint(obj['n']) + ) + + return numbers.public_key(default_backend()) + else: + raise InvalidKeyError('Not a public or private key') + def sign(self, msg, key): signer = key.signer( padding.PKCS1v15(), diff --git a/jwt/utils.py b/jwt/utils.py index 637b89299..d1f2c8a9d 100644 --- a/jwt/utils.py +++ b/jwt/utils.py @@ -1,5 +1,8 @@ import base64 import binascii +import struct + +from .compat import text_type try: from cryptography.hazmat.primitives.asymmetric.utils import ( @@ -10,6 +13,9 @@ def base64url_decode(input): + if isinstance(input, text_type): + input = input.encode('ascii') + rem = len(input) % 4 if rem > 0: @@ -22,6 +28,35 @@ def base64url_encode(input): return base64.urlsafe_b64encode(input).replace(b'=', b'') +def to_base64url_uint(val): + if val < 0: + raise ValueError('Must be a positive integer') + + buf = [] + while val: + val, remainder = divmod(val, 256) + buf.append(remainder) + + buf.reverse() + + data = struct.pack('%sB' % len(buf), *buf) + + if len(data) == 0: + data = '\x00' + + return base64url_encode(data) + + +def from_base64url_uint(val): + if isinstance(val, text_type): + val = val.encode('ascii') + + data = base64url_decode(val) + + buf = struct.unpack('%sB' % len(data), data) + return int(''.join(["%02x" % byte for byte in buf]), 16) + + def merge_dict(original, updates): if not updates: return original diff --git a/tests/keys/__init__.py b/tests/keys/__init__.py index fad09f57e..47bce70b3 100644 --- a/tests/keys/__init__.py +++ b/tests/keys/__init__.py @@ -20,10 +20,9 @@ def load_hmac_key(): return base64url_decode(ensure_bytes(keyobj['k'])) try: - from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.backends import default_backend - + from jwt.algorithms import RSAAlgorithm has_crypto = True except ImportError: has_crypto = False @@ -31,26 +30,11 @@ def load_hmac_key(): if has_crypto: def load_rsa_key(): with open(os.path.join(BASE_PATH, 'jwk_rsa_key.json'), 'r') as infile: - keyobj = json.load(infile) - - return rsa.RSAPrivateNumbers( - p=decode_value(keyobj['p']), - q=decode_value(keyobj['q']), - d=decode_value(keyobj['d']), - dmp1=decode_value(keyobj['dp']), - dmq1=decode_value(keyobj['dq']), - iqmp=decode_value(keyobj['qi']), - public_numbers=load_rsa_pub_key().public_numbers() - ).private_key(default_backend()) + return RSAAlgorithm.from_jwk(infile.read()) def load_rsa_pub_key(): with open(os.path.join(BASE_PATH, 'jwk_rsa_pub.json'), 'r') as infile: - keyobj = json.load(infile) - - return rsa.RSAPublicNumbers( - n=decode_value(keyobj['n']), - e=decode_value(keyobj['e']) - ).public_key(default_backend()) + return RSAAlgorithm.from_jwk(infile.read()) def load_ec_key(): with open(os.path.join(BASE_PATH, 'jwk_ec_key.json'), 'r') as infile: diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index e07185414..c16965988 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -1,4 +1,5 @@ import base64 +import json from jwt.algorithms import Algorithm, HMACAlgorithm, NoneAlgorithm from jwt.exceptions import InvalidKeyError @@ -88,6 +89,15 @@ def test_hmac_should_throw_exception_if_key_is_x509_cert(self): with open(key_path('testkey2_rsa.pub.pem'), 'r') as keyfile: algo.prepare_key(keyfile.read()) + def test_hmac_jwk_public_and_private_keys_should_parse_and_verify(self): + algo = HMACAlgorithm(HMACAlgorithm.SHA256) + + with open(key_path('jwk_hmac.json'), 'r') as keyfile: + key = algo.from_jwk(keyfile.read()) + + signature = algo.sign('Hello World!', key) + assert algo.verify('Hello World!', key, signature) + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') def test_rsa_should_parse_pem_public_key(self): algo = RSAAlgorithm(RSAAlgorithm.SHA256) @@ -131,6 +141,19 @@ def test_rsa_verify_should_return_false_if_signature_invalid(self): result = algo.verify(message, pub_key, sig) assert not result + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') + def test_rsa_jwk_public_and_private_keys_should_parse_and_verify(self): + algo = RSAAlgorithm(RSAAlgorithm.SHA256) + + with open(key_path('jwk_rsa_pub.json'), 'r') as keyfile: + pub_key = algo.from_jwk(keyfile.read()) + + with open(key_path('jwk_rsa_key.json'), 'r') as keyfile: + priv_key = algo.from_jwk(keyfile.read()) + + signature = algo.sign('Hello World!', priv_key) + assert algo.verify('Hello World!', pub_key, signature) + @pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library') def test_ec_should_reject_non_string_key(self): algo = ECAlgorithm(ECAlgorithm.SHA256)