Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Token exchange support for ManagedIdentityCredential #19902

Merged
merged 10 commits into from
Aug 6, 2021
3 changes: 3 additions & 0 deletions sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ class EnvironmentVariables:
AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST"
AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION = "AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION"
AZURE_REGIONAL_AUTHORITY_NAME = "AZURE_REGIONAL_AUTHORITY_NAME"

TOKEN_FILE_PATH = "TOKEN_FILE_PATH"
TOKEN_EXCHANGE_VARS = (AZURE_CLIENT_ID, AZURE_TENANT_ID, TOKEN_FILE_PATH)
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import TYPE_CHECKING

from .._internal import AadClient
from .._internal.get_token_mixin import GetTokenMixin

if TYPE_CHECKING:
from typing import Any, Callable, Optional
from azure.core.credentials import AccessToken


class ClientAssertionCredential(GetTokenMixin):
def __init__(self, tenant_id, client_id, get_assertion, **kwargs):
# type: (str, str, Callable[[], str], **Any) -> None
"""Authenticates a service principal with a JWT assertion.
This credential is for advanced scenarios. :class:`~azure.identity.ClientCertificateCredential` has a more
convenient API for the most common assertion scenario, authenticating a service principal with a certificate.
:param str tenant_id: ID of the principal's tenant. Also called its "directory" ID.
:param str client_id: the principal's client ID
:param get_assertion: a callable that returns a string assertion. The credential will call this every time it
acquires a new token.
:paramtype get_assertion: Callable[[], str]
:keyword str authority: authority of an Azure Active Directory endpoint, for example
"login.microsoftonline.com", the authority for Azure Public Cloud (which is the default).
:class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds.
"""
self._get_assertion = get_assertion
self._client = AadClient(tenant_id, client_id, **kwargs)
super(ClientAssertionCredential, self).__init__(**kwargs)

def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
return self._client.get_cached_access_token(scopes, **kwargs)

def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
assertion = self._get_assertion()
token = self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs)
return token
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ def __init__(self, **kwargs):
from .azure_arc import AzureArcCredential

self._credential = AzureArcCredential(**kwargs)
elif all(os.environ.get(var) for var in EnvironmentVariables.TOKEN_EXCHANGE_VARS):
_LOGGER.info("%s will use token exchange", self.__class__.__name__)
from .token_exchange import TokenExchangeCredential

self._credential = TokenExchangeCredential(
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
token_file_path=os.environ[EnvironmentVariables.TOKEN_FILE_PATH],
**kwargs
)
else:
from .imds import ImdsCredential

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import time
from typing import TYPE_CHECKING

from .client_assertion import ClientAssertionCredential

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any


class TokenFileMixin(object):
def __init__(self, token_file_path, **_):
# type: (str, **Any) -> None
super(TokenFileMixin, self).__init__()
self._jwt = ""
self._last_read_time = 0
self._token_file_path = token_file_path

def get_service_account_token(self):
# type: () -> str
now = int(time.time())
if now - self._last_read_time > 300:
with open(self._token_file_path) as f:
self._jwt = f.read()
self._last_read_time = now
return self._jwt


class TokenExchangeCredential(ClientAssertionCredential, TokenFileMixin):
def __init__(self, tenant_id, client_id, token_file_path, **kwargs):
# type: (str, str, str, **Any) -> None
super(TokenExchangeCredential, self).__init__(
tenant_id=tenant_id,
client_id=client_id,
get_assertion=self.get_service_account_token,
token_file_path=token_file_path,
**kwargs
)
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ def obtain_token_by_client_secret(self, scopes, secret, **kwargs):
response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

def obtain_token_by_jwt_assertion(self, scopes, assertion, **kwargs):
# type: (Iterable[str], str, **Any) -> AccessToken
request = self._get_jwt_assertion_request(scopes, assertion)
now = int(time.time())
response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
# type: (Iterable[str], str, **Any) -> AccessToken
request = self._get_refresh_token_request(scopes, refresh_token, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def get_cached_refresh_tokens(self, scopes):
def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs):
pass

@abc.abstractmethod
def obtain_token_by_jwt_assertion(self, scopes, assertion, **kwargs):
pass

@abc.abstractmethod
def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
pass
Expand Down Expand Up @@ -165,10 +169,8 @@ def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None,
request = self._post(data, **kwargs)
return request

def _get_client_certificate_request(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest
audience = self._get_token_url(**kwargs)
assertion = self._get_jwt_assertion(certificate, audience)
def _get_jwt_assertion_request(self, scopes, assertion, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
data = {
"client_assertion": assertion,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
Expand All @@ -180,19 +182,8 @@ def _get_client_certificate_request(self, scopes, certificate, **kwargs):
request = self._post(data, **kwargs)
return request

def _get_client_secret_request(self, scopes, secret, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
data = {
"client_id": self._client_id,
"client_secret": secret,
"grant_type": "client_credentials",
"scope": " ".join(scopes),
}
request = self._post(data, **kwargs)
return request

def _get_jwt_assertion(self, certificate, audience):
# type: (AadClientCertificate, str) -> str
def _get_client_certificate_request(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest
now = int(time.time())
header = six.ensure_binary(
json.dumps({"typ": "JWT", "alg": "RS256", "x5t": certificate.thumbprint}), encoding="utf-8"
Expand All @@ -201,7 +192,7 @@ def _get_jwt_assertion(self, certificate, audience):
json.dumps(
{
"jti": str(uuid4()),
"aud": audience,
"aud": self._get_token_url(**kwargs),
"iss": self._client_id,
"sub": self._client_id,
"nbf": now,
Expand All @@ -213,8 +204,20 @@ def _get_jwt_assertion(self, certificate, audience):
jws = base64.urlsafe_b64encode(header) + b"." + base64.urlsafe_b64encode(payload)
signature = certificate.sign(jws)
jwt_bytes = jws + b"." + base64.urlsafe_b64encode(signature)
assertion = jwt_bytes.decode("utf-8")

return jwt_bytes.decode("utf-8")
return self._get_jwt_assertion_request(scopes, assertion, **kwargs)

def _get_client_secret_request(self, scopes, secret, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
data = {
"client_id": self._client_id,
"client_secret": secret,
"grant_type": "client_credentials",
"scope": " ".join(scopes),
}
request = self._post(data, **kwargs)
return request

def _get_refresh_token_request(self, scopes, refresh_token, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import TYPE_CHECKING

from .._internal import AadClient, AsyncContextManager
from .._internal.get_token_mixin import GetTokenMixin

if TYPE_CHECKING:
from typing import Any, Callable, Optional
from azure.core.credentials import AccessToken


class ClientAssertionCredential(AsyncContextManager, GetTokenMixin):
def __init__(self, tenant_id: str, client_id: str, get_assertion: "Callable[[], str]", **kwargs: "Any") -> None:
"""Authenticates a service principal with a JWT assertion.
This credential is for advanced scenarios. :class:`~azure.identity.ClientCertificateCredential` has a more
convenient API for the most common assertion scenario, authenticating a service principal with a certificate.
:param str tenant_id: ID of the principal's tenant. Also called its "directory" ID.
:param str client_id: the principal's client ID
:param get_assertion: a callable that returns a string assertion. The credential will call this every time it
acquires a new token.
:paramtype get_assertion: Callable[[], str]
:keyword str authority: authority of an Azure Active Directory endpoint, for example
"login.microsoftonline.com", the authority for Azure Public Cloud (which is the default).
:class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds.
"""
self._get_assertion = get_assertion
self._client = AadClient(tenant_id, client_id, **kwargs)
super().__init__(**kwargs)

async def __aenter__(self):
await self._client.__aenter__()
return self

async def close(self) -> None:
"""Close the credential's transport session."""
await self._client.close()
Comment on lines +36 to +42
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you keeping the async context management API since it's more relevant for async credentials?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That and self._client here already implements the API. I removed it from the sync credential because I don't want to block this PR on #19746 or paste part of that PR into this one.


async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]":
return self._client.get_cached_access_token(scopes, **kwargs)

async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
assertion = self._get_assertion()
token = await self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs)
return token
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ def __init__(self, **kwargs: "Any") -> None:
from .azure_arc import AzureArcCredential

self._credential = AzureArcCredential(**kwargs)
elif all(os.environ.get(var) for var in EnvironmentVariables.TOKEN_EXCHANGE_VARS):
_LOGGER.info("%s will use token exchange", self.__class__.__name__)
from .token_exchange import TokenExchangeCredential

self._credential = TokenExchangeCredential(
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
token_file_path=os.environ[EnvironmentVariables.TOKEN_FILE_PATH],
**kwargs
)
else:
from .imds import ImdsCredential

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import TYPE_CHECKING

from .client_assertion import ClientAssertionCredential
from ..._credentials.token_exchange import TokenFileMixin

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any


class TokenExchangeCredential(ClientAssertionCredential, TokenFileMixin):
def __init__(self, tenant_id: str, client_id: str, token_file_path: str, **kwargs: "Any") -> None:
super().__init__(
tenant_id=tenant_id,
client_id=client_id,
get_assertion=self.get_service_account_token,
token_file_path=token_file_path,
**kwargs
)
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ async def obtain_token_by_authorization_code(
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> AccessToken
async def obtain_token_by_client_certificate(
self, scopes: "Iterable[str]", certificate: "AadClientCertificate", **kwargs: "Any"
) -> "AccessToken":
request = self._get_client_certificate_request(scopes, certificate, **kwargs)
now = int(time.time())
response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
Expand All @@ -63,6 +64,14 @@ async def obtain_token_by_client_secret(
response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_jwt_assertion(
self, scopes: "Iterable[str]", assertion: str, **kwargs: "Any"
) -> "AccessToken":
request = self._get_jwt_assertion_request(scopes, assertion)
now = int(time.time())
response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

async def obtain_token_by_refresh_token(
self, scopes: "Iterable[str]", refresh_token: str, **kwargs: "Any"
) -> "AccessToken":
Expand Down
10 changes: 6 additions & 4 deletions sdk/identity/azure-identity/tests/test_aad_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
# Licensed under the MIT License.
# ------------------------------------
import functools
import time

from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError
from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal import AadClient, AadClientCertificate
from azure.core.credentials import AccessToken

import pytest
from msal import TokenCache
Expand Down Expand Up @@ -234,10 +232,14 @@ def test_retries_token_requests():
transport.send.reset_mock()

with pytest.raises(ServiceRequestError, match=message):
client.obtain_token_by_refresh_token("", "")
client.obtain_token_by_jwt_assertion("", "")
assert transport.send.call_count > 1
transport.send.reset_mock()

with pytest.raises(ServiceRequestError, match=message):
client.obtain_token_by_refresh_token("", "")
assert transport.send.call_count > 1


def test_shared_cache():
"""The client should return only tokens associated with its own client_id"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,14 @@ async def test_retries_token_requests():
transport.send.reset_mock()

with pytest.raises(ServiceRequestError, match=message):
await client.obtain_token_by_refresh_token("", "")
await client.obtain_token_by_jwt_assertion("", "")
assert transport.send.call_count > 1
transport.send.reset_mock()

with pytest.raises(ServiceRequestError, match=message):
await client.obtain_token_by_refresh_token("", "")
assert transport.send.call_count > 1


async def test_shared_cache():
"""The client should return only tokens associated with its own client_id"""
Expand Down
Loading