Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CAE in azure-identity #16323

Merged
merged 4 commits into from
Feb 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _request_token(self, *scopes, **kwargs):
request_state = str(uuid.uuid4())
app = self._get_app()
auth_url = app.get_authorization_request_url(
scopes, redirect_uri=redirect_uri, state=request_state, prompt="select_account", **kwargs
scopes, redirect_uri=redirect_uri, state=request_state, prompt="select_account"
)

# open browser to that url
Expand All @@ -113,7 +113,9 @@ def _request_token(self, *scopes, **kwargs):

# redeem the authorization code for a token
code = self._parse_response(request_state, response)
return app.acquire_token_by_authorization_code(code, scopes=scopes, redirect_uri=redirect_uri, **kwargs)
return app.acquire_token_by_authorization_code(
code, scopes=scopes, redirect_uri=redirect_uri, claims_challenge=kwargs.get("claims")
)

@staticmethod
def _parse_response(request_state, response):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,12 @@ def _request_token(self, *scopes, **kwargs):
if self._timeout is not None and self._timeout < flow["expires_in"]:
# user specified an effective timeout we will observe
deadline = int(time.time()) + self._timeout
result = app.acquire_token_by_device_flow(flow, exit_condition=lambda flow: time.time() > deadline)
result = app.acquire_token_by_device_flow(
flow, exit_condition=lambda flow: time.time() > deadline, claims_challenge=kwargs.get("claims")
)
else:
# MSAL will stop polling when the device code expires
result = app.acquire_token_by_device_flow(flow)
result = app.acquire_token_by_device_flow(flow, claims_challenge=kwargs.get("claims"))

if "access_token" not in result:
if result.get("error") == "authorization_pending":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
This method is called automatically by Azure SDK clients.

:param str scopes: desired scopes for the access token. This method requires at least one scope.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
:rtype: :class:`azure.core.credentials.AccessToken`
:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
information
Expand All @@ -87,7 +89,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
raise CredentialUnavailableError(message="Shared token cache unavailable")

if self._auth_record:
return self._acquire_token_silent(*scopes)
return self._acquire_token_silent(*scopes, **kwargs)

account = self._get_account(self._username, self._tenant_id)

Expand Down Expand Up @@ -121,6 +123,7 @@ def _initialize(self):
authority="https://{}/{}".format(self._auth_record.authority, self._tenant_id),
token_cache=self._cache,
http_client=MsalClient(**self._client_kwargs),
client_capabilities=["CP1"]
)

self._initialized = True
Expand All @@ -146,7 +149,9 @@ def _acquire_token_silent(self, *scopes, **kwargs):
continue

now = int(time.time())
result = self._app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs)
result = self._app.acquire_token_silent_with_error(
list(scopes), account=account, claims_challenge=kwargs.get("claims")
)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,8 @@ def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> dict
app = self._get_app()
return app.acquire_token_by_username_password(
username=self._username, password=self._password, scopes=list(scopes)
username=self._username,
password=self._password,
scopes=list(scopes),
claims_challenge=kwargs.get("claims"),
)
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def get_token(self, *scopes, **kwargs):
This method is called automatically by Azure SDK clients.

:param str scopes: desired scopes for the access token. This method requires at least one scope.
:keyword str claims: additional claims required in the token, such as those returned in a resource provider's
claims challenge following an authorization failure
:rtype: :class:`azure.core.credentials.AccessToken`
:raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks
required data, state, or platform support
Expand Down Expand Up @@ -187,7 +189,9 @@ def _acquire_token_silent(self, *scopes, **kwargs):
continue

now = int(time.time())
result = app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs)
result = app.acquire_token_silent_with_error(
list(scopes), account=account, claims_challenge=kwargs.get("claims")
)
Comment on lines +192 to +194
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The removal of **kwargs will make acquiring SSH certificate stop working as data is passed in kwargs.

Azure CLI is aiming to support acquiring SSH certificate for Service Principal as well (#16397).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summarizing offline discussion here, there are two orthogonal concerns: acquiring SSH certificates, and allowing applications to pass keyword arguments through get_token to MSAL. Neither is supported today. Acquiring SSH certificates may be supported in a future version. Passing MSAL arguments through get_token is possible in some cases but really isn't supportable because routine maintenance requires internal changes (e.g. #16449) that can break applications relying on this behavior. My goal with this change is to prevent applications from taking a dependency on such implementation details.

if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

Expand All @@ -200,7 +204,7 @@ def _acquire_token_silent(self, *scopes, **kwargs):
def _get_app(self):
# type: () -> msal.PublicClientApplication
if not self._msal_app:
self._msal_app = self._create_app(msal.PublicClientApplication)
self._msal_app = self._create_app(msal.PublicClientApplication, client_capabilities=["CP1"])
return self._msal_app

@abc.abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,15 @@ def _get_app(self):
# type: () -> msal.ClientApplication
pass

def _create_app(self, cls):
# type: (Type[msal.ClientApplication]) -> msal.ClientApplication
def _create_app(self, cls, **kwargs):
# type: (Type[msal.ClientApplication], **Any) -> msal.ClientApplication
app = cls(
client_id=self._client_id,
client_credential=self._client_credential,
authority="{}/{}".format(self._authority, self._tenant_id),
token_cache=self._cache,
http_client=self._client,
**kwargs
)

return app
77 changes: 77 additions & 0 deletions sdk/identity/azure-identity/tests/recording_processors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import base64
import binascii
import hashlib
import json
import re
import time

from azure_devtools.scenario_tests import RecordingProcessor
import six


SECRETS = frozenset({
"access_token",
"client_secret",
"code",
"device_code",
"message",
"password",
"refresh_token",
"user_code",
})


class RecordingRedactor(RecordingProcessor):
"""Removes authentication secrets from recordings"""

def process_request(self, request):
# don't record the body because it probably contains secrets and is formed by msal anyway,
# i.e. it isn't this library's responsibility
request.body = None
return request

def process_response(self, response):
try:
body = json.loads(response["body"]["string"])
except (KeyError, ValueError):
return response

for field in body:
if field in SECRETS:
# record a hash of the secret instead of a simple replacement like "redacted"
# because some tests (e.g. for CAE) require unique, consistent values
digest = hashlib.sha256(six.ensure_binary(body[field])).digest()
body[field] = six.ensure_str(binascii.hexlify(digest))

response["body"]["string"] = json.dumps(body)
return response


class IdTokenProcessor(RecordingProcessor):
def process_response(self, response):
"""Changes the "exp" claim of recorded id tokens to be in the future during playback

This is necessary because msal always validates id tokens, raising an exception when they've expired.
"""
try:
# decode the recorded token
body = json.loads(six.ensure_str(response["body"]["string"]))
header, encoded_payload, signed = body["id_token"].split(".")
decoded_payload = base64.b64decode(encoded_payload + "=" * (4 - len(encoded_payload) % 4))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, what are the "=" added onto the payload for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A base64 encoded string should be padded with "=" to make its length divisible by 4. CPython's base64 strictly requires padding. However, because a decoder can infer the padding, encoders commonly omit it.


# set the token's expiry time to one hour from now
payload = json.loads(six.ensure_str(decoded_payload))
payload["exp"] = int(time.time()) + 3600

# write the modified token to the response body
new_payload = six.ensure_binary(json.dumps(payload))
body["id_token"] = ".".join((header, base64.b64encode(new_payload).decode("utf-8"), signed))
response["body"]["string"] = six.ensure_binary(json.dumps(body))
except KeyError:
pass

return response
Loading