diff --git a/social_core/backends/apple.py b/social_core/backends/apple.py index e74e19e5a..99eb54f6e 100644 --- a/social_core/backends/apple.py +++ b/social_core/backends/apple.py @@ -125,7 +125,7 @@ def decode_id_token(self, id_token): id_token, key=public_key, audience=self.get_audience(), - algorithm='RS256', + algorithms=['RS256'], ) except PyJWTError: raise AuthFailed(self, 'Token validation failed') diff --git a/social_core/backends/azuread_b2c.py b/social_core/backends/azuread_b2c.py index a540a994d..997ad6b59 100644 --- a/social_core/backends/azuread_b2c.py +++ b/social_core/backends/azuread_b2c.py @@ -28,11 +28,9 @@ """ import json -import six from cryptography.hazmat.primitives import serialization -from jwt import DecodeError, ExpiredSignature, decode as jwt_decode -from jwt.utils import base64url_decode +from jwt import DecodeError, ExpiredSignature, decode as jwt_decode, get_unverified_header try: @@ -173,22 +171,16 @@ def user_data(self, access_token, *args, **kwargs): response = kwargs.get('response') id_token = response.get('id_token') - if six.PY2: - # str() to fix a bug in Python's base64 - # https://stackoverflow.com/a/2230623/161278 - id_token = str(id_token) - - jwt_header_json = base64url_decode(id_token.split('.')[0]) - jwt_header = json.loads(jwt_header_json.decode('ascii')) # `kid` is short for key id - key = self.get_public_key(jwt_header['kid']) + kid = get_unverified_header(id_token)['kid'] + key = self.get_public_key(kid) try: return jwt_decode( id_token, key=key, - algorithms=jwt_header['alg'], + algorithms=['RS256'], audience=self.setting('KEY'), leeway=self.setting('JWT_LEEWAY', default=0), ) diff --git a/social_core/backends/azuread_tenant.py b/social_core/backends/azuread_tenant.py index 7960ff1b5..0038dbada 100644 --- a/social_core/backends/azuread_tenant.py +++ b/social_core/backends/azuread_tenant.py @@ -1,9 +1,6 @@ -import base64 -import json - from cryptography.x509 import load_pem_x509_certificate from cryptography.hazmat.backends import default_backend -from jwt import DecodeError, ExpiredSignature, decode as jwt_decode +from jwt import DecodeError, ExpiredSignature, decode as jwt_decode, get_unverified_header from ..exceptions import AuthTokenError from .azuread import AzureADOAuth2 @@ -95,14 +92,8 @@ def get_user_id(self, details, response): def user_data(self, access_token, *args, **kwargs): id_token = access_token - # decode the JWT header as JSON dict - jwt_header = json.loads( - base64.b64decode(id_token.split('.', 1)[0]).decode() - ) - # get key id and algorithm - key_id = jwt_header['kid'] - algorithm = jwt_header['alg'] + key_id = get_unverified_header(id_token)['kid'] try: # retrieve certificate for key_id @@ -111,7 +102,7 @@ def user_data(self, access_token, *args, **kwargs): return jwt_decode( id_token, key=certificate.public_key(), - algorithms=algorithm, + algorithms=['RS256'], audience=self.setting('KEY') ) except (DecodeError, ExpiredSignature) as error: diff --git a/social_core/backends/open_id_connect.py b/social_core/backends/open_id_connect.py index 8d13c1909..520471866 100644 --- a/social_core/backends/open_id_connect.py +++ b/social_core/backends/open_id_connect.py @@ -44,6 +44,7 @@ class OpenIdConnectAuth(BaseOAuth2): REVOKE_TOKEN_URL = '' USERINFO_URL = '' JWKS_URI = '' + JWT_ALGORITHMS = ['RS256'] JWT_DECODE_OPTIONS = dict() def __init__(self, *args, **kwargs): @@ -162,14 +163,13 @@ def validate_and_return_id_token(self, id_token, access_token): if not key: raise AuthTokenError(self, 'Signature verification failed') - alg = key['alg'] rsakey = jwk.construct(key) try: claims = jwt.decode( id_token, rsakey.to_pem().decode('utf-8'), - algorithms=[alg], + algorithms=self.JWT_ALGORITHMS, audience=client_id, issuer=self.id_token_issuer(), access_token=access_token, diff --git a/social_core/tests/backends/test_keycloak.py b/social_core/tests/backends/test_keycloak.py index 84b18209e..a137afb05 100644 --- a/social_core/tests/backends/test_keycloak.py +++ b/social_core/tests/backends/test_keycloak.py @@ -95,7 +95,7 @@ def _encode( def _decode( token, key=_PUBLIC_KEY, - algorithms=_ALGORITHM, + algorithms=[_ALGORITHM], audience=_KEY, ): return jwt.decode(token, key=key, algorithms=algorithms, audience=audience)