Skip to content

Commit

Permalink
factor base class out of ManagedIdentityClient
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell committed Dec 16, 2020
1 parent 5075cab commit 0383294
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
PolicyType = Union[HTTPPolicy, SansIOHTTPPolicy]


class ManagedIdentityClient(object):
class ManagedIdentityClientBase(ABC):
# pylint:disable=missing-client-constructor-parameter-credential
def __init__(self, request_factory, client_id=None, **kwargs):
# type: (Callable[[str, dict], HttpRequest], Optional[str], **Any) -> None
Expand All @@ -55,24 +55,6 @@ def __init__(self, request_factory, client_id=None, **kwargs):

self._request_factory = request_factory

def get_cached_token(self, *scopes):
# type: (*str) -> Optional[AccessToken]
resource = _scopes_to_resource(*scopes)
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource])
for token in tokens:
if token["expires_on"] > time.time():
return AccessToken(token["secret"], token["expires_on"])
return None

def request_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken
resource = _scopes_to_resource(*scopes)
request = self._request_factory(resource, self._identity_config)
request_time = int(time.time())
response = self._pipeline.run(request)
token = self._process_response(response, request_time)
return token

def _process_response(self, response, request_time):
# type: (PipelineResponse, int) -> AccessToken

Expand Down Expand Up @@ -102,6 +84,34 @@ def _process_response(self, response, request_time):

return token

def get_cached_token(self, *scopes):
# type: (*str) -> Optional[AccessToken]
resource = _scopes_to_resource(*scopes)
tokens = self._cache.find(TokenCache.CredentialType.ACCESS_TOKEN, target=[resource])
for token in tokens:
if token["expires_on"] > time.time():
return AccessToken(token["secret"], token["expires_on"])
return None

@abc.abstractmethod
def request_token(self, *scopes, **kwargs):
pass

@abc.abstractmethod
def _build_pipeline(self, config, policies=None, transport=None, **kwargs):
pass


class ManagedIdentityClient(ManagedIdentityClientBase):
def request_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (*str, **Any) -> AccessToken
resource = _scopes_to_resource(*scopes)
request = self._request_factory(resource, self._identity_config)
request_time = int(time.time())
response = self._pipeline.run(request)
token = self._process_response(response, request_time)
return token

def _build_pipeline(self, config, policies=None, transport=None, **kwargs): # pylint:disable=no-self-use
# type: (Configuration, Optional[List[PolicyType]], Optional[HttpTransport], **Any) -> Pipeline
if policies is None: # [] is a valid policy list
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from azure.core.pipeline.policies import AsyncRetryPolicy

from ..._internal import _scopes_to_resource
from ..._internal.managed_identity_client import ManagedIdentityClient, _get_policies
from ..._internal.managed_identity_client import ManagedIdentityClientBase, _get_policies

if TYPE_CHECKING:
# pylint:disable=ungrouped-imports
Expand All @@ -23,7 +23,7 @@


# pylint:disable=async-client-bad-name,missing-client-constructor-parameter-credential
class AsyncManagedIdentityClient(ManagedIdentityClient):
class AsyncManagedIdentityClient(ManagedIdentityClientBase):
def __init__(self, request_factory: "Callable[[str, dict], HttpRequest]", **kwargs: "Any") -> None:
config = _get_configuration(**kwargs)
super().__init__(request_factory, _config=config, **kwargs)
Expand Down

0 comments on commit 0383294

Please sign in to comment.