Skip to content

Commit

Permalink
Token exchange support for ManagedIdentityCredential (#19902)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Aug 6, 2021
1 parent 63088bb commit fff4d58
Show file tree
Hide file tree
Showing 16 changed files with 380 additions and 62 deletions.
3 changes: 3 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ class EnvironmentVariables:
AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST"
AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION = "AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION"
AZURE_REGIONAL_AUTHORITY_NAME = "AZURE_REGIONAL_AUTHORITY_NAME"

TOKEN_FILE_PATH = "TOKEN_FILE_PATH"
TOKEN_EXCHANGE_VARS = (AZURE_CLIENT_ID, AZURE_TENANT_ID, TOKEN_FILE_PATH)
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import TYPE_CHECKING

from .._internal import AadClient
from .._internal.get_token_mixin import GetTokenMixin

if TYPE_CHECKING:
from typing import Any, Callable, Optional
from azure.core.credentials import AccessToken


class ClientAssertionCredential(GetTokenMixin):
def __init__(self, tenant_id, client_id, get_assertion, **kwargs):
# type: (str, str, Callable[[], str], **Any) -> None
"""Authenticates a service principal with a JWT assertion.
This credential is for advanced scenarios. :class:`~azure.identity.ClientCertificateCredential` has a more
convenient API for the most common assertion scenario, authenticating a service principal with a certificate.
:param str tenant_id: ID of the principal's tenant. Also called its "directory" ID.
:param str client_id: the principal's client ID
:param get_assertion: a callable that returns a string assertion. The credential will call this every time it
acquires a new token.
:paramtype get_assertion: Callable[[], str]
:keyword str authority: authority of an Azure Active Directory endpoint, for example
"login.microsoftonline.com", the authority for Azure Public Cloud (which is the default).
:class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds.
"""
self._get_assertion = get_assertion
self._client = AadClient(tenant_id, client_id, **kwargs)
super(ClientAssertionCredential, self).__init__(**kwargs)

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_access_token(scopes, **kwargs)

def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
assertion = self._get_assertion()
token = self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs)
return token
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ def __init__(self, **kwargs):
from .azure_arc import AzureArcCredential

self._credential = AzureArcCredential(**kwargs)
elif all(os.environ.get(var) for var in EnvironmentVariables.TOKEN_EXCHANGE_VARS):
_LOGGER.info("%s will use token exchange", self.__class__.__name__)
from .token_exchange import TokenExchangeCredential

self._credential = TokenExchangeCredential(
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
token_file_path=os.environ[EnvironmentVariables.TOKEN_FILE_PATH],
**kwargs
)
else:
from .imds import ImdsCredential

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import time
from typing import TYPE_CHECKING

from .client_assertion import ClientAssertionCredential

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any


class TokenFileMixin(object):
def __init__(self, token_file_path, **_):
# type: (str, **Any) -> None
super(TokenFileMixin, self).__init__()
self._jwt = ""
self._last_read_time = 0
self._token_file_path = token_file_path

def get_service_account_token(self):
# type: () -> str
now = int(time.time())
if now - self._last_read_time > 300:
with open(self._token_file_path) as f:
self._jwt = f.read()
self._last_read_time = now
return self._jwt


class TokenExchangeCredential(ClientAssertionCredential, TokenFileMixin):
def __init__(self, tenant_id, client_id, token_file_path, **kwargs):
# type: (str, str, str, **Any) -> None
super(TokenExchangeCredential, self).__init__(
tenant_id=tenant_id,
client_id=client_id,
get_assertion=self.get_service_account_token,
token_file_path=token_file_path,
**kwargs
)
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ def obtain_token_by_client_secret(self, scopes, secret, **kwargs):
response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

def obtain_token_by_jwt_assertion(self, scopes, assertion, **kwargs):
# type: (Iterable[str], str, **Any) -> AccessToken
request = self._get_jwt_assertion_request(scopes, assertion)
now = int(time.time())
response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
# type: (Iterable[str], str, **Any) -> AccessToken
request = self._get_refresh_token_request(scopes, refresh_token, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def get_cached_refresh_tokens(self, scopes):
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
pass

@abc.abstractmethod
def obtain_token_by_jwt_assertion(self, scopes, assertion, **kwargs):
pass

@abc.abstractmethod
def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
pass
Expand Down Expand Up @@ -165,10 +169,8 @@ def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None,
request = self._post(data, **kwargs)
return request

def _get_client_certificate_request(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest
audience = self._get_token_url(**kwargs)
assertion = self._get_jwt_assertion(certificate, audience)
def _get_jwt_assertion_request(self, scopes, assertion, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
data = {
"client_assertion": assertion,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
Expand All @@ -180,19 +182,8 @@ def _get_client_certificate_request(self, scopes, certificate, **kwargs):
request = self._post(data, **kwargs)
return request

def _get_client_secret_request(self, scopes, secret, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
data = {
"client_id": self._client_id,
"client_secret": secret,
"grant_type": "client_credentials",
"scope": " ".join(scopes),
}
request = self._post(data, **kwargs)
return request

def _get_jwt_assertion(self, certificate, audience):
# type: (AadClientCertificate, str) -> str
def _get_client_certificate_request(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest
now = int(time.time())
header = six.ensure_binary(
json.dumps({"typ": "JWT", "alg": "RS256", "x5t": certificate.thumbprint}), encoding="utf-8"
Expand All @@ -201,7 +192,7 @@ def _get_jwt_assertion(self, certificate, audience):
json.dumps(
{
"jti": str(uuid4()),
"aud": audience,
"aud": self._get_token_url(**kwargs),
"iss": self._client_id,
"sub": self._client_id,
"nbf": now,
Expand All @@ -213,8 +204,20 @@ def _get_jwt_assertion(self, certificate, audience):
jws = base64.urlsafe_b64encode(header) + b"." + base64.urlsafe_b64encode(payload)
signature = certificate.sign(jws)
jwt_bytes = jws + b"." + base64.urlsafe_b64encode(signature)
assertion = jwt_bytes.decode("utf-8")

return jwt_bytes.decode("utf-8")
return self._get_jwt_assertion_request(scopes, assertion, **kwargs)

def _get_client_secret_request(self, scopes, secret, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
data = {
"client_id": self._client_id,
"client_secret": secret,
"grant_type": "client_credentials",
"scope": " ".join(scopes),
}
request = self._post(data, **kwargs)
return request

def _get_refresh_token_request(self, scopes, refresh_token, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import TYPE_CHECKING

from .._internal import AadClient, AsyncContextManager
from .._internal.get_token_mixin import GetTokenMixin

if TYPE_CHECKING:
from typing import Any, Callable, Optional
from azure.core.credentials import AccessToken


class ClientAssertionCredential(AsyncContextManager, GetTokenMixin):
def __init__(self, tenant_id: str, client_id: str, get_assertion: "Callable[[], str]", **kwargs: "Any") -> None:
"""Authenticates a service principal with a JWT assertion.
This credential is for advanced scenarios. :class:`~azure.identity.ClientCertificateCredential` has a more
convenient API for the most common assertion scenario, authenticating a service principal with a certificate.
:param str tenant_id: ID of the principal's tenant. Also called its "directory" ID.
:param str client_id: the principal's client ID
:param get_assertion: a callable that returns a string assertion. The credential will call this every time it
acquires a new token.
:paramtype get_assertion: Callable[[], str]
:keyword str authority: authority of an Azure Active Directory endpoint, for example
"login.microsoftonline.com", the authority for Azure Public Cloud (which is the default).
:class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds.
"""
self._get_assertion = get_assertion
self._client = AadClient(tenant_id, client_id, **kwargs)
super().__init__(**kwargs)

async def __aenter__(self):
await self._client.__aenter__()
return self

async def close(self) -> None:
"""Close the credential's transport session."""
await self._client.close()

async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]":
return self._client.get_cached_access_token(scopes, **kwargs)

async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
assertion = self._get_assertion()
token = await self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs)
return token
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ def __init__(self, **kwargs: "Any") -> None:
from .azure_arc import AzureArcCredential

self._credential = AzureArcCredential(**kwargs)
elif all(os.environ.get(var) for var in EnvironmentVariables.TOKEN_EXCHANGE_VARS):
_LOGGER.info("%s will use token exchange", self.__class__.__name__)
from .token_exchange import TokenExchangeCredential

self._credential = TokenExchangeCredential(
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
token_file_path=os.environ[EnvironmentVariables.TOKEN_FILE_PATH],
**kwargs
)
else:
from .imds import ImdsCredential

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import TYPE_CHECKING

from .client_assertion import ClientAssertionCredential
from ..._credentials.token_exchange import TokenFileMixin

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any


class TokenExchangeCredential(ClientAssertionCredential, TokenFileMixin):
def __init__(self, tenant_id: str, client_id: str, token_file_path: str, **kwargs: "Any") -> None:
super().__init__(
tenant_id=tenant_id,
client_id=client_id,
get_assertion=self.get_service_account_token,
token_file_path=token_file_path,
**kwargs
)
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ async def obtain_token_by_authorization_code(
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> AccessToken
async def obtain_token_by_client_certificate(
self, scopes: "Iterable[str]", certificate: "AadClientCertificate", **kwargs: "Any"
) -> "AccessToken":
request = self._get_client_certificate_request(scopes, certificate, **kwargs)
now = int(time.time())
response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
Expand All @@ -63,6 +64,14 @@ async def obtain_token_by_client_secret(
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_jwt_assertion(
self, scopes: "Iterable[str]", assertion: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_jwt_assertion_request(scopes, assertion)
now = int(time.time())
response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_refresh_token(
self, scopes: "Iterable[str]", refresh_token: str, **kwargs: "Any"
) -> "AccessToken":
Expand Down
10 changes: 6 additions & 4 deletions sdk/identity/azure-identity/tests/test_aad_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
# Licensed under the MIT License.
# ------------------------------------
import functools
import time

from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError
from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal import AadClient, AadClientCertificate
from azure.core.credentials import AccessToken

import pytest
from msal import TokenCache
Expand Down Expand Up @@ -234,10 +232,14 @@ def test_retries_token_requests():
transport.send.reset_mock()

with pytest.raises(ServiceRequestError, match=message):
client.obtain_token_by_refresh_token("", "")
client.obtain_token_by_jwt_assertion("", "")
assert transport.send.call_count > 1
transport.send.reset_mock()

with pytest.raises(ServiceRequestError, match=message):
client.obtain_token_by_refresh_token("", "")
assert transport.send.call_count > 1


def test_shared_cache():
"""The client should return only tokens associated with its own client_id"""
Expand Down
6 changes: 5 additions & 1 deletion sdk/identity/azure-identity/tests/test_aad_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,14 @@ async def test_retries_token_requests():
transport.send.reset_mock()

with pytest.raises(ServiceRequestError, match=message):
await client.obtain_token_by_refresh_token("", "")
await client.obtain_token_by_jwt_assertion("", "")
assert transport.send.call_count > 1
transport.send.reset_mock()

with pytest.raises(ServiceRequestError, match=message):
await client.obtain_token_by_refresh_token("", "")
assert transport.send.call_count > 1


async def test_shared_cache():
"""The client should return only tokens associated with its own client_id"""
Expand Down
Loading

0 comments on commit fff4d58

Please sign in to comment.