Skip to content

Commit

Permalink
smallish fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell committed Dec 16, 2020
1 parent 9c5b00c commit 5075cab
Show file tree
Hide file tree
Showing 17 changed files with 80 additions and 66 deletions.
15 changes: 8 additions & 7 deletions sdk/identity/azure-identity/azure/identity/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@
if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from time import struct_time
from typing import Any, Dict, Iterable, Mapping, Optional, Union
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
from azure.core.pipeline import PipelineResponse
from azure.core.pipeline.transport import HttpTransport
from azure.core.pipeline.policies import HTTPPolicy
from azure.core.pipeline.policies import HTTPPolicy, SansIOHTTPPolicy

PolicyListType = List[Union[HTTPPolicy, SansIOHTTPPolicy]]


class AuthnClientBase(ABC):
Expand Down Expand Up @@ -166,10 +168,9 @@ def _parse_app_service_expires_on(expires_on):

raise ValueError("'{}' doesn't match the expected format".format(expires_on))

# TODO: public, factor out of request_token
def _prepare_request(
self,
method="POST", # type: Optional[str]
method="POST", # type: str
headers=None, # type: Optional[Mapping[str, str]]
form_data=None, # type: Optional[Mapping[str, str]]
params=None, # type: Optional[Dict[str, str]]
Expand Down Expand Up @@ -200,7 +201,7 @@ class AuthnClient(AuthnClientBase):
def __init__(
self,
config=None, # type: Optional[Configuration]
policies=None, # type: Optional[Iterable[HTTPPolicy]]
policies=None, # type: Optional[PolicyListType]
transport=None, # type: Optional[HttpTransport]
**kwargs # type: Any
):
Expand All @@ -217,13 +218,13 @@ def __init__(
]
if not transport:
transport = RequestsTransport(**kwargs)
self._pipeline = Pipeline(transport=transport, policies=policies)
self._pipeline = Pipeline(transport=transport, policies=policies) # type: Pipeline
super(AuthnClient, self).__init__(**kwargs)

def request_token(
self,
scopes, # type: Iterable[str]
method="POST", # type: Optional[str]
method="POST", # type: str
headers=None, # type: Optional[Mapping[str, str]]
form_data=None, # type: Optional[Mapping[str, str]]
params=None, # type: Optional[Dict[str, str]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ def __init__(self, **kwargs):

client_args = _get_client_args(**kwargs)
if client_args:
self._available = True
self._client = ManagedIdentityClient(**client_args)
else:
self._client = None
self._available = False

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._client:
if not self._available:
raise CredentialUnavailableError(
message="App Service managed identity configuration not found in environment"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from typing import TYPE_CHECKING

from azure.core.exceptions import ClientAuthenticationError
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.transport import HttpRequest, HttpResponse
from azure.core.pipeline.transport import HttpRequest
from azure.core.pipeline.policies import (
DistributedTracingPolicy,
HttpLoggingPolicy,
Expand All @@ -28,6 +27,7 @@
from typing import Any, List, Optional, Union
from azure.core.configuration import Configuration
from azure.core.credentials import AccessToken
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import SansIOHTTPPolicy

PolicyType = Union[HTTPPolicy, SansIOHTTPPolicy]
Expand All @@ -40,24 +40,21 @@ def __init__(self, **kwargs):

url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT)
imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT)
if not (url and imds):
# Azure Arc managed identity isn't available in this environment
self._client = None
return

identity_config = kwargs.pop("_identity_config", None) or {}
config = _get_configuration()

self._client = ManagedIdentityClient(
_identity_config=identity_config,
policies=_get_policies(config),
request_factory=functools.partial(_get_request, url),
**kwargs
)
self._available = url and imds
if self._available:
identity_config = kwargs.pop("_identity_config", None) or {}
config = _get_configuration()

self._client = ManagedIdentityClient(
_identity_config=identity_config,
policies=_get_policies(config),
request_factory=functools.partial(_get_request, url),
**kwargs
)

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._client:
if not self._available:
raise CredentialUnavailableError(
message="Azure Arc managed identity configuration not found in environment"
)
Expand Down Expand Up @@ -125,7 +122,7 @@ class ArcChallengeAuthPolicy(HTTPPolicy):
"""Policy for handling Azure Arc's challenge authentication"""

def send(self, request):
# type: (PipelineRequest) -> HttpResponse
# type: (PipelineRequest) -> PipelineResponse
request.http_request.headers["Metadata"] = "true"
response = self.next.send(request)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
TYPE_CHECKING = False

if TYPE_CHECKING:
from typing import Any
from azure.core.credentials import AccessToken
from typing import Any, List
from azure.core.credentials import AccessToken, TokenCredential

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -101,7 +101,7 @@ def __init__(self, **kwargs):
exclude_cli_credential = kwargs.pop("exclude_cli_credential", False)
exclude_interactive_browser_credential = kwargs.pop("exclude_interactive_browser_credential", True)

credentials = []
credentials = [] # type: List[TokenCredential]
if not exclude_environment_credential:
credentials.append(EnvironmentCredential(authority=authority, **kwargs))
if not exclude_managed_identity_credential:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
if TYPE_CHECKING:
# pylint:disable=unused-import
from typing import Any, Optional, Type
from azure.core.credentials import TokenCredential

_LOGGER = logging.getLogger(__name__)

Expand All @@ -52,7 +53,7 @@ class ManagedIdentityCredential(object):

def __init__(self, **kwargs):
# type: (**Any) -> None
self._credential = None
self._credential = None # type: Optional[TokenCredential]
if os.environ.get(EnvironmentVariables.MSI_ENDPOINT):
if os.environ.get(EnvironmentVariables.MSI_SECRET):
_LOGGER.info("%s will use App Service managed identity", self.__class__.__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ def __init__(self, **kwargs):

client_args = _get_client_args(**kwargs)
if client_args:
self._available = True
self._client = ManagedIdentityClient(**client_args)
else:
self._client = None
self._available = False

def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
if not self._client:
if not self._available:
raise CredentialUnavailableError(
message="Service Fabric managed identity configuration not found in environment"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def _acquire_token_silent(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
"""Silently acquire a token from MSAL. Requires an AuthenticationRecord."""

# self._auth_record and ._app will not be None when this method is called by get_token
# but should either be None anyway (and to satisfy mypy) we raise
if self._app is None or self._auth_record is None:
raise CredentialUnavailableError("Initialization failed")

result = None

accounts_for_user = self._app.get_accounts(username=self._auth_record.username)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def get_default_authority():


def validate_tenant_id(tenant_id):
"""Raise ValueError if tenant_id is empty or contains a character invalid for a tenant id"""
# type: (str) -> None
"""Raise ValueError if tenant_id is empty or contains a character invalid for a tenant id"""
if not tenant_id or any(c not in VALID_TENANT_ID_CHARACTERS for c in tenant_id):
raise ValueError(
"Invalid tenant id provided. You can locate your tenant id by following the instructions here: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ class GetTokenMixin(ABC):
def __init__(self, *args, **kwargs):
# type: (*Any, **Any) -> None
self._last_request_time = 0
super(GetTokenMixin, self).__init__(*args, **kwargs)

# https://github.com/python/mypy/issues/5887
super(GetTokenMixin, self).__init__(*args, **kwargs) # type: ignore

@abc.abstractmethod
def _acquire_token_silently(self, *scopes):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class _SECRET_SCHEMA(ct.Structure):
_libsecret.secret_password_lookup_sync.restype = ct.c_char_p
_libsecret.secret_password_free.argtypes = [ct.c_char_p]
except OSError:
_libsecret = None
_libsecret = None # type: ignore


def _get_user_settings_path():
Expand Down
12 changes: 7 additions & 5 deletions sdk/identity/azure-identity/azure/identity/aio/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
from .._internal.user_agent import USER_AGENT

if TYPE_CHECKING:
from typing import Any, Dict, Iterable, Mapping, Optional
from azure.core.pipeline.policies import HTTPPolicy
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
from azure.core.pipeline.policies import AsyncHTTPPolicy, SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpTransport

PolicyListType = List[Union[AsyncHTTPPolicy, SansIOHTTPPolicy]]


class AsyncAuthnClient(AuthnClientBase): # pylint:disable=async-client-bad-name
"""Async authentication client"""
Expand All @@ -35,7 +37,7 @@ class AsyncAuthnClient(AuthnClientBase): # pylint:disable=async-client-bad-name
def __init__(
self,
config: "Optional[Configuration]" = None,
policies: "Optional[Iterable[HTTPPolicy]]" = None,
policies: "Optional[PolicyListType]" = None,
transport: "Optional[AsyncHttpTransport]" = None,
**kwargs: "Any"
) -> None:
Expand All @@ -51,7 +53,7 @@ def __init__(
]
if not transport:
transport = AioHttpTransport(**kwargs)
self._pipeline = AsyncPipeline(transport=transport, policies=policies)
self._pipeline = AsyncPipeline(transport=transport, policies=policies) # type: AsyncPipeline
super().__init__(**kwargs)

async def __aenter__(self):
Expand All @@ -67,7 +69,7 @@ async def close(self) -> None:
async def request_token( # pylint:disable=invalid-overridden-method
self,
scopes: "Iterable[str]",
method: "Optional[str]" = "POST",
method: str = "POST",
headers: "Optional[Mapping[str, str]]" = None,
form_data: "Optional[Mapping[str, str]]" = None,
params: "Optional[Dict[str, str]]" = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ def __init__(self, **kwargs: "Any") -> None:

client_args = _get_client_args(**kwargs)
if client_args:
self._available = True
self._client = AsyncManagedIdentityClient(**client_args)
else:
self._client = None
self._available = False

async def get_token( # pylint:disable=invalid-overridden-method
self, *scopes: str, **kwargs: "Any"
) -> "AccessToken":
if not self._client:
if not self._available:
raise CredentialUnavailableError(
message="App Service managed identity configuration not found in environment"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import Any, List, Optional, Union
from azure.core.configuration import Configuration
from azure.core.credentials import AccessToken
from azure.core.pipeline import PipelineRequest
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import SansIOHTTPPolicy
from azure.core.pipeline.transport import AsyncHttpResponse

Expand All @@ -39,25 +39,23 @@ def __init__(self, **kwargs: "Any") -> None:
super().__init__()

url = os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT)
if not url:
# Azure Arc managed identity isn't available in this environment
self._client = None
return
identity_config = kwargs.pop("_identity_config", None) or {}
config = _get_configuration()
client_args = dict(
kwargs,
_identity_config=identity_config,
policies=_get_policies(config),
request_factory=functools.partial(_get_request, url),
)

self._client = AsyncManagedIdentityClient(**client_args)
imds = os.environ.get(EnvironmentVariables.IMDS_ENDPOINT)
self._available = url and imds
if self._available:
identity_config = kwargs.pop("_identity_config", None) or {}
config = _get_configuration()

self._client = AsyncManagedIdentityClient(
_identity_config=identity_config,
policies=_get_policies(config),
request_factory=functools.partial(_get_request, url),
**kwargs
)

async def get_token( # pylint:disable=invalid-overridden-method
self, *scopes: str, **kwargs: "Any"
) -> "AccessToken":
if not self._client:
if not self._available:
raise CredentialUnavailableError(
message="Service Fabric managed identity configuration not found in environment"
)
Expand Down Expand Up @@ -89,7 +87,7 @@ def _get_policies(config: "Configuration", **kwargs: "Any") -> "List[PolicyType]
class ArcChallengeAuthPolicy(AsyncHTTPPolicy):
"""Policy for handling Azure Arc's challenge authentication"""

async def send(self, request: "PipelineRequest") -> "AsyncHttpResponse":
async def send(self, request: "PipelineRequest") -> "PipelineResponse":
request.http_request.headers["Metadata"] = "true"
response = await self.next.send(request)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from .vscode import VisualStudioCodeCredential

if TYPE_CHECKING:
from typing import Any
from typing import Any, List
from azure.core.credentials_async import AsyncTokenCredential

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -83,7 +84,7 @@ def __init__(self, **kwargs: "Any") -> None:
exclude_managed_identity_credential = kwargs.pop("exclude_managed_identity_credential", False)
exclude_shared_token_cache_credential = kwargs.pop("exclude_shared_token_cache_credential", False)

credentials = []
credentials = [] # type: List[AsyncTokenCredential]
if not exclude_environment_credential:
credentials.append(EnvironmentCredential(authority=authority, **kwargs))
if not exclude_managed_identity_credential:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
if TYPE_CHECKING:
from typing import Any, Optional
from azure.core.configuration import Configuration
from azure.core.credentials_async import AsyncTokenCredential

_LOGGER = logging.getLogger(__name__)

Expand All @@ -39,7 +40,7 @@ class ManagedIdentityCredential(AsyncContextManager):
"""

def __init__(self, **kwargs: "Any") -> None:
self._credential = None
self._credential = None # type: Optional[AsyncTokenCredential]

if os.environ.get(EnvironmentVariables.MSI_ENDPOINT):
if os.environ.get(EnvironmentVariables.MSI_SECRET):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ def __init__(self, **kwargs: "Any") -> None:

client_args = _get_client_args(**kwargs)
if client_args:
self._available = True
self._client = AsyncManagedIdentityClient(**client_args)
else:
self._client = None
self._available = False

async def get_token( # pylint:disable=invalid-overridden-method
self, *scopes: str, **kwargs: "Any"
) -> "AccessToken":
if not self._client:
if not self._available:
raise CredentialUnavailableError(
message="Service Fabric managed identity configuration not found in environment"
)
Expand Down
Loading

0 comments on commit 5075cab

Please sign in to comment.