Skip to content

Commit

Permalink
Sync the _shared folder for Communication (#21777)
Browse files Browse the repository at this point in the history
* identity and network traversal unified with sms

* unified get_current_utc_as_int + added a test

* adjusted variable name to match the comment

* added owners to enforce reviews

* token deserialization fix + test
  • Loading branch information
petrsvihlik authored Jan 12, 2022
1 parent e9ca031 commit cdeb608
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 118 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
/sdk/communication/azure-communication-phonenumbers/ @RoyHerrod @danielav7 @whisper6284 @AlonsoMondal
/sdk/communication/azure-communication-sms/ @RoyHerrod @arifibrahim4
/sdk/communication/azure-communication-identity/ @Azure/acs-identity-sdk
/sdk/communication/**/_shared/ @Azure/acs-identity-sdk

# PRLabel: %KeyVault
/sdk/keyvault/ @schaabs @mccoyp @YalinLi0312
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from msrest.serialization import TZ_UTC
from azure.core.credentials import AccessToken

def _convert_datetime_to_utc_int(expires_on):
def _convert_datetime_to_utc_int(input_datetime):
"""
Converts DateTime in local time to the Epoch in UTC in second.
Expand All @@ -24,7 +24,7 @@ def _convert_datetime_to_utc_int(expires_on):
:return: Integer
:rtype: int
"""
return int(calendar.timegm(expires_on.utctimetuple()))
return int(calendar.timegm(input_datetime.utctimetuple()))

def parse_connection_str(conn_str):
# type: (str) -> Tuple[str, str, str, str]
Expand Down Expand Up @@ -87,7 +87,7 @@ def create_access_token(token):
padded_base64_payload = base64.b64decode(parts[1] + "==").decode('ascii')
payload = json.loads(padded_base64_payload)
return AccessToken(token,
_convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp']).replace(tzinfo=TZ_UTC)))
_convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp'], TZ_UTC)))
except ValueError:
raise ValueError(token_parse_err_msg)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
import unittest
from datetime import datetime
from azure.communication.chat._shared.utils import create_access_token
from azure.communication.chat._shared.utils import get_current_utc_as_int
import dateutil.tz
import base64

from azure.communication.chat._shared.utils import(
_convert_datetime_to_utc_int
)

class UtilsTest(unittest.TestCase):

@staticmethod
def get_token_with_custom_expiry(expires_on):
expiry_json = '{"exp": ' + str(expires_on) + '}'
base64expiry = base64.b64encode(
expiry_json.encode('utf-8')).decode('utf-8').rstrip("=")
token_template = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." +\
base64expiry + ".adM-ddBZZlQ1WlN3pdPBOF5G4Wh9iZpxNP_fSvpF4cWs"
return token_template

def test_convert_datetime_to_utc_int(self):
# UTC
utc_time_in_sec = _convert_datetime_to_utc_int(datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=dateutil.tz.tzutc()))
assert utc_time_in_sec == 0
assert utc_time_in_sec == 0
# UTC naive (without a timezone specified)
utc_naive_time_in_sec = _convert_datetime_to_utc_int(datetime(1970, 1, 1, 0, 0, 0, 0))
assert utc_naive_time_in_sec == 0
# PST is UTC-8
pst_time_in_sec = _convert_datetime_to_utc_int(datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=dateutil.tz.gettz('America/Vancouver')))
assert pst_time_in_sec == 8 * 3600
Expand All @@ -22,5 +37,16 @@ def test_convert_datetime_to_utc_int(self):
cst_time_in_sec = _convert_datetime_to_utc_int(datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=dateutil.tz.gettz('Asia/Shanghai')))
assert cst_time_in_sec == -8 * 3600


def test_access_token_expiry_deserialized_correctly_from_payload(self):
start_timestamp = get_current_utc_as_int()
token_validity_minutes = 60
token_expiry = start_timestamp + token_validity_minutes * 60

token = create_access_token(
self.get_token_with_custom_expiry(token_expiry))

self.assertEqual(token.expires_on, token_expiry)

if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
# license information.
# --------------------------------------------------------------------------
from threading import Lock, Condition
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
:param str token: The token used to authenticate to an Azure Communication service
Expand All @@ -24,9 +24,9 @@ class CommunicationTokenCredential(object):
_ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2

def __init__(self,
token, # type: str
**kwargs
):
token, # type: str
**kwargs
):
token_refresher = kwargs.pop('token_refresher', None)
communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token,
token_refresher=token_refresher)
Expand All @@ -35,8 +35,8 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""
Expand Down Expand Up @@ -79,12 +79,8 @@ def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES)
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()

def _is_currenttoken_valid(self):
return self._get_utc_now() < self._token.expires_on

@classmethod
def _get_utc_now(cls):
return datetime.now().replace(tzinfo=TZ_UTC)
return get_current_utc_as_int() < self._token.expires_on
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,27 @@
# license information.
# --------------------------------------------------------------------------
from asyncio import Condition, Lock
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
Any
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
:param str token: The token used to authenticate to an Azure Communication service
:keyword token_refresher: The token refresher to provide capacity to fetch fresh token
:keyword token_refresher: The async token refresher to provide capacity to fetch fresh token
:raises: TypeError
"""

_ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2

def __init__(self,
token, # type: str
**kwargs
):
def __init__(self, token: str, **kwargs: Any):
token_refresher = kwargs.pop('token_refresher', None)
communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token,
token_refresher=token_refresher)
Expand All @@ -35,25 +33,24 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
async def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""

if not self._token_refresher or not self._token_expiring():
return self._token

should_this_thread_refresh = False

with self._lock:
async with self._lock:

while self._token_expiring():
if self._some_thread_refreshing:
if self._is_currenttoken_valid():
return self._token

self._wait_till_inprogress_thread_finish_refreshing()
await self._wait_till_inprogress_thread_finish_refreshing()
else:
should_this_thread_refresh = True
self._some_thread_refreshing = True
Expand All @@ -62,32 +59,37 @@ def get_token(self):

if should_this_thread_refresh:
try:
newtoken = self._token_refresher() # pylint:disable=not-callable
newtoken = await self._token_refresher() # pylint:disable=not-callable

with self._lock:
async with self._lock:
self._token = newtoken
self._some_thread_refreshing = False
self._lock.notify_all()
except:
with self._lock:
async with self._lock:
self._some_thread_refreshing = False
self._lock.notify_all()

raise

return self._token

def _wait_till_inprogress_thread_finish_refreshing(self):
async def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.release()
self._lock.acquire()
await self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES)
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()

def _is_currenttoken_valid(self):
return self._get_utc_now() < self._token.expires_on
return get_current_utc_as_int() < self._token.expires_on

async def close(self) -> None:
pass

async def __aenter__(self):
return self

@classmethod
def _get_utc_now(cls):
return datetime.now().replace(tzinfo=TZ_UTC)
async def __aexit__(self, *args):
await self.close()
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import base64
import json
import time
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
Expand All @@ -16,8 +15,16 @@
from msrest.serialization import TZ_UTC
from azure.core.credentials import AccessToken

def _convert_datetime_to_utc_int(expires_on):
return int(calendar.timegm(expires_on.utctimetuple()))
def _convert_datetime_to_utc_int(input_datetime):
"""
Converts DateTime in local time to the Epoch in UTC in second.
:param input_datetime: Input datetime
:type input_datetime: datetime
:return: Integer
:rtype: int
"""
return int(calendar.timegm(input_datetime.utctimetuple()))

def parse_connection_str(conn_str):
# type: (str) -> Tuple[str, str, str, str]
Expand Down Expand Up @@ -50,9 +57,10 @@ def get_current_utc_time():
# type: () -> str
return str(datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT"


def get_current_utc_as_int():
# type: () -> int
current_utc_datetime = datetime.utcnow().replace(tzinfo=TZ_UTC)
current_utc_datetime = datetime.utcnow()
return _convert_datetime_to_utc_int(current_utc_datetime)

def create_access_token(token):
Expand All @@ -79,14 +87,10 @@ def create_access_token(token):
padded_base64_payload = base64.b64decode(parts[1] + "==").decode('ascii')
payload = json.loads(padded_base64_payload)
return AccessToken(token,
_convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp']).replace(tzinfo=TZ_UTC)))
_convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp'], TZ_UTC)))
except ValueError:
raise ValueError(token_parse_err_msg)

def _convert_expires_on_datetime_to_utc_int(expires_on):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
return epoch-time.mktime(expires_on.timetuple())

def get_authentication_policy(
endpoint, # type: str
credential, # type: TokenCredential or str
Expand Down Expand Up @@ -122,7 +126,3 @@ def get_authentication_policy(

raise TypeError("Unsupported credential: {}. Use an access token string to use HMACCredentialsPolicy"
"or a token credential from azure.identity".format(type(credential)))

def _convert_expires_on_datetime_to_utc_int(expires_on):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
return epoch-time.mktime(expires_on.timetuple())
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def __init__(self,
access_key, # type: str
decode_url=False # type: bool
):
# pylint: disable=bad-option-value,useless-object-inheritance,disable=super-with-arguments
# type: (...) -> None
super(HMACCredentialsPolicy, self).__init__()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
# license information.
# --------------------------------------------------------------------------
from threading import Lock, Condition
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
:param str token: The token used to authenticate to an Azure Communication service
Expand All @@ -24,9 +24,9 @@ class CommunicationTokenCredential(object):
_ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2

def __init__(self,
token, # type: str
**kwargs
):
token, # type: str
**kwargs
):
token_refresher = kwargs.pop('token_refresher', None)
communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token,
token_refresher=token_refresher)
Expand All @@ -35,8 +35,8 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""
Expand Down Expand Up @@ -79,12 +79,8 @@ def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES)
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()

def _is_currenttoken_valid(self):
return self._get_utc_now() < self._token.expires_on

@classmethod
def _get_utc_now(cls):
return datetime.now().replace(tzinfo=TZ_UTC)
return get_current_utc_as_int() < self._token.expires_on
Loading

0 comments on commit cdeb608

Please sign in to comment.