Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync the _shared folder for Communication #21777

Merged
merged 5 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
/sdk/communication/azure-communication-phonenumbers/ @RoyHerrod @danielav7 @whisper6284 @AlonsoMondal
/sdk/communication/azure-communication-sms/ @RoyHerrod @arifibrahim4
/sdk/communication/azure-communication-identity/ @Azure/acs-identity-sdk
/sdk/communication/**/_shared/ @Azure/acs-identity-sdk

# PRLabel: %KeyVault
/sdk/keyvault/ @schaabs @chlowell @mccoyp @YalinLi0312
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from msrest.serialization import TZ_UTC
from azure.core.credentials import AccessToken

def _convert_datetime_to_utc_int(expires_on):
def _convert_datetime_to_utc_int(input_datetime):
petrsvihlik marked this conversation as resolved.
Show resolved Hide resolved
"""
Converts DateTime in local time to the Epoch in UTC in second.

Expand All @@ -24,7 +24,7 @@ def _convert_datetime_to_utc_int(expires_on):
:return: Integer
:rtype: int
"""
return int(calendar.timegm(expires_on.utctimetuple()))
return int(calendar.timegm(input_datetime.utctimetuple()))

def parse_connection_str(conn_str):
# type: (str) -> Tuple[str, str, str, str]
Expand Down Expand Up @@ -87,7 +87,7 @@ def create_access_token(token):
padded_base64_payload = base64.b64decode(parts[1] + "==").decode('ascii')
payload = json.loads(padded_base64_payload)
return AccessToken(token,
_convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp']).replace(tzinfo=TZ_UTC)))
_convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp'], TZ_UTC)))
except ValueError:
raise ValueError(token_parse_err_msg)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
import unittest
from datetime import datetime
from azure.communication.chat._shared.utils import create_access_token
from azure.communication.chat._shared.utils import get_current_utc_as_int
import dateutil.tz
import base64

from azure.communication.chat._shared.utils import(
_convert_datetime_to_utc_int
)

class UtilsTest(unittest.TestCase):

@staticmethod
def get_token_with_custom_expiry(expires_on):
expiry_json = '{"exp": ' + str(expires_on) + '}'
base64expiry = base64.b64encode(
expiry_json.encode('utf-8')).decode('utf-8').rstrip("=")
token_template = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9." +\
base64expiry + ".adM-ddBZZlQ1WlN3pdPBOF5G4Wh9iZpxNP_fSvpF4cWs"
return token_template

def test_convert_datetime_to_utc_int(self):
# UTC
utc_time_in_sec = _convert_datetime_to_utc_int(datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=dateutil.tz.tzutc()))
assert utc_time_in_sec == 0
assert utc_time_in_sec == 0
# UTC naive (without a timezone specified)
utc_naive_time_in_sec = _convert_datetime_to_utc_int(datetime(1970, 1, 1, 0, 0, 0, 0))
assert utc_naive_time_in_sec == 0
# PST is UTC-8
pst_time_in_sec = _convert_datetime_to_utc_int(datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=dateutil.tz.gettz('America/Vancouver')))
assert pst_time_in_sec == 8 * 3600
Expand All @@ -22,5 +37,16 @@ def test_convert_datetime_to_utc_int(self):
cst_time_in_sec = _convert_datetime_to_utc_int(datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=dateutil.tz.gettz('Asia/Shanghai')))
assert cst_time_in_sec == -8 * 3600


def test_access_token_expiry_deserialized_correctly_from_payload(self):
start_timestamp = get_current_utc_as_int()
token_validity_minutes = 60
token_expiry = start_timestamp + token_validity_minutes * 60

token = create_access_token(
self.get_token_with_custom_expiry(token_expiry))

self.assertEqual(token.expires_on, token_expiry)

if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
# license information.
# --------------------------------------------------------------------------
from threading import Lock, Condition
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
:param str token: The token used to authenticate to an Azure Communication service
Expand All @@ -24,9 +24,9 @@ class CommunicationTokenCredential(object):
_ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2

def __init__(self,
token, # type: str
**kwargs
):
token, # type: str
**kwargs
):
token_refresher = kwargs.pop('token_refresher', None)
communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token,
token_refresher=token_refresher)
Expand All @@ -35,8 +35,8 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""
Expand Down Expand Up @@ -79,12 +79,8 @@ def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES)
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()

def _is_currenttoken_valid(self):
return self._get_utc_now() < self._token.expires_on

@classmethod
def _get_utc_now(cls):
return datetime.now().replace(tzinfo=TZ_UTC)
return get_current_utc_as_int() < self._token.expires_on
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,27 @@
# license information.
# --------------------------------------------------------------------------
from asyncio import Condition, Lock
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
Any
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
:param str token: The token used to authenticate to an Azure Communication service
:keyword token_refresher: The token refresher to provide capacity to fetch fresh token
:keyword token_refresher: The async token refresher to provide capacity to fetch fresh token
:raises: TypeError
"""

_ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2

def __init__(self,
token, # type: str
**kwargs
):
def __init__(self, token: str, **kwargs: Any):
token_refresher = kwargs.pop('token_refresher', None)
communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token,
token_refresher=token_refresher)
Expand All @@ -35,25 +33,24 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
async def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""

if not self._token_refresher or not self._token_expiring():
return self._token

should_this_thread_refresh = False

with self._lock:
async with self._lock:

while self._token_expiring():
if self._some_thread_refreshing:
if self._is_currenttoken_valid():
return self._token

self._wait_till_inprogress_thread_finish_refreshing()
await self._wait_till_inprogress_thread_finish_refreshing()
else:
should_this_thread_refresh = True
self._some_thread_refreshing = True
Expand All @@ -62,32 +59,37 @@ def get_token(self):

if should_this_thread_refresh:
try:
newtoken = self._token_refresher() # pylint:disable=not-callable
newtoken = await self._token_refresher() # pylint:disable=not-callable

with self._lock:
async with self._lock:
self._token = newtoken
self._some_thread_refreshing = False
self._lock.notify_all()
except:
with self._lock:
async with self._lock:
self._some_thread_refreshing = False
self._lock.notify_all()

raise

return self._token

def _wait_till_inprogress_thread_finish_refreshing(self):
async def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.release()
self._lock.acquire()
await self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES)
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()

def _is_currenttoken_valid(self):
return self._get_utc_now() < self._token.expires_on
return get_current_utc_as_int() < self._token.expires_on

async def close(self) -> None:
pass

async def __aenter__(self):
return self

@classmethod
def _get_utc_now(cls):
return datetime.now().replace(tzinfo=TZ_UTC)
async def __aexit__(self, *args):
await self.close()
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import base64
import json
import time
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
Expand All @@ -16,8 +15,16 @@
from msrest.serialization import TZ_UTC
from azure.core.credentials import AccessToken

def _convert_datetime_to_utc_int(expires_on):
return int(calendar.timegm(expires_on.utctimetuple()))
def _convert_datetime_to_utc_int(input_datetime):
"""
Converts DateTime in local time to the Epoch in UTC in second.

:param input_datetime: Input datetime
petrsvihlik marked this conversation as resolved.
Show resolved Hide resolved
:type input_datetime: datetime
:return: Integer
:rtype: int
"""
return int(calendar.timegm(input_datetime.utctimetuple()))

def parse_connection_str(conn_str):
# type: (str) -> Tuple[str, str, str, str]
Expand Down Expand Up @@ -50,9 +57,10 @@ def get_current_utc_time():
# type: () -> str
return str(datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT"


def get_current_utc_as_int():
# type: () -> int
current_utc_datetime = datetime.utcnow().replace(tzinfo=TZ_UTC)
current_utc_datetime = datetime.utcnow()
return _convert_datetime_to_utc_int(current_utc_datetime)

def create_access_token(token):
Expand All @@ -79,14 +87,10 @@ def create_access_token(token):
padded_base64_payload = base64.b64decode(parts[1] + "==").decode('ascii')
payload = json.loads(padded_base64_payload)
return AccessToken(token,
_convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp']).replace(tzinfo=TZ_UTC)))
_convert_datetime_to_utc_int(datetime.fromtimestamp(payload['exp'], TZ_UTC)))
except ValueError:
raise ValueError(token_parse_err_msg)

def _convert_expires_on_datetime_to_utc_int(expires_on):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
return epoch-time.mktime(expires_on.timetuple())

def get_authentication_policy(
endpoint, # type: str
credential, # type: TokenCredential or str
Expand Down Expand Up @@ -122,7 +126,3 @@ def get_authentication_policy(

raise TypeError("Unsupported credential: {}. Use an access token string to use HMACCredentialsPolicy"
"or a token credential from azure.identity".format(type(credential)))

def _convert_expires_on_datetime_to_utc_int(expires_on):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
return epoch-time.mktime(expires_on.timetuple())
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def __init__(self,
access_key, # type: str
decode_url=False # type: bool
):
# pylint: disable=bad-option-value,useless-object-inheritance,disable=super-with-arguments
# type: (...) -> None
super(HMACCredentialsPolicy, self).__init__()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
# license information.
# --------------------------------------------------------------------------
from threading import Lock, Condition
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
:param str token: The token used to authenticate to an Azure Communication service
Expand All @@ -24,9 +24,9 @@ class CommunicationTokenCredential(object):
_ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2

def __init__(self,
token, # type: str
**kwargs
):
token, # type: str
**kwargs
):
token_refresher = kwargs.pop('token_refresher', None)
communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token,
token_refresher=token_refresher)
Expand All @@ -35,8 +35,8 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""
Expand Down Expand Up @@ -79,12 +79,8 @@ def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES)
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()
petrsvihlik marked this conversation as resolved.
Show resolved Hide resolved

def _is_currenttoken_valid(self):
return self._get_utc_now() < self._token.expires_on

@classmethod
def _get_utc_now(cls):
return datetime.now().replace(tzinfo=TZ_UTC)
return get_current_utc_as_int() < self._token.expires_on
Loading