diff --git a/sdk/identity/azure-identity/azure/identity/_constants.py b/sdk/identity/azure-identity/azure/identity/_constants.py index 7766f6b82834..ea1526cac540 100644 --- a/sdk/identity/azure-identity/azure/identity/_constants.py +++ b/sdk/identity/azure-identity/azure/identity/_constants.py @@ -46,3 +46,6 @@ class EnvironmentVariables: AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST" AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION = "AZURE_IDENTITY_ENABLE_LEGACY_TENANT_SELECTION" AZURE_REGIONAL_AUTHORITY_NAME = "AZURE_REGIONAL_AUTHORITY_NAME" + + TOKEN_FILE_PATH = "TOKEN_FILE_PATH" + TOKEN_EXCHANGE_VARS = (AZURE_CLIENT_ID, AZURE_TENANT_ID, TOKEN_FILE_PATH) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py new file mode 100644 index 000000000000..013307c3e39b --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py @@ -0,0 +1,45 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from typing import TYPE_CHECKING + +from .._internal import AadClient +from .._internal.get_token_mixin import GetTokenMixin + +if TYPE_CHECKING: + from typing import Any, Callable, Optional + from azure.core.credentials import AccessToken + + +class ClientAssertionCredential(GetTokenMixin): + def __init__(self, tenant_id, client_id, get_assertion, **kwargs): + # type: (str, str, Callable[[], str], **Any) -> None + """Authenticates a service principal with a JWT assertion. + + This credential is for advanced scenarios. :class:`~azure.identity.ClientCertificateCredential` has a more + convenient API for the most common assertion scenario, authenticating a service principal with a certificate. + + :param str tenant_id: ID of the principal's tenant. Also called its "directory" ID. + :param str client_id: the principal's client ID + :param get_assertion: a callable that returns a string assertion. The credential will call this every time it + acquires a new token. + :paramtype get_assertion: Callable[[], str] + + :keyword str authority: authority of an Azure Active Directory endpoint, for example + "login.microsoftonline.com", the authority for Azure Public Cloud (which is the default). + :class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds. + """ + self._get_assertion = get_assertion + self._client = AadClient(tenant_id, client_id, **kwargs) + super(ClientAssertionCredential, self).__init__(**kwargs) + + def _acquire_token_silently(self, *scopes, **kwargs): + # type: (*str, **Any) -> Optional[AccessToken] + return self._client.get_cached_access_token(scopes, **kwargs) + + def _request_token(self, *scopes, **kwargs): + # type: (*str, **Any) -> AccessToken + assertion = self._get_assertion() + token = self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs) + return token diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index 2fc67178fee0..d0a0acef7931 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -65,6 +65,16 @@ def __init__(self, **kwargs): from .azure_arc import AzureArcCredential self._credential = AzureArcCredential(**kwargs) + elif all(os.environ.get(var) for var in EnvironmentVariables.TOKEN_EXCHANGE_VARS): + _LOGGER.info("%s will use token exchange", self.__class__.__name__) + from .token_exchange import TokenExchangeCredential + + self._credential = TokenExchangeCredential( + tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID], + client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID], + token_file_path=os.environ[EnvironmentVariables.TOKEN_FILE_PATH], + **kwargs + ) else: from .imds import ImdsCredential diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/token_exchange.py b/sdk/identity/azure-identity/azure/identity/_credentials/token_exchange.py new file mode 100644 index 000000000000..bb5bcee00058 --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/_credentials/token_exchange.py @@ -0,0 +1,42 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +import time +from typing import TYPE_CHECKING + +from .client_assertion import ClientAssertionCredential + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any + + +class TokenFileMixin(object): + def __init__(self, token_file_path, **_): + # type: (str, **Any) -> None + super(TokenFileMixin, self).__init__() + self._jwt = "" + self._last_read_time = 0 + self._token_file_path = token_file_path + + def get_service_account_token(self): + # type: () -> str + now = int(time.time()) + if now - self._last_read_time > 300: + with open(self._token_file_path) as f: + self._jwt = f.read() + self._last_read_time = now + return self._jwt + + +class TokenExchangeCredential(ClientAssertionCredential, TokenFileMixin): + def __init__(self, tenant_id, client_id, token_file_path, **kwargs): + # type: (str, str, str, **Any) -> None + super(TokenExchangeCredential, self).__init__( + tenant_id=tenant_id, + client_id=client_id, + get_assertion=self.get_service_account_token, + token_file_path=token_file_path, + **kwargs + ) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py index a08a5fa8d6e4..79f986fca8a5 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client.py @@ -40,6 +40,13 @@ def obtain_token_by_client_secret(self, scopes, secret, **kwargs): response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs) return self._process_response(response, now) + def obtain_token_by_jwt_assertion(self, scopes, assertion, **kwargs): + # type: (Iterable[str], str, **Any) -> AccessToken + request = self._get_jwt_assertion_request(scopes, assertion) + now = int(time.time()) + response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs) + return self._process_response(response, now) + def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs): # type: (Iterable[str], str, **Any) -> AccessToken request = self._get_refresh_token_request(scopes, refresh_token, **kwargs) diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index 0aeb6e0ee83e..1bf0b2d02aff 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -80,6 +80,10 @@ def get_cached_refresh_tokens(self, scopes): def obtain_token_by_authorization_code(self, scopes, code, redirect_uri, client_secret=None, **kwargs): pass + @abc.abstractmethod + def obtain_token_by_jwt_assertion(self, scopes, assertion, **kwargs): + pass + @abc.abstractmethod def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs): pass @@ -165,10 +169,8 @@ def _get_auth_code_request(self, scopes, code, redirect_uri, client_secret=None, request = self._post(data, **kwargs) return request - def _get_client_certificate_request(self, scopes, certificate, **kwargs): - # type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest - audience = self._get_token_url(**kwargs) - assertion = self._get_jwt_assertion(certificate, audience) + def _get_jwt_assertion_request(self, scopes, assertion, **kwargs): + # type: (Iterable[str], str, **Any) -> HttpRequest data = { "client_assertion": assertion, "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", @@ -180,19 +182,8 @@ def _get_client_certificate_request(self, scopes, certificate, **kwargs): request = self._post(data, **kwargs) return request - def _get_client_secret_request(self, scopes, secret, **kwargs): - # type: (Iterable[str], str, **Any) -> HttpRequest - data = { - "client_id": self._client_id, - "client_secret": secret, - "grant_type": "client_credentials", - "scope": " ".join(scopes), - } - request = self._post(data, **kwargs) - return request - - def _get_jwt_assertion(self, certificate, audience): - # type: (AadClientCertificate, str) -> str + def _get_client_certificate_request(self, scopes, certificate, **kwargs): + # type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest now = int(time.time()) header = six.ensure_binary( json.dumps({"typ": "JWT", "alg": "RS256", "x5t": certificate.thumbprint}), encoding="utf-8" @@ -201,7 +192,7 @@ def _get_jwt_assertion(self, certificate, audience): json.dumps( { "jti": str(uuid4()), - "aud": audience, + "aud": self._get_token_url(**kwargs), "iss": self._client_id, "sub": self._client_id, "nbf": now, @@ -213,8 +204,20 @@ def _get_jwt_assertion(self, certificate, audience): jws = base64.urlsafe_b64encode(header) + b"." + base64.urlsafe_b64encode(payload) signature = certificate.sign(jws) jwt_bytes = jws + b"." + base64.urlsafe_b64encode(signature) + assertion = jwt_bytes.decode("utf-8") - return jwt_bytes.decode("utf-8") + return self._get_jwt_assertion_request(scopes, assertion, **kwargs) + + def _get_client_secret_request(self, scopes, secret, **kwargs): + # type: (Iterable[str], str, **Any) -> HttpRequest + data = { + "client_id": self._client_id, + "client_secret": secret, + "grant_type": "client_credentials", + "scope": " ".join(scopes), + } + request = self._post(data, **kwargs) + return request def _get_refresh_token_request(self, scopes, refresh_token, **kwargs): # type: (Iterable[str], str, **Any) -> HttpRequest diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py new file mode 100644 index 000000000000..8b09b43fca3c --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py @@ -0,0 +1,50 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from typing import TYPE_CHECKING + +from .._internal import AadClient, AsyncContextManager +from .._internal.get_token_mixin import GetTokenMixin + +if TYPE_CHECKING: + from typing import Any, Callable, Optional + from azure.core.credentials import AccessToken + + +class ClientAssertionCredential(AsyncContextManager, GetTokenMixin): + def __init__(self, tenant_id: str, client_id: str, get_assertion: "Callable[[], str]", **kwargs: "Any") -> None: + """Authenticates a service principal with a JWT assertion. + + This credential is for advanced scenarios. :class:`~azure.identity.ClientCertificateCredential` has a more + convenient API for the most common assertion scenario, authenticating a service principal with a certificate. + + :param str tenant_id: ID of the principal's tenant. Also called its "directory" ID. + :param str client_id: the principal's client ID + :param get_assertion: a callable that returns a string assertion. The credential will call this every time it + acquires a new token. + :paramtype get_assertion: Callable[[], str] + + :keyword str authority: authority of an Azure Active Directory endpoint, for example + "login.microsoftonline.com", the authority for Azure Public Cloud (which is the default). + :class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds. + """ + self._get_assertion = get_assertion + self._client = AadClient(tenant_id, client_id, **kwargs) + super().__init__(**kwargs) + + async def __aenter__(self): + await self._client.__aenter__() + return self + + async def close(self) -> None: + """Close the credential's transport session.""" + await self._client.close() + + async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]": + return self._client.get_cached_access_token(scopes, **kwargs) + + async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": + assertion = self._get_assertion() + token = await self._client.obtain_token_by_jwt_assertion(scopes, assertion, **kwargs) + return token diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py index 11c685572794..075dffe1d463 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py @@ -62,6 +62,16 @@ def __init__(self, **kwargs: "Any") -> None: from .azure_arc import AzureArcCredential self._credential = AzureArcCredential(**kwargs) + elif all(os.environ.get(var) for var in EnvironmentVariables.TOKEN_EXCHANGE_VARS): + _LOGGER.info("%s will use token exchange", self.__class__.__name__) + from .token_exchange import TokenExchangeCredential + + self._credential = TokenExchangeCredential( + tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID], + client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID], + token_file_path=os.environ[EnvironmentVariables.TOKEN_FILE_PATH], + **kwargs + ) else: from .imds import ImdsCredential diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/token_exchange.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/token_exchange.py new file mode 100644 index 000000000000..1ac071ed310e --- /dev/null +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/token_exchange.py @@ -0,0 +1,23 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from typing import TYPE_CHECKING + +from .client_assertion import ClientAssertionCredential +from ..._credentials.token_exchange import TokenFileMixin + +if TYPE_CHECKING: + # pylint:disable=unused-import,ungrouped-imports + from typing import Any + + +class TokenExchangeCredential(ClientAssertionCredential, TokenFileMixin): + def __init__(self, tenant_id: str, client_id: str, token_file_path: str, **kwargs: "Any") -> None: + super().__init__( + tenant_id=tenant_id, + client_id=client_id, + get_assertion=self.get_service_account_token, + token_file_path=token_file_path, + **kwargs + ) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py index 249a63d69e0e..44123a2c9f69 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_internal/aad_client.py @@ -48,8 +48,9 @@ async def obtain_token_by_authorization_code( response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs) return self._process_response(response, now) - async def obtain_token_by_client_certificate(self, scopes, certificate, **kwargs): - # type: (Iterable[str], AadClientCertificate, **Any) -> AccessToken + async def obtain_token_by_client_certificate( + self, scopes: "Iterable[str]", certificate: "AadClientCertificate", **kwargs: "Any" + ) -> "AccessToken": request = self._get_client_certificate_request(scopes, certificate, **kwargs) now = int(time.time()) response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs) @@ -63,6 +64,14 @@ async def obtain_token_by_client_secret( response = await self._pipeline.run(request, retry_on_methods=self._POST, **kwargs) return self._process_response(response, now) + async def obtain_token_by_jwt_assertion( + self, scopes: "Iterable[str]", assertion: str, **kwargs: "Any" + ) -> "AccessToken": + request = self._get_jwt_assertion_request(scopes, assertion) + now = int(time.time()) + response = await self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs) + return self._process_response(response, now) + async def obtain_token_by_refresh_token( self, scopes: "Iterable[str]", refresh_token: str, **kwargs: "Any" ) -> "AccessToken": diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index 9f798d50c4fe..2737f955b51d 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -3,12 +3,10 @@ # Licensed under the MIT License. # ------------------------------------ import functools -import time from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError -from azure.identity._constants import EnvironmentVariables, DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY +from azure.identity._constants import EnvironmentVariables from azure.identity._internal import AadClient, AadClientCertificate -from azure.core.credentials import AccessToken import pytest from msal import TokenCache @@ -234,10 +232,14 @@ def test_retries_token_requests(): transport.send.reset_mock() with pytest.raises(ServiceRequestError, match=message): - client.obtain_token_by_refresh_token("", "") + client.obtain_token_by_jwt_assertion("", "") assert transport.send.call_count > 1 transport.send.reset_mock() + with pytest.raises(ServiceRequestError, match=message): + client.obtain_token_by_refresh_token("", "") + assert transport.send.call_count > 1 + def test_shared_cache(): """The client should return only tokens associated with its own client_id""" diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index 64212d573e97..5f3fd757d399 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -236,10 +236,14 @@ async def test_retries_token_requests(): transport.send.reset_mock() with pytest.raises(ServiceRequestError, match=message): - await client.obtain_token_by_refresh_token("", "") + await client.obtain_token_by_jwt_assertion("", "") assert transport.send.call_count > 1 transport.send.reset_mock() + with pytest.raises(ServiceRequestError, match=message): + await client.obtain_token_by_refresh_token("", "") + assert transport.send.call_count > 1 + async def test_shared_cache(): """The client should return only tokens associated with its own client_id""" diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py index a492c8315555..eb199f08b368 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity.py @@ -11,12 +11,10 @@ import mock # type: ignore from azure.core.credentials import AccessToken -from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError -from azure.core.pipeline.transport import HttpRequest +from azure.core.exceptions import ClientAuthenticationError from azure.identity import ManagedIdentityCredential from azure.identity._constants import EnvironmentVariables from azure.identity._credentials.imds import IMDS_AUTHORITY, IMDS_TOKEN_PATH -from azure.identity._internal.managed_identity_client import ManagedIdentityClient from azure.identity._internal.user_agent import USER_AGENT import pytest @@ -34,6 +32,11 @@ EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: "...", }, {EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc + { # token exchange + EnvironmentVariables.AZURE_CLIENT_ID: "...", + EnvironmentVariables.AZURE_TENANT_ID: "...", + EnvironmentVariables.TOKEN_FILE_PATH: __file__, + }, {}, # IMDS ) @@ -547,9 +550,7 @@ def send(request, **_): # Cloud Shell with mock.patch.dict( - MANAGED_IDENTITY_ENVIRON, - {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, - clear=True, + MANAGED_IDENTITY_ENVIRON, {EnvironmentVariables.MSI_ENDPOINT: "https://localhost"}, clear=True ): credential = ManagedIdentityCredential(client_id=None, transport=mock.Mock(send=send)) token = credential.get_token(scope) @@ -733,18 +734,56 @@ def test_azure_arc_client_id(): credential.get_token("scope") -def test_managed_identity_client_retry(): - """ManagedIdentityClient should retry token requests""" +def test_token_exchange(tmpdir): + exchange_token = "exchange-token" + token_file = tmpdir.join("token") + token_file.write(exchange_token) + access_token = "***" + authority = "https://localhost" + client_id = "client_id" + tenant = "tenant_id" + scope = "scope" - message = "can't connect" - transport = mock.Mock(send=mock.Mock(side_effect=ServiceRequestError(message))) - request_factory = mock.Mock() + transport = validating_transport( + requests=[ + Request( + base_url=authority, + method="POST", + required_data={ + "client_assertion": exchange_token, + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_id": client_id, + "grant_type": "client_credentials", + "scope": scope, + }, + ) + ], + responses=[ + mock_response( + json_payload={ + "access_token": access_token, + "expires_in": 3600, + "ext_expires_in": 3600, + "expires_on": int(time.time()) + 3600, + "not_before": int(time.time()), + "resource": scope, + "token_type": "Bearer", + } + ) + ], + ) - client = ManagedIdentityClient(request_factory, transport=transport) + with mock.patch.dict( + "os.environ", + { + EnvironmentVariables.AZURE_AUTHORITY_HOST: authority, + EnvironmentVariables.AZURE_CLIENT_ID: client_id, + EnvironmentVariables.AZURE_TENANT_ID: tenant, + EnvironmentVariables.TOKEN_FILE_PATH: token_file.strpath, + }, + clear=True, + ): + credential = ManagedIdentityCredential(transport=transport) + token = credential.get_token(scope) - for method in ("GET", "POST"): - request_factory.return_value = HttpRequest(method, "https://localhost") - with pytest.raises(ServiceRequestError, match=message): - client.request_token("scope") - assert transport.send.call_count > 1 - transport.send.reset_mock() + assert token.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py index dde080f03961..14f616c404cc 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py @@ -7,10 +7,8 @@ from unittest import mock from azure.core.credentials import AccessToken -from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError -from azure.core.pipeline.transport import HttpRequest +from azure.core.exceptions import ClientAuthenticationError from azure.identity.aio import ManagedIdentityCredential -from azure.identity.aio._internal.managed_identity_client import AsyncManagedIdentityClient from azure.identity._credentials.imds import IMDS_AUTHORITY, IMDS_TOKEN_PATH from azure.identity._constants import EnvironmentVariables from azure.identity._internal.user_agent import USER_AGENT @@ -716,18 +714,56 @@ async def test_azure_arc_client_id(): @pytest.mark.asyncio -async def test_managed_identity_client_retry(): - """AsyncManagedIdentityClient should retry token requests""" +async def test_token_exchange(tmpdir): + exchange_token = "exchange-token" + token_file = tmpdir.join("token") + token_file.write(exchange_token) + access_token = "***" + authority = "https://localhost" + client_id = "client_id" + tenant = "tenant_id" + scope = "scope" - message = "can't connect" - transport = mock.Mock(send=mock.Mock(side_effect=ServiceRequestError(message)), sleep=get_completed_future) - request_factory = mock.Mock() + transport = async_validating_transport( + requests=[ + Request( + base_url=authority, + method="POST", + required_data={ + "client_assertion": exchange_token, + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_id": client_id, + "grant_type": "client_credentials", + "scope": scope, + }, + ) + ], + responses=[ + mock_response( + json_payload={ + "access_token": access_token, + "expires_in": 3600, + "ext_expires_in": 3600, + "expires_on": int(time.time()) + 3600, + "not_before": int(time.time()), + "resource": scope, + "token_type": "Bearer", + } + ) + ], + ) - client = AsyncManagedIdentityClient(request_factory, transport=transport) + with mock.patch.dict( + "os.environ", + { + EnvironmentVariables.AZURE_AUTHORITY_HOST: authority, + EnvironmentVariables.AZURE_CLIENT_ID: client_id, + EnvironmentVariables.AZURE_TENANT_ID: tenant, + EnvironmentVariables.TOKEN_FILE_PATH: token_file.strpath, + }, + clear=True, + ): + credential = ManagedIdentityCredential(transport=transport) + token = await credential.get_token(scope) - for method in ("GET", "POST"): - request_factory.return_value = HttpRequest(method, "https://localhost") - with pytest.raises(ServiceRequestError, match=message): - await client.request_token("scope") - assert transport.send.call_count > 1 - transport.send.reset_mock() + assert token.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_client.py b/sdk/identity/azure-identity/tests/test_managed_identity_client.py index 909a31e90cea..e797cd875543 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_client.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_client.py @@ -5,7 +5,7 @@ import json import time -from azure.core.exceptions import ClientAuthenticationError +from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError from azure.core.pipeline.transport import HttpRequest from azure.identity._internal.managed_identity_client import ManagedIdentityClient import pytest @@ -83,6 +83,23 @@ def send(request, **_): assert token.token == expected_token +def test_retry(): + """ManagedIdentityClient should retry token requests""" + + message = "can't connect" + transport = mock.Mock(send=mock.Mock(side_effect=ServiceRequestError(message))) + request_factory = mock.Mock() + + client = ManagedIdentityClient(request_factory, transport=transport) + + for method in ("GET", "POST"): + request_factory.return_value = HttpRequest(method, "https://localhost") + with pytest.raises(ServiceRequestError, match=message): + client.request_token("scope") + assert transport.send.call_count > 1 + transport.send.reset_mock() + + @pytest.mark.parametrize("content_type", ("text/html","application/json")) def test_unexpected_content(content_type): content = "not JSON" diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_client_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_client_async.py index bba2a5c77b9d..03a9d870aa4a 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_client_async.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_client_async.py @@ -6,13 +6,13 @@ import time from unittest.mock import Mock, patch -from azure.core.exceptions import ClientAuthenticationError +from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError from azure.core.pipeline.transport import HttpRequest from azure.identity.aio._internal.managed_identity_client import AsyncManagedIdentityClient import pytest from helpers import mock_response, Request -from helpers_async import async_validating_transport, AsyncMockTransport +from helpers_async import async_validating_transport, AsyncMockTransport, get_completed_future pytestmark = pytest.mark.asyncio @@ -108,6 +108,24 @@ async def send(request, **_): assert token.token == expected_token +@pytest.mark.asyncio +async def test_managed_identity_client_retry(): + """AsyncManagedIdentityClient should retry token requests""" + + message = "can't connect" + transport = Mock(send=Mock(side_effect=ServiceRequestError(message)), sleep=get_completed_future) + request_factory = Mock() + + client = AsyncManagedIdentityClient(request_factory, transport=transport) + + for method in ("GET", "POST"): + request_factory.return_value = HttpRequest(method, "https://localhost") + with pytest.raises(ServiceRequestError, match=message): + await client.request_token("scope") + assert transport.send.call_count > 1 + transport.send.reset_mock() + + @pytest.mark.parametrize("content_type", ("text/html", "application/json")) async def test_unexpected_content(content_type): content = "not JSON"