Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SharedTokenCacheCredential uses MSAL when given an AuthenticationRecord #13490

Merged
merged 5 commits into from
Sep 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
future version.
([#10816](https://github.com/Azure/azure-sdk-for-python/issues/10816))

### Breaking changes
- Removed `authentication_record` keyword argument from the async
`SharedTokenCacheCredential`, i.e. `azure.identity.aio.SharedTokenCacheCredential`

## 1.4.0 (2020-08-10)
### Added
- `DefaultAzureCredential` uses the value of environment variable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import time

from msal.application import PublicClientApplication

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError

from .. import CredentialUnavailableError
from .._constants import AZURE_CLI_CLIENT_ID
from .._internal import AadClient
from .._internal.decorators import log_get_token
from .._internal.decorators import log_get_token, wrap_exceptions
from .._internal.msal_client import MsalClient
from .._internal.shared_token_cache import NO_TOKEN, SharedTokenCacheBase

try:
Expand All @@ -15,7 +23,8 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any
from typing import Any, Optional
from .. import AuthenticationRecord
from .._internal import AadClientBase


Expand All @@ -37,6 +46,20 @@ class SharedTokenCacheCredential(SharedTokenCacheBase):
is unavailable. Defaults to False.
"""

def __init__(self, username=None, **kwargs):
# type: (Optional[str], **Any) -> None

self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord]
if self._auth_record:
# authenticate in the tenant that produced the record unless "tenant_id" specifies another
self._tenant_id = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
self._cache = kwargs.pop("_cache", None)
self._app = None
self._client_kwargs = kwargs
self._initialized = False
else:
super(SharedTokenCacheCredential, self).__init__(username=username, **kwargs)

@log_get_token("SharedTokenCacheCredential")
def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type (*str, **Any) -> AccessToken
Expand All @@ -51,18 +74,20 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
:raises ~azure.identity.CredentialUnavailableError: the cache is unavailable or contains insufficient user
information
:raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
attribute gives a reason. Any error response from Azure Active Directory is available as the error's
``response`` attribute.
attribute gives a reason.
"""
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if not self._initialized:
self._initialize()

if not self._client:
if not self._cache:
raise CredentialUnavailableError(message="Shared token cache unavailable")

if self._auth_record:
return self._acquire_token_silent(*scopes)

account = self._get_account(self._username, self._tenant_id)

token = self._get_cached_access_token(scopes, account)
Expand All @@ -79,3 +104,54 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
def _get_auth_client(self, **kwargs):
# type: (**Any) -> AadClientBase
return AadClient(client_id=AZURE_CLI_CLIENT_ID, **kwargs)

def _initialize(self):
if self._initialized:
return

if not self._auth_record:
super(SharedTokenCacheCredential, self)._initialize()
return

self._load_cache()
if self._cache:
self._app = PublicClientApplication(
client_id=self._auth_record.client_id,
authority="https://{}/{}".format(self._auth_record.authority, self._tenant_id),
token_cache=self._cache,
http_client=MsalClient(**self._client_kwargs),
)

self._initialized = True

@wrap_exceptions
def _acquire_token_silent(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Silently acquire a token from MSAL. Requires an AuthenticationRecord."""

result = None

accounts_for_user = self._app.get_accounts(username=self._auth_record.username)
if not accounts_for_user:
raise CredentialUnavailableError("The cache contains no account matching the given AuthenticationRecord.")

for account in accounts_for_user:
if account.get("home_account_id") != self._auth_record.home_account_id:
continue

now = int(time.time())
result = self._app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

# if we get this far, the cache contained a matching account but MSAL failed to authenticate it silently
if result:
# cache contains a matching refresh token but STS returned an error response when MSAL tried to use it
message = "Token acquisition failed"
details = result.get("error_description") or result.get("error")
if details:
message += ": {}".format(details)
raise ClientAuthenticationError(message=message)

# cache doesn't contain a matching refresh (or access) token
raise CredentialUnavailableError(message=NO_TOKEN.format(self._auth_record.username))
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Iterable, List, Mapping, Optional
from .._internal import AadClientBase
from azure.identity import AuthenticationRecord

CacheItem = Mapping[str, str]

Expand Down Expand Up @@ -89,46 +88,36 @@ def _filtered_accounts(accounts, username=None, tenant_id=None):
class SharedTokenCacheBase(ABC):
def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
# type: (Optional[str], **Any) -> None

self._auth_record = kwargs.pop("authentication_record", None) # type: Optional[AuthenticationRecord]
if self._auth_record:
# authenticate in the tenant that produced the record unless 'tenant_id' specifies another
authenticating_tenant = kwargs.pop("tenant_id", None) or self._auth_record.tenant_id
self._tenant_id = self._auth_record.tenant_id
self._authority = self._auth_record.authority
self._username = self._auth_record.username
self._environment_aliases = frozenset((self._authority,))
else:
authenticating_tenant = "organizations"
authority = kwargs.pop("authority", None)
self._authority = normalize_authority(authority) if authority else get_default_authority()
environment = urlparse(self._authority).netloc
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
self._username = username
self._tenant_id = kwargs.pop("tenant_id", None)

authority = kwargs.pop("authority", None)
self._authority = normalize_authority(authority) if authority else get_default_authority()
environment = urlparse(self._authority).netloc
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
self._username = username
self._tenant_id = kwargs.pop("tenant_id", None)
self._cache = kwargs.pop("_cache", None)
self._client = None # type: Optional[AadClientBase]
self._client_kwargs = kwargs
self._client_kwargs["tenant_id"] = authenticating_tenant
self._client_kwargs["tenant_id"] = "organizations"
self._initialized = False

def _initialize(self):
if self._initialized:
return

self._load_cache()
if self._cache:
self._client = self._get_auth_client(authority=self._authority, cache=self._cache, **self._client_kwargs)

self._initialized = True

def _load_cache(self):
if not self._cache and self.supported():
allow_unencrypted = self._client_kwargs.get("allow_unencrypted_cache", False)
try:
self._cache = load_user_cache(allow_unencrypted)
except Exception: # pylint:disable=broad-except
pass

if self._cache:
self._client = self._get_auth_client(authority=self._authority, cache=self._cache, **self._client_kwargs)

self._initialized = True

@abc.abstractmethod
def _get_auth_client(self, **kwargs):
# type: (**Any) -> AadClientBase
Expand Down Expand Up @@ -176,14 +165,6 @@ def _get_account(self, username=None, tenant_id=None):
# cache is empty or contains no refresh token -> user needs to sign in
raise CredentialUnavailableError(message=NO_ACCOUNTS)

if self._auth_record:
for account in accounts:
if account.get("home_account_id") == self._auth_record.home_account_id:
return account
raise CredentialUnavailableError(
message="The cache contains no account matching the given AuthenticationRecord."
)

filtered_accounts = _filtered_accounts(accounts, username, tenant_id)
if len(filtered_accounts) == 1:
return filtered_accounts[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class SharedTokenCacheCredential(SharedTokenCacheBase, AsyncContextManager):
defines authorities for other clouds.
:keyword str tenant_id: an Azure Active Directory tenant ID. Used to select an account when the cache contains
tokens for multiple identities.
:keyword AuthenticationRecord authentication_record: an authentication record returned by a user credential such as
:class:`DeviceCodeCredential` or :class:`InteractiveBrowserCredential`
:keyword bool allow_unencrypted_cache: if True, the credential will fall back to a plaintext cache when encryption
is unavailable. Defaults to False.
"""
Expand Down
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def validate_request(request, **_):
try:
expected_request, response = next(sessions)
except StopIteration:
assert False, "unexpected request: {}".format(request)
assert False, "unexpected request: {} {}".format(request.method, request.url)
expected_request.assert_matches(request)
return response

Expand Down
58 changes: 40 additions & 18 deletions sdk/identity/azure-identity/tests/test_shared_cache_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@
except ImportError: # python < 3.3
from mock import Mock, patch # type: ignore

from helpers import build_aad_response, build_id_token, mock_response, Request, validating_transport
from helpers import (
build_aad_response,
build_id_token,
get_discovery_response,
mock_response,
msal_validating_transport,
Request,
validating_transport,
)


def test_supported():
Expand Down Expand Up @@ -513,8 +521,13 @@ def test_authority_environment_variable():

def test_authentication_record_empty_cache():
record = AuthenticationRecord("tenant_id", "client_id", "authority", "home_account_id", "username")
transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=TokenCache())

def send(request, **_):
# expecting only MSAL discovery requests
assert request.method == 'GET'
return get_discovery_response()

credential = SharedTokenCacheCredential(authentication_record=record, transport=Mock(send=send), _cache=TokenCache())

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")
Expand All @@ -529,13 +542,17 @@ def test_authentication_record_no_match():
username = "me"
record = AuthenticationRecord(tenant_id, client_id, authority, home_account_id, username)

transport = Mock(side_effect=Exception("the credential shouldn't send a request"))
def send(request, **_):
# expecting only MSAL discovery requests
assert request.method == 'GET'
return get_discovery_response()

cache = populated_cache(
get_account_event(
"not-" + username, "not-" + object_id, "different-" + tenant_id, client_id="not-" + client_id,
),
)
credential = SharedTokenCacheCredential(authentication_record=record, transport=transport, _cache=cache)
credential = SharedTokenCacheCredential(authentication_record=record, transport=Mock(send=send), _cache=cache)

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")
Expand All @@ -557,7 +574,8 @@ def test_authentication_record():
)
cache = populated_cache(account)

transport = validating_transport(
transport = msal_validating_transport(
endpoint="https://{}/{}".format(authority, tenant_id),
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
Expand Down Expand Up @@ -593,7 +611,8 @@ def test_auth_record_multiple_accounts_for_username():
),
)

transport = validating_transport(
transport = msal_validating_transport(
endpoint="https://{}/{}".format(authority, tenant_id),
requests=[Request(authority=authority, required_data={"refresh_token": expected_refresh_token})],
responses=[mock_response(json_payload=build_aad_response(access_token=expected_access_token))],
)
Expand Down Expand Up @@ -741,19 +760,22 @@ def test_authentication_record_authenticating_tenant():
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""

expected_tenant_id = "tenant-id"
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "localhost", "...", "...")

with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
credential = SharedTokenCacheCredential(
authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id
)
with pytest.raises(CredentialUnavailableError):
# this raises because the cache is empty
credential.get_token("scope")
def mock_send(request, **_):
if not request.body:
return get_discovery_response()
assert request.url.startswith("https://localhost/" + expected_tenant_id)
return mock_response(json_payload=build_aad_response(access_token="*"))

transport = Mock(send=Mock(wraps=mock_send))
credential = SharedTokenCacheCredential(
authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id, transport=transport
)
with pytest.raises(CredentialUnavailableError):
credential.get_token("scope") # this raises because the cache is empty

assert get_auth_client.call_count == 1
_, kwargs = get_auth_client.call_args
assert kwargs["tenant_id"] == expected_tenant_id
assert transport.send.called


def get_account_event(
Expand Down
Loading