diff --git a/.vscode/cspell.json b/.vscode/cspell.json index 3b44fde1ba2d..90b5ef38ffb3 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -214,6 +214,7 @@ "msrest", "msrestazure", "MSSQL", + "mutex", "myacr", "nazsdk", "noarch", @@ -372,13 +373,14 @@ ] }, { - "filename": "sdk/communication/azure-communication-identity/tests/*.py", + "filename": "sdk/communication/azure-communication-identity/tests/**", "words": [ "XVCJ", "Njgw", "FNNHHJT", "Zwiz", - "nypg" + "nypg", + "PBOF" ] }, { diff --git a/sdk/communication/azure-communication-chat/CHANGELOG.md b/sdk/communication/azure-communication-chat/CHANGELOG.md index 1ac145c46cd7..df8c5ce833be 100644 --- a/sdk/communication/azure-communication-chat/CHANGELOG.md +++ b/sdk/communication/azure-communication-chat/CHANGELOG.md @@ -2,6 +2,10 @@ ## 1.2.0 (Unreleased) +- Added support for proactive refreshing of tokens + - `CommunicationTokenCredential` exposes a new boolean keyword argument `proactive_refresh` that defaults to `False`. If set to `True`, the refreshing of the token will be scheduled in the background ensuring continuous authentication state. + - Added disposal function `close` for `CommunicationTokenCredential`. + ### Features Added ### Breaking Changes @@ -12,16 +16,20 @@ Python 2.7 is no longer supported. Please use Python version 3.6 or later. ## 1.1.0 (2021-09-15) + - Updated `azure-communication-chat` version. ## 1.1.0b1 (2021-08-16) ### Added + - Added support to add `metadata` for `message` - Added support to add `sender_display_name` for `ChatThreadClient.send_typing_notification` ## 1.0.0 (2021-03-29) + ### Breaking Changes + - Renamed `ChatThread` to `ChatThreadProperties`. - Renamed `get_chat_thread` to `get_properties`. - Moved `get_properties` under `ChatThreadClient`. @@ -37,22 +45,29 @@ Python 2.7 is no longer supported. Please use Python version 3.6 or later. - Refactored implementation of `CommunicationUserIdentifier`, `PhoneNumberIdentifier`, `MicrosoftTeamsUserIdentifier`, `UnknownIdentifier` to use a `dict` property bag. ## 1.0.0b5 (2021-03-09) + ### Breaking Changes + - Added support for communication identifiers instead of raw strings. - Changed return type of `create_chat_thread`: `ChatThreadClient -> CreateChatThreadResult` - Changed return types `add_participants`: `None -> list[(ChatThreadParticipant, CommunicationError)]` - Added check for failure in `add_participant` - Dropped support for Python 3.5 + ### Added + - Removed nullable references from method signatures. ## 1.0.0b4 (2021-02-09) + ### Breaking Changes + - Uses `CommunicationUserIdentifier` and `CommunicationIdentifier` in place of `CommunicationUser`, and `CommunicationTokenCredential` instead of `CommunicationUserCredential`. - Removed priority field (ChatMessage.Priority). - Renamed PhoneNumber to PhoneNumberIdentifier. ### Added + - Support for CreateChatThreadResult and AddChatParticipantsResult to handle partial errors in batch calls. - Added idempotency identifier parameter for chat creation calls. - Added support for readreceipts and getparticipants pagination. @@ -61,10 +76,13 @@ Python 2.7 is no longer supported. Please use Python version 3.6 or later. - Added `MicrosoftTeamsUserIdentifier`. ## 1.0.0b3 (2020-11-16) + - Updated `azure-communication-chat` version. ## 1.0.0b2 (2020-10-06) + - Updated `azure-communication-chat` version. ## 1.0.0b1 (2020-09-22) - - Add ChatClient and ChatThreadClient. + +- Add ChatClient and ChatThreadClient. diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential.py index 9b5f17dcc95d..f4a89336ad58 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential.py @@ -3,56 +3,68 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from threading import Lock, Condition -from datetime import timedelta -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) +from threading import Lock, Condition, Timer, TIMEOUT_MAX, Event +from datetime import timedelta +from typing import Any +import six from .utils import get_current_utc_as_int -from .user_token_refresh_options import CommunicationTokenRefreshOptions +from .utils import create_access_token class CommunicationTokenCredential(object): """Credential type used for authenticating to an Azure Communication service. - :param str token: The token used to authenticate to an Azure Communication service - :keyword token_refresher: The token refresher to provide capacity to fetch fresh token - :raises: TypeError + :param str token: The token used to authenticate to an Azure Communication service. + :keyword token_refresher: The sync token refresher to provide capacity to fetch a fresh token. + The returned token must be valid (expiration date must be in the future). + :paramtype token_refresher: Callable[[], AccessToken] + :keyword bool proactive_refresh: Whether to refresh the token proactively or not. + If the proactive refreshing is enabled ('proactive_refresh' is true), the credential will use + a background thread to attempt to refresh the token within 10 minutes before the cached token expires, + the proactive refresh will request a new token by calling the 'token_refresher' callback. + When 'proactive_refresh' is enabled, the Credential object must be either run within a context manager + or the 'close' method must be called once the object usage has been finished. + :raises: TypeError if paramater 'token' is not a string + :raises: ValueError if the 'proactive_refresh' is enabled without providing the 'token_refresher' callable. """ _ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2 - - def __init__(self, - token, # type: str - **kwargs - ): - token_refresher = kwargs.pop('token_refresher', None) - communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token, - token_refresher=token_refresher) - self._token = communication_token_refresh_options.get_token() - self._token_refresher = communication_token_refresh_options.get_token_refresher() + _DEFAULT_AUTOREFRESH_INTERVAL_MINUTES = 10 + + def __init__(self, token: str, **kwargs: Any): + if not isinstance(token, six.string_types): + raise TypeError("Token must be a string.") + self._token = create_access_token(token) + self._token_refresher = kwargs.pop('token_refresher', None) + self._proactive_refresh = kwargs.pop('proactive_refresh', False) + if(self._proactive_refresh and self._token_refresher is None): + raise ValueError("When 'proactive_refresh' is True, 'token_refresher' must not be None.") + self._timer = None self._lock = Condition(Lock()) self._some_thread_refreshing = False + self._is_closed = Event() def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument # type (*str, **Any) -> AccessToken """The value of the configured token. :rtype: ~azure.core.credentials.AccessToken """ + if self._proactive_refresh and self._is_closed.is_set(): + raise RuntimeError("An instance of CommunicationTokenCredential cannot be reused once it has been closed.") - if not self._token_refresher or not self._token_expiring(): + if not self._token_refresher or not self._is_token_expiring_soon(self._token): return self._token + self._update_token_and_reschedule() + return self._token + def _update_token_and_reschedule(self): should_this_thread_refresh = False - with self._lock: - while self._token_expiring(): + while self._is_token_expiring_soon(self._token): if self._some_thread_refreshing: - if self._is_currenttoken_valid(): + if self._is_token_valid(self._token): return self._token - - self._wait_till_inprogress_thread_finish_refreshing() + self._wait_till_lock_owner_finishes_refreshing() else: should_this_thread_refresh = True self._some_thread_refreshing = True @@ -60,27 +72,74 @@ def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument if should_this_thread_refresh: try: - newtoken = self._token_refresher() # pylint:disable=not-callable - + new_token = self._token_refresher() + if not self._is_token_valid(new_token): + raise ValueError( + "The token returned from the token_refresher is expired.") with self._lock: - self._token = newtoken + self._token = new_token self._some_thread_refreshing = False self._lock.notify_all() except: with self._lock: self._some_thread_refreshing = False self._lock.notify_all() - raise + if self._proactive_refresh: + self._schedule_refresh() return self._token - def _wait_till_inprogress_thread_finish_refreshing(self): + def _schedule_refresh(self): + if self._is_closed.is_set(): + return + if self._timer is not None: + self._timer.cancel() + + token_ttl = self._token.expires_on - get_current_utc_as_int() + + if self._is_token_expiring_soon(self._token): + # Schedule the next refresh for when it reaches a certain percentage of the remaining lifetime. + timespan = token_ttl // 2 + else: + # Schedule the next refresh for when it gets in to the soon-to-expire window. + timespan = token_ttl - timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES).total_seconds() + if timespan <= TIMEOUT_MAX: + self._timer = Timer(timespan, self._update_token_and_reschedule) + self._timer.daemon = True + self._timer.start() + + def _wait_till_lock_owner_finishes_refreshing(self): self._lock.release() self._lock.acquire() - def _token_expiring(self): - return self._token.expires_on - get_current_utc_as_int() <\ - timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() - - def _is_currenttoken_valid(self): - return get_current_utc_as_int() < self._token.expires_on + def _is_token_expiring_soon(self, token): + if self._proactive_refresh: + interval = timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES) + else: + interval = timedelta( + minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES) + return ((token.expires_on - get_current_utc_as_int()) + < interval.total_seconds()) + + @classmethod + def _is_token_valid(cls, token): + return get_current_utc_as_int() < token.expires_on + + def __enter__(self): + if self._proactive_refresh: + if self._is_closed.is_set(): + raise RuntimeError( + "An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + self._schedule_refresh() + return self + + def __exit__(self, *args): + self.close() + + def close(self) -> None: + if self._timer is not None: + self._timer.cancel() + self._timer = None + self._is_closed.set() diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential_async.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential_async.py index 52a99e7a4b6a..c41dc363c3e4 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential_async.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_credential_async.py @@ -3,93 +3,149 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from asyncio import Condition, Lock -from datetime import timedelta -from typing import ( # pylint: disable=unused-import - cast, - Tuple, - Any -) +from asyncio import Condition, Lock, Event +from datetime import timedelta +from typing import Any +import sys +import six from .utils import get_current_utc_as_int -from .user_token_refresh_options import CommunicationTokenRefreshOptions +from .utils import create_access_token +from .utils_async import AsyncTimer class CommunicationTokenCredential(object): """Credential type used for authenticating to an Azure Communication service. - :param str token: The token used to authenticate to an Azure Communication service - :keyword token_refresher: The async token refresher to provide capacity to fetch fresh token - :raises: TypeError + :param str token: The token used to authenticate to an Azure Communication service. + :keyword token_refresher: The async token refresher to provide capacity to fetch a fresh token. + The returned token must be valid (expiration date must be in the future). + :paramtype token_refresher: Callable[[], Awaitable[AccessToken]] + :keyword bool proactive_refresh: Whether to refresh the token proactively or not. + If the proactive refreshing is enabled ('proactive_refresh' is true), the credential will use + a background thread to attempt to refresh the token within 10 minutes before the cached token expires, + the proactive refresh will request a new token by calling the 'token_refresher' callback. + When 'proactive_refresh is enabled', the Credential object must be either run within a context manager + or the 'close' method must be called once the object usage has been finished. + :raises: TypeError if paramater 'token' is not a string + :raises: ValueError if the 'proactive_refresh' is enabled without providing the 'token_refresher' function. """ _ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2 + _DEFAULT_AUTOREFRESH_INTERVAL_MINUTES = 10 def __init__(self, token: str, **kwargs: Any): - token_refresher = kwargs.pop('token_refresher', None) - communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token, - token_refresher=token_refresher) - self._token = communication_token_refresh_options.get_token() - self._token_refresher = communication_token_refresh_options.get_token_refresher() - self._lock = Condition(Lock()) + if not isinstance(token, six.string_types): + raise TypeError("Token must be a string.") + self._token = create_access_token(token) + self._token_refresher = kwargs.pop('token_refresher', None) + self._proactive_refresh = kwargs.pop('proactive_refresh', False) + if(self._proactive_refresh and self._token_refresher is None): + raise ValueError("When 'proactive_refresh' is True, 'token_refresher' must not be None.") + self._timer = None + self._async_mutex = Lock() + if sys.version_info[:3] == (3, 10, 0): + # Workaround for Python 3.10 bug(https://bugs.python.org/issue45416): + getattr(self._async_mutex, '_get_loop', lambda: None)() + self._lock = Condition(self._async_mutex) self._some_thread_refreshing = False + self._is_closed = Event() async def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument # type (*str, **Any) -> AccessToken """The value of the configured token. :rtype: ~azure.core.credentials.AccessToken """ - if not self._token_refresher or not self._token_expiring(): + if self._proactive_refresh and self._is_closed.is_set(): + raise RuntimeError("An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + + if not self._token_refresher or not self._is_token_expiring_soon(self._token): return self._token + await self._update_token_and_reschedule() + return self._token + async def _update_token_and_reschedule(self): should_this_thread_refresh = False - async with self._lock: - - while self._token_expiring(): + while self._is_token_expiring_soon(self._token): if self._some_thread_refreshing: - if self._is_currenttoken_valid(): + if self._is_token_valid(self._token): return self._token - - await self._wait_till_inprogress_thread_finish_refreshing() + await self._wait_till_lock_owner_finishes_refreshing() else: should_this_thread_refresh = True self._some_thread_refreshing = True break - if should_this_thread_refresh: try: - newtoken = await self._token_refresher() # pylint:disable=not-callable - + new_token = await self._token_refresher() + if not self._is_token_valid(new_token): + raise ValueError( + "The token returned from the token_refresher is expired.") async with self._lock: - self._token = newtoken + self._token = new_token self._some_thread_refreshing = False self._lock.notify_all() except: async with self._lock: self._some_thread_refreshing = False self._lock.notify_all() - raise - + if self._proactive_refresh: + self._schedule_refresh() return self._token - async def _wait_till_inprogress_thread_finish_refreshing(self): + def _schedule_refresh(self): + if self._is_closed.is_set(): + return + if self._timer is not None: + self._timer.cancel() + + token_ttl = self._token.expires_on - get_current_utc_as_int() + + if self._is_token_expiring_soon(self._token): + # Schedule the next refresh for when it reaches a certain percentage of the remaining lifetime. + timespan = token_ttl // 2 + else: + # Schedule the next refresh for when it gets in to the soon-to-expire window. + timespan = token_ttl - timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES).total_seconds() + + self._timer = AsyncTimer(timespan, self._update_token_and_reschedule) + self._timer.start() + + async def _wait_till_lock_owner_finishes_refreshing(self): + self._lock.release() await self._lock.acquire() - def _token_expiring(self): - return self._token.expires_on - get_current_utc_as_int() <\ - timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() - - def _is_currenttoken_valid(self): - return get_current_utc_as_int() < self._token.expires_on + def _is_token_expiring_soon(self, token): + if self._proactive_refresh: + interval = timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES) + else: + interval = timedelta( + minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES) + return ((token.expires_on - get_current_utc_as_int()) + < interval.total_seconds()) - async def close(self) -> None: - pass + @classmethod + def _is_token_valid(cls, token): + return get_current_utc_as_int() < token.expires_on async def __aenter__(self): + if self._proactive_refresh: + if self._is_closed.is_set(): + raise RuntimeError( + "An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + self._schedule_refresh() return self async def __aexit__(self, *args): await self.close() + + async def close(self) -> None: + if self._timer is not None: + self._timer.cancel() + self._timer = None + self._is_closed.set() diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_token_refresh_options.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_token_refresh_options.py deleted file mode 100644 index 6bdc0d456026..000000000000 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/user_token_refresh_options.py +++ /dev/null @@ -1,36 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) -import six -from .utils import create_access_token - -class CommunicationTokenRefreshOptions(object): - """Options for refreshing CommunicationTokenCredential. - :param str token: The token used to authenticate to an Azure Communication service - :param token_refresher: The token refresher to provide capacity to fetch fresh token - :raises: TypeError - """ - - def __init__(self, - token, # type: str - token_refresher=None - ): - # type: (str) -> None - if not isinstance(token, six.string_types): - raise TypeError("token must be a string.") - self._token = token - self._token_refresher = token_refresher - - def get_token(self): - """Return the the serialized JWT token.""" - return create_access_token(self._token) - - def get_token_refresher(self): - """Return the token refresher to provide capacity to fetch fresh token.""" - return self._token_refresher diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/utils.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/utils.py index c9255a4217d7..0b3556bbaa44 100644 --- a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/utils.py +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/utils.py @@ -6,15 +6,15 @@ import base64 import json -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) -from datetime import datetime import calendar +from typing import (cast, + Tuple, + ) +from datetime import datetime from msrest.serialization import TZ_UTC from azure.core.credentials import AccessToken + def _convert_datetime_to_utc_int(input_datetime): """ Converts DateTime in local time to the Epoch in UTC in second. @@ -26,6 +26,7 @@ def _convert_datetime_to_utc_int(input_datetime): """ return int(calendar.timegm(input_datetime.utctimetuple())) + def parse_connection_str(conn_str): # type: (str) -> Tuple[str, str, str, str] if conn_str is None: @@ -53,9 +54,10 @@ def parse_connection_str(conn_str): return host, str(shared_access_key) + def get_current_utc_time(): # type: () -> str - return str(datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT" + return str(datetime.now(tz=TZ_UTC).strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT" def get_current_utc_as_int(): @@ -63,6 +65,7 @@ def get_current_utc_as_int(): current_utc_datetime = datetime.utcnow() return _convert_datetime_to_utc_int(current_utc_datetime) + def create_access_token(token): # type: (str) -> azure.core.credentials.AccessToken """Creates an instance of azure.core.credentials.AccessToken from a @@ -84,18 +87,20 @@ def create_access_token(token): raise ValueError(token_parse_err_msg) try: - padded_base64_payload = base64.b64decode(parts[1] + "==").decode('ascii') + padded_base64_payload = base64.b64decode( + parts[1] + '==').decode('ascii') payload = json.loads(padded_base64_payload) return AccessToken(token, _convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp'], TZ_UTC))) - except ValueError: - raise ValueError(token_parse_err_msg) + except ValueError as val_error: + raise ValueError(token_parse_err_msg) from val_error + def get_authentication_policy( - endpoint, # type: str - credential, # type: TokenCredential or str - decode_url=False, # type: bool - is_async=False, # type: bool + endpoint, # type: str + credential, # type: TokenCredential or str + decode_url=False, # type: bool + is_async=False, # type: bool ): # type: (...) -> BearerTokenCredentialPolicy or HMACCredentialPolicy """Returns the correct authentication policy based diff --git a/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/utils_async.py b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/utils_async.py new file mode 100644 index 000000000000..86e0e04d273c --- /dev/null +++ b/sdk/communication/azure-communication-chat/azure/communication/chat/_shared/utils_async.py @@ -0,0 +1,31 @@ +# ------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +import asyncio + + +class AsyncTimer: + """A non-blocking timer, that calls a function after a specified number of seconds: + :param int interval: time interval in seconds + :param callable callback: function to be called after the interval has elapsed + """ + + def __init__(self, interval, callback): + self._interval = interval + self._callback = callback + self._task = None + + def start(self): + self._task = asyncio.ensure_future(self._job()) + + async def _job(self): + await asyncio.sleep(self._interval) + await self._callback() + + def cancel(self): + if self._task is not None: + self._task.cancel() + self._task = None diff --git a/sdk/communication/azure-communication-chat/samples/user_credential_sample.py b/sdk/communication/azure-communication-chat/samples/user_credential_sample.py new file mode 100644 index 000000000000..055fd9e4d96a --- /dev/null +++ b/sdk/communication/azure-communication-chat/samples/user_credential_sample.py @@ -0,0 +1,72 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +""" +FILE: user_credential_sample.py +DESCRIPTION: + These samples demonstrate creating a `CommunicationTokenCredential` object. + The `CommunicationTokenCredential` object is used to authenticate a user with Communication Services, + such as Chat or Calling. It optionally provides an auto-refresh mechanism to ensure a continuously + stable authentication state during communications. + +USAGE: + python user_credential_sample.py + Set the environment variables with your own values before running the sample: + 1) COMMUNICATION_SAMPLES_CONNECTION_STRING - the connection string in your Communication Services resource +""" + + +import os +from azure.communication.chat import CommunicationTokenCredential +from azure.communication.identity import CommunicationIdentityClient + +class CommunicationTokenCredentialSamples(object): + + connection_string = os.environ.get("COMMUNICATION_SAMPLES_CONNECTION_STRING", None) + if not connection_string: + raise ValueError("Set COMMUNICATION_SAMPLES_CONNECTION_STRING env before running this sample.") + + identity_client = CommunicationIdentityClient.from_connection_string(connection_string) + user = identity_client.create_user() + token_response = identity_client.get_token(user, scopes=["chat"]) + token = token_response.token + + def create_credential_with_static_token(self): + # For short-lived clients, refreshing the token upon expiry is not necessary + # and `CommunicationTokenCredential` may be instantiated with a static token. + with CommunicationTokenCredential(self.token) as credential: + token_response = credential.get_token() + print("Token issued with value: " + token_response.token) + + def create_credential_with_refreshing_callback(self): + # Alternatively, for long-lived clients, you can create a `CommunicationTokenCredential` with a callback to renew tokens if expired. + # Here we assume that we have a function `fetch_token_from_server` that makes a network request to retrieve a token string for a user. + # It's necessary that the `fetch_token_from_server` function returns a valid token (with an expiration date set in the future) at all times. + fetch_token_from_server = lambda: None + with CommunicationTokenCredential( + self.token, token_refresher=fetch_token_from_server) as credential: + token_response = credential.get_token() + print("Token issued with value: " + token_response.token) + + def create_credential_with_proactive_refreshing_callback(self): + # Optionally, you can enable proactive token refreshing where a fresh token will be acquired as soon as the + # previous token approaches expiry. Using this method, your requests are less likely to be blocked to acquire a fresh token + fetch_token_from_server = lambda: None + with CommunicationTokenCredential( + self.token, token_refresher=fetch_token_from_server, proactive_refresh=True) as credential: + token_response = credential.get_token() + print("Token issued with value: " + token_response.token) + + def clean_up(self): + print("cleaning up: deleting created user.") + self.identity_client.delete_user(self.user) + +if __name__ == '__main__': + sample = CommunicationTokenCredentialSamples() + sample.create_credential_with_static_token() + sample.create_credential_with_refreshing_callback() + sample.create_credential_with_proactive_refreshing_callback() + sample.clean_up() diff --git a/sdk/communication/azure-communication-chat/samples/user_credential_sample_async.py b/sdk/communication/azure-communication-chat/samples/user_credential_sample_async.py new file mode 100644 index 000000000000..60791cee875c --- /dev/null +++ b/sdk/communication/azure-communication-chat/samples/user_credential_sample_async.py @@ -0,0 +1,76 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +""" +FILE: user_credential_sample_async.py +DESCRIPTION: + These samples demonstrate creating a `CommunicationTokenCredential` object. + The `CommunicationTokenCredential` object is used to authenticate a user with Communication Services, + such as Chat or Calling. It optionally provides an auto-refresh mechanism to ensure a continuously + stable authentication state during communications. + +USAGE: + python user_credential_sample_async.py + Set the environment variables with your own values before running the sample: + 1) COMMUNICATION_SAMPLES_CONNECTION_STRING - the connection string in your Communication Services resource +""" + + +import os +import asyncio +from azure.communication.chat.aio import CommunicationTokenCredential +from azure.communication.identity import CommunicationIdentityClient + +class CommunicationTokenCredentialSamples(object): + + connection_string = os.environ.get("COMMUNICATION_SAMPLES_CONNECTION_STRING", None) + if not connection_string: + raise ValueError("Set COMMUNICATION_SAMPLES_CONNECTION_STRING env before running this sample.") + + identity_client = CommunicationIdentityClient.from_connection_string(connection_string) + user = identity_client.create_user() + token_response = identity_client.get_token(user, scopes=["chat"]) + token = token_response.token + + async def create_credential_with_static_token(self): + # For short-lived clients, refreshing the token upon expiry is not necessary + # and `CommunicationTokenCredential` may be instantiated with a static token. + async with CommunicationTokenCredential(self.token) as credential: + token_response = await credential.get_token() + print("Token issued with value: " + token_response.token) + + async def create_credential_with_refreshing_callback(self): + # Alternatively, for long-lived clients, you can create a `CommunicationTokenCredential` with a callback to renew tokens if expired. + # Here we assume that we have a function `fetch_token_from_server` that makes a network request to retrieve a token string for a user. + # It's necessary that the `fetch_token_from_server` function returns a valid token (with an expiration date set in the future) at all times. + fetch_token_from_server = lambda: None + async with CommunicationTokenCredential( + self.token, token_refresher=fetch_token_from_server) as credential: + token_response = await credential.get_token() + print("Token issued with value: " + token_response.token) + + async def create_credential_with_proactive_refreshing_callback(self): + # Optionally, you can enable proactive token refreshing where a fresh token will be acquired as soon as the + # previous token approaches expiry. Using this method, your requests are less likely to be blocked to acquire a fresh token + fetch_token_from_server = lambda: None + async with CommunicationTokenCredential( + self.token, token_refresher=fetch_token_from_server, proactive_refresh=True) as credential: + token_response = await credential.get_token() + print("Token issued with value: " + token_response.token) + + def clean_up(self): + print("cleaning up: deleting created user.") + self.identity_client.delete_user(self.user) + +async def main(): + sample = CommunicationTokenCredentialSamples() + await sample.create_credential_with_static_token() + await sample.create_credential_with_refreshing_callback() + await sample.create_credential_with_proactive_refreshing_callback() + sample.clean_up() + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/sdk/communication/azure-communication-chat/tests/_shared/__init__.py b/sdk/communication/azure-communication-chat/tests/_shared/__init__.py index 3b0cfe17e031..841b812e10ba 100644 --- a/sdk/communication/azure-communication-chat/tests/_shared/__init__.py +++ b/sdk/communication/azure-communication-chat/tests/_shared/__init__.py @@ -3,4 +3,4 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -# -------------------------------------------------------------------------- +# -------------------------------------------------------------------------- \ No newline at end of file diff --git a/sdk/communication/azure-communication-chat/tests/_shared/helper.py b/sdk/communication/azure-communication-chat/tests/_shared/helper.py new file mode 100644 index 000000000000..4d3585695f5a --- /dev/null +++ b/sdk/communication/azure-communication-chat/tests/_shared/helper.py @@ -0,0 +1,42 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import re +import base64 +from azure_devtools.scenario_tests import RecordingProcessor +from datetime import datetime, timedelta +from functools import wraps +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse +import sys + +def generate_token_with_custom_expiry(valid_for_seconds): + return generate_token_with_custom_expiry_epoch((datetime.now() + timedelta(seconds=valid_for_seconds)).timestamp()) + +def generate_token_with_custom_expiry_epoch(expires_on_epoch): + expiry_json = f'{{"exp": {str(expires_on_epoch)} }}' + base64expiry = base64.b64encode( + expiry_json.encode('utf-8')).decode('utf-8').rstrip("=") + token_template = (f'''eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9. + {base64expiry}.adM-ddBZZlQ1WlN3pdPBOF5G4Wh9iZpxNP_fSvpF4cWs''') + return token_template + + +class URIIdentityReplacer(RecordingProcessor): + """Replace the identity in request uri""" + def process_request(self, request): + resource = (urlparse(request.uri).netloc).split('.')[0] + request.uri = re.sub('/identities/([^/?]+)', '/identities/sanitized', request.uri) + request.uri = re.sub(resource, 'sanitized', request.uri) + request.uri = re.sub('/identities/([^/?]+)', '/identities/sanitized', request.uri) + request.uri = re.sub(resource, 'sanitized', request.uri) + return request + + def process_response(self, response): + if 'url' in response: + response['url'] = re.sub('/identities/([^/?]+)', '/identities/sanitized', response['url']) + return response \ No newline at end of file diff --git a/sdk/communication/azure-communication-chat/tests/helper.py b/sdk/communication/azure-communication-chat/tests/helper.py deleted file mode 100644 index 83ea3cc8397a..000000000000 --- a/sdk/communication/azure-communication-chat/tests/helper.py +++ /dev/null @@ -1,19 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from azure_devtools.scenario_tests import RecordingProcessor - -class URIIdentityReplacer(RecordingProcessor): - """Replace the identity in request uri""" - def process_request(self, request): - import re - request.uri = re.sub('/identities/([^/?]+)', '/identities/sanitized', request.uri) - return request - - def process_response(self, response): - import re - if 'url' in response: - response['url'] = re.sub('/identities/([^/?]+)', '/identities/sanitized', response['url']) - return response diff --git a/sdk/communication/azure-communication-chat/tests/test_chat_client_e2e.py b/sdk/communication/azure-communication-chat/tests/test_chat_client_e2e.py index 5f93abbeb4b4..737fcd3ee48b 100644 --- a/sdk/communication/azure-communication-chat/tests/test_chat_client_e2e.py +++ b/sdk/communication/azure-communication-chat/tests/test_chat_client_e2e.py @@ -20,7 +20,7 @@ from azure.communication.chat._shared.utils import parse_connection_str from azure_devtools.scenario_tests import RecordingProcessor -from helper import URIIdentityReplacer +from _shared.helper import URIIdentityReplacer from chat_e2e_helper import ChatURIReplacer from _shared.testcase import ( CommunicationTestCase, diff --git a/sdk/communication/azure-communication-chat/tests/test_chat_client_e2e_async.py b/sdk/communication/azure-communication-chat/tests/test_chat_client_e2e_async.py index df4658693336..19627da57dd1 100644 --- a/sdk/communication/azure-communication-chat/tests/test_chat_client_e2e_async.py +++ b/sdk/communication/azure-communication-chat/tests/test_chat_client_e2e_async.py @@ -20,7 +20,7 @@ ) from azure.communication.identity._shared.utils import parse_connection_str from azure_devtools.scenario_tests import RecordingProcessor -from helper import URIIdentityReplacer +from _shared.helper import URIIdentityReplacer from chat_e2e_helper import ChatURIReplacer from _shared.asynctestcase import AsyncCommunicationTestCase from _shared.testcase import BodyReplacerProcessor, ResponseReplacerProcessor diff --git a/sdk/communication/azure-communication-chat/tests/test_chat_thread_client_e2e.py b/sdk/communication/azure-communication-chat/tests/test_chat_thread_client_e2e.py index 990a849cb4c2..55d7263cef29 100644 --- a/sdk/communication/azure-communication-chat/tests/test_chat_thread_client_e2e.py +++ b/sdk/communication/azure-communication-chat/tests/test_chat_thread_client_e2e.py @@ -20,7 +20,7 @@ from azure.communication.chat._shared.utils import parse_connection_str from azure_devtools.scenario_tests import RecordingProcessor -from helper import URIIdentityReplacer +from _shared.helper import URIIdentityReplacer from chat_e2e_helper import ChatURIReplacer from _shared.testcase import ( CommunicationTestCase, diff --git a/sdk/communication/azure-communication-chat/tests/test_chat_thread_client_e2e_async.py b/sdk/communication/azure-communication-chat/tests/test_chat_thread_client_e2e_async.py index 1a468780b8ab..26b59f1c1776 100644 --- a/sdk/communication/azure-communication-chat/tests/test_chat_thread_client_e2e_async.py +++ b/sdk/communication/azure-communication-chat/tests/test_chat_thread_client_e2e_async.py @@ -20,7 +20,7 @@ ) from azure.communication.identity._shared.utils import parse_connection_str from azure_devtools.scenario_tests import RecordingProcessor -from helper import URIIdentityReplacer +from _shared.helper import URIIdentityReplacer from chat_e2e_helper import ChatURIReplacer from _shared.asynctestcase import AsyncCommunicationTestCase from _shared.testcase import BodyReplacerProcessor, ResponseReplacerProcessor diff --git a/sdk/communication/azure-communication-identity/CHANGELOG.md b/sdk/communication/azure-communication-identity/CHANGELOG.md index bbc96fbd5e9b..39e9859b6450 100644 --- a/sdk/communication/azure-communication-identity/CHANGELOG.md +++ b/sdk/communication/azure-communication-identity/CHANGELOG.md @@ -12,20 +12,26 @@ - Python 2.7 is no longer supported. Please use Python version 3.6 or later. ## 1.1.0b1 (2021-11-09) + ### Features Added + - Added support for Microsoft 365 Teams identities - `CommunicationIdentityClient` added a new method `get_token_for_teams_user` that provides the ability to exchange an AAD access token of a Teams user for a Communication Identity access token ## 1.0.1 (2021-06-08) + ### Bug Fixes + - Fixed async client to use async bearer token credential policy instead of sync policy. ## 1.0.0 (2021-03-29) + - Stable release of `azure-communication-identity`. ## 1.0.0b5 (2021-03-09) ### Breaking + - CommunicationIdentityClient's (synchronous and asynchronous) `issue_token` function is now renamed to `get_token`. - The CommunicationIdentityClient constructor uses type `TokenCredential` and `AsyncTokenCredential` for the credential parameter. - Dropped support for 3.5 @@ -33,13 +39,16 @@ ## 1.0.0b4 (2021-02-09) ### Added + - Added CommunicationIdentityClient (originally was part of the azure.communication.administration package). - Added ability to create a user and issue token for it at the same time. ### Breaking + - CommunicationIdentityClient.revoke_tokens now revoke all the currently issued tokens instead of revoking tokens issued prior to a given time. - CommunicationIdentityClient.issue_tokens returns an instance of `azure.core.credentials.AccessToken` instead of `CommunicationUserToken`. + [read_me]: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/communication/azure-communication-identity/README.md [documentation]: https://docs.microsoft.com/azure/communication-services/quickstarts/access-tokens?pivots=programming-language-python diff --git a/sdk/communication/azure-communication-identity/README.md b/sdk/communication/azure-communication-identity/README.md index d023b1c3cb65..94b694ef4cfa 100644 --- a/sdk/communication/azure-communication-identity/README.md +++ b/sdk/communication/azure-communication-identity/README.md @@ -103,7 +103,7 @@ identity_client.delete_user(user) Use the `get_token_for_teams_user` method to exchange an AAD access token of a Teams User for a new Communication Identity access token. ```python -identity_client.get_token_for_teams_user(add_token) +identity_client.get_token_for_teams_user(aad_token) ``` # Troubleshooting diff --git a/sdk/communication/azure-communication-identity/azure/communication/identity/_communication_identity_client.py b/sdk/communication/azure-communication-identity/azure/communication/identity/_communication_identity_client.py index 820ee43e4969..571d507d38c1 100644 --- a/sdk/communication/azure-communication-identity/azure/communication/identity/_communication_identity_client.py +++ b/sdk/communication/azure-communication-identity/azure/communication/identity/_communication_identity_client.py @@ -188,20 +188,20 @@ def revoke_tokens( @distributed_trace def get_token_for_teams_user( self, - add_token, # type: str + aad_token, # type: str **kwargs ): # type: (...) -> AccessToken """Exchanges an AAD access token of a Teams User for a new Communication Identity access token. - :param add_token: an AAD access token of a Teams User - :type add_token: str + :param aad_token: an AAD access token of a Teams User + :type aad_token: str :return: AccessToken :rtype: ~azure.core.credentials.AccessToken """ api_version = kwargs.pop("api_version", self._api_version) return self._identity_service_client.communication_identity.exchange_teams_user_access_token( - token=add_token, + token=aad_token, api_version=api_version, cls=lambda pr, u, e: AccessToken(u.token, u.expires_on), **kwargs) diff --git a/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/user_credential.py b/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/user_credential.py index 9b5f17dcc95d..f4a89336ad58 100644 --- a/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/user_credential.py +++ b/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/user_credential.py @@ -3,56 +3,68 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from threading import Lock, Condition -from datetime import timedelta -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) +from threading import Lock, Condition, Timer, TIMEOUT_MAX, Event +from datetime import timedelta +from typing import Any +import six from .utils import get_current_utc_as_int -from .user_token_refresh_options import CommunicationTokenRefreshOptions +from .utils import create_access_token class CommunicationTokenCredential(object): """Credential type used for authenticating to an Azure Communication service. - :param str token: The token used to authenticate to an Azure Communication service - :keyword token_refresher: The token refresher to provide capacity to fetch fresh token - :raises: TypeError + :param str token: The token used to authenticate to an Azure Communication service. + :keyword token_refresher: The sync token refresher to provide capacity to fetch a fresh token. + The returned token must be valid (expiration date must be in the future). + :paramtype token_refresher: Callable[[], AccessToken] + :keyword bool proactive_refresh: Whether to refresh the token proactively or not. + If the proactive refreshing is enabled ('proactive_refresh' is true), the credential will use + a background thread to attempt to refresh the token within 10 minutes before the cached token expires, + the proactive refresh will request a new token by calling the 'token_refresher' callback. + When 'proactive_refresh' is enabled, the Credential object must be either run within a context manager + or the 'close' method must be called once the object usage has been finished. + :raises: TypeError if paramater 'token' is not a string + :raises: ValueError if the 'proactive_refresh' is enabled without providing the 'token_refresher' callable. """ _ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2 - - def __init__(self, - token, # type: str - **kwargs - ): - token_refresher = kwargs.pop('token_refresher', None) - communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token, - token_refresher=token_refresher) - self._token = communication_token_refresh_options.get_token() - self._token_refresher = communication_token_refresh_options.get_token_refresher() + _DEFAULT_AUTOREFRESH_INTERVAL_MINUTES = 10 + + def __init__(self, token: str, **kwargs: Any): + if not isinstance(token, six.string_types): + raise TypeError("Token must be a string.") + self._token = create_access_token(token) + self._token_refresher = kwargs.pop('token_refresher', None) + self._proactive_refresh = kwargs.pop('proactive_refresh', False) + if(self._proactive_refresh and self._token_refresher is None): + raise ValueError("When 'proactive_refresh' is True, 'token_refresher' must not be None.") + self._timer = None self._lock = Condition(Lock()) self._some_thread_refreshing = False + self._is_closed = Event() def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument # type (*str, **Any) -> AccessToken """The value of the configured token. :rtype: ~azure.core.credentials.AccessToken """ + if self._proactive_refresh and self._is_closed.is_set(): + raise RuntimeError("An instance of CommunicationTokenCredential cannot be reused once it has been closed.") - if not self._token_refresher or not self._token_expiring(): + if not self._token_refresher or not self._is_token_expiring_soon(self._token): return self._token + self._update_token_and_reschedule() + return self._token + def _update_token_and_reschedule(self): should_this_thread_refresh = False - with self._lock: - while self._token_expiring(): + while self._is_token_expiring_soon(self._token): if self._some_thread_refreshing: - if self._is_currenttoken_valid(): + if self._is_token_valid(self._token): return self._token - - self._wait_till_inprogress_thread_finish_refreshing() + self._wait_till_lock_owner_finishes_refreshing() else: should_this_thread_refresh = True self._some_thread_refreshing = True @@ -60,27 +72,74 @@ def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument if should_this_thread_refresh: try: - newtoken = self._token_refresher() # pylint:disable=not-callable - + new_token = self._token_refresher() + if not self._is_token_valid(new_token): + raise ValueError( + "The token returned from the token_refresher is expired.") with self._lock: - self._token = newtoken + self._token = new_token self._some_thread_refreshing = False self._lock.notify_all() except: with self._lock: self._some_thread_refreshing = False self._lock.notify_all() - raise + if self._proactive_refresh: + self._schedule_refresh() return self._token - def _wait_till_inprogress_thread_finish_refreshing(self): + def _schedule_refresh(self): + if self._is_closed.is_set(): + return + if self._timer is not None: + self._timer.cancel() + + token_ttl = self._token.expires_on - get_current_utc_as_int() + + if self._is_token_expiring_soon(self._token): + # Schedule the next refresh for when it reaches a certain percentage of the remaining lifetime. + timespan = token_ttl // 2 + else: + # Schedule the next refresh for when it gets in to the soon-to-expire window. + timespan = token_ttl - timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES).total_seconds() + if timespan <= TIMEOUT_MAX: + self._timer = Timer(timespan, self._update_token_and_reschedule) + self._timer.daemon = True + self._timer.start() + + def _wait_till_lock_owner_finishes_refreshing(self): self._lock.release() self._lock.acquire() - def _token_expiring(self): - return self._token.expires_on - get_current_utc_as_int() <\ - timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() - - def _is_currenttoken_valid(self): - return get_current_utc_as_int() < self._token.expires_on + def _is_token_expiring_soon(self, token): + if self._proactive_refresh: + interval = timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES) + else: + interval = timedelta( + minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES) + return ((token.expires_on - get_current_utc_as_int()) + < interval.total_seconds()) + + @classmethod + def _is_token_valid(cls, token): + return get_current_utc_as_int() < token.expires_on + + def __enter__(self): + if self._proactive_refresh: + if self._is_closed.is_set(): + raise RuntimeError( + "An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + self._schedule_refresh() + return self + + def __exit__(self, *args): + self.close() + + def close(self) -> None: + if self._timer is not None: + self._timer.cancel() + self._timer = None + self._is_closed.set() diff --git a/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/user_credential_async.py b/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/user_credential_async.py index 52a99e7a4b6a..c41dc363c3e4 100644 --- a/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/user_credential_async.py +++ b/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/user_credential_async.py @@ -3,93 +3,149 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from asyncio import Condition, Lock -from datetime import timedelta -from typing import ( # pylint: disable=unused-import - cast, - Tuple, - Any -) +from asyncio import Condition, Lock, Event +from datetime import timedelta +from typing import Any +import sys +import six from .utils import get_current_utc_as_int -from .user_token_refresh_options import CommunicationTokenRefreshOptions +from .utils import create_access_token +from .utils_async import AsyncTimer class CommunicationTokenCredential(object): """Credential type used for authenticating to an Azure Communication service. - :param str token: The token used to authenticate to an Azure Communication service - :keyword token_refresher: The async token refresher to provide capacity to fetch fresh token - :raises: TypeError + :param str token: The token used to authenticate to an Azure Communication service. + :keyword token_refresher: The async token refresher to provide capacity to fetch a fresh token. + The returned token must be valid (expiration date must be in the future). + :paramtype token_refresher: Callable[[], Awaitable[AccessToken]] + :keyword bool proactive_refresh: Whether to refresh the token proactively or not. + If the proactive refreshing is enabled ('proactive_refresh' is true), the credential will use + a background thread to attempt to refresh the token within 10 minutes before the cached token expires, + the proactive refresh will request a new token by calling the 'token_refresher' callback. + When 'proactive_refresh is enabled', the Credential object must be either run within a context manager + or the 'close' method must be called once the object usage has been finished. + :raises: TypeError if paramater 'token' is not a string + :raises: ValueError if the 'proactive_refresh' is enabled without providing the 'token_refresher' function. """ _ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2 + _DEFAULT_AUTOREFRESH_INTERVAL_MINUTES = 10 def __init__(self, token: str, **kwargs: Any): - token_refresher = kwargs.pop('token_refresher', None) - communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token, - token_refresher=token_refresher) - self._token = communication_token_refresh_options.get_token() - self._token_refresher = communication_token_refresh_options.get_token_refresher() - self._lock = Condition(Lock()) + if not isinstance(token, six.string_types): + raise TypeError("Token must be a string.") + self._token = create_access_token(token) + self._token_refresher = kwargs.pop('token_refresher', None) + self._proactive_refresh = kwargs.pop('proactive_refresh', False) + if(self._proactive_refresh and self._token_refresher is None): + raise ValueError("When 'proactive_refresh' is True, 'token_refresher' must not be None.") + self._timer = None + self._async_mutex = Lock() + if sys.version_info[:3] == (3, 10, 0): + # Workaround for Python 3.10 bug(https://bugs.python.org/issue45416): + getattr(self._async_mutex, '_get_loop', lambda: None)() + self._lock = Condition(self._async_mutex) self._some_thread_refreshing = False + self._is_closed = Event() async def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument # type (*str, **Any) -> AccessToken """The value of the configured token. :rtype: ~azure.core.credentials.AccessToken """ - if not self._token_refresher or not self._token_expiring(): + if self._proactive_refresh and self._is_closed.is_set(): + raise RuntimeError("An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + + if not self._token_refresher or not self._is_token_expiring_soon(self._token): return self._token + await self._update_token_and_reschedule() + return self._token + async def _update_token_and_reschedule(self): should_this_thread_refresh = False - async with self._lock: - - while self._token_expiring(): + while self._is_token_expiring_soon(self._token): if self._some_thread_refreshing: - if self._is_currenttoken_valid(): + if self._is_token_valid(self._token): return self._token - - await self._wait_till_inprogress_thread_finish_refreshing() + await self._wait_till_lock_owner_finishes_refreshing() else: should_this_thread_refresh = True self._some_thread_refreshing = True break - if should_this_thread_refresh: try: - newtoken = await self._token_refresher() # pylint:disable=not-callable - + new_token = await self._token_refresher() + if not self._is_token_valid(new_token): + raise ValueError( + "The token returned from the token_refresher is expired.") async with self._lock: - self._token = newtoken + self._token = new_token self._some_thread_refreshing = False self._lock.notify_all() except: async with self._lock: self._some_thread_refreshing = False self._lock.notify_all() - raise - + if self._proactive_refresh: + self._schedule_refresh() return self._token - async def _wait_till_inprogress_thread_finish_refreshing(self): + def _schedule_refresh(self): + if self._is_closed.is_set(): + return + if self._timer is not None: + self._timer.cancel() + + token_ttl = self._token.expires_on - get_current_utc_as_int() + + if self._is_token_expiring_soon(self._token): + # Schedule the next refresh for when it reaches a certain percentage of the remaining lifetime. + timespan = token_ttl // 2 + else: + # Schedule the next refresh for when it gets in to the soon-to-expire window. + timespan = token_ttl - timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES).total_seconds() + + self._timer = AsyncTimer(timespan, self._update_token_and_reschedule) + self._timer.start() + + async def _wait_till_lock_owner_finishes_refreshing(self): + self._lock.release() await self._lock.acquire() - def _token_expiring(self): - return self._token.expires_on - get_current_utc_as_int() <\ - timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() - - def _is_currenttoken_valid(self): - return get_current_utc_as_int() < self._token.expires_on + def _is_token_expiring_soon(self, token): + if self._proactive_refresh: + interval = timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES) + else: + interval = timedelta( + minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES) + return ((token.expires_on - get_current_utc_as_int()) + < interval.total_seconds()) - async def close(self) -> None: - pass + @classmethod + def _is_token_valid(cls, token): + return get_current_utc_as_int() < token.expires_on async def __aenter__(self): + if self._proactive_refresh: + if self._is_closed.is_set(): + raise RuntimeError( + "An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + self._schedule_refresh() return self async def __aexit__(self, *args): await self.close() + + async def close(self) -> None: + if self._timer is not None: + self._timer.cancel() + self._timer = None + self._is_closed.set() diff --git a/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/user_token_refresh_options.py b/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/user_token_refresh_options.py deleted file mode 100644 index 6bdc0d456026..000000000000 --- a/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/user_token_refresh_options.py +++ /dev/null @@ -1,36 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) -import six -from .utils import create_access_token - -class CommunicationTokenRefreshOptions(object): - """Options for refreshing CommunicationTokenCredential. - :param str token: The token used to authenticate to an Azure Communication service - :param token_refresher: The token refresher to provide capacity to fetch fresh token - :raises: TypeError - """ - - def __init__(self, - token, # type: str - token_refresher=None - ): - # type: (str) -> None - if not isinstance(token, six.string_types): - raise TypeError("token must be a string.") - self._token = token - self._token_refresher = token_refresher - - def get_token(self): - """Return the the serialized JWT token.""" - return create_access_token(self._token) - - def get_token_refresher(self): - """Return the token refresher to provide capacity to fetch fresh token.""" - return self._token_refresher diff --git a/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/utils.py b/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/utils.py index c9255a4217d7..0b3556bbaa44 100644 --- a/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/utils.py +++ b/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/utils.py @@ -6,15 +6,15 @@ import base64 import json -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) -from datetime import datetime import calendar +from typing import (cast, + Tuple, + ) +from datetime import datetime from msrest.serialization import TZ_UTC from azure.core.credentials import AccessToken + def _convert_datetime_to_utc_int(input_datetime): """ Converts DateTime in local time to the Epoch in UTC in second. @@ -26,6 +26,7 @@ def _convert_datetime_to_utc_int(input_datetime): """ return int(calendar.timegm(input_datetime.utctimetuple())) + def parse_connection_str(conn_str): # type: (str) -> Tuple[str, str, str, str] if conn_str is None: @@ -53,9 +54,10 @@ def parse_connection_str(conn_str): return host, str(shared_access_key) + def get_current_utc_time(): # type: () -> str - return str(datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT" + return str(datetime.now(tz=TZ_UTC).strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT" def get_current_utc_as_int(): @@ -63,6 +65,7 @@ def get_current_utc_as_int(): current_utc_datetime = datetime.utcnow() return _convert_datetime_to_utc_int(current_utc_datetime) + def create_access_token(token): # type: (str) -> azure.core.credentials.AccessToken """Creates an instance of azure.core.credentials.AccessToken from a @@ -84,18 +87,20 @@ def create_access_token(token): raise ValueError(token_parse_err_msg) try: - padded_base64_payload = base64.b64decode(parts[1] + "==").decode('ascii') + padded_base64_payload = base64.b64decode( + parts[1] + '==').decode('ascii') payload = json.loads(padded_base64_payload) return AccessToken(token, _convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp'], TZ_UTC))) - except ValueError: - raise ValueError(token_parse_err_msg) + except ValueError as val_error: + raise ValueError(token_parse_err_msg) from val_error + def get_authentication_policy( - endpoint, # type: str - credential, # type: TokenCredential or str - decode_url=False, # type: bool - is_async=False, # type: bool + endpoint, # type: str + credential, # type: TokenCredential or str + decode_url=False, # type: bool + is_async=False, # type: bool ): # type: (...) -> BearerTokenCredentialPolicy or HMACCredentialPolicy """Returns the correct authentication policy based diff --git a/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/utils_async.py b/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/utils_async.py new file mode 100644 index 000000000000..86e0e04d273c --- /dev/null +++ b/sdk/communication/azure-communication-identity/azure/communication/identity/_shared/utils_async.py @@ -0,0 +1,31 @@ +# ------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +import asyncio + + +class AsyncTimer: + """A non-blocking timer, that calls a function after a specified number of seconds: + :param int interval: time interval in seconds + :param callable callback: function to be called after the interval has elapsed + """ + + def __init__(self, interval, callback): + self._interval = interval + self._callback = callback + self._task = None + + def start(self): + self._task = asyncio.ensure_future(self._job()) + + async def _job(self): + await asyncio.sleep(self._interval) + await self._callback() + + def cancel(self): + if self._task is not None: + self._task.cancel() + self._task = None diff --git a/sdk/communication/azure-communication-identity/azure/communication/identity/aio/_communication_identity_client_async.py b/sdk/communication/azure-communication-identity/azure/communication/identity/aio/_communication_identity_client_async.py index 70d5bd7e2d3b..60c37e4d95d7 100644 --- a/sdk/communication/azure-communication-identity/azure/communication/identity/aio/_communication_identity_client_async.py +++ b/sdk/communication/azure-communication-identity/azure/communication/identity/aio/_communication_identity_client_async.py @@ -183,20 +183,20 @@ async def revoke_tokens( @distributed_trace_async async def get_token_for_teams_user( self, - add_token, # type: str + aad_token, # type: str **kwargs ) -> AccessToken: # type: (...) -> AccessToken """Exchanges an AAD access token of a Teams User for a new Communication Identity access token. - :param add_token: an AAD access token of a Teams User - :type add_token: str + :param aad_token: an AAD access token of a Teams User + :type aad_token: str :return: AccessToken :rtype: ~azure.core.credentials.AccessToken """ api_version = kwargs.pop("api_version", self._api_version) return await self._identity_service_client.communication_identity.exchange_teams_user_access_token( - token=add_token, + token=aad_token, api_version=api_version, cls=lambda pr, u, e: AccessToken(u.token, u.expires_on), **kwargs) diff --git a/sdk/communication/azure-communication-identity/tests/_shared/asynctestcase.py b/sdk/communication/azure-communication-identity/tests/_shared/asynctestcase.py index 197c48e0079b..4c331bb79598 100644 --- a/sdk/communication/azure-communication-identity/tests/_shared/asynctestcase.py +++ b/sdk/communication/azure-communication-identity/tests/_shared/asynctestcase.py @@ -1,4 +1,3 @@ - # coding: utf-8 # ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. @@ -24,4 +23,4 @@ def run(test_class_instance, *args, **kwargs): loop = asyncio.get_event_loop() return loop.run_until_complete(test_fn(test_class_instance, **kwargs)) - return run + return run \ No newline at end of file diff --git a/sdk/communication/azure-communication-identity/tests/_shared/helper.py b/sdk/communication/azure-communication-identity/tests/_shared/helper.py index 146d94b649b0..4d3585695f5a 100644 --- a/sdk/communication/azure-communication-identity/tests/_shared/helper.py +++ b/sdk/communication/azure-communication-identity/tests/_shared/helper.py @@ -4,8 +4,27 @@ # license information. # -------------------------------------------------------------------------- import re +import base64 from azure_devtools.scenario_tests import RecordingProcessor -from urllib.parse import urlparse +from datetime import datetime, timedelta +from functools import wraps +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse +import sys + +def generate_token_with_custom_expiry(valid_for_seconds): + return generate_token_with_custom_expiry_epoch((datetime.now() + timedelta(seconds=valid_for_seconds)).timestamp()) + +def generate_token_with_custom_expiry_epoch(expires_on_epoch): + expiry_json = f'{{"exp": {str(expires_on_epoch)} }}' + base64expiry = base64.b64encode( + expiry_json.encode('utf-8')).decode('utf-8').rstrip("=") + token_template = (f'''eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9. + {base64expiry}.adM-ddBZZlQ1WlN3pdPBOF5G4Wh9iZpxNP_fSvpF4cWs''') + return token_template + class URIIdentityReplacer(RecordingProcessor): """Replace the identity in request uri""" @@ -13,6 +32,8 @@ def process_request(self, request): resource = (urlparse(request.uri).netloc).split('.')[0] request.uri = re.sub('/identities/([^/?]+)', '/identities/sanitized', request.uri) request.uri = re.sub(resource, 'sanitized', request.uri) + request.uri = re.sub('/identities/([^/?]+)', '/identities/sanitized', request.uri) + request.uri = re.sub(resource, 'sanitized', request.uri) return request def process_response(self, response): diff --git a/sdk/communication/azure-communication-identity/tests/asynctestcase.py b/sdk/communication/azure-communication-identity/tests/asynctestcase.py index 31d2c2ab6c75..aa08cac55e12 100644 --- a/sdk/communication/azure-communication-identity/tests/asynctestcase.py +++ b/sdk/communication/azure-communication-identity/tests/asynctestcase.py @@ -10,6 +10,7 @@ from azure_devtools.scenario_tests.utilities import trim_kwargs_from_test_function from testcase import CommunicationIdentityTestCase + class AsyncCommunicationIdentityTestCase(CommunicationIdentityTestCase): @staticmethod diff --git a/sdk/communication/azure-communication-identity/tests/test_communication_identity_client.py b/sdk/communication/azure-communication-identity/tests/test_communication_identity_client.py index a88aee6e5b35..fe3251034b01 100644 --- a/sdk/communication/azure-communication-identity/tests/test_communication_identity_client.py +++ b/sdk/communication/azure-communication-identity/tests/test_communication_identity_client.py @@ -233,8 +233,8 @@ def test_get_token_for_teams_user_from_managed_identity(self, communication_live credential, http_logging_policy=get_http_logging_policy() ) - add_token = self.generate_teams_user_aad_token() - token_response = identity_client.get_token_for_teams_user(add_token) + aad_token = self.generate_teams_user_aad_token() + token_response = identity_client.get_token_for_teams_user(aad_token) assert token_response.token is not None @CommunicationPreparer() @@ -245,8 +245,8 @@ def test_get_token_for_teams_user_with_valid_token(self, communication_livetest_ communication_livetest_dynamic_connection_string, http_logging_policy=get_http_logging_policy() ) - add_token = self.generate_teams_user_aad_token() - token_response = identity_client.get_token_for_teams_user(add_token) + aad_token = self.generate_teams_user_aad_token() + token_response = identity_client.get_token_for_teams_user(aad_token) assert token_response.token is not None @CommunicationPreparer() diff --git a/sdk/communication/azure-communication-identity/tests/test_communication_identity_client_async.py b/sdk/communication/azure-communication-identity/tests/test_communication_identity_client_async.py index b4cde9c67227..7dc9a6027acd 100644 --- a/sdk/communication/azure-communication-identity/tests/test_communication_identity_client_async.py +++ b/sdk/communication/azure-communication-identity/tests/test_communication_identity_client_async.py @@ -244,8 +244,8 @@ async def test_get_token_for_teams_user_from_managed_identity(self, communicatio http_logging_policy=get_http_logging_policy() ) async with identity_client: - add_token = self.generate_teams_user_aad_token() - token_response = await identity_client.get_token_for_teams_user(add_token) + aad_token = self.generate_teams_user_aad_token() + token_response = await identity_client.get_token_for_teams_user(aad_token) assert token_response.token is not None @@ -258,8 +258,8 @@ async def test_get_token_for_teams_user_with_valid_token(self, communication_liv http_logging_policy=get_http_logging_policy() ) async with identity_client: - add_token = self.generate_teams_user_aad_token() - token_response = await identity_client.get_token_for_teams_user(add_token) + aad_token = self.generate_teams_user_aad_token() + token_response = await identity_client.get_token_for_teams_user(aad_token) assert token_response.token is not None diff --git a/sdk/communication/azure-communication-identity/tests/test_user_credential.py b/sdk/communication/azure-communication-identity/tests/test_user_credential.py new file mode 100644 index 000000000000..1fd0ed6f4ca2 --- /dev/null +++ b/sdk/communication/azure-communication-identity/tests/test_user_credential.py @@ -0,0 +1,220 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from typing import Type +import platform +import pytest +from unittest import TestCase +try: + from unittest.mock import MagicMock, patch +except ImportError: # python < 3.3 + from mock import MagicMock, patch # type: ignore +import azure.communication.identity._shared.user_credential as user_credential +from azure.communication.identity._shared.user_credential import CommunicationTokenCredential +from azure.communication.identity._shared.utils import create_access_token +from azure.communication.identity._shared.utils import get_current_utc_as_int +from datetime import timedelta +from _shared.helper import generate_token_with_custom_expiry_epoch, generate_token_with_custom_expiry + + +class TestCommunicationTokenCredential(TestCase): + + @classmethod + def setUpClass(cls): + cls.sample_token = generate_token_with_custom_expiry_epoch( + 32503680000) # 1/1/2030 + cls.expired_token = generate_token_with_custom_expiry_epoch( + 100) # 1/1/1970 + + def test_communicationtokencredential_decodes_token(self): + credential = CommunicationTokenCredential(self.sample_token) + access_token = credential.get_token() + self.assertEqual(access_token.token, self.sample_token) + + def test_communicationtokencredential_throws_if_invalid_token(self): + self.assertRaises( + ValueError, lambda: CommunicationTokenCredential("foo.bar.tar")) + + def test_communicationtokencredential_throws_if_nonstring_token(self): + self.assertRaises(TypeError, lambda: CommunicationTokenCredential(454)) + + def test_communicationtokencredential_throws_if_proactive_refresh_enabled_without_token_refresher(self): + with pytest.raises(ValueError) as err: + CommunicationTokenCredential(self.sample_token, proactive_refresh=True) + assert str(err.value) == "When 'proactive_refresh' is True, 'token_refresher' must not be None." + with pytest.raises(ValueError) as err: + CommunicationTokenCredential( + self.sample_token, + proactive_refresh=True, + token_refresher=None) + assert str(err.value) == "When 'proactive_refresh' is True, 'token_refresher' must not be None." + + def test_communicationtokencredential_static_token_returns_expired_token(self): + credential = CommunicationTokenCredential(self.expired_token) + self.assertEqual(credential.get_token().token, self.expired_token) + + def test_communicationtokencredential_token_expired_refresh_called(self): + refresher = MagicMock( + return_value=create_access_token(self.sample_token)) + credential = CommunicationTokenCredential(self.expired_token, token_refresher=refresher) + access_token = credential.get_token() + refresher.assert_called_once() + self.assertEqual(access_token.token, self.sample_token) + + def test_communicationtokencredential_raises_if_refresher_returns_expired_token(self): + refresher = MagicMock( + return_value=create_access_token(self.expired_token)) + credential = CommunicationTokenCredential(self.expired_token, token_refresher=refresher) + with self.assertRaises(ValueError): + credential.get_token() + self.assertEqual(refresher.call_count, 1) + + def test_uses_initial_token_as_expected(self): + refresher = MagicMock( + return_value=create_access_token(self.expired_token)) + credential = CommunicationTokenCredential( + self.sample_token, token_refresher=refresher, proactive_refresh=True) + access_token = credential.get_token() + + self.assertEqual(refresher.call_count, 0) + self.assertEqual(access_token.token, self.sample_token) + + def test_proactive_refresher_should_not_be_called_before_specified_time(self): + refresh_minutes = 10 + token_validity_minutes = 60 + start_timestamp = get_current_utc_as_int() + skip_to_timestamp = start_timestamp + (refresh_minutes - 5) * 60 + + initial_token = generate_token_with_custom_expiry( + token_validity_minutes * 60) + refreshed_token = generate_token_with_custom_expiry( + 2 * token_validity_minutes * 60) + refresher = MagicMock( + return_value=create_access_token(refreshed_token)) + + with patch(user_credential.__name__+'.'+get_current_utc_as_int.__name__, return_value=skip_to_timestamp): + credential = CommunicationTokenCredential( + initial_token, + token_refresher=refresher, + proactive_refresh=True) + access_token = credential.get_token() + + assert refresher.call_count == 0 + assert access_token.token == initial_token + # check that next refresh is always scheduled + assert credential._timer is None + + def test_proactive_refresher_should_be_called_after_specified_time(self): + refresh_minutes = 10 + token_validity_minutes = 60 + start_timestamp = get_current_utc_as_int() + skip_to_timestamp = start_timestamp + \ + (token_validity_minutes - refresh_minutes + 5) * 60 + + initial_token = generate_token_with_custom_expiry( + token_validity_minutes * 60) + refreshed_token = generate_token_with_custom_expiry( + 2 * token_validity_minutes * 60) + refresher = MagicMock( + return_value=create_access_token(refreshed_token)) + + with patch(user_credential.__name__+'.'+get_current_utc_as_int.__name__, return_value=skip_to_timestamp): + credential = CommunicationTokenCredential( + initial_token, + token_refresher=refresher, + proactive_refresh=True) + access_token = credential.get_token() + + assert refresher.call_count == 1 + assert access_token.token == refreshed_token + # check that next refresh is always scheduled + assert credential._timer is not None + + def test_proactive_refresher_keeps_scheduling_again(self): + refresh_minutes = 10 + token_validity_minutes = 60 + expired_token = generate_token_with_custom_expiry(-5 * 60) + skip_to_timestamp = get_current_utc_as_int() + (token_validity_minutes - + refresh_minutes) * 60 + 1 + first_refreshed_token = create_access_token( + generate_token_with_custom_expiry(token_validity_minutes * 60)) + last_refreshed_token = create_access_token( + generate_token_with_custom_expiry(2 * token_validity_minutes * 60)) + refresher = MagicMock( + side_effect=[first_refreshed_token, last_refreshed_token]) + + credential = CommunicationTokenCredential( + expired_token, + token_refresher=refresher, + proactive_refresh=True) + access_token = credential.get_token() + with patch(user_credential.__name__+'.'+get_current_utc_as_int.__name__, return_value=skip_to_timestamp): + access_token = credential.get_token() + + assert refresher.call_count == 2 + assert access_token.token == last_refreshed_token.token + # check that next refresh is always scheduled + assert credential._timer is not None + + def test_fractional_backoff_applied_when_token_expiring(self): + token_validity_seconds = 5 * 60 + expiring_token = generate_token_with_custom_expiry( + token_validity_seconds) + + refresher = MagicMock( + side_effect=[create_access_token(expiring_token), create_access_token(expiring_token)]) + + credential = CommunicationTokenCredential( + expiring_token, + token_refresher=refresher, + proactive_refresh=True) + + next_milestone = token_validity_seconds / 2 + + with patch(user_credential.__name__+'.'+get_current_utc_as_int.__name__, return_value=(get_current_utc_as_int() + next_milestone)): + credential.get_token() + assert refresher.call_count == 1 + next_milestone = next_milestone / 2 + assert credential._timer.interval == next_milestone + + def test_refresher_should_not_be_called_when_token_still_valid(self): + generated_token = generate_token_with_custom_expiry(15 * 60) + new_token = generate_token_with_custom_expiry(10 * 60) + refresher = MagicMock(return_value=create_access_token(new_token)) + + credential = CommunicationTokenCredential( + generated_token, token_refresher=refresher, proactive_refresh=False) + for _ in range(10): + access_token = credential.get_token() + + refresher.assert_not_called() + assert generated_token == access_token.token + + def test_exit_cancels_timer(self): + refreshed_token = create_access_token( + generate_token_with_custom_expiry(30 * 60)) + refresher = MagicMock(return_value=refreshed_token) + credential = CommunicationTokenCredential( + self.expired_token,token_refresher=refresher, proactive_refresh=True) + credential.get_token() + credential.close() + assert credential._timer is None + + def test_exit_enter_scenario_throws_exception(self): + refreshed_token = create_access_token( + generate_token_with_custom_expiry(30 * 60)) + refresher = MagicMock(return_value=refreshed_token) + credential = CommunicationTokenCredential( + self.expired_token,token_refresher=refresher, proactive_refresh=True) + credential.get_token() + credential.close() + assert credential._timer is None + + with pytest.raises(RuntimeError) as err: + credential.get_token() + assert str(err.value) == "An instance of CommunicationTokenCredential cannot be reused once it has been closed." + + + \ No newline at end of file diff --git a/sdk/communication/azure-communication-identity/tests/test_user_credential_async.py b/sdk/communication/azure-communication-identity/tests/test_user_credential_async.py new file mode 100644 index 000000000000..b60569f7147a --- /dev/null +++ b/sdk/communication/azure-communication-identity/tests/test_user_credential_async.py @@ -0,0 +1,251 @@ + +# coding: utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from unittest import TestCase +import pytest +from asyncio import Future +try: + from unittest.mock import MagicMock, patch +except ImportError: # python < 3.3 + from mock import MagicMock, patch +from azure.communication.identity._shared.user_credential_async import CommunicationTokenCredential +import azure.communication.identity._shared.user_credential_async as user_credential_async +from azure.communication.identity._shared.utils import create_access_token +from azure.communication.identity._shared.utils import get_current_utc_as_int +from _shared.helper import generate_token_with_custom_expiry + + +class TestCommunicationTokenCredential(TestCase): + + def get_completed_future(self, result=None): + future = Future() + future.set_result(result) + return future + + @pytest.mark.asyncio + async def test_raises_error_for_init_with_nonstring_token(self): + with pytest.raises(TypeError) as err: + CommunicationTokenCredential(1234) + assert str(err.value) == "Token must be a string." + + @pytest.mark.asyncio + async def test_raises_error_for_init_with_invalid_token(self): + with pytest.raises(ValueError) as err: + CommunicationTokenCredential("not a token") + assert str(err.value) == "Token is not formatted correctly" + + @pytest.mark.asyncio + async def test_init_with_valid_token(self): + initial_token = generate_token_with_custom_expiry(5 * 60) + credential = CommunicationTokenCredential(initial_token) + access_token = await credential.get_token() + assert initial_token == access_token.token + + @pytest.mark.asyncio + async def test_communicationtokencredential_throws_if_proactive_refresh_enabled_without_token_refresher(self): + with pytest.raises(ValueError) as err: + CommunicationTokenCredential(self.sample_token, proactive_refresh=True) + assert str(err.value) == "When 'proactive_refresh' is True, 'token_refresher' must not be None." + with pytest.raises(ValueError) as err: + CommunicationTokenCredential( + self.sample_token, + proactive_refresh=True, + token_refresher=None) + assert str(err.value) == "When 'proactive_refresh' is True, 'token_refresher' must not be None." + + @pytest.mark.asyncio + async def test_refresher_should_be_called_immediately_with_expired_token(self): + refreshed_token = generate_token_with_custom_expiry(10 * 60) + refresher = MagicMock( + return_value=self.get_completed_future(create_access_token(refreshed_token))) + expired_token = generate_token_with_custom_expiry(-(5 * 60)) + + credential = CommunicationTokenCredential( + expired_token, token_refresher=refresher) + access_token = await credential.get_token() + + refresher.assert_called_once() + assert refreshed_token == access_token.token + + @pytest.mark.asyncio + async def test_refresher_should_not_be_called_before_expiring_time(self): + initial_token = generate_token_with_custom_expiry(15 * 60) + refreshed_token = generate_token_with_custom_expiry(10 * 60) + refresher = MagicMock( + return_value=create_access_token(refreshed_token)) + + credential = CommunicationTokenCredential( + initial_token, token_refresher=refresher, proactive_refresh=True) + access_token = await credential.get_token() + + refresher.assert_not_called() + assert initial_token == access_token.token + + @pytest.mark.asyncio + async def test_refresher_should_not_be_called_when_token_still_valid(self): + generated_token = generate_token_with_custom_expiry(15 * 60) + new_token = generate_token_with_custom_expiry(10 * 60) + refresher = MagicMock(return_value=create_access_token(new_token)) + + credential = CommunicationTokenCredential( + generated_token, token_refresher=refresher, proactive_refresh=False) + for _ in range(10): + access_token = await credential.get_token() + + refresher.assert_not_called() + assert generated_token == access_token.token + + @pytest.mark.asyncio + async def test_raises_if_refresher_returns_expired_token(self): + expired_token = generate_token_with_custom_expiry(-(10 * 60)) + refresher = MagicMock(return_value=self.get_completed_future( + create_access_token(expired_token))) + + credential = CommunicationTokenCredential( + expired_token, token_refresher=refresher) + with self.assertRaises(ValueError): + await credential.get_token() + + assert refresher.call_count == 1 + + @pytest.mark.asyncio + async def test_proactive_refresher_should_not_be_called_before_specified_time(self): + refresh_minutes = 30 + token_validity_minutes = 60 + start_timestamp = get_current_utc_as_int() + skip_to_timestamp = start_timestamp + (refresh_minutes - 5) * 60 + + initial_token = generate_token_with_custom_expiry( + token_validity_minutes * 60) + refreshed_token = generate_token_with_custom_expiry( + 2 * token_validity_minutes * 60) + refresher = MagicMock( + return_value=create_access_token(refreshed_token)) + + with patch(user_credential_async.__name__+'.'+get_current_utc_as_int.__name__, return_value=skip_to_timestamp): + credential = CommunicationTokenCredential( + initial_token, + token_refresher=refresher, + proactive_refresh=True) + access_token = await credential.get_token() + + assert refresher.call_count == 0 + assert access_token.token == initial_token + # check that next refresh is always scheduled + assert credential._timer is not None + + @pytest.mark.asyncio + async def test_proactive_refresher_should_be_called_after_specified_time(self): + refresh_minutes = 10 + token_validity_minutes = 60 + start_timestamp = get_current_utc_as_int() + skip_to_timestamp = start_timestamp + \ + (token_validity_minutes - refresh_minutes + 5) * 60 + + initial_token = generate_token_with_custom_expiry( + token_validity_minutes * 60) + refreshed_token = generate_token_with_custom_expiry( + 2 * token_validity_minutes * 60) + refresher = MagicMock( + return_value=self.get_completed_future(create_access_token(refreshed_token))) + + with patch(user_credential_async.__name__+'.'+get_current_utc_as_int.__name__, return_value=skip_to_timestamp): + credential = CommunicationTokenCredential( + initial_token, + token_refresher=refresher, + proactive_refresh=True) + access_token = await credential.get_token() + + assert refresher.call_count == 1 + assert access_token.token == refreshed_token + # check that next refresh is always scheduled + assert credential._timer is not None + + @pytest.mark.asyncio + async def test_proactive_refresher_keeps_scheduling_again(self): + refresh_minutes = 10 + token_validity_minutes = 60 + expired_token = generate_token_with_custom_expiry(-5 * 60) + skip_to_timestamp = get_current_utc_as_int() + (token_validity_minutes - + refresh_minutes) * 60 + 1 + first_refreshed_token = create_access_token( + generate_token_with_custom_expiry(token_validity_minutes * 60)) + last_refreshed_token = create_access_token( + generate_token_with_custom_expiry(2 * token_validity_minutes * 60)) + refresher = MagicMock( + side_effect=[self.get_completed_future(first_refreshed_token), self.get_completed_future(last_refreshed_token)]) + + credential = CommunicationTokenCredential( + expired_token, + token_refresher=refresher, + proactive_refresh=True) + access_token = await credential.get_token() + with patch(user_credential_async.__name__+'.'+get_current_utc_as_int.__name__, return_value=skip_to_timestamp): + access_token = await credential.get_token() + + assert refresher.call_count == 2 + assert access_token.token == last_refreshed_token.token + # check that next refresh is always scheduled + assert credential._timer is not None + + @pytest.mark.asyncio + async def test_fractional_backoff_applied_when_token_expiring(self): + token_validity_seconds = 5 * 60 + expiring_token = generate_token_with_custom_expiry( + token_validity_seconds) + + refresher = MagicMock( + side_effect=[create_access_token(expiring_token), create_access_token(expiring_token)]) + + credential = CommunicationTokenCredential( + expiring_token, + token_refresher=refresher, + proactive_refresh=True) + + next_milestone = token_validity_seconds / 2 + assert credential._timer.interval == next_milestone + + with patch(user_credential_async.__name__+'.'+get_current_utc_as_int.__name__, return_value=(get_current_utc_as_int() + next_milestone)): + await credential.get_token() + + assert refresher.call_count == 1 + next_milestone = next_milestone / 2 + assert credential._timer.interval == next_milestone + + @pytest.mark.asyncio + async def test_exit_cancels_timer(self): + refreshed_token = create_access_token( + generate_token_with_custom_expiry(30 * 60)) + refresher = MagicMock(return_value=refreshed_token) + expired_token = generate_token_with_custom_expiry(-10 * 60) + credential = CommunicationTokenCredential( + expired_token, + token_refresher=refresher, + proactive_refresh=True) + credential.get_token() + credential.close() + assert credential._timer is not None + assert refresher.call_count == 0 + assert credential._timer is not None + + @pytest.mark.asyncio + async def test_exit_enter_scenario_throws_exception(self): + refreshed_token = create_access_token( + generate_token_with_custom_expiry(30 * 60)) + refresher = MagicMock(return_value=refreshed_token) + expired_token = generate_token_with_custom_expiry(-10 * 60) + credential = CommunicationTokenCredential( + expired_token, + token_refresher=refresher, + proactive_refresh=True) + credential.get_token() + credential.close() + assert credential._timer is not None + + with pytest.raises(RuntimeError) as err: + credential.get_token() + assert str(err.value) == "An instance of CommunicationTokenCredential cannot be reused once it has been closed." diff --git a/sdk/communication/azure-communication-identity/tests/test_user_credential_async_with_context_manager.py b/sdk/communication/azure-communication-identity/tests/test_user_credential_async_with_context_manager.py new file mode 100644 index 000000000000..baebb6e6d35b --- /dev/null +++ b/sdk/communication/azure-communication-identity/tests/test_user_credential_async_with_context_manager.py @@ -0,0 +1,259 @@ + +# coding: utf-8 +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from unittest import TestCase +import pytest +from asyncio import Future +try: + from unittest.mock import MagicMock, patch +except ImportError: # python < 3.3 + from mock import MagicMock, patch +from azure.communication.identity._shared.user_credential_async import CommunicationTokenCredential +import azure.communication.identity._shared.user_credential_async as user_credential_async +from azure.communication.identity._shared.utils import create_access_token +from azure.communication.identity._shared.utils import get_current_utc_as_int +from _shared.helper import generate_token_with_custom_expiry + + +class TestCommunicationTokenCredential(TestCase): + + def get_completed_future(self, result=None): + future = Future() + future.set_result(result) + return future + + @pytest.mark.asyncio + async def test_raises_error_for_init_with_nonstring_token(self): + with pytest.raises(TypeError) as err: + CommunicationTokenCredential(1234) + assert str(err.value) == "Token must be a string." + + @pytest.mark.asyncio + async def test_raises_error_for_init_with_invalid_token(self): + with pytest.raises(ValueError) as err: + CommunicationTokenCredential("not a token") + assert str(err.value) == "Token is not formatted correctly" + + @pytest.mark.asyncio + async def test_init_with_valid_token(self): + initial_token = generate_token_with_custom_expiry(5 * 60) + credential = CommunicationTokenCredential(initial_token) + access_token = await credential.get_token() + assert initial_token == access_token.token + + @pytest.mark.asyncio + async def test_communicationtokencredential_throws_if_proactive_refresh_enabled_without_token_refresher(self): + with pytest.raises(ValueError) as err: + CommunicationTokenCredential(self.sample_token, proactive_refresh=True) + assert str(err.value) == "When 'proactive_refresh' is True, 'token_refresher' must not be None." + with pytest.raises(ValueError) as err: + CommunicationTokenCredential( + self.sample_token, + proactive_refresh=True, + token_refresher=None) + assert str(err.value) == "When 'proactive_refresh' is True, 'token_refresher' must not be None." + + @pytest.mark.asyncio + async def test_refresher_should_be_called_immediately_with_expired_token(self): + refreshed_token = generate_token_with_custom_expiry(10 * 60) + refresher = MagicMock( + return_value=self.get_completed_future(create_access_token(refreshed_token))) + expired_token = generate_token_with_custom_expiry(-(5 * 60)) + + credential = CommunicationTokenCredential( + expired_token, token_refresher=refresher) + async with credential: + access_token = await credential.get_token() + + refresher.assert_called_once() + assert refreshed_token == access_token.token + + @pytest.mark.asyncio + async def test_refresher_should_not_be_called_before_expiring_time(self): + initial_token = generate_token_with_custom_expiry(15 * 60) + refreshed_token = generate_token_with_custom_expiry(10 * 60) + refresher = MagicMock( + return_value=create_access_token(refreshed_token)) + + credential = CommunicationTokenCredential( + initial_token, token_refresher=refresher, proactive_refresh=True) + async with credential: + access_token = await credential.get_token() + + refresher.assert_not_called() + assert initial_token == access_token.token + + @pytest.mark.asyncio + async def test_refresher_should_not_be_called_when_token_still_valid(self): + generated_token = generate_token_with_custom_expiry(15 * 60) + new_token = generate_token_with_custom_expiry(10 * 60) + refresher = MagicMock(return_value=create_access_token(new_token)) + + credential = CommunicationTokenCredential( + generated_token, token_refresher=refresher, proactive_refresh=False) + async with credential: + for _ in range(10): + access_token = await credential.get_token() + + refresher.assert_not_called() + assert generated_token == access_token.token + + @pytest.mark.asyncio + async def test_raises_if_refresher_returns_expired_token(self): + expired_token = generate_token_with_custom_expiry(-(10 * 60)) + refresher = MagicMock(return_value=self.get_completed_future( + create_access_token(expired_token))) + + credential = CommunicationTokenCredential( + expired_token, token_refresher=refresher) + async with credential: + with self.assertRaises(ValueError): + await credential.get_token() + + assert refresher.call_count == 1 + + @pytest.mark.asyncio + async def test_proactive_refresher_should_not_be_called_before_specified_time(self): + refresh_minutes = 30 + token_validity_minutes = 60 + start_timestamp = get_current_utc_as_int() + skip_to_timestamp = start_timestamp + (refresh_minutes - 5) * 60 + + initial_token = generate_token_with_custom_expiry( + token_validity_minutes * 60) + refreshed_token = generate_token_with_custom_expiry( + 2 * token_validity_minutes * 60) + refresher = MagicMock( + return_value=create_access_token(refreshed_token)) + + with patch(user_credential_async.__name__+'.'+get_current_utc_as_int.__name__, return_value=skip_to_timestamp): + credential = CommunicationTokenCredential( + initial_token, + token_refresher=refresher, + proactive_refresh=True) + async with credential: + access_token = await credential.get_token() + + assert refresher.call_count == 0 + assert access_token.token == initial_token + # check that next refresh is always scheduled + assert credential._timer is not None + + @pytest.mark.asyncio + async def test_proactive_refresher_should_be_called_after_specified_time(self): + refresh_minutes = 10 + token_validity_minutes = 60 + start_timestamp = get_current_utc_as_int() + skip_to_timestamp = start_timestamp + \ + (token_validity_minutes - refresh_minutes + 5) * 60 + + initial_token = generate_token_with_custom_expiry( + token_validity_minutes * 60) + refreshed_token = generate_token_with_custom_expiry( + 2 * token_validity_minutes * 60) + refresher = MagicMock( + return_value=self.get_completed_future(create_access_token(refreshed_token))) + + with patch(user_credential_async.__name__+'.'+get_current_utc_as_int.__name__, return_value=skip_to_timestamp): + credential = CommunicationTokenCredential( + initial_token, + token_refresher=refresher, + proactive_refresh=True) + async with credential: + access_token = await credential.get_token() + + assert refresher.call_count == 1 + assert access_token.token == refreshed_token + # check that next refresh is always scheduled + assert credential._timer is not None + + @pytest.mark.asyncio + async def test_proactive_refresher_keeps_scheduling_again(self): + refresh_minutes = 10 + token_validity_minutes = 60 + expired_token = generate_token_with_custom_expiry(-5 * 60) + skip_to_timestamp = get_current_utc_as_int() + (token_validity_minutes - + refresh_minutes) * 60 + 1 + first_refreshed_token = create_access_token( + generate_token_with_custom_expiry(token_validity_minutes * 60)) + last_refreshed_token = create_access_token( + generate_token_with_custom_expiry(2 * token_validity_minutes * 60)) + refresher = MagicMock( + side_effect=[self.get_completed_future(first_refreshed_token), self.get_completed_future(last_refreshed_token)]) + + credential = CommunicationTokenCredential( + expired_token, + token_refresher=refresher, + proactive_refresh=True) + async with credential: + access_token = await credential.get_token() + with patch(user_credential_async.__name__+'.'+get_current_utc_as_int.__name__, return_value=skip_to_timestamp): + access_token = await credential.get_token() + + assert refresher.call_count == 2 + assert access_token.token == last_refreshed_token.token + # check that next refresh is always scheduled + assert credential._timer is not None + + @pytest.mark.asyncio + async def test_fractional_backoff_applied_when_token_expiring(self): + token_validity_seconds = 5 * 60 + expiring_token = generate_token_with_custom_expiry( + token_validity_seconds) + + refresher = MagicMock( + side_effect=[create_access_token(expiring_token), create_access_token(expiring_token)]) + + credential = CommunicationTokenCredential( + expiring_token, + token_refresher=refresher, + proactive_refresh=True) + + next_milestone = token_validity_seconds / 2 + assert credential._timer.interval == next_milestone + + async with credential: + with patch(user_credential_async.__name__+'.'+get_current_utc_as_int.__name__, return_value=(get_current_utc_as_int() + next_milestone)): + await credential.get_token() + + assert refresher.call_count == 1 + next_milestone = next_milestone / 2 + assert credential._timer.interval == next_milestone + + @pytest.mark.asyncio + async def test_exit_cancels_timer(self): + refreshed_token = create_access_token( + generate_token_with_custom_expiry(30 * 60)) + refresher = MagicMock(return_value=refreshed_token) + expired_token = generate_token_with_custom_expiry(-10 * 60) + credential = CommunicationTokenCredential( + expired_token, + token_refresher=refresher, + proactive_refresh=True) + async with credential: + assert credential._timer is not None + assert refresher.call_count == 0 + assert credential._timer is not None + + @pytest.mark.asyncio + async def test_exit_enter_scenario_throws_exception(self): + refreshed_token = create_access_token( + generate_token_with_custom_expiry(30 * 60)) + refresher = MagicMock(return_value=refreshed_token) + expired_token = generate_token_with_custom_expiry(-10 * 60) + credential = CommunicationTokenCredential( + expired_token, + token_refresher=refresher, + proactive_refresh=True) + async with credential: + assert credential._timer is not None + assert credential._timer is not None + + with pytest.raises(RuntimeError) as err: + with credential: + assert credential._timer is not None + assert str(err.value) == "An instance of CommunicationTokenCredential cannot be reused once it has been closed." diff --git a/sdk/communication/azure-communication-identity/tests/test_user_credential_with_context_manager.py b/sdk/communication/azure-communication-identity/tests/test_user_credential_with_context_manager.py new file mode 100644 index 000000000000..73a6eee38caa --- /dev/null +++ b/sdk/communication/azure-communication-identity/tests/test_user_credential_with_context_manager.py @@ -0,0 +1,228 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from typing import Type +import platform +import pytest +from unittest import TestCase +try: + from unittest.mock import MagicMock, patch +except ImportError: # python < 3.3 + from mock import MagicMock, patch # type: ignore +import azure.communication.identity._shared.user_credential as user_credential +from azure.communication.identity._shared.user_credential import CommunicationTokenCredential +from azure.communication.identity._shared.utils import create_access_token +from azure.communication.identity._shared.utils import get_current_utc_as_int +from datetime import timedelta +from _shared.helper import generate_token_with_custom_expiry_epoch, generate_token_with_custom_expiry + + +class TestCommunicationTokenCredential(TestCase): + + @classmethod + def setUpClass(cls): + cls.sample_token = generate_token_with_custom_expiry_epoch( + 32503680000) # 1/1/2030 + cls.expired_token = generate_token_with_custom_expiry_epoch( + 100) # 1/1/1970 + + def test_communicationtokencredential_decodes_token(self): + with CommunicationTokenCredential(self.sample_token) as credential: + access_token = credential.get_token() + self.assertEqual(access_token.token, self.sample_token) + + def test_communicationtokencredential_throws_if_invalid_token(self): + self.assertRaises( + ValueError, lambda: CommunicationTokenCredential("foo.bar.tar")) + + def test_communicationtokencredential_throws_if_nonstring_token(self): + self.assertRaises(TypeError, lambda: CommunicationTokenCredential(454)) + + def test_communicationtokencredential_throws_if_proactive_refresh_enabled_without_token_refresher(self): + with pytest.raises(ValueError) as err: + CommunicationTokenCredential(self.sample_token, proactive_refresh=True) + assert str(err.value) == "When 'proactive_refresh' is True, 'token_refresher' must not be None." + with pytest.raises(ValueError) as err: + CommunicationTokenCredential( + self.sample_token, + proactive_refresh=True, + token_refresher=None) + assert str(err.value) == "When 'proactive_refresh' is True, 'token_refresher' must not be None." + + def test_communicationtokencredential_static_token_returns_expired_token(self): + with CommunicationTokenCredential(self.expired_token) as credential: + self.assertEqual(credential.get_token().token, self.expired_token) + + def test_communicationtokencredential_token_expired_refresh_called(self): + refresher = MagicMock( + return_value=create_access_token(self.sample_token)) + with CommunicationTokenCredential(self.expired_token, token_refresher=refresher) as credential: + access_token = credential.get_token() + refresher.assert_called_once() + self.assertEqual(access_token.token, self.sample_token) + + def test_communicationtokencredential_raises_if_refresher_returns_expired_token(self): + refresher = MagicMock( + return_value=create_access_token(self.expired_token)) + with CommunicationTokenCredential(self.expired_token, token_refresher=refresher) as credential: + with self.assertRaises(ValueError): + credential.get_token() + self.assertEqual(refresher.call_count, 1) + + def test_uses_initial_token_as_expected(self): + refresher = MagicMock( + return_value=create_access_token(self.expired_token)) + credential = CommunicationTokenCredential( + self.sample_token, token_refresher=refresher, proactive_refresh=True) + with credential: + access_token = credential.get_token() + + self.assertEqual(refresher.call_count, 0) + self.assertEqual(access_token.token, self.sample_token) + + def test_proactive_refresher_should_not_be_called_before_specified_time(self): + refresh_minutes = 10 + token_validity_minutes = 60 + start_timestamp = get_current_utc_as_int() + skip_to_timestamp = start_timestamp + (refresh_minutes - 5) * 60 + + initial_token = generate_token_with_custom_expiry( + token_validity_minutes * 60) + refreshed_token = generate_token_with_custom_expiry( + 2 * token_validity_minutes * 60) + refresher = MagicMock( + return_value=create_access_token(refreshed_token)) + + with patch(user_credential.__name__+'.'+get_current_utc_as_int.__name__, return_value=skip_to_timestamp): + credential = CommunicationTokenCredential( + initial_token, + token_refresher=refresher, + proactive_refresh=True) + with credential: + access_token = credential.get_token() + + assert refresher.call_count == 0 + assert access_token.token == initial_token + # check that next refresh is always scheduled + assert credential._timer is not None + + def test_proactive_refresher_should_be_called_after_specified_time(self): + refresh_minutes = 10 + token_validity_minutes = 60 + start_timestamp = get_current_utc_as_int() + skip_to_timestamp = start_timestamp + \ + (token_validity_minutes - refresh_minutes + 5) * 60 + + initial_token = generate_token_with_custom_expiry( + token_validity_minutes * 60) + refreshed_token = generate_token_with_custom_expiry( + 2 * token_validity_minutes * 60) + refresher = MagicMock( + return_value=create_access_token(refreshed_token)) + + with patch(user_credential.__name__+'.'+get_current_utc_as_int.__name__, return_value=skip_to_timestamp): + credential = CommunicationTokenCredential( + initial_token, + token_refresher=refresher, + proactive_refresh=True) + with credential: + access_token = credential.get_token() + + assert refresher.call_count == 1 + assert access_token.token == refreshed_token + # check that next refresh is always scheduled + assert credential._timer is not None + + def test_proactive_refresher_keeps_scheduling_again(self): + refresh_minutes = 10 + token_validity_minutes = 60 + expired_token = generate_token_with_custom_expiry(-5 * 60) + skip_to_timestamp = get_current_utc_as_int() + (token_validity_minutes - + refresh_minutes) * 60 + 1 + first_refreshed_token = create_access_token( + generate_token_with_custom_expiry(token_validity_minutes * 60)) + last_refreshed_token = create_access_token( + generate_token_with_custom_expiry(2 * token_validity_minutes * 60)) + refresher = MagicMock( + side_effect=[first_refreshed_token, last_refreshed_token]) + + credential = CommunicationTokenCredential( + expired_token, + token_refresher=refresher, + proactive_refresh=True) + with credential: + access_token = credential.get_token() + with patch(user_credential.__name__+'.'+get_current_utc_as_int.__name__, return_value=skip_to_timestamp): + access_token = credential.get_token() + + assert refresher.call_count == 2 + assert access_token.token == last_refreshed_token.token + # check that next refresh is always scheduled + assert credential._timer is not None + + def test_fractional_backoff_applied_when_token_expiring(self): + token_validity_seconds = 5 * 60 + expiring_token = generate_token_with_custom_expiry( + token_validity_seconds) + + refresher = MagicMock( + side_effect=[create_access_token(expiring_token), create_access_token(expiring_token)]) + + credential = CommunicationTokenCredential( + expiring_token, + token_refresher=refresher, + proactive_refresh=True) + + next_milestone = token_validity_seconds / 2 + + with credential: + assert credential._timer.interval == next_milestone + with patch(user_credential.__name__+'.'+get_current_utc_as_int.__name__, return_value=(get_current_utc_as_int() + next_milestone)): + credential.get_token() + assert refresher.call_count == 1 + next_milestone = next_milestone / 2 + assert credential._timer.interval == next_milestone + + def test_refresher_should_not_be_called_when_token_still_valid(self): + generated_token = generate_token_with_custom_expiry(15 * 60) + new_token = generate_token_with_custom_expiry(10 * 60) + refresher = MagicMock(return_value=create_access_token(new_token)) + + credential = CommunicationTokenCredential( + generated_token, token_refresher=refresher, proactive_refresh=False) + with credential: + for _ in range(10): + access_token = credential.get_token() + + refresher.assert_not_called() + assert generated_token == access_token.token + + def test_exit_cancels_timer(self): + refreshed_token = create_access_token( + generate_token_with_custom_expiry(30 * 60)) + refresher = MagicMock(return_value=refreshed_token) + credential = CommunicationTokenCredential( + self.expired_token,token_refresher=refresher, proactive_refresh=True) + with credential: + assert credential._timer is not None + assert credential._timer is None + + def test_exit_enter_scenario_throws_exception(self): + refreshed_token = create_access_token( + generate_token_with_custom_expiry(30 * 60)) + refresher = MagicMock(return_value=refreshed_token) + credential = CommunicationTokenCredential( + self.expired_token,token_refresher=refresher, proactive_refresh=True) + with credential: + assert credential._timer is not None + assert credential._timer is None + + with pytest.raises(RuntimeError) as err: + with credential: + assert credential._timer is not None + assert str(err.value) == "An instance of CommunicationTokenCredential cannot be reused once it has been closed." + + + \ No newline at end of file diff --git a/sdk/communication/azure-communication-identity/tests/user_credential_tests.py b/sdk/communication/azure-communication-identity/tests/user_credential_tests.py deleted file mode 100644 index ec461402d09a..000000000000 --- a/sdk/communication/azure-communication-identity/tests/user_credential_tests.py +++ /dev/null @@ -1,61 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from unittest import TestCase -from unittest.mock import MagicMock -from azure.communication.identity._shared.user_credential import CommunicationTokenCredential -from azure.communication.identity._shared.user_token_refresh_options import CommunicationTokenRefreshOptions -from azure.communication.identity._shared.utils import create_access_token - - -class TestCommunicationTokenCredential(TestCase): - sample_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."+\ - "eyJleHAiOjMyNTAzNjgwMDAwfQ.9i7FNNHHJT8cOzo-yrAUJyBSfJ-tPPk2emcHavOEpWc" - sample_token_expiry = 32503680000 - expired_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."+\ - "eyJleHAiOjEwMH0.1h_scYkNp-G98-O4cW6KvfJZwiz54uJMyeDACE4nypg" - - - def test_communicationtokencredential_decodes_token(self): - refresh_options = CommunicationTokenRefreshOptions(self.sample_token) - credential = CommunicationTokenCredential(refresh_options) - access_token = credential.get_token() - - self.assertEqual(access_token.token, self.sample_token) - - def test_communicationtokencredential_throws_if_invalid_token(self): - refresh_options = CommunicationTokenRefreshOptions("foo.bar.tar") - self.assertRaises(ValueError, lambda: CommunicationTokenCredential(refresh_options)) - - def test_communicationtokencredential_throws_if_nonstring_token(self): - refresh_options = CommunicationTokenRefreshOptions(454): - self.assertRaises(TypeError, lambda: CommunicationTokenCredential(refresh_options) - - def test_communicationtokencredential_static_token_returns_expired_token(self): - refresh_options = CommunicationTokenRefreshOptions(self.expired_token) - credential = CommunicationTokenCredential(refresh_options) - - self.assertEqual(credential.get_token().token, self.expired_token) - - def test_communicationtokencredential_token_expired_refresh_called(self): - refresher = MagicMock(return_value=self.sample_token) - refresh_options = CommunicationTokenRefreshOptions(self.sample_token, refresher) - access_token = CommunicationTokenCredential( - self.expired_token, - token_refresher=refresher).get_token() - refresher.assert_called_once() - self.assertEqual(access_token, self.sample_token) - - - def test_communicationtokencredential_token_expired_refresh_called_as_necessary(self): - refresher = MagicMock(return_value=create_access_token(self.expired_token)) - refresh_options = CommunicationTokenRefreshOptions(self.expired_token, refresher) - credential = CommunicationTokenCredential(refresh_options) - - credential.get_token() - access_token = credential.get_token() - - self.assertEqual(refresher.call_count, 2) - self.assertEqual(access_token.token, self.expired_token) diff --git a/sdk/communication/azure-communication-networktraversal/CHANGELOG.md b/sdk/communication/azure-communication-networktraversal/CHANGELOG.md index 7aff8ea34c71..e698d9db93af 100644 --- a/sdk/communication/azure-communication-networktraversal/CHANGELOG.md +++ b/sdk/communication/azure-communication-networktraversal/CHANGELOG.md @@ -43,4 +43,5 @@ The first preview of the Azure Communication Relay Client has the following feat - Added CommunicationRelayClient.get_relay_configuration in preview. + [read_me]: https://github.com/Azure/azure-sdk-for-python/blob/master/sdk/communication/ diff --git a/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/policy.py b/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/policy.py index 301bfb545028..d4197ede0e38 100644 --- a/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/policy.py +++ b/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/policy.py @@ -21,7 +21,7 @@ def __init__(self, decode_url=False # type: bool ): # type: (...) -> None - super().__init__() + super(HMACCredentialsPolicy, self).__init__() if host.startswith("https://"): self._host = host.replace("https://", "") diff --git a/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/user_credential.py b/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/user_credential.py index 9b5f17dcc95d..f4a89336ad58 100644 --- a/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/user_credential.py +++ b/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/user_credential.py @@ -3,56 +3,68 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from threading import Lock, Condition -from datetime import timedelta -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) +from threading import Lock, Condition, Timer, TIMEOUT_MAX, Event +from datetime import timedelta +from typing import Any +import six from .utils import get_current_utc_as_int -from .user_token_refresh_options import CommunicationTokenRefreshOptions +from .utils import create_access_token class CommunicationTokenCredential(object): """Credential type used for authenticating to an Azure Communication service. - :param str token: The token used to authenticate to an Azure Communication service - :keyword token_refresher: The token refresher to provide capacity to fetch fresh token - :raises: TypeError + :param str token: The token used to authenticate to an Azure Communication service. + :keyword token_refresher: The sync token refresher to provide capacity to fetch a fresh token. + The returned token must be valid (expiration date must be in the future). + :paramtype token_refresher: Callable[[], AccessToken] + :keyword bool proactive_refresh: Whether to refresh the token proactively or not. + If the proactive refreshing is enabled ('proactive_refresh' is true), the credential will use + a background thread to attempt to refresh the token within 10 minutes before the cached token expires, + the proactive refresh will request a new token by calling the 'token_refresher' callback. + When 'proactive_refresh' is enabled, the Credential object must be either run within a context manager + or the 'close' method must be called once the object usage has been finished. + :raises: TypeError if paramater 'token' is not a string + :raises: ValueError if the 'proactive_refresh' is enabled without providing the 'token_refresher' callable. """ _ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2 - - def __init__(self, - token, # type: str - **kwargs - ): - token_refresher = kwargs.pop('token_refresher', None) - communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token, - token_refresher=token_refresher) - self._token = communication_token_refresh_options.get_token() - self._token_refresher = communication_token_refresh_options.get_token_refresher() + _DEFAULT_AUTOREFRESH_INTERVAL_MINUTES = 10 + + def __init__(self, token: str, **kwargs: Any): + if not isinstance(token, six.string_types): + raise TypeError("Token must be a string.") + self._token = create_access_token(token) + self._token_refresher = kwargs.pop('token_refresher', None) + self._proactive_refresh = kwargs.pop('proactive_refresh', False) + if(self._proactive_refresh and self._token_refresher is None): + raise ValueError("When 'proactive_refresh' is True, 'token_refresher' must not be None.") + self._timer = None self._lock = Condition(Lock()) self._some_thread_refreshing = False + self._is_closed = Event() def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument # type (*str, **Any) -> AccessToken """The value of the configured token. :rtype: ~azure.core.credentials.AccessToken """ + if self._proactive_refresh and self._is_closed.is_set(): + raise RuntimeError("An instance of CommunicationTokenCredential cannot be reused once it has been closed.") - if not self._token_refresher or not self._token_expiring(): + if not self._token_refresher or not self._is_token_expiring_soon(self._token): return self._token + self._update_token_and_reschedule() + return self._token + def _update_token_and_reschedule(self): should_this_thread_refresh = False - with self._lock: - while self._token_expiring(): + while self._is_token_expiring_soon(self._token): if self._some_thread_refreshing: - if self._is_currenttoken_valid(): + if self._is_token_valid(self._token): return self._token - - self._wait_till_inprogress_thread_finish_refreshing() + self._wait_till_lock_owner_finishes_refreshing() else: should_this_thread_refresh = True self._some_thread_refreshing = True @@ -60,27 +72,74 @@ def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument if should_this_thread_refresh: try: - newtoken = self._token_refresher() # pylint:disable=not-callable - + new_token = self._token_refresher() + if not self._is_token_valid(new_token): + raise ValueError( + "The token returned from the token_refresher is expired.") with self._lock: - self._token = newtoken + self._token = new_token self._some_thread_refreshing = False self._lock.notify_all() except: with self._lock: self._some_thread_refreshing = False self._lock.notify_all() - raise + if self._proactive_refresh: + self._schedule_refresh() return self._token - def _wait_till_inprogress_thread_finish_refreshing(self): + def _schedule_refresh(self): + if self._is_closed.is_set(): + return + if self._timer is not None: + self._timer.cancel() + + token_ttl = self._token.expires_on - get_current_utc_as_int() + + if self._is_token_expiring_soon(self._token): + # Schedule the next refresh for when it reaches a certain percentage of the remaining lifetime. + timespan = token_ttl // 2 + else: + # Schedule the next refresh for when it gets in to the soon-to-expire window. + timespan = token_ttl - timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES).total_seconds() + if timespan <= TIMEOUT_MAX: + self._timer = Timer(timespan, self._update_token_and_reschedule) + self._timer.daemon = True + self._timer.start() + + def _wait_till_lock_owner_finishes_refreshing(self): self._lock.release() self._lock.acquire() - def _token_expiring(self): - return self._token.expires_on - get_current_utc_as_int() <\ - timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() - - def _is_currenttoken_valid(self): - return get_current_utc_as_int() < self._token.expires_on + def _is_token_expiring_soon(self, token): + if self._proactive_refresh: + interval = timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES) + else: + interval = timedelta( + minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES) + return ((token.expires_on - get_current_utc_as_int()) + < interval.total_seconds()) + + @classmethod + def _is_token_valid(cls, token): + return get_current_utc_as_int() < token.expires_on + + def __enter__(self): + if self._proactive_refresh: + if self._is_closed.is_set(): + raise RuntimeError( + "An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + self._schedule_refresh() + return self + + def __exit__(self, *args): + self.close() + + def close(self) -> None: + if self._timer is not None: + self._timer.cancel() + self._timer = None + self._is_closed.set() diff --git a/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/user_credential_async.py b/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/user_credential_async.py index 52a99e7a4b6a..c41dc363c3e4 100644 --- a/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/user_credential_async.py +++ b/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/user_credential_async.py @@ -3,93 +3,149 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from asyncio import Condition, Lock -from datetime import timedelta -from typing import ( # pylint: disable=unused-import - cast, - Tuple, - Any -) +from asyncio import Condition, Lock, Event +from datetime import timedelta +from typing import Any +import sys +import six from .utils import get_current_utc_as_int -from .user_token_refresh_options import CommunicationTokenRefreshOptions +from .utils import create_access_token +from .utils_async import AsyncTimer class CommunicationTokenCredential(object): """Credential type used for authenticating to an Azure Communication service. - :param str token: The token used to authenticate to an Azure Communication service - :keyword token_refresher: The async token refresher to provide capacity to fetch fresh token - :raises: TypeError + :param str token: The token used to authenticate to an Azure Communication service. + :keyword token_refresher: The async token refresher to provide capacity to fetch a fresh token. + The returned token must be valid (expiration date must be in the future). + :paramtype token_refresher: Callable[[], Awaitable[AccessToken]] + :keyword bool proactive_refresh: Whether to refresh the token proactively or not. + If the proactive refreshing is enabled ('proactive_refresh' is true), the credential will use + a background thread to attempt to refresh the token within 10 minutes before the cached token expires, + the proactive refresh will request a new token by calling the 'token_refresher' callback. + When 'proactive_refresh is enabled', the Credential object must be either run within a context manager + or the 'close' method must be called once the object usage has been finished. + :raises: TypeError if paramater 'token' is not a string + :raises: ValueError if the 'proactive_refresh' is enabled without providing the 'token_refresher' function. """ _ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2 + _DEFAULT_AUTOREFRESH_INTERVAL_MINUTES = 10 def __init__(self, token: str, **kwargs: Any): - token_refresher = kwargs.pop('token_refresher', None) - communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token, - token_refresher=token_refresher) - self._token = communication_token_refresh_options.get_token() - self._token_refresher = communication_token_refresh_options.get_token_refresher() - self._lock = Condition(Lock()) + if not isinstance(token, six.string_types): + raise TypeError("Token must be a string.") + self._token = create_access_token(token) + self._token_refresher = kwargs.pop('token_refresher', None) + self._proactive_refresh = kwargs.pop('proactive_refresh', False) + if(self._proactive_refresh and self._token_refresher is None): + raise ValueError("When 'proactive_refresh' is True, 'token_refresher' must not be None.") + self._timer = None + self._async_mutex = Lock() + if sys.version_info[:3] == (3, 10, 0): + # Workaround for Python 3.10 bug(https://bugs.python.org/issue45416): + getattr(self._async_mutex, '_get_loop', lambda: None)() + self._lock = Condition(self._async_mutex) self._some_thread_refreshing = False + self._is_closed = Event() async def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument # type (*str, **Any) -> AccessToken """The value of the configured token. :rtype: ~azure.core.credentials.AccessToken """ - if not self._token_refresher or not self._token_expiring(): + if self._proactive_refresh and self._is_closed.is_set(): + raise RuntimeError("An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + + if not self._token_refresher or not self._is_token_expiring_soon(self._token): return self._token + await self._update_token_and_reschedule() + return self._token + async def _update_token_and_reschedule(self): should_this_thread_refresh = False - async with self._lock: - - while self._token_expiring(): + while self._is_token_expiring_soon(self._token): if self._some_thread_refreshing: - if self._is_currenttoken_valid(): + if self._is_token_valid(self._token): return self._token - - await self._wait_till_inprogress_thread_finish_refreshing() + await self._wait_till_lock_owner_finishes_refreshing() else: should_this_thread_refresh = True self._some_thread_refreshing = True break - if should_this_thread_refresh: try: - newtoken = await self._token_refresher() # pylint:disable=not-callable - + new_token = await self._token_refresher() + if not self._is_token_valid(new_token): + raise ValueError( + "The token returned from the token_refresher is expired.") async with self._lock: - self._token = newtoken + self._token = new_token self._some_thread_refreshing = False self._lock.notify_all() except: async with self._lock: self._some_thread_refreshing = False self._lock.notify_all() - raise - + if self._proactive_refresh: + self._schedule_refresh() return self._token - async def _wait_till_inprogress_thread_finish_refreshing(self): + def _schedule_refresh(self): + if self._is_closed.is_set(): + return + if self._timer is not None: + self._timer.cancel() + + token_ttl = self._token.expires_on - get_current_utc_as_int() + + if self._is_token_expiring_soon(self._token): + # Schedule the next refresh for when it reaches a certain percentage of the remaining lifetime. + timespan = token_ttl // 2 + else: + # Schedule the next refresh for when it gets in to the soon-to-expire window. + timespan = token_ttl - timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES).total_seconds() + + self._timer = AsyncTimer(timespan, self._update_token_and_reschedule) + self._timer.start() + + async def _wait_till_lock_owner_finishes_refreshing(self): + self._lock.release() await self._lock.acquire() - def _token_expiring(self): - return self._token.expires_on - get_current_utc_as_int() <\ - timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() - - def _is_currenttoken_valid(self): - return get_current_utc_as_int() < self._token.expires_on + def _is_token_expiring_soon(self, token): + if self._proactive_refresh: + interval = timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES) + else: + interval = timedelta( + minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES) + return ((token.expires_on - get_current_utc_as_int()) + < interval.total_seconds()) - async def close(self) -> None: - pass + @classmethod + def _is_token_valid(cls, token): + return get_current_utc_as_int() < token.expires_on async def __aenter__(self): + if self._proactive_refresh: + if self._is_closed.is_set(): + raise RuntimeError( + "An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + self._schedule_refresh() return self async def __aexit__(self, *args): await self.close() + + async def close(self) -> None: + if self._timer is not None: + self._timer.cancel() + self._timer = None + self._is_closed.set() diff --git a/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/user_token_refresh_options.py b/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/user_token_refresh_options.py deleted file mode 100644 index 6bdc0d456026..000000000000 --- a/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/user_token_refresh_options.py +++ /dev/null @@ -1,36 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) -import six -from .utils import create_access_token - -class CommunicationTokenRefreshOptions(object): - """Options for refreshing CommunicationTokenCredential. - :param str token: The token used to authenticate to an Azure Communication service - :param token_refresher: The token refresher to provide capacity to fetch fresh token - :raises: TypeError - """ - - def __init__(self, - token, # type: str - token_refresher=None - ): - # type: (str) -> None - if not isinstance(token, six.string_types): - raise TypeError("token must be a string.") - self._token = token - self._token_refresher = token_refresher - - def get_token(self): - """Return the the serialized JWT token.""" - return create_access_token(self._token) - - def get_token_refresher(self): - """Return the token refresher to provide capacity to fetch fresh token.""" - return self._token_refresher diff --git a/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/utils.py b/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/utils.py index ce7330b2288a..0b3556bbaa44 100644 --- a/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/utils.py +++ b/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/utils.py @@ -6,15 +6,15 @@ import base64 import json -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) -from datetime import datetime import calendar +from typing import (cast, + Tuple, + ) +from datetime import datetime from msrest.serialization import TZ_UTC from azure.core.credentials import AccessToken + def _convert_datetime_to_utc_int(input_datetime): """ Converts DateTime in local time to the Epoch in UTC in second. @@ -26,6 +26,7 @@ def _convert_datetime_to_utc_int(input_datetime): """ return int(calendar.timegm(input_datetime.utctimetuple())) + def parse_connection_str(conn_str): # type: (str) -> Tuple[str, str, str, str] if conn_str is None: @@ -53,9 +54,10 @@ def parse_connection_str(conn_str): return host, str(shared_access_key) + def get_current_utc_time(): # type: () -> str - return str(datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT" + return str(datetime.now(tz=TZ_UTC).strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT" def get_current_utc_as_int(): @@ -63,6 +65,7 @@ def get_current_utc_as_int(): current_utc_datetime = datetime.utcnow() return _convert_datetime_to_utc_int(current_utc_datetime) + def create_access_token(token): # type: (str) -> azure.core.credentials.AccessToken """Creates an instance of azure.core.credentials.AccessToken from a @@ -84,18 +87,20 @@ def create_access_token(token): raise ValueError(token_parse_err_msg) try: - padded_base64_payload = base64.b64decode(parts[1] + "==").decode('ascii') + padded_base64_payload = base64.b64decode( + parts[1] + '==').decode('ascii') payload = json.loads(padded_base64_payload) return AccessToken(token, _convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp'], TZ_UTC))) except ValueError as val_error: raise ValueError(token_parse_err_msg) from val_error + def get_authentication_policy( - endpoint, # type: str - credential, # type: TokenCredential or str - decode_url=False, # type: bool - is_async=False, # type: bool + endpoint, # type: str + credential, # type: TokenCredential or str + decode_url=False, # type: bool + is_async=False, # type: bool ): # type: (...) -> BearerTokenCredentialPolicy or HMACCredentialPolicy """Returns the correct authentication policy based diff --git a/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/utils_async.py b/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/utils_async.py new file mode 100644 index 000000000000..86e0e04d273c --- /dev/null +++ b/sdk/communication/azure-communication-networktraversal/azure/communication/networktraversal/_shared/utils_async.py @@ -0,0 +1,31 @@ +# ------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +import asyncio + + +class AsyncTimer: + """A non-blocking timer, that calls a function after a specified number of seconds: + :param int interval: time interval in seconds + :param callable callback: function to be called after the interval has elapsed + """ + + def __init__(self, interval, callback): + self._interval = interval + self._callback = callback + self._task = None + + def start(self): + self._task = asyncio.ensure_future(self._job()) + + async def _job(self): + await asyncio.sleep(self._interval) + await self._callback() + + def cancel(self): + if self._task is not None: + self._task.cancel() + self._task = None diff --git a/sdk/communication/azure-communication-networktraversal/tests/_shared/helper.py b/sdk/communication/azure-communication-networktraversal/tests/_shared/helper.py index 2f415d7f7f51..4d3585695f5a 100644 --- a/sdk/communication/azure-communication-networktraversal/tests/_shared/helper.py +++ b/sdk/communication/azure-communication-networktraversal/tests/_shared/helper.py @@ -4,8 +4,27 @@ # license information. # -------------------------------------------------------------------------- import re +import base64 from azure_devtools.scenario_tests import RecordingProcessor -from urllib.parse import urlparse +from datetime import datetime, timedelta +from functools import wraps +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse +import sys + +def generate_token_with_custom_expiry(valid_for_seconds): + return generate_token_with_custom_expiry_epoch((datetime.now() + timedelta(seconds=valid_for_seconds)).timestamp()) + +def generate_token_with_custom_expiry_epoch(expires_on_epoch): + expiry_json = f'{{"exp": {str(expires_on_epoch)} }}' + base64expiry = base64.b64encode( + expiry_json.encode('utf-8')).decode('utf-8').rstrip("=") + token_template = (f'''eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9. + {base64expiry}.adM-ddBZZlQ1WlN3pdPBOF5G4Wh9iZpxNP_fSvpF4cWs''') + return token_template + class URIIdentityReplacer(RecordingProcessor): """Replace the identity in request uri""" diff --git a/sdk/communication/azure-communication-phonenumbers/CHANGELOG.md b/sdk/communication/azure-communication-phonenumbers/CHANGELOG.md index 72298df0fa70..fb19947067b6 100644 --- a/sdk/communication/azure-communication-phonenumbers/CHANGELOG.md +++ b/sdk/communication/azure-communication-phonenumbers/CHANGELOG.md @@ -17,30 +17,34 @@ - Updates dependency `azure-core` to `1.20.0` ## 1.0.1 (2021-06-08) + ### Bug Fixes + - Fixed async client to use async bearer token credential policy instead of sync policy. ## 1.0.0 (2021-04-26) + - Stable release of `azure-communication-phonenumbers`. ## 1.0.0b5 (2021-03-29) ### Breaking Changes + - Renamed AcquiredPhoneNumber to PurchasedPhoneNumber - Renamed PhoneNumbersClient.get_phone_number and PhoneNumbersAsyncClient.get_phone_number to PhoneNumbersClient.get_purchased_phone_number -and PhoneNumbersAsyncClient.get_purchased_phone_number + and PhoneNumbersAsyncClient.get_purchased_phone_number - Renamed PhoneNumbersClient.list_acquired_phone_numbers and PhoneNumbersAsyncClient.list_acquired_phone_numbers to PhoneNumbersClient.list_purchased_phone_numbers -and PhoneNumbersAsyncClient.list_purchased_phone_numbers + and PhoneNumbersAsyncClient.list_purchased_phone_numbers ## 1.0.0b4 (2021-03-09) + - Dropped support for Python 3.5 ### Added -- Added PhoneNumbersClient (originally was part of the azure.communication.administration package). - - +- Added PhoneNumbersClient (originally was part of the azure.communication.administration package). + [read_me]: https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/communication/azure-communication-phonenumbers/README.md [documentation]: https://docs.microsoft.com/azure/communication-services/quickstarts/access-tokens?pivots=programming-language-python diff --git a/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/user_credential.py b/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/user_credential.py index 9b5f17dcc95d..f4a89336ad58 100644 --- a/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/user_credential.py +++ b/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/user_credential.py @@ -3,56 +3,68 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from threading import Lock, Condition -from datetime import timedelta -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) +from threading import Lock, Condition, Timer, TIMEOUT_MAX, Event +from datetime import timedelta +from typing import Any +import six from .utils import get_current_utc_as_int -from .user_token_refresh_options import CommunicationTokenRefreshOptions +from .utils import create_access_token class CommunicationTokenCredential(object): """Credential type used for authenticating to an Azure Communication service. - :param str token: The token used to authenticate to an Azure Communication service - :keyword token_refresher: The token refresher to provide capacity to fetch fresh token - :raises: TypeError + :param str token: The token used to authenticate to an Azure Communication service. + :keyword token_refresher: The sync token refresher to provide capacity to fetch a fresh token. + The returned token must be valid (expiration date must be in the future). + :paramtype token_refresher: Callable[[], AccessToken] + :keyword bool proactive_refresh: Whether to refresh the token proactively or not. + If the proactive refreshing is enabled ('proactive_refresh' is true), the credential will use + a background thread to attempt to refresh the token within 10 minutes before the cached token expires, + the proactive refresh will request a new token by calling the 'token_refresher' callback. + When 'proactive_refresh' is enabled, the Credential object must be either run within a context manager + or the 'close' method must be called once the object usage has been finished. + :raises: TypeError if paramater 'token' is not a string + :raises: ValueError if the 'proactive_refresh' is enabled without providing the 'token_refresher' callable. """ _ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2 - - def __init__(self, - token, # type: str - **kwargs - ): - token_refresher = kwargs.pop('token_refresher', None) - communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token, - token_refresher=token_refresher) - self._token = communication_token_refresh_options.get_token() - self._token_refresher = communication_token_refresh_options.get_token_refresher() + _DEFAULT_AUTOREFRESH_INTERVAL_MINUTES = 10 + + def __init__(self, token: str, **kwargs: Any): + if not isinstance(token, six.string_types): + raise TypeError("Token must be a string.") + self._token = create_access_token(token) + self._token_refresher = kwargs.pop('token_refresher', None) + self._proactive_refresh = kwargs.pop('proactive_refresh', False) + if(self._proactive_refresh and self._token_refresher is None): + raise ValueError("When 'proactive_refresh' is True, 'token_refresher' must not be None.") + self._timer = None self._lock = Condition(Lock()) self._some_thread_refreshing = False + self._is_closed = Event() def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument # type (*str, **Any) -> AccessToken """The value of the configured token. :rtype: ~azure.core.credentials.AccessToken """ + if self._proactive_refresh and self._is_closed.is_set(): + raise RuntimeError("An instance of CommunicationTokenCredential cannot be reused once it has been closed.") - if not self._token_refresher or not self._token_expiring(): + if not self._token_refresher or not self._is_token_expiring_soon(self._token): return self._token + self._update_token_and_reschedule() + return self._token + def _update_token_and_reschedule(self): should_this_thread_refresh = False - with self._lock: - while self._token_expiring(): + while self._is_token_expiring_soon(self._token): if self._some_thread_refreshing: - if self._is_currenttoken_valid(): + if self._is_token_valid(self._token): return self._token - - self._wait_till_inprogress_thread_finish_refreshing() + self._wait_till_lock_owner_finishes_refreshing() else: should_this_thread_refresh = True self._some_thread_refreshing = True @@ -60,27 +72,74 @@ def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument if should_this_thread_refresh: try: - newtoken = self._token_refresher() # pylint:disable=not-callable - + new_token = self._token_refresher() + if not self._is_token_valid(new_token): + raise ValueError( + "The token returned from the token_refresher is expired.") with self._lock: - self._token = newtoken + self._token = new_token self._some_thread_refreshing = False self._lock.notify_all() except: with self._lock: self._some_thread_refreshing = False self._lock.notify_all() - raise + if self._proactive_refresh: + self._schedule_refresh() return self._token - def _wait_till_inprogress_thread_finish_refreshing(self): + def _schedule_refresh(self): + if self._is_closed.is_set(): + return + if self._timer is not None: + self._timer.cancel() + + token_ttl = self._token.expires_on - get_current_utc_as_int() + + if self._is_token_expiring_soon(self._token): + # Schedule the next refresh for when it reaches a certain percentage of the remaining lifetime. + timespan = token_ttl // 2 + else: + # Schedule the next refresh for when it gets in to the soon-to-expire window. + timespan = token_ttl - timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES).total_seconds() + if timespan <= TIMEOUT_MAX: + self._timer = Timer(timespan, self._update_token_and_reschedule) + self._timer.daemon = True + self._timer.start() + + def _wait_till_lock_owner_finishes_refreshing(self): self._lock.release() self._lock.acquire() - def _token_expiring(self): - return self._token.expires_on - get_current_utc_as_int() <\ - timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() - - def _is_currenttoken_valid(self): - return get_current_utc_as_int() < self._token.expires_on + def _is_token_expiring_soon(self, token): + if self._proactive_refresh: + interval = timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES) + else: + interval = timedelta( + minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES) + return ((token.expires_on - get_current_utc_as_int()) + < interval.total_seconds()) + + @classmethod + def _is_token_valid(cls, token): + return get_current_utc_as_int() < token.expires_on + + def __enter__(self): + if self._proactive_refresh: + if self._is_closed.is_set(): + raise RuntimeError( + "An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + self._schedule_refresh() + return self + + def __exit__(self, *args): + self.close() + + def close(self) -> None: + if self._timer is not None: + self._timer.cancel() + self._timer = None + self._is_closed.set() diff --git a/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/user_credential_async.py b/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/user_credential_async.py index 52a99e7a4b6a..c41dc363c3e4 100644 --- a/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/user_credential_async.py +++ b/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/user_credential_async.py @@ -3,93 +3,149 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from asyncio import Condition, Lock -from datetime import timedelta -from typing import ( # pylint: disable=unused-import - cast, - Tuple, - Any -) +from asyncio import Condition, Lock, Event +from datetime import timedelta +from typing import Any +import sys +import six from .utils import get_current_utc_as_int -from .user_token_refresh_options import CommunicationTokenRefreshOptions +from .utils import create_access_token +from .utils_async import AsyncTimer class CommunicationTokenCredential(object): """Credential type used for authenticating to an Azure Communication service. - :param str token: The token used to authenticate to an Azure Communication service - :keyword token_refresher: The async token refresher to provide capacity to fetch fresh token - :raises: TypeError + :param str token: The token used to authenticate to an Azure Communication service. + :keyword token_refresher: The async token refresher to provide capacity to fetch a fresh token. + The returned token must be valid (expiration date must be in the future). + :paramtype token_refresher: Callable[[], Awaitable[AccessToken]] + :keyword bool proactive_refresh: Whether to refresh the token proactively or not. + If the proactive refreshing is enabled ('proactive_refresh' is true), the credential will use + a background thread to attempt to refresh the token within 10 minutes before the cached token expires, + the proactive refresh will request a new token by calling the 'token_refresher' callback. + When 'proactive_refresh is enabled', the Credential object must be either run within a context manager + or the 'close' method must be called once the object usage has been finished. + :raises: TypeError if paramater 'token' is not a string + :raises: ValueError if the 'proactive_refresh' is enabled without providing the 'token_refresher' function. """ _ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2 + _DEFAULT_AUTOREFRESH_INTERVAL_MINUTES = 10 def __init__(self, token: str, **kwargs: Any): - token_refresher = kwargs.pop('token_refresher', None) - communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token, - token_refresher=token_refresher) - self._token = communication_token_refresh_options.get_token() - self._token_refresher = communication_token_refresh_options.get_token_refresher() - self._lock = Condition(Lock()) + if not isinstance(token, six.string_types): + raise TypeError("Token must be a string.") + self._token = create_access_token(token) + self._token_refresher = kwargs.pop('token_refresher', None) + self._proactive_refresh = kwargs.pop('proactive_refresh', False) + if(self._proactive_refresh and self._token_refresher is None): + raise ValueError("When 'proactive_refresh' is True, 'token_refresher' must not be None.") + self._timer = None + self._async_mutex = Lock() + if sys.version_info[:3] == (3, 10, 0): + # Workaround for Python 3.10 bug(https://bugs.python.org/issue45416): + getattr(self._async_mutex, '_get_loop', lambda: None)() + self._lock = Condition(self._async_mutex) self._some_thread_refreshing = False + self._is_closed = Event() async def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument # type (*str, **Any) -> AccessToken """The value of the configured token. :rtype: ~azure.core.credentials.AccessToken """ - if not self._token_refresher or not self._token_expiring(): + if self._proactive_refresh and self._is_closed.is_set(): + raise RuntimeError("An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + + if not self._token_refresher or not self._is_token_expiring_soon(self._token): return self._token + await self._update_token_and_reschedule() + return self._token + async def _update_token_and_reschedule(self): should_this_thread_refresh = False - async with self._lock: - - while self._token_expiring(): + while self._is_token_expiring_soon(self._token): if self._some_thread_refreshing: - if self._is_currenttoken_valid(): + if self._is_token_valid(self._token): return self._token - - await self._wait_till_inprogress_thread_finish_refreshing() + await self._wait_till_lock_owner_finishes_refreshing() else: should_this_thread_refresh = True self._some_thread_refreshing = True break - if should_this_thread_refresh: try: - newtoken = await self._token_refresher() # pylint:disable=not-callable - + new_token = await self._token_refresher() + if not self._is_token_valid(new_token): + raise ValueError( + "The token returned from the token_refresher is expired.") async with self._lock: - self._token = newtoken + self._token = new_token self._some_thread_refreshing = False self._lock.notify_all() except: async with self._lock: self._some_thread_refreshing = False self._lock.notify_all() - raise - + if self._proactive_refresh: + self._schedule_refresh() return self._token - async def _wait_till_inprogress_thread_finish_refreshing(self): + def _schedule_refresh(self): + if self._is_closed.is_set(): + return + if self._timer is not None: + self._timer.cancel() + + token_ttl = self._token.expires_on - get_current_utc_as_int() + + if self._is_token_expiring_soon(self._token): + # Schedule the next refresh for when it reaches a certain percentage of the remaining lifetime. + timespan = token_ttl // 2 + else: + # Schedule the next refresh for when it gets in to the soon-to-expire window. + timespan = token_ttl - timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES).total_seconds() + + self._timer = AsyncTimer(timespan, self._update_token_and_reschedule) + self._timer.start() + + async def _wait_till_lock_owner_finishes_refreshing(self): + self._lock.release() await self._lock.acquire() - def _token_expiring(self): - return self._token.expires_on - get_current_utc_as_int() <\ - timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() - - def _is_currenttoken_valid(self): - return get_current_utc_as_int() < self._token.expires_on + def _is_token_expiring_soon(self, token): + if self._proactive_refresh: + interval = timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES) + else: + interval = timedelta( + minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES) + return ((token.expires_on - get_current_utc_as_int()) + < interval.total_seconds()) - async def close(self) -> None: - pass + @classmethod + def _is_token_valid(cls, token): + return get_current_utc_as_int() < token.expires_on async def __aenter__(self): + if self._proactive_refresh: + if self._is_closed.is_set(): + raise RuntimeError( + "An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + self._schedule_refresh() return self async def __aexit__(self, *args): await self.close() + + async def close(self) -> None: + if self._timer is not None: + self._timer.cancel() + self._timer = None + self._is_closed.set() diff --git a/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/user_token_refresh_options.py b/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/user_token_refresh_options.py deleted file mode 100644 index 6bdc0d456026..000000000000 --- a/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/user_token_refresh_options.py +++ /dev/null @@ -1,36 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) -import six -from .utils import create_access_token - -class CommunicationTokenRefreshOptions(object): - """Options for refreshing CommunicationTokenCredential. - :param str token: The token used to authenticate to an Azure Communication service - :param token_refresher: The token refresher to provide capacity to fetch fresh token - :raises: TypeError - """ - - def __init__(self, - token, # type: str - token_refresher=None - ): - # type: (str) -> None - if not isinstance(token, six.string_types): - raise TypeError("token must be a string.") - self._token = token - self._token_refresher = token_refresher - - def get_token(self): - """Return the the serialized JWT token.""" - return create_access_token(self._token) - - def get_token_refresher(self): - """Return the token refresher to provide capacity to fetch fresh token.""" - return self._token_refresher diff --git a/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/utils.py b/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/utils.py index c9255a4217d7..0b3556bbaa44 100644 --- a/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/utils.py +++ b/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/utils.py @@ -6,15 +6,15 @@ import base64 import json -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) -from datetime import datetime import calendar +from typing import (cast, + Tuple, + ) +from datetime import datetime from msrest.serialization import TZ_UTC from azure.core.credentials import AccessToken + def _convert_datetime_to_utc_int(input_datetime): """ Converts DateTime in local time to the Epoch in UTC in second. @@ -26,6 +26,7 @@ def _convert_datetime_to_utc_int(input_datetime): """ return int(calendar.timegm(input_datetime.utctimetuple())) + def parse_connection_str(conn_str): # type: (str) -> Tuple[str, str, str, str] if conn_str is None: @@ -53,9 +54,10 @@ def parse_connection_str(conn_str): return host, str(shared_access_key) + def get_current_utc_time(): # type: () -> str - return str(datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT" + return str(datetime.now(tz=TZ_UTC).strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT" def get_current_utc_as_int(): @@ -63,6 +65,7 @@ def get_current_utc_as_int(): current_utc_datetime = datetime.utcnow() return _convert_datetime_to_utc_int(current_utc_datetime) + def create_access_token(token): # type: (str) -> azure.core.credentials.AccessToken """Creates an instance of azure.core.credentials.AccessToken from a @@ -84,18 +87,20 @@ def create_access_token(token): raise ValueError(token_parse_err_msg) try: - padded_base64_payload = base64.b64decode(parts[1] + "==").decode('ascii') + padded_base64_payload = base64.b64decode( + parts[1] + '==').decode('ascii') payload = json.loads(padded_base64_payload) return AccessToken(token, _convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp'], TZ_UTC))) - except ValueError: - raise ValueError(token_parse_err_msg) + except ValueError as val_error: + raise ValueError(token_parse_err_msg) from val_error + def get_authentication_policy( - endpoint, # type: str - credential, # type: TokenCredential or str - decode_url=False, # type: bool - is_async=False, # type: bool + endpoint, # type: str + credential, # type: TokenCredential or str + decode_url=False, # type: bool + is_async=False, # type: bool ): # type: (...) -> BearerTokenCredentialPolicy or HMACCredentialPolicy """Returns the correct authentication policy based diff --git a/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/utils_async.py b/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/utils_async.py new file mode 100644 index 000000000000..86e0e04d273c --- /dev/null +++ b/sdk/communication/azure-communication-phonenumbers/azure/communication/phonenumbers/_shared/utils_async.py @@ -0,0 +1,31 @@ +# ------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +import asyncio + + +class AsyncTimer: + """A non-blocking timer, that calls a function after a specified number of seconds: + :param int interval: time interval in seconds + :param callable callback: function to be called after the interval has elapsed + """ + + def __init__(self, interval, callback): + self._interval = interval + self._callback = callback + self._task = None + + def start(self): + self._task = asyncio.ensure_future(self._job()) + + async def _job(self): + await asyncio.sleep(self._interval) + await self._callback() + + def cancel(self): + if self._task is not None: + self._task.cancel() + self._task = None diff --git a/sdk/communication/azure-communication-phonenumbers/test/_shared/helper.py b/sdk/communication/azure-communication-phonenumbers/test/_shared/helper.py new file mode 100644 index 000000000000..4d3585695f5a --- /dev/null +++ b/sdk/communication/azure-communication-phonenumbers/test/_shared/helper.py @@ -0,0 +1,42 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import re +import base64 +from azure_devtools.scenario_tests import RecordingProcessor +from datetime import datetime, timedelta +from functools import wraps +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse +import sys + +def generate_token_with_custom_expiry(valid_for_seconds): + return generate_token_with_custom_expiry_epoch((datetime.now() + timedelta(seconds=valid_for_seconds)).timestamp()) + +def generate_token_with_custom_expiry_epoch(expires_on_epoch): + expiry_json = f'{{"exp": {str(expires_on_epoch)} }}' + base64expiry = base64.b64encode( + expiry_json.encode('utf-8')).decode('utf-8').rstrip("=") + token_template = (f'''eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9. + {base64expiry}.adM-ddBZZlQ1WlN3pdPBOF5G4Wh9iZpxNP_fSvpF4cWs''') + return token_template + + +class URIIdentityReplacer(RecordingProcessor): + """Replace the identity in request uri""" + def process_request(self, request): + resource = (urlparse(request.uri).netloc).split('.')[0] + request.uri = re.sub('/identities/([^/?]+)', '/identities/sanitized', request.uri) + request.uri = re.sub(resource, 'sanitized', request.uri) + request.uri = re.sub('/identities/([^/?]+)', '/identities/sanitized', request.uri) + request.uri = re.sub(resource, 'sanitized', request.uri) + return request + + def process_response(self, response): + if 'url' in response: + response['url'] = re.sub('/identities/([^/?]+)', '/identities/sanitized', response['url']) + return response \ No newline at end of file diff --git a/sdk/communication/azure-communication-sms/CHANGELOG.md b/sdk/communication/azure-communication-sms/CHANGELOG.md index c9ebba5e89d4..9c42e376d729 100644 --- a/sdk/communication/azure-communication-sms/CHANGELOG.md +++ b/sdk/communication/azure-communication-sms/CHANGELOG.md @@ -12,15 +12,19 @@ Python 2.7 is no longer supported. Please use Python version 3.6 or later. ## 1.0.1 (2021-06-08) + ### Bug Fixes -- Fixed async client to use async bearer token credential policy instead of sync policy. +- Fixed async client to use async bearer token credential policy instead of sync policy. ## 1.0.0 (2021-03-29) + - Stable release of `azure-communication-sms`. ## 1.0.0b6 (2021-03-09) + ### Added + - Added support for Azure Active Directory authentication. - Added support for 1:N SMS messaging. - Added support for SMS idempotency. @@ -29,22 +33,28 @@ Python 2.7 is no longer supported. Please use Python version 3.6 or later. - The SmsClient constructor uses type `TokenCredential` and `AsyncTokenCredential` for the credential parameter. ### Breaking + - Send method takes in strings for phone numbers instead of `PhoneNumberIdentifier`. - Send method returns a list of `SmsSendResult`s instead of a `SendSmsResponse`. - Dropped support for Python 3.5 ## 1.0.0b4 (2020-11-16) + - Updated `azure-communication-sms` version. ### Breaking Changes + - Replaced CommunicationUser with CommunicationUserIdentifier. - Replaced PhoneNumber with PhoneNumberIdentifier. ## 1.0.0b3 (2020-10-07) + - Add dependency to `azure-communication-nspkg` package, to support py2 ## 1.0.0b2 (2020-10-06) + - Updated `azure-communication-sms` version. ## 1.0.0b1 (2020-09-22) + - Preview release of the package. diff --git a/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/user_credential.py b/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/user_credential.py index 9b5f17dcc95d..f4a89336ad58 100644 --- a/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/user_credential.py +++ b/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/user_credential.py @@ -3,56 +3,68 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from threading import Lock, Condition -from datetime import timedelta -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) +from threading import Lock, Condition, Timer, TIMEOUT_MAX, Event +from datetime import timedelta +from typing import Any +import six from .utils import get_current_utc_as_int -from .user_token_refresh_options import CommunicationTokenRefreshOptions +from .utils import create_access_token class CommunicationTokenCredential(object): """Credential type used for authenticating to an Azure Communication service. - :param str token: The token used to authenticate to an Azure Communication service - :keyword token_refresher: The token refresher to provide capacity to fetch fresh token - :raises: TypeError + :param str token: The token used to authenticate to an Azure Communication service. + :keyword token_refresher: The sync token refresher to provide capacity to fetch a fresh token. + The returned token must be valid (expiration date must be in the future). + :paramtype token_refresher: Callable[[], AccessToken] + :keyword bool proactive_refresh: Whether to refresh the token proactively or not. + If the proactive refreshing is enabled ('proactive_refresh' is true), the credential will use + a background thread to attempt to refresh the token within 10 minutes before the cached token expires, + the proactive refresh will request a new token by calling the 'token_refresher' callback. + When 'proactive_refresh' is enabled, the Credential object must be either run within a context manager + or the 'close' method must be called once the object usage has been finished. + :raises: TypeError if paramater 'token' is not a string + :raises: ValueError if the 'proactive_refresh' is enabled without providing the 'token_refresher' callable. """ _ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2 - - def __init__(self, - token, # type: str - **kwargs - ): - token_refresher = kwargs.pop('token_refresher', None) - communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token, - token_refresher=token_refresher) - self._token = communication_token_refresh_options.get_token() - self._token_refresher = communication_token_refresh_options.get_token_refresher() + _DEFAULT_AUTOREFRESH_INTERVAL_MINUTES = 10 + + def __init__(self, token: str, **kwargs: Any): + if not isinstance(token, six.string_types): + raise TypeError("Token must be a string.") + self._token = create_access_token(token) + self._token_refresher = kwargs.pop('token_refresher', None) + self._proactive_refresh = kwargs.pop('proactive_refresh', False) + if(self._proactive_refresh and self._token_refresher is None): + raise ValueError("When 'proactive_refresh' is True, 'token_refresher' must not be None.") + self._timer = None self._lock = Condition(Lock()) self._some_thread_refreshing = False + self._is_closed = Event() def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument # type (*str, **Any) -> AccessToken """The value of the configured token. :rtype: ~azure.core.credentials.AccessToken """ + if self._proactive_refresh and self._is_closed.is_set(): + raise RuntimeError("An instance of CommunicationTokenCredential cannot be reused once it has been closed.") - if not self._token_refresher or not self._token_expiring(): + if not self._token_refresher or not self._is_token_expiring_soon(self._token): return self._token + self._update_token_and_reschedule() + return self._token + def _update_token_and_reschedule(self): should_this_thread_refresh = False - with self._lock: - while self._token_expiring(): + while self._is_token_expiring_soon(self._token): if self._some_thread_refreshing: - if self._is_currenttoken_valid(): + if self._is_token_valid(self._token): return self._token - - self._wait_till_inprogress_thread_finish_refreshing() + self._wait_till_lock_owner_finishes_refreshing() else: should_this_thread_refresh = True self._some_thread_refreshing = True @@ -60,27 +72,74 @@ def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument if should_this_thread_refresh: try: - newtoken = self._token_refresher() # pylint:disable=not-callable - + new_token = self._token_refresher() + if not self._is_token_valid(new_token): + raise ValueError( + "The token returned from the token_refresher is expired.") with self._lock: - self._token = newtoken + self._token = new_token self._some_thread_refreshing = False self._lock.notify_all() except: with self._lock: self._some_thread_refreshing = False self._lock.notify_all() - raise + if self._proactive_refresh: + self._schedule_refresh() return self._token - def _wait_till_inprogress_thread_finish_refreshing(self): + def _schedule_refresh(self): + if self._is_closed.is_set(): + return + if self._timer is not None: + self._timer.cancel() + + token_ttl = self._token.expires_on - get_current_utc_as_int() + + if self._is_token_expiring_soon(self._token): + # Schedule the next refresh for when it reaches a certain percentage of the remaining lifetime. + timespan = token_ttl // 2 + else: + # Schedule the next refresh for when it gets in to the soon-to-expire window. + timespan = token_ttl - timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES).total_seconds() + if timespan <= TIMEOUT_MAX: + self._timer = Timer(timespan, self._update_token_and_reschedule) + self._timer.daemon = True + self._timer.start() + + def _wait_till_lock_owner_finishes_refreshing(self): self._lock.release() self._lock.acquire() - def _token_expiring(self): - return self._token.expires_on - get_current_utc_as_int() <\ - timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() - - def _is_currenttoken_valid(self): - return get_current_utc_as_int() < self._token.expires_on + def _is_token_expiring_soon(self, token): + if self._proactive_refresh: + interval = timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES) + else: + interval = timedelta( + minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES) + return ((token.expires_on - get_current_utc_as_int()) + < interval.total_seconds()) + + @classmethod + def _is_token_valid(cls, token): + return get_current_utc_as_int() < token.expires_on + + def __enter__(self): + if self._proactive_refresh: + if self._is_closed.is_set(): + raise RuntimeError( + "An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + self._schedule_refresh() + return self + + def __exit__(self, *args): + self.close() + + def close(self) -> None: + if self._timer is not None: + self._timer.cancel() + self._timer = None + self._is_closed.set() diff --git a/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/user_credential_async.py b/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/user_credential_async.py index 52a99e7a4b6a..c41dc363c3e4 100644 --- a/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/user_credential_async.py +++ b/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/user_credential_async.py @@ -3,93 +3,149 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- -from asyncio import Condition, Lock -from datetime import timedelta -from typing import ( # pylint: disable=unused-import - cast, - Tuple, - Any -) +from asyncio import Condition, Lock, Event +from datetime import timedelta +from typing import Any +import sys +import six from .utils import get_current_utc_as_int -from .user_token_refresh_options import CommunicationTokenRefreshOptions +from .utils import create_access_token +from .utils_async import AsyncTimer class CommunicationTokenCredential(object): """Credential type used for authenticating to an Azure Communication service. - :param str token: The token used to authenticate to an Azure Communication service - :keyword token_refresher: The async token refresher to provide capacity to fetch fresh token - :raises: TypeError + :param str token: The token used to authenticate to an Azure Communication service. + :keyword token_refresher: The async token refresher to provide capacity to fetch a fresh token. + The returned token must be valid (expiration date must be in the future). + :paramtype token_refresher: Callable[[], Awaitable[AccessToken]] + :keyword bool proactive_refresh: Whether to refresh the token proactively or not. + If the proactive refreshing is enabled ('proactive_refresh' is true), the credential will use + a background thread to attempt to refresh the token within 10 minutes before the cached token expires, + the proactive refresh will request a new token by calling the 'token_refresher' callback. + When 'proactive_refresh is enabled', the Credential object must be either run within a context manager + or the 'close' method must be called once the object usage has been finished. + :raises: TypeError if paramater 'token' is not a string + :raises: ValueError if the 'proactive_refresh' is enabled without providing the 'token_refresher' function. """ _ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2 + _DEFAULT_AUTOREFRESH_INTERVAL_MINUTES = 10 def __init__(self, token: str, **kwargs: Any): - token_refresher = kwargs.pop('token_refresher', None) - communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token, - token_refresher=token_refresher) - self._token = communication_token_refresh_options.get_token() - self._token_refresher = communication_token_refresh_options.get_token_refresher() - self._lock = Condition(Lock()) + if not isinstance(token, six.string_types): + raise TypeError("Token must be a string.") + self._token = create_access_token(token) + self._token_refresher = kwargs.pop('token_refresher', None) + self._proactive_refresh = kwargs.pop('proactive_refresh', False) + if(self._proactive_refresh and self._token_refresher is None): + raise ValueError("When 'proactive_refresh' is True, 'token_refresher' must not be None.") + self._timer = None + self._async_mutex = Lock() + if sys.version_info[:3] == (3, 10, 0): + # Workaround for Python 3.10 bug(https://bugs.python.org/issue45416): + getattr(self._async_mutex, '_get_loop', lambda: None)() + self._lock = Condition(self._async_mutex) self._some_thread_refreshing = False + self._is_closed = Event() async def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument # type (*str, **Any) -> AccessToken """The value of the configured token. :rtype: ~azure.core.credentials.AccessToken """ - if not self._token_refresher or not self._token_expiring(): + if self._proactive_refresh and self._is_closed.is_set(): + raise RuntimeError("An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + + if not self._token_refresher or not self._is_token_expiring_soon(self._token): return self._token + await self._update_token_and_reschedule() + return self._token + async def _update_token_and_reschedule(self): should_this_thread_refresh = False - async with self._lock: - - while self._token_expiring(): + while self._is_token_expiring_soon(self._token): if self._some_thread_refreshing: - if self._is_currenttoken_valid(): + if self._is_token_valid(self._token): return self._token - - await self._wait_till_inprogress_thread_finish_refreshing() + await self._wait_till_lock_owner_finishes_refreshing() else: should_this_thread_refresh = True self._some_thread_refreshing = True break - if should_this_thread_refresh: try: - newtoken = await self._token_refresher() # pylint:disable=not-callable - + new_token = await self._token_refresher() + if not self._is_token_valid(new_token): + raise ValueError( + "The token returned from the token_refresher is expired.") async with self._lock: - self._token = newtoken + self._token = new_token self._some_thread_refreshing = False self._lock.notify_all() except: async with self._lock: self._some_thread_refreshing = False self._lock.notify_all() - raise - + if self._proactive_refresh: + self._schedule_refresh() return self._token - async def _wait_till_inprogress_thread_finish_refreshing(self): + def _schedule_refresh(self): + if self._is_closed.is_set(): + return + if self._timer is not None: + self._timer.cancel() + + token_ttl = self._token.expires_on - get_current_utc_as_int() + + if self._is_token_expiring_soon(self._token): + # Schedule the next refresh for when it reaches a certain percentage of the remaining lifetime. + timespan = token_ttl // 2 + else: + # Schedule the next refresh for when it gets in to the soon-to-expire window. + timespan = token_ttl - timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES).total_seconds() + + self._timer = AsyncTimer(timespan, self._update_token_and_reschedule) + self._timer.start() + + async def _wait_till_lock_owner_finishes_refreshing(self): + self._lock.release() await self._lock.acquire() - def _token_expiring(self): - return self._token.expires_on - get_current_utc_as_int() <\ - timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds() - - def _is_currenttoken_valid(self): - return get_current_utc_as_int() < self._token.expires_on + def _is_token_expiring_soon(self, token): + if self._proactive_refresh: + interval = timedelta( + minutes=self._DEFAULT_AUTOREFRESH_INTERVAL_MINUTES) + else: + interval = timedelta( + minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES) + return ((token.expires_on - get_current_utc_as_int()) + < interval.total_seconds()) - async def close(self) -> None: - pass + @classmethod + def _is_token_valid(cls, token): + return get_current_utc_as_int() < token.expires_on async def __aenter__(self): + if self._proactive_refresh: + if self._is_closed.is_set(): + raise RuntimeError( + "An instance of CommunicationTokenCredential cannot be reused once it has been closed.") + self._schedule_refresh() return self async def __aexit__(self, *args): await self.close() + + async def close(self) -> None: + if self._timer is not None: + self._timer.cancel() + self._timer = None + self._is_closed.set() diff --git a/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/user_token_refresh_options.py b/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/user_token_refresh_options.py deleted file mode 100644 index 6bdc0d456026..000000000000 --- a/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/user_token_refresh_options.py +++ /dev/null @@ -1,36 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) -import six -from .utils import create_access_token - -class CommunicationTokenRefreshOptions(object): - """Options for refreshing CommunicationTokenCredential. - :param str token: The token used to authenticate to an Azure Communication service - :param token_refresher: The token refresher to provide capacity to fetch fresh token - :raises: TypeError - """ - - def __init__(self, - token, # type: str - token_refresher=None - ): - # type: (str) -> None - if not isinstance(token, six.string_types): - raise TypeError("token must be a string.") - self._token = token - self._token_refresher = token_refresher - - def get_token(self): - """Return the the serialized JWT token.""" - return create_access_token(self._token) - - def get_token_refresher(self): - """Return the token refresher to provide capacity to fetch fresh token.""" - return self._token_refresher diff --git a/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/utils.py b/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/utils.py index c9255a4217d7..0b3556bbaa44 100644 --- a/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/utils.py +++ b/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/utils.py @@ -6,15 +6,15 @@ import base64 import json -from typing import ( # pylint: disable=unused-import - cast, - Tuple, -) -from datetime import datetime import calendar +from typing import (cast, + Tuple, + ) +from datetime import datetime from msrest.serialization import TZ_UTC from azure.core.credentials import AccessToken + def _convert_datetime_to_utc_int(input_datetime): """ Converts DateTime in local time to the Epoch in UTC in second. @@ -26,6 +26,7 @@ def _convert_datetime_to_utc_int(input_datetime): """ return int(calendar.timegm(input_datetime.utctimetuple())) + def parse_connection_str(conn_str): # type: (str) -> Tuple[str, str, str, str] if conn_str is None: @@ -53,9 +54,10 @@ def parse_connection_str(conn_str): return host, str(shared_access_key) + def get_current_utc_time(): # type: () -> str - return str(datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT" + return str(datetime.now(tz=TZ_UTC).strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT" def get_current_utc_as_int(): @@ -63,6 +65,7 @@ def get_current_utc_as_int(): current_utc_datetime = datetime.utcnow() return _convert_datetime_to_utc_int(current_utc_datetime) + def create_access_token(token): # type: (str) -> azure.core.credentials.AccessToken """Creates an instance of azure.core.credentials.AccessToken from a @@ -84,18 +87,20 @@ def create_access_token(token): raise ValueError(token_parse_err_msg) try: - padded_base64_payload = base64.b64decode(parts[1] + "==").decode('ascii') + padded_base64_payload = base64.b64decode( + parts[1] + '==').decode('ascii') payload = json.loads(padded_base64_payload) return AccessToken(token, _convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp'], TZ_UTC))) - except ValueError: - raise ValueError(token_parse_err_msg) + except ValueError as val_error: + raise ValueError(token_parse_err_msg) from val_error + def get_authentication_policy( - endpoint, # type: str - credential, # type: TokenCredential or str - decode_url=False, # type: bool - is_async=False, # type: bool + endpoint, # type: str + credential, # type: TokenCredential or str + decode_url=False, # type: bool + is_async=False, # type: bool ): # type: (...) -> BearerTokenCredentialPolicy or HMACCredentialPolicy """Returns the correct authentication policy based diff --git a/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/utils_async.py b/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/utils_async.py new file mode 100644 index 000000000000..86e0e04d273c --- /dev/null +++ b/sdk/communication/azure-communication-sms/azure/communication/sms/_shared/utils_async.py @@ -0,0 +1,31 @@ +# ------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- + +import asyncio + + +class AsyncTimer: + """A non-blocking timer, that calls a function after a specified number of seconds: + :param int interval: time interval in seconds + :param callable callback: function to be called after the interval has elapsed + """ + + def __init__(self, interval, callback): + self._interval = interval + self._callback = callback + self._task = None + + def start(self): + self._task = asyncio.ensure_future(self._job()) + + async def _job(self): + await asyncio.sleep(self._interval) + await self._callback() + + def cancel(self): + if self._task is not None: + self._task.cancel() + self._task = None diff --git a/sdk/communication/azure-communication-sms/tests/_shared/helper.py b/sdk/communication/azure-communication-sms/tests/_shared/helper.py new file mode 100644 index 000000000000..4d3585695f5a --- /dev/null +++ b/sdk/communication/azure-communication-sms/tests/_shared/helper.py @@ -0,0 +1,42 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import re +import base64 +from azure_devtools.scenario_tests import RecordingProcessor +from datetime import datetime, timedelta +from functools import wraps +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse +import sys + +def generate_token_with_custom_expiry(valid_for_seconds): + return generate_token_with_custom_expiry_epoch((datetime.now() + timedelta(seconds=valid_for_seconds)).timestamp()) + +def generate_token_with_custom_expiry_epoch(expires_on_epoch): + expiry_json = f'{{"exp": {str(expires_on_epoch)} }}' + base64expiry = base64.b64encode( + expiry_json.encode('utf-8')).decode('utf-8').rstrip("=") + token_template = (f'''eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9. + {base64expiry}.adM-ddBZZlQ1WlN3pdPBOF5G4Wh9iZpxNP_fSvpF4cWs''') + return token_template + + +class URIIdentityReplacer(RecordingProcessor): + """Replace the identity in request uri""" + def process_request(self, request): + resource = (urlparse(request.uri).netloc).split('.')[0] + request.uri = re.sub('/identities/([^/?]+)', '/identities/sanitized', request.uri) + request.uri = re.sub(resource, 'sanitized', request.uri) + request.uri = re.sub('/identities/([^/?]+)', '/identities/sanitized', request.uri) + request.uri = re.sub(resource, 'sanitized', request.uri) + return request + + def process_response(self, response): + if 'url' in response: + response['url'] = re.sub('/identities/([^/?]+)', '/identities/sanitized', response['url']) + return response \ No newline at end of file