Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SharedTokenCacheCredential lazily loads the cache #12172

Merged
merged 5 commits into from
Jun 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if not self._initialized:
self._initialize()
chlowell marked this conversation as resolved.
Show resolved Hide resolved

if not self._client:
raise CredentialUnavailableError(message="Shared token cache unavailable")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import abc
import platform
import time

from msal import TokenCache
Expand Down Expand Up @@ -107,20 +108,26 @@ def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
self._tenant_id = kwargs.pop("tenant_id", None)

self._cache = kwargs.pop("_cache", None)
if not self._cache:
allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False)
self._client = None # type: Optional[AadClientBase]
self._client_kwargs = kwargs
self._client_kwargs["tenant_id"] = authenticating_tenant
self._initialized = False

def _initialize(self):
if self._initialized:
return

if not self._cache and self.supported():
allow_unencrypted = self._client_kwargs.get("allow_unencrypted_cache", False)
try:
self._cache = load_user_cache(allow_unencrypted)
except Exception: # pylint:disable=broad-except
pass

if self._cache:
self._client = self._get_auth_client(
authority=self._authority, cache=self._cache, tenant_id=authenticating_tenant, **kwargs
) # type: Optional[AadClientBase]
else:
# couldn't load the cache -> credential will be unavailable
self._client = None
self._client = self._get_auth_client(authority=self._authority, cache=self._cache, **self._client_kwargs)

self._initialized = True

@abc.abstractmethod
def _get_auth_client(self, **kwargs):
Expand Down Expand Up @@ -236,12 +243,4 @@ def supported():

:rtype: bool
"""
try:
load_user_cache(allow_unencrypted=False)
except NotImplementedError:
return False
except ValueError:
# cache is supported but can't be encrypted
pass

return True
return platform.system() in {"Darwin", "Linux", "Windows"}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if not self._initialized:
self._initialize()

if not self._client:
raise CredentialUnavailableError(message="Shared token cache unavailable")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
from helpers import build_aad_response, build_id_token, mock_response, Request, validating_transport


def test_supported():
"""the cache is supported on Linux, macOS, Windows, so this should pass unless you're developing on e.g. FreeBSD"""
assert SharedTokenCacheCredential.supported()


def test_no_scopes():
"""The credential should raise when get_token is called with no scopes"""

Expand Down Expand Up @@ -717,14 +722,34 @@ def test_access_token_caching():
)


def test_initialization():
"""the credential should attempt to load the cache only once, when it's first needed"""

with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader:
mock_cache_loader.side_effect = Exception("it didn't work")

credential = SharedTokenCacheCredential()
assert mock_cache_loader.call_count == 0

for _ in range(2):
with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")
assert mock_cache_loader.call_count == 1


def test_authentication_record_authenticating_tenant():
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""

expected_tenant_id = "tenant-id"
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")

with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id)
credential = SharedTokenCacheCredential(
authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id
)
with pytest.raises(CredentialUnavailableError):
# this raises because the cache is empty
credential.get_token("scope")

assert get_auth_client.call_count == 1
_, kwargs = get_auth_client.call_args
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
from test_shared_cache_credential import get_account_event, populated_cache


def test_supported():
"""the cache is supported on Linux, macOS, Windows, so this should pass unless you're developing on e.g. FreeBSD"""
assert SharedTokenCacheCredential.supported()


@pytest.mark.asyncio
async def test_no_scopes():
"""The credential should raise when get_token is called with no scopes"""
Expand All @@ -37,39 +42,53 @@ async def test_no_scopes():

@pytest.mark.asyncio
async def test_close():
transport = AsyncMockTransport()
async def send(*_, **__):
return mock_response(json_payload=build_aad_response(access_token="**"))

transport = AsyncMockTransport(send=send)
credential = SharedTokenCacheCredential(
_cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport
)

# the credential doesn't open a transport session before one is needed, so we send a request
await credential.get_token("scope")

await credential.close()

assert transport.__aexit__.call_count == 1


@pytest.mark.asyncio
async def test_context_manager():
transport = AsyncMockTransport()
async def send(*_, **__):
return mock_response(json_payload=build_aad_response(access_token="**"))

transport = AsyncMockTransport(send=send)
credential = SharedTokenCacheCredential(
_cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport
)

# async with before initialization: credential should call aexit but not aenter
async with credential:
assert transport.__aenter__.call_count == 1
await credential.get_token("scope")

assert transport.__aenter__.call_count == 1
assert transport.__aenter__.call_count == 0
assert transport.__aexit__.call_count == 1

# async with after initialization: credential should call aenter and aexit
async with credential:
await credential.get_token("scope")
assert transport.__aenter__.call_count == 1
assert transport.__aexit__.call_count == 2


@pytest.mark.asyncio
async def test_context_manager_no_cache():
"""the credential shouldn't open/close sessions when instantiated in an environment with no cache"""

transport = AsyncMockTransport()

with patch(
"azure.identity._internal.shared_token_cache.load_user_cache", Mock(side_effect=NotImplementedError)
):
with patch("azure.identity._internal.shared_token_cache.load_user_cache", Mock(side_effect=NotImplementedError)):
credential = SharedTokenCacheCredential(transport=transport)

async with credential:
Expand Down Expand Up @@ -666,14 +685,20 @@ async def test_auth_record_multiple_accounts_for_username():
assert token.token == expected_access_token


def test_authentication_record_authenticating_tenant():
@pytest.mark.asyncio
async def test_authentication_record_authenticating_tenant():
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""

expected_tenant_id = "tenant-id"
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")

with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id)
credential = SharedTokenCacheCredential(
authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id
)
with pytest.raises(CredentialUnavailableError):
# this raises because the cache is empty
await credential.get_token("scope")

assert get_auth_client.call_count == 1
_, kwargs = get_auth_client.call_args
Expand Down Expand Up @@ -713,3 +738,20 @@ async def test_allow_unencrypted_cache():

msal_extensions_patch.stop()
platform_patch.stop()


@pytest.mark.asyncio
async def test_initialization():
"""the credential should attempt to load the cache only once, when it's first needed"""

with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader:
mock_cache_loader.side_effect = Exception("it didn't work")

credential = SharedTokenCacheCredential()
assert mock_cache_loader.call_count == 0

for _ in range(2):
with pytest.raises(CredentialUnavailableError):
await credential.get_token("scope")
assert mock_cache_loader.call_count == 1