Skip to content

Commit

Permalink
Support CAE in azure-identity (#16323)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Feb 3, 2021
1 parent 891d7aa commit ef46a5c
Show file tree
Hide file tree
Showing 14 changed files with 1,835 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _request_token(self, *scopes, **kwargs):
request_state = str(uuid.uuid4())
app = self._get_app()
auth_url = app.get_authorization_request_url(
scopes, redirect_uri=redirect_uri, state=request_state, prompt="select_account", **kwargs
scopes, redirect_uri=redirect_uri, state=request_state, prompt="select_account"
)

# open browser to that url
Expand All @@ -113,7 +113,9 @@ def _request_token(self, *scopes, **kwargs):

# redeem the authorization code for a token
code = self._parse_response(request_state, response)
return app.acquire_token_by_authorization_code(code, scopes=scopes, redirect_uri=redirect_uri, **kwargs)
return app.acquire_token_by_authorization_code(
code, scopes=scopes, redirect_uri=redirect_uri, claims_challenge=kwargs.get("claims")
)

@staticmethod
def _parse_response(request_state, response):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ def _request_token(self, *scopes, **kwargs):
if self._timeout is not None and self._timeout < flow["expires_in"]:
# user specified an effective timeout we will observe
deadline = int(time.time()) + self._timeout
result = app.acquire_token_by_device_flow(flow, exit_condition=lambda flow: time.time() > deadline)
result = app.acquire_token_by_device_flow(
flow, exit_condition=lambda flow: time.time() > deadline, claims_challenge=kwargs.get("claims")
)
else:
# MSAL will stop polling when the device code expires
result = app.acquire_token_by_device_flow(flow)
result = app.acquire_token_by_device_flow(flow, claims_challenge=kwargs.get("claims"))

if "access_token" not in result:
if result.get("error") == "authorization_pending":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
This method is called automatically by Azure SDK clients.
:param str scopes: desired scopes for the access token. This method requires at least one scope.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
:rtype: :class:`azure.core.credentials.AccessToken`
:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
information
Expand All @@ -87,7 +89,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
raise CredentialUnavailableError(message="Shared token cache unavailable")

if self._auth_record:
return self._acquire_token_silent(*scopes)
return self._acquire_token_silent(*scopes, **kwargs)

account = self._get_account(self._username, self._tenant_id)

Expand Down Expand Up @@ -121,6 +123,7 @@ def _initialize(self):
authority="https://{}/{}".format(self._auth_record.authority, self._tenant_id),
token_cache=self._cache,
http_client=MsalClient(**self._client_kwargs),
client_capabilities=["CP1"]
)

self._initialized = True
Expand All @@ -146,7 +149,9 @@ def _acquire_token_silent(self, *scopes, **kwargs):
continue

now = int(time.time())
result = self._app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs)
result = self._app.acquire_token_silent_with_error(
list(scopes), account=account, claims_challenge=kwargs.get("claims")
)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,8 @@ def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> dict
app = self._get_app()
return app.acquire_token_by_username_password(
username=self._username, password=self._password, scopes=list(scopes)
username=self._username,
password=self._password,
scopes=list(scopes),
claims_challenge=kwargs.get("claims"),
)
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def get_token(self, *scopes, **kwargs):
This method is called automatically by Azure SDK clients.
:param str scopes: desired scopes for the access token. This method requires at least one scope.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
:rtype: :class:`azure.core.credentials.AccessToken`
:raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks
required data, state, or platform support
Expand Down Expand Up @@ -187,7 +189,9 @@ def _acquire_token_silent(self, *scopes, **kwargs):
continue

now = int(time.time())
result = app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs)
result = app.acquire_token_silent_with_error(
list(scopes), account=account, claims_challenge=kwargs.get("claims")
)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

Expand All @@ -200,7 +204,7 @@ def _acquire_token_silent(self, *scopes, **kwargs):
def _get_app(self):
# type: () -> msal.PublicClientApplication
if not self._msal_app:
self._msal_app = self._create_app(msal.PublicClientApplication)
self._msal_app = self._create_app(msal.PublicClientApplication, client_capabilities=["CP1"])
return self._msal_app

@abc.abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,15 @@ def _get_app(self):
# type: () -> msal.ClientApplication
pass

def _create_app(self, cls):
# type: (Type[msal.ClientApplication]) -> msal.ClientApplication
def _create_app(self, cls, **kwargs):
# type: (Type[msal.ClientApplication], **Any) -> msal.ClientApplication
app = cls(
client_id=self._client_id,
client_credential=self._client_credential,
authority="{}/{}".format(self._authority, self._tenant_id),
token_cache=self._cache,
http_client=self._client,
**kwargs
)

return app
77 changes: 77 additions & 0 deletions sdk/identity/azure-identity/tests/recording_processors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import base64
import binascii
import hashlib
import json
import re
import time

from azure_devtools.scenario_tests import RecordingProcessor
import six


SECRETS = frozenset({
"access_token",
"client_secret",
"code",
"device_code",
"message",
"password",
"refresh_token",
"user_code",
})


class RecordingRedactor(RecordingProcessor):
"""Removes authentication secrets from recordings"""

def process_request(self, request):
# don't record the body because it probably contains secrets and is formed by msal anyway,
# i.e. it isn't this library's responsibility
request.body = None
return request

def process_response(self, response):
try:
body = json.loads(response["body"]["string"])
except (KeyError, ValueError):
return response

for field in body:
if field in SECRETS:
# record a hash of the secret instead of a simple replacement like "redacted"
# because some tests (e.g. for CAE) require unique, consistent values
digest = hashlib.sha256(six.ensure_binary(body[field])).digest()
body[field] = six.ensure_str(binascii.hexlify(digest))

response["body"]["string"] = json.dumps(body)
return response


class IdTokenProcessor(RecordingProcessor):
def process_response(self, response):
"""Changes the "exp" claim of recorded id tokens to be in the future during playback
This is necessary because msal always validates id tokens, raising an exception when they've expired.
"""
try:
# decode the recorded token
body = json.loads(six.ensure_str(response["body"]["string"]))
header, encoded_payload, signed = body["id_token"].split(".")
decoded_payload = base64.b64decode(encoded_payload + "=" * (4 - len(encoded_payload) % 4))

# set the token's expiry time to one hour from now
payload = json.loads(six.ensure_str(decoded_payload))
payload["exp"] = int(time.time()) + 3600

# write the modified token to the response body
new_payload = six.ensure_binary(json.dumps(payload))
body["id_token"] = ".".join((header, base64.b64encode(new_payload).decode("utf-8"), signed))
response["body"]["string"] = six.ensure_binary(json.dumps(body))
except KeyError:
pass

return response
Loading

0 comments on commit ef46a5c

Please sign in to comment.