Skip to content

Commit

Permalink
Add JWK support for HMAC and RSA keys
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-adams committed May 6, 2016
1 parent d363ae9 commit de2d6af
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 21 deletions.
141 changes: 139 additions & 2 deletions jwt/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import hashlib
import hmac
import json

from .compat import constant_time_compare, string_types, text_type
from .exceptions import InvalidKeyError
from .utils import der_to_raw_signature, raw_to_der_signature
from .utils import (
base64url_encode, base64url_decode,
to_base64url_uint, from_base64url_uint,
der_to_raw_signature, raw_to_der_signature
)

try:
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import (
load_pem_private_key, load_pem_public_key, load_ssh_public_key
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPrivateKey, RSAPublicKey
RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers,
rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
)
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePrivateKey, EllipticCurvePublicKey
Expand Down Expand Up @@ -77,6 +83,18 @@ def verify(self, msg, key, sig):
"""
raise NotImplementedError

def to_jwk(self, key_obj):
"""
Serializes a given RSA key into a JWK
"""
raise NotImplementedError

def from_jwk(self, jwk):
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
"""
return NotImplementedError


class NoneAlgorithm(Algorithm):
"""
Expand All @@ -92,6 +110,14 @@ def prepare_key(self, key):

return key

@staticmethod
def to_jwk(key_obj):
return {}

@staticmethod
def from_jwk(jwk):
return None

def sign(self, msg, key):
return b''

Expand Down Expand Up @@ -131,6 +157,22 @@ def prepare_key(self, key):

return key

@staticmethod
def to_jwk(key_obj):
return json.dumps({
'k': base64url_encode(key_obj),
'typ': 'oct'
})

@staticmethod
def from_jwk(jwk):
obj = json.loads(jwk)

if obj.get('kty') != 'oct':
raise InvalidKeyError('Not an HMAC key')

return base64url_decode(obj['k'])

def sign(self, msg, key):
return hmac.new(key, msg, self.hash_alg).digest()

Expand Down Expand Up @@ -172,6 +214,101 @@ def prepare_key(self, key):

return key

@staticmethod
def to_jwk(key_obj):
obj = None

if getattr(key_obj, 'private_numbers', None):
# Private key
numbers = key_obj.private_numbers()

obj = {
'kty': 'RSA',
'key_ops': ['sign'],
'd': to_base64url_uint(numbers.d),
'p': to_base64url_uint(numbers.p),
'q': to_base64url_uint(numbers.q),
'dp': to_base64url_uint(numbers.dmp1),
'dq': to_base64url_uint(numbers.dmq1),
'qi': to_base64url_uint(numbers.iqmp)
}

elif getattr(key_obj, 'verifier', None):
# Public key
numbers = key_obj.public_numbers()

obj = {
'kty': 'RSA',
'use': 'sig',
'key_ops': ['verify'],
'n': to_base64url_uint(numbers.n),
'e': to_base64url_uint(numbers.e)
}
else:
raise InvalidKeyError('Not a public or private key')

return json.dumps(obj)

@staticmethod
def from_jwk(jwk):
obj = json.loads(jwk)

if obj.get('kty') != 'RSA':
raise InvalidKeyError('Not an RSA key')

if 'd' in obj and 'e' in obj and 'n' in obj:
# Private key
if 'oth' in obj:
raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')

other_props = ['p', 'q', 'dp', 'dq', 'qi']
props_found = [True for prop in other_props if prop in obj]
any_props_found = any(props_found)

if any_props_found and not all(props_found):
raise InvalidKeyError('RSA key must include all parameters if any are present besides d')

public_numbers = RSAPublicNumbers(
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
)

if any_props_found:
numbers = RSAPrivateNumbers(
d=from_base64url_uint(obj['d']),
p=from_base64url_uint(obj['p']),
q=from_base64url_uint(obj['q']),
dmp1=from_base64url_uint(obj['dp']),
dmq1=from_base64url_uint(obj['dq']),
iqmp=from_base64url_uint(obj['qi']),
public_numbers=public_numbers
)
else:
p, q = rsa_recover_prime_factors(
public_numbers.n, public_numbers.d, public_numbers.e
)
d = from_base64url_uint(obj['d'])

numbers = RSAPrivateNumbers(
d=d,
p=p,
q=q,
dmp1=rsa_crt_dmp1(d, p),
dmq1=rsa_crt_dmq1(d, q),
iqmp=rsa_crt_iqmp(p, q),
public_numbers=public_numbers
)

return numbers.private_key(default_backend())
elif 'n' in obj and 'e' in obj:
# Public key
numbers = RSAPublicNumbers(
from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
)

return numbers.public_key(default_backend())
else:
raise InvalidKeyError('Not a public or private key')

def sign(self, msg, key):
signer = key.signer(
padding.PKCS1v15(),
Expand Down
35 changes: 35 additions & 0 deletions jwt/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import base64
import binascii
import struct

from .compat import text_type

try:
from cryptography.hazmat.primitives.asymmetric.utils import (
Expand All @@ -10,6 +13,9 @@


def base64url_decode(input):
if isinstance(input, text_type):
input = input.encode('ascii')

rem = len(input) % 4

if rem > 0:
Expand All @@ -22,6 +28,35 @@ def base64url_encode(input):
return base64.urlsafe_b64encode(input).replace(b'=', b'')


def to_base64url_uint(val):
if val < 0:
raise ValueError('Must be a positive integer')

buf = []
while val:
val, remainder = divmod(val, 256)
buf.append(remainder)

buf.reverse()

data = struct.pack('%sB' % len(buf), *buf)

if len(data) == 0:
data = '\x00'

return base64url_encode(data)


def from_base64url_uint(val):
if isinstance(val, text_type):
val = val.encode('ascii')

data = base64url_decode(val)

buf = struct.unpack('%sB' % len(data), data)
return int(''.join(["%02x" % byte for byte in buf]), 16)


def merge_dict(original, updates):
if not updates:
return original
Expand Down
22 changes: 3 additions & 19 deletions tests/keys/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,21 @@ def load_hmac_key():
return base64url_decode(ensure_bytes(keyobj['k']))

try:
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.backends import default_backend

from jwt.algorithms import RSAAlgorithm
has_crypto = True
except ImportError:
has_crypto = False

if has_crypto:
def load_rsa_key():
with open(os.path.join(BASE_PATH, 'jwk_rsa_key.json'), 'r') as infile:
keyobj = json.load(infile)

return rsa.RSAPrivateNumbers(
p=decode_value(keyobj['p']),
q=decode_value(keyobj['q']),
d=decode_value(keyobj['d']),
dmp1=decode_value(keyobj['dp']),
dmq1=decode_value(keyobj['dq']),
iqmp=decode_value(keyobj['qi']),
public_numbers=load_rsa_pub_key().public_numbers()
).private_key(default_backend())
return RSAAlgorithm.from_jwk(infile.read())

def load_rsa_pub_key():
with open(os.path.join(BASE_PATH, 'jwk_rsa_pub.json'), 'r') as infile:
keyobj = json.load(infile)

return rsa.RSAPublicNumbers(
n=decode_value(keyobj['n']),
e=decode_value(keyobj['e'])
).public_key(default_backend())
return RSAAlgorithm.from_jwk(infile.read())

def load_ec_key():
with open(os.path.join(BASE_PATH, 'jwk_ec_key.json'), 'r') as infile:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_algorithms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import json

from jwt.algorithms import Algorithm, HMACAlgorithm, NoneAlgorithm
from jwt.exceptions import InvalidKeyError
Expand Down Expand Up @@ -88,6 +89,15 @@ def test_hmac_should_throw_exception_if_key_is_x509_cert(self):
with open(key_path('testkey2_rsa.pub.pem'), 'r') as keyfile:
algo.prepare_key(keyfile.read())

def test_hmac_jwk_public_and_private_keys_should_parse_and_verify(self):
algo = HMACAlgorithm(HMACAlgorithm.SHA256)

with open(key_path('jwk_hmac.json'), 'r') as keyfile:
key = algo.from_jwk(keyfile.read())

signature = algo.sign('Hello World!', key)
assert algo.verify('Hello World!', key, signature)

@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
def test_rsa_should_parse_pem_public_key(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)
Expand Down Expand Up @@ -131,6 +141,19 @@ def test_rsa_verify_should_return_false_if_signature_invalid(self):
result = algo.verify(message, pub_key, sig)
assert not result

@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
def test_rsa_jwk_public_and_private_keys_should_parse_and_verify(self):
algo = RSAAlgorithm(RSAAlgorithm.SHA256)

with open(key_path('jwk_rsa_pub.json'), 'r') as keyfile:
pub_key = algo.from_jwk(keyfile.read())

with open(key_path('jwk_rsa_key.json'), 'r') as keyfile:
priv_key = algo.from_jwk(keyfile.read())

signature = algo.sign('Hello World!', priv_key)
assert algo.verify('Hello World!', pub_key, signature)

@pytest.mark.skipif(not has_crypto, reason='Not supported without cryptography library')
def test_ec_should_reject_non_string_key(self):
algo = ECAlgorithm(ECAlgorithm.SHA256)
Expand Down

0 comments on commit de2d6af

Please sign in to comment.