Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OnBehalfOfCredential #20451

Merged
merged 13 commits into from
Sep 3, 2021
4 changes: 3 additions & 1 deletion sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
### Features Added
- `CertificateCredential` accepts certificates in PKCS12 format
([#13540](https://github.com/Azure/azure-sdk-for-python/issues/13540))
- `OnBehalfOfCredential` supports the on-behalf-of authentication flow for
accessing resources on behalf of users
([#19308](https://github.com/Azure/azure-sdk-for-python/issues/19308))

### Breaking Changes

Expand All @@ -17,7 +20,6 @@
([#18798](https://github.com/Azure/azure-sdk-for-python/issues/18798))



## 1.6.1 (2021-08-19)

### Other Changes
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/azure/identity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
EnvironmentCredential,
InteractiveBrowserCredential,
ManagedIdentityCredential,
OnBehalfOfCredential,
SharedTokenCacheCredential,
UsernamePasswordCredential,
VisualStudioCodeCredential,
Expand All @@ -45,6 +46,7 @@
"EnvironmentCredential",
"InteractiveBrowserCredential",
"KnownAuthorities",
"OnBehalfOfCredential",
"RegionalAuthority",
"ManagedIdentityCredential",
"SharedTokenCacheCredential",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .default import DefaultAzureCredential
from .environment import EnvironmentCredential
from .managed_identity import ManagedIdentityCredential
from .on_behalf_of import OnBehalfOfCredential
from .shared_cache import SharedTokenCacheCredential
from .azure_cli import AzureCliCredential
from .device_code import DeviceCodeCredential
Expand All @@ -32,6 +33,7 @@
"EnvironmentCredential",
"InteractiveBrowserCredential",
"ManagedIdentityCredential",
"OnBehalfOfCredential",
"SharedTokenCacheCredential",
"AzureCliCredential",
"UsernamePasswordCredential",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import time
from typing import cast, TYPE_CHECKING

import msal
import six

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError

from .certificate import get_client_credential
from .._internal.decorators import wrap_exceptions
from .._internal.get_token_mixin import GetTokenMixin
from .._internal.interactive import _build_auth_record
from .._internal.msal_credentials import MsalCredential

if TYPE_CHECKING:
from typing import Any, Dict, Optional, Union
from .. import AuthenticationRecord


class OnBehalfOfCredential(MsalCredential, GetTokenMixin):
"""Authenticates a service principal via the on-behalf-of flow.

This flow is typically used by middle-tier services that authorize requests to other services with a delegated
user identity. Because this is not an interactive authentication flow, an application using it must have admin
consent for any delegated permissions before requesting tokens for them. See `Azure Active Directory documentation
<https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-on-behalf-of-flow>`_ for a more detailed
description of the on-behalf-of flow.

:param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID.
:param str client_id: the service principal's client ID
:param client_credential: a credential to authenticate the service principal, either one of its client secrets (a
string) or the bytes of a certificate in PEM or PKCS12 format including the private key
:type client_credential: str or bytes
:param str user_assertion: the access token the credential will use as the user assertion when requesting
on-behalf-of tokens

:keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If allow_multitenant_authentication is true, is tenant_id still required?

Copy link
Member Author

@chlowell chlowell Aug 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if only to identify a tenant the service principal is registered in.

the application is registered in. When False, which is the default, the credential will acquire tokens only
from the tenant specified by **tenant_id**.
:keyword str authority: Authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com",
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
defines authorities for other clouds.
:keyword password: a certificate password. Used only when **client_credential** is certificate bytes. If this value
is a unicode string, it will be encoded as UTF-8. If the certificate requires a different encoding, pass
appropriately encoded bytes instead.
:paramtype password: str or bytes
"""

def __init__(self, tenant_id, client_id, client_credential, user_assertion, **kwargs):
# type: (str, str, Union[bytes, str], str, **Any) -> None
credential = cast("Union[Dict, str]", client_credential)
if isinstance(client_credential, six.binary_type):
try:
credential = get_client_credential(
certificate_path=None, password=kwargs.pop("password", None), certificate_data=client_credential
)
except ValueError:
# client_credential isn't a cert, which is to be expected on 2.7 where str == bytes
pass

super(OnBehalfOfCredential, self).__init__(client_id, credential, tenant_id=tenant_id, **kwargs)
self._assertion = user_assertion
self._auth_record = None # type: Optional[AuthenticationRecord]

@wrap_exceptions
def _acquire_token_silently(self, *scopes, **kwargs):
# type: (*str, **Any) -> Optional[AccessToken]
if self._auth_record:
claims = kwargs.get("claims")
app = self._get_app(**kwargs)
for account in app.get_accounts(username=self._auth_record.username):
if account.get("home_account_id") != self._auth_record.home_account_id:
continue

now = int(time.time())
result = app.acquire_token_silent_with_error(list(scopes), account=account, claims_challenge=claims)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

return None

@wrap_exceptions
def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
app = self._get_app(**kwargs) # type: msal.ConfidentialClientApplication
request_time = int(time.time())
result = app.acquire_token_on_behalf_of(self._assertion, list(scopes), claims_challenge=kwargs.get("claims"))
if "access_token" not in result or "expires_in" not in result:
message = "Authentication failed: {}".format(result.get("error_description") or result.get("error"))
response = self._client.get_error_response(result)
raise ClientAuthenticationError(message=message, response=response)

try:
self._auth_record = _build_auth_record(result)
except ClientAuthenticationError:
pass # non-fatal; we'll use the assertion again next time instead of a refresh token

return AccessToken(result["access_token"], request_time + int(result["expires_in"]))
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, Optional
from typing import Any, Iterable, Optional, Union
from azure.core.credentials import AccessToken
from azure.core.pipeline import Pipeline
from .._internal import AadClientCertificate
Expand Down Expand Up @@ -65,6 +65,11 @@ def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
response = self._pipeline.run(request, stream=False, retry_on_methods=self._POST, **kwargs)
return self._process_response(response, now)

def obtain_token_on_behalf_of(self, scopes, client_credential, user_assertion, **kwargs):
# type: (Iterable[str], Union[str, AadClientCertificate], str, **Any) -> AccessToken
# no need for an implementation, non-async OnBehalfOfCredential acquires tokens through MSAL
raise NotImplementedError()

# pylint:disable=no-self-use
def _build_pipeline(self, **kwargs):
# type: (**Any) -> Pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from azure.core.exceptions import ClientAuthenticationError
from . import get_default_authority, normalize_authority
from .._internal import resolve_tenant
from .._internal.aadclient_certificate import AadClientCertificate

try:
from typing import TYPE_CHECKING
Expand All @@ -34,12 +35,13 @@
from azure.core.pipeline import AsyncPipeline, Pipeline, PipelineResponse
from azure.core.pipeline.policies import AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport, HttpTransport
from .._internal import AadClientCertificate

PipelineType = Union[AsyncPipeline, Pipeline]
PolicyType = Union[AsyncHTTPPolicy, HTTPPolicy, SansIOHTTPPolicy]
TransportType = Union[AsyncHttpTransport, HttpTransport]

JWT_BEARER_ASSERTION = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"


class AadClientBase(ABC):
_POST = ["POST"]
Expand Down Expand Up @@ -96,6 +98,10 @@ def obtain_token_by_client_secret(self, scopes, secret, **kwargs):
def obtain_token_by_refresh_token(self, scopes, refresh_token, **kwargs):
pass

@abc.abstractmethod
def obtain_token_on_behalf_of(self, scopes, client_credential, user_assertion, **kwargs):
pass

@abc.abstractmethod
def _build_pipeline(self, **kwargs):
pass
Expand Down Expand Up @@ -173,7 +179,7 @@ def _get_jwt_assertion_request(self, scopes, assertion, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
data = {
"client_assertion": assertion,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_assertion_type": JWT_BEARER_ASSERTION,
"client_id": self._client_id,
"grant_type": "client_credentials",
"scope": " ".join(scopes),
Expand All @@ -182,8 +188,8 @@ def _get_jwt_assertion_request(self, scopes, assertion, **kwargs):
request = self._post(data, **kwargs)
return request

def _get_client_certificate_request(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest
def _get_client_certificate_assertion(self, certificate, **kwargs):
# type: (AadClientCertificate, **Any) -> str
now = int(time.time())
header = six.ensure_binary(
json.dumps({"typ": "JWT", "alg": "RS256", "x5t": certificate.thumbprint}), encoding="utf-8"
Expand All @@ -204,8 +210,11 @@ def _get_client_certificate_request(self, scopes, certificate, **kwargs):
jws = base64.urlsafe_b64encode(header) + b"." + base64.urlsafe_b64encode(payload)
signature = certificate.sign(jws)
jwt_bytes = jws + b"." + base64.urlsafe_b64encode(signature)
assertion = jwt_bytes.decode("utf-8")
return jwt_bytes.decode("utf-8")

def _get_client_certificate_request(self, scopes, certificate, **kwargs):
# type: (Iterable[str], AadClientCertificate, **Any) -> HttpRequest
assertion = self._get_client_certificate_assertion(certificate, **kwargs)
return self._get_jwt_assertion_request(scopes, assertion, **kwargs)

def _get_client_secret_request(self, scopes, secret, **kwargs):
Expand All @@ -219,6 +228,24 @@ def _get_client_secret_request(self, scopes, secret, **kwargs):
request = self._post(data, **kwargs)
return request

def _get_on_behalf_of_request(self, scopes, client_credential, user_assertion, **kwargs):
# type: (Iterable[str], Union[str, AadClientCertificate], str, **Any) -> HttpRequest
data = {
"assertion": user_assertion,
"client_id": self._client_id,
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"requested_token_use": "on_behalf_of",
"scope": " ".join(scopes),
}
if isinstance(client_credential, AadClientCertificate):
data["client_assertion"] = self._get_client_certificate_assertion(client_credential)
data["client_assertion_type"] = JWT_BEARER_ASSERTION
else:
data["client_secret"] = client_credential

request = self._post(data, **kwargs)
return request

def _get_refresh_token_request(self, scopes, refresh_token, **kwargs):
# type: (Iterable[str], str, **Any) -> HttpRequest
data = {
Expand Down
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/azure/identity/aio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DefaultAzureCredential,
EnvironmentCredential,
ManagedIdentityCredential,
OnBehalfOfCredential,
SharedTokenCacheCredential,
VisualStudioCodeCredential,
)
Expand All @@ -30,6 +31,7 @@
"DefaultAzureCredential",
"EnvironmentCredential",
"ManagedIdentityCredential",
"OnBehalfOfCredential",
"ChainedTokenCredential",
"SharedTokenCacheCredential",
"VisualStudioCodeCredential",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .default import DefaultAzureCredential
from .environment import EnvironmentCredential
from .managed_identity import ManagedIdentityCredential
from .on_behalf_of import OnBehalfOfCredential
from .certificate import CertificateCredential
from .client_secret import ClientSecretCredential
from .shared_cache import SharedTokenCacheCredential
Expand All @@ -27,6 +28,7 @@
"DefaultAzureCredential",
"EnvironmentCredential",
"ManagedIdentityCredential",
"OnBehalfOfCredential",
"SharedTokenCacheCredential",
"VisualStudioCodeCredential",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import logging
from typing import TYPE_CHECKING

from azure.core.exceptions import ClientAuthenticationError

from .._internal import AadClient, AsyncContextManager
from .._internal.get_token_mixin import GetTokenMixin
from ..._credentials.certificate import get_client_credential
from ..._internal import AadClientCertificate, validate_tenant_id

if TYPE_CHECKING:
from typing import Any, Optional, Union
from azure.core.credentials import AccessToken

_LOGGER = logging.getLogger(__name__)


class OnBehalfOfCredential(AsyncContextManager, GetTokenMixin):
"""Authenticates a service principal via the on-behalf-of flow.

This flow is typically used by middle-tier services that authorize requests to other services with a delegated
user identity. Because this is not an interactive authentication flow, an application using it must have admin
consent for any delegated permissions before requesting tokens for them. See `Azure Active Directory documentation
<https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-on-behalf-of-flow>`_ for a more detailed
description of the on-behalf-of flow.

:param str tenant_id: ID of the service principal's tenant. Also called its "directory" ID.
:param str client_id: the service principal's client ID
:param client_credential: a credential to authenticate the service principal, either one of its client secrets (a
string) or the bytes of a certificate in PEM or PKCS12 format including the private key
:paramtype client_credential: str or bytes
:param str user_assertion: the access token the credential will use as the user assertion when requesting
on-behalf-of tokens

:keyword bool allow_multitenant_authentication: when True, enables the credential to acquire tokens from any tenant
the application is registered in. When False, which is the default, the credential will acquire tokens only
from the tenant specified by **tenant_id**.
:keyword str authority: Authority of an Azure Active Directory endpoint, for example "login.microsoftonline.com",
the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts`
defines authorities for other clouds.
:keyword password: a certificate password. Used only when **client_credential** is certificate bytes. If this value
is a unicode string, it will be encoded as UTF-8. If the certificate requires a different encoding, pass
appropriately encoded bytes instead.
:paramtype password: str or bytes
"""

def __init__(
self,
tenant_id: str,
client_id: str,
client_credential: "Union[bytes, str]",
user_assertion: str,
**kwargs: "Any"
) -> None:
super().__init__()
validate_tenant_id(tenant_id)

if isinstance(client_credential, bytes):
cert = get_client_credential(None, kwargs.pop("password", None), client_credential)
self._credential = AadClientCertificate(
cert["private_key"], password=cert.get("passphrase")
) # type: Union[str, AadClientCertificate]
else:
self._credential = client_credential

# note AadClient handles "allow_multitenant_authentication", "authority", and any pipeline kwargs
self._client = AadClient(tenant_id, client_id, **kwargs)
self._assertion = user_assertion

async def __aenter__(self):
await self._client.__aenter__()
return self

async def close(self):
await self._client.close()

async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]":
return self._client.get_cached_access_token(scopes, **kwargs)

async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
# Note we assume the cache has tokens for one user only. That's okay because each instance of this class is
# locked to a single user (assertion). This assumption will become unsafe if this class allows applications
# to change an instance's assertion.
refresh_tokens = self._client.get_cached_refresh_tokens(scopes)
if len(refresh_tokens) == 1: # there should be only one
try:
refresh_token = refresh_tokens[0]["secret"]
return await self._client.obtain_token_by_refresh_token(scopes, refresh_token, **kwargs)
except ClientAuthenticationError as ex:
_LOGGER.debug("silent authentication failed: %s", ex, exc_info=True)
except (IndexError, KeyError, TypeError) as ex:
# this is purely defensive, hasn't been observed in practice
_LOGGER.debug("silent authentication failed due to malformed refresh token: %s", ex, exc_info=True)

# we don't have a refresh token, or silent auth failed: acquire a new token from the assertion
return await self._client.obtain_token_on_behalf_of(scopes, self._credential, self._assertion, **kwargs)
Loading