Skip to content

Commit

Permalink
migrate from python-jose to pyjwt
Browse files Browse the repository at this point in the history
  • Loading branch information
dlilue committed May 23, 2024
1 parent 562882b commit 6e67c7e
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 131 deletions.
4 changes: 2 additions & 2 deletions demo_project/api/api_v1/endpoints/graph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any

import httpx
import jwt
from demo_project.api.dependencies import azure_scheme
from demo_project.core.config import settings
from fastapi import APIRouter, Depends, Request
from httpx import AsyncClient
from jose import jwt

router = APIRouter()

Expand Down Expand Up @@ -47,7 +47,7 @@ async def graph_world(request: Request) -> Any: # noqa: ANN401

# Return all the information to the end user
return (
{'claims': jwt.get_unverified_claims(token=request.state.user.access_token)}
{'claims': jwt.decode(request.state.user.access_token, options={'verify_signature': False})}
| {'obo_response': obo_response.json()}
| {'graph_response': graph}
)
75 changes: 49 additions & 26 deletions fastapi_azure_auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,30 @@
from typing import Any, Awaitable, Callable, Dict, Literal, Optional
from warnings import warn

import jwt
from fastapi.exceptions import HTTPException
from fastapi.security import OAuth2AuthorizationCodeBearer, SecurityScopes
from fastapi.security.base import SecurityBase
from jose import jwt
from jose.exceptions import ExpiredSignatureError, JWTClaimsError, JWTError
from jwt.exceptions import (
ExpiredSignatureError,
InvalidTokenError,
InvalidAudienceError,
InvalidIssuerError,
InvalidIssuedAtError,
ImmatureSignatureError,
InvalidAlgorithmError,
MissingRequiredClaimError,
)
from starlette.requests import Request

from fastapi_azure_auth.exceptions import InvalidAuth
from fastapi_azure_auth.openid_config import OpenIdConfig
from fastapi_azure_auth.user import User
from fastapi_azure_auth.utils import is_guest
from fastapi_azure_auth.utils import (
is_guest,
get_unverified_header,
get_unverified_claims,
)

log = logging.getLogger('fastapi_azure_auth')

Expand Down Expand Up @@ -148,8 +161,8 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
access_token = await self.oauth(request=request)
try:
# Extract header information of the token.
header: dict[str, str] = jwt.get_unverified_header(token=access_token) or {}
claims: dict[str, Any] = jwt.get_unverified_claims(token=access_token) or {}
header: dict[str, Any] = get_unverified_header(access_token)
claims: dict[str, Any] = get_unverified_claims(access_token)
except Exception as error:
log.warning('Malformed token received. %s. Error: %s', access_token, error, exc_info=True)
raise InvalidAuth(detail='Invalid token format') from error
Expand Down Expand Up @@ -180,48 +193,44 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
try:
if key := self.openid_config.signing_keys.get(header.get('kid', '')):
# We require and validate all fields in an Azure AD token
required_claims = ['exp', 'aud', 'iat', 'nbf', 'sub']
if self.validate_iss:
required_claims.append('iss')

options = {
'verify_signature': True,
'verify_aud': True,
'verify_iat': True,
'verify_exp': True,
'verify_nbf': True,
'verify_iss': self.validate_iss,
'verify_sub': True,
'verify_jti': True,
'verify_at_hash': True,
'require_aud': True,
'require_iat': True,
'require_exp': True,
'require_nbf': True,
'require_iss': self.validate_iss,
'require_sub': True,
'require_jti': False,
'require_at_hash': False,
'leeway': self.leeway,
}
# Validate token
token = jwt.decode(
access_token,
token = self.validate(
access_token=access_token,
iss=iss,
key=key,
algorithms=['RS256'],
audience=self.app_client_id if self.token_version == 2 else f'api://{self.app_client_id}',
issuer=iss,
options=options,
)
options=options)
# Attach the user to the request. Can be accessed through `request.state.user`
user: User = User(
**{**token, 'claims': token, 'access_token': access_token, 'is_guest': user_is_guest}
)
request.state.user = user
return user
except JWTClaimsError as error:
except (
InvalidAudienceError,
InvalidIssuerError,
InvalidIssuedAtError,
ImmatureSignatureError,
InvalidAlgorithmError,
MissingRequiredClaimError
) as error:
log.info('Token contains invalid claims. %s', error)
raise InvalidAuth(detail='Token contains invalid claims') from error
except ExpiredSignatureError as error:
log.info('Token signature has expired. %s', error)
raise InvalidAuth(detail='Token signature has expired') from error
except JWTError as error:
except InvalidTokenError as error:
log.warning('Invalid token. Error: %s', error, exc_info=True)
raise InvalidAuth(detail='Unable to validate token') from error
except Exception as error:
Expand All @@ -235,6 +244,20 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
return None
raise

def validate(self, access_token: str, key: str, iss: str, options: Dict[str, Any]) -> Dict[str, Any]:
alg = 'RS256'
aud = self.app_client_id if self.token_version == 2 else f'api://{self.app_client_id}'
return jwt.decode(
access_token,
key=key,
algorithms=[alg],
audience=aud,
issuer=iss,
leeway=self.leeway,
options=options,
)



class SingleTenantAzureAuthorizationCodeBearer(AzureAuthorizationCodeBearerBase):
def __init__(
Expand Down
6 changes: 3 additions & 3 deletions fastapi_azure_auth/openid_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

import jwt
from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes as KeyTypes
from fastapi import HTTPException, status
from httpx import AsyncClient
from jose import jwk

log = logging.getLogger('fastapi_azure_auth')

Expand Down Expand Up @@ -98,6 +98,6 @@ def _load_keys(self, keys: List[Dict[str, Any]]) -> None:
for key in keys:
if key.get('use') == 'sig': # Only care about keys that are used for signatures, not encryption
log.debug('Loading public key from certificate: %s', key)
cert_obj = jwk.construct(key, 'RS256')
cert_obj = jwt.PyJWK(key, 'RS256')
if kid := key.get('kid'): # In case a key would not have a thumbprint we can match, we don't want it.
self.signing_keys[kid] = cert_obj
self.signing_keys[kid] = cert_obj.key
20 changes: 20 additions & 0 deletions fastapi_azure_auth/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Dict

import jwt


def is_guest(claims: Dict[str, Any]) -> bool:
"""
Expand All @@ -12,3 +14,21 @@ def is_guest(claims: Dict[str, Any]) -> bool:
claims_iss: str = claims.get('iss', '')
idp: str = claims.get('idp', claims_iss)
return idp != claims_iss


def get_unverified_header(access_token: str | None) -> Dict[str, Any]:
"""
Get header from the access token without verifying the signature
"""
if access_token is None:
return {}
return jwt.get_unverified_header(access_token)


def get_unverified_claims(access_token: str | None) -> Dict[str, Any]:
"""
Get claims from the access token without verifying the signature
"""
if access_token is None:
return {}
return jwt.decode(access_token, options={'verify_signature': False})
85 changes: 19 additions & 66 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ classifiers = [
python = "^3.8"
fastapi = ">0.68.0"
cryptography = ">=40.0.1"
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
httpx = ">0.18.2"
pyjwt = "^2.8.0"


[tool.poetry.group.dev.dependencies]
Expand Down
3 changes: 2 additions & 1 deletion tests/multi_tenant/test_multi_tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)

from fastapi_azure_auth import MultiTenantAzureAuthorizationCodeBearer
from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase
from fastapi_azure_auth.exceptions import InvalidAuth


Expand Down Expand Up @@ -283,7 +284,7 @@ async def test_only_header(multi_tenant_app, mock_openid_and_keys):

@pytest.mark.anyio
async def test_exception_raised(multi_tenant_app, mock_openid_and_keys, mocker):
mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol'))
mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol'))
async with AsyncClient(
app=app,
base_url='http://test',
Expand Down
3 changes: 2 additions & 1 deletion tests/multi_tenant_b2c/test_multi_tenant.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
build_evil_access_token,
)

from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase
from fastapi_azure_auth.openid_config import OpenIdConfig


Expand Down Expand Up @@ -259,7 +260,7 @@ async def test_only_header(multi_tenant_app, mock_openid_and_keys):

@pytest.mark.anyio
async def test_exception_raised(multi_tenant_app, mock_openid_and_keys, mocker):
mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol'))
mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol'))
mocker.patch.object(OpenIdConfig, 'load_config', return_value=True)
async with AsyncClient(
app=app,
Expand Down
4 changes: 3 additions & 1 deletion tests/single_tenant/test_single_tenant_v1_v2_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
build_evil_access_token,
)

from fastapi_azure_auth.auth import AzureAuthorizationCodeBearerBase


def current_version(current_cases) -> int:
return current_cases['single_tenant_app']['token_version'].params['version']
Expand Down Expand Up @@ -349,7 +351,7 @@ async def test_only_header(single_tenant_app, mock_openid_and_keys_v1_v2):
@pytest.mark.anyio
async def test_exception_raised(single_tenant_app, mock_openid_and_keys_v1_v2, mocker, current_cases):
test_version = current_version(current_cases)
mocker.patch('fastapi_azure_auth.auth.jwt.decode', side_effect=ValueError('lol'))
mocker.patch.object(AzureAuthorizationCodeBearerBase, 'validate', side_effect=ValueError('lol'))
async with AsyncClient(
app=app,
base_url='http://test',
Expand Down
Loading

0 comments on commit 6e67c7e

Please sign in to comment.