From 7373db2a493e23bb03a4d8d2f74fff9767f0cb94 Mon Sep 17 00:00:00 2001 From: Yalin Li Date: Wed, 8 Sep 2021 18:49:45 -0700 Subject: [PATCH] Run mypy in azure-keyvault-keys CI (#20545) --- eng/tox/mypy_hard_failure_packages.py | 1 + .../azure/keyvault/keys/_client.py | 8 +- .../azure/keyvault/keys/_models.py | 87 ++++++++++--------- .../azure/keyvault/keys/crypto/_client.py | 63 +++++++------- .../keys/crypto/_internal/__init__.py | 4 +- .../keys/crypto/_internal/algorithm.py | 14 ++- .../keys/crypto/_internal/algorithms/sha_2.py | 11 ++- .../keyvault/keys/crypto/_internal/key.py | 14 ++- .../azure/keyvault/keys/crypto/_models.py | 14 +-- .../keys/crypto/_providers/__init__.py | 8 +- .../keyvault/keys/crypto/_providers/ec.py | 2 +- .../keys/crypto/_providers/local_provider.py | 22 +++-- .../keyvault/keys/crypto/_providers/rsa.py | 2 +- .../keys/crypto/_providers/symmetric.py | 2 +- .../azure/keyvault/keys/crypto/aio/_client.py | 65 +++++++------- sdk/keyvault/azure-keyvault-keys/mypy.ini | 7 ++ 16 files changed, 191 insertions(+), 133 deletions(-) create mode 100644 sdk/keyvault/azure-keyvault-keys/mypy.ini diff --git a/eng/tox/mypy_hard_failure_packages.py b/eng/tox/mypy_hard_failure_packages.py index 3ae063eb4fb1..175058d3f286 100644 --- a/eng/tox/mypy_hard_failure_packages.py +++ b/eng/tox/mypy_hard_failure_packages.py @@ -11,6 +11,7 @@ "azure-identity", "azure-keyvault-administration", "azure-keyvault-certificates", + "azure-keyvault-keys", "azure-keyvault-secrets", "azure-servicebus", "azure-ai-textanalytics", diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py index 498d6afb32b2..1e561321b9c2 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py @@ -19,7 +19,9 @@ # pylint:disable=unused-import from typing import Any, Optional, Union from azure.core.paging import ItemPaged + from azure.core.polling import LROPoller from ._models import JsonWebKey + from ._enums import KeyType class KeyClient(KeyVaultClientBase): @@ -55,7 +57,7 @@ def _get_attributes(self, enabled, not_before, expires_on, exportable=None): @distributed_trace def create_key(self, name, key_type, **kwargs): - # type: (str, Union[str, azure.keyvault.keys.KeyType], **Any) -> KeyVaultKey + # type: (str, Union[str, KeyType], **Any) -> KeyVaultKey """Create a key or, if ``name`` is already in use, create a new version of the key. Requires keys/create permission. @@ -242,7 +244,7 @@ def create_oct_key(self, name, **kwargs): @distributed_trace def begin_delete_key(self, name, **kwargs): - # type: (str, **Any) -> DeletedKey + # type: (str, **Any) -> LROPoller """Delete all versions of a key and its cryptographic material. Requires keys/delete permission. When this method returns Key Vault has begun deleting the key. Deletion may @@ -450,7 +452,7 @@ def purge_deleted_key(self, name, **kwargs): @distributed_trace def begin_recover_deleted_key(self, name, **kwargs): - # type: (str, **Any) -> KeyVaultKey + # type: (str, **Any) -> LROPoller """Recover a deleted key to its latest version. Possible only in a vault with soft-delete enabled. Requires keys/recover permission. diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_models.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_models.py index 822cf9d420c6..6a127ced8283 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_models.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_models.py @@ -13,10 +13,10 @@ if TYPE_CHECKING: # pylint:disable=unused-import - from typing import Any, Dict, Optional + from typing import Any, Dict, Optional, List from datetime import datetime from ._generated.v7_0 import models as _models - from ._enums import KeyOperation + from ._enums import KeyOperation, KeyType KeyOperationResult = namedtuple("KeyOperationResult", ["id", "value"]) @@ -83,13 +83,15 @@ def _from_key_bundle(cls, key_bundle): """Construct a KeyProperties from an autorest-generated KeyBundle""" # release_policy was added in 7.3-preview release_policy = None - if hasattr(key_bundle, "release_policy") and key_bundle.release_policy is not None: + if (hasattr(key_bundle, "release_policy") and + key_bundle.release_policy is not None): # type: ignore[attr-defined] release_policy = KeyReleasePolicy( - data=key_bundle.release_policy.data, content_type=key_bundle.release_policy.content_type + data=key_bundle.release_policy.data, # type: ignore[attr-defined] + content_type=key_bundle.release_policy.content_type # type: ignore[attr-defined] ) return cls( - key_bundle.key.kid, + key_bundle.key.kid, # type: ignore attributes=key_bundle.attributes, managed=key_bundle.managed, tags=key_bundle.tags, @@ -100,7 +102,12 @@ def _from_key_bundle(cls, key_bundle): def _from_key_item(cls, key_item): # type: (_models.KeyItem) -> KeyProperties """Construct a KeyProperties from an autorest-generated KeyItem""" - return cls(key_id=key_item.kid, attributes=key_item.attributes, managed=key_item.managed, tags=key_item.tags) + return cls( + key_id=key_item.kid, # type: ignore + attributes=key_item.attributes, + managed=key_item.managed, + tags=key_item.tags + ) @property def id(self): @@ -122,57 +129,57 @@ def name(self): @property def version(self): - # type: () -> str + # type: () -> Optional[str] """The key's version - :rtype: str + :rtype: str or None """ return self._vault_id.version @property def enabled(self): - # type: () -> bool + # type: () -> Optional[bool] """Whether the key is enabled for use - :rtype: bool + :rtype: bool or None """ - return self._attributes.enabled + return self._attributes.enabled if self._attributes else None @property def not_before(self): - # type: () -> datetime + # type: () -> Optional[datetime] """The time before which the key can not be used, in UTC - :rtype: ~datetime.datetime + :rtype: ~datetime.datetime or None """ - return self._attributes.not_before + return self._attributes.not_before if self._attributes else None @property def expires_on(self): - # type: () -> datetime + # type: () -> Optional[datetime] """When the key will expire, in UTC - :rtype: ~datetime.datetime + :rtype: ~datetime.datetime or None """ - return self._attributes.expires + return self._attributes.expires if self._attributes else None @property def created_on(self): - # type: () -> datetime + # type: () -> Optional[datetime] """When the key was created, in UTC - :rtype: ~datetime.datetime + :rtype: ~datetime.datetime or None """ - return self._attributes.created + return self._attributes.created if self._attributes else None @property def updated_on(self): - # type: () -> datetime + # type: () -> Optional[datetime] """When the key was last updated, in UTC - :rtype: ~datetime.datetime + :rtype: ~datetime.datetime or None """ - return self._attributes.updated + return self._attributes.updated if self._attributes else None @property def vault_url(self): @@ -188,7 +195,7 @@ def recoverable_days(self): # type: () -> Optional[int] """The number of days the key is retained before being deleted from a soft-delete enabled Key Vault. - :rtype: int + :rtype: int or None """ # recoverable_days was added in 7.1-preview if self._attributes: @@ -197,12 +204,12 @@ def recoverable_days(self): @property def recovery_level(self): - # type: () -> str + # type: () -> Optional[str] """The vault's deletion recovery level for keys - :rtype: str + :rtype: str or None """ - return self._attributes.recovery_level + return self._attributes.recovery_level if self._attributes else None @property def tags(self): @@ -326,7 +333,7 @@ def _from_key_bundle(cls, key_bundle): """Construct a KeyVaultKey from an autorest-generated KeyBundle""" # pylint:disable=protected-access return cls( - key_id=key_bundle.key.kid, + key_id=key_bundle.key.kid, # type: ignore jwk={field: getattr(key_bundle.key, field, None) for field in JsonWebKey._FIELDS}, properties=KeyProperties._from_key_bundle(key_bundle), ) @@ -369,21 +376,23 @@ def key(self): @property def key_type(self): - # type: () -> str + # type: () -> KeyType """The key's type. See :class:`~azure.keyvault.keys.KeyType` for possible values. :rtype: ~azure.keyvault.keys.KeyType or str """ - return self._key_material.kty # pylint:disable=no-member + # pylint:disable=no-member + return self._key_material.kty # type: ignore[attr-defined] @property def key_operations(self): - # type: () -> list[KeyOperation] + # type: () -> List[KeyOperation] """Permitted operations. See :class:`~azure.keyvault.keys.KeyOperation` for possible values. :rtype: list[~azure.keyvault.keys.KeyOperation or str] """ - return self._key_material.key_ops # pylint:disable=no-member + # pylint:disable=no-member + return self._key_material.key_ops # type: ignore[attr-defined] class KeyVaultKeyIdentifier(object): @@ -454,7 +463,7 @@ def _from_deleted_key_bundle(cls, deleted_key_bundle): # pylint:disable=protected-access return cls( properties=KeyProperties._from_key_bundle(deleted_key_bundle), - key_id=deleted_key_bundle.key.kid, + key_id=deleted_key_bundle.key.kid, # type: ignore jwk={field: getattr(deleted_key_bundle.key, field, None) for field in JsonWebKey._FIELDS}, deleted_date=deleted_key_bundle.deleted_date, recovery_id=deleted_key_bundle.recovery_id, @@ -475,28 +484,28 @@ def _from_deleted_key_item(cls, deleted_key_item): @property def deleted_date(self): - # type: () -> datetime + # type: () -> Optional[datetime] """When the key was deleted, in UTC - :rtype: ~datetime.datetime + :rtype: ~datetime.datetime or None """ return self._deleted_date @property def recovery_id(self): - # type: () -> str + # type: () -> Optional[str] """An identifier used to recover the deleted key. Returns ``None`` if soft-delete is disabled. - :rtype: str + :rtype: str or None """ return self._recovery_id @property def scheduled_purge_date(self): - # type: () -> datetime + # type: () -> Optional[datetime] """When the key is scheduled to be purged, in UTC. Returns ``None`` if soft-delete is disabled. - :rtype: ~datetime.datetime + :rtype: ~datetime.datetime or None """ return self._scheduled_purge_date diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py index c577ae69c471..1008ee131392 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import six from azure.core.exceptions import HttpResponseError @@ -112,7 +112,7 @@ def __init__(self, key, credential, **kwargs): self._key_id = None # type: Optional[KeyVaultResourceId] if isinstance(key, KeyVaultKey): - self._key = key.key + self._key = key.key # type: Union[JsonWebKey, KeyVaultKey, str, None] self._key_id = parse_key_vault_id(key.id) if key.properties._attributes: # pylint:disable=protected-access self._not_before = key.properties.not_before @@ -120,18 +120,18 @@ def __init__(self, key, credential, **kwargs): elif isinstance(key, six.string_types): self._key = None self._key_id = parse_key_vault_id(key) - self._keys_get_forbidden = None # type: Optional[bool] + self._keys_get_forbidden = False elif self._jwk: self._key = key else: raise ValueError("'key' must be a KeyVaultKey instance or a key ID string including a version") - if not (self._jwk or self._key_id.version): + if not (self._jwk or (self._key_id.version if self._key_id else None)): raise ValueError("'key' must include a version") if self._jwk: try: - self._local_provider = get_local_cryptography_provider(self._key) + self._local_provider = get_local_cryptography_provider(cast(JsonWebKey, self._key)) self._initialized = True except Exception as ex: # pylint:disable=broad-except six.raise_from(ValueError("The provided jwk is not valid for local cryptography"), ex) @@ -139,7 +139,7 @@ def __init__(self, key, credential, **kwargs): self._local_provider = NoLocalCryptography() self._initialized = False - self._vault_url = None if self._jwk else self._key_id.vault_url + self._vault_url = None if (self._jwk or self._key_id is None) else self._key_id.vault_url # type: ignore super(CryptographyClient, self).__init__( vault_url=self._vault_url or "vault_url", credential=credential, **kwargs ) @@ -154,11 +154,11 @@ def key_id(self): :rtype: str or None """ if not self._jwk: - return self._key_id.source_id - return self._key.kid + return self._key_id.source_id if self._key_id else None + return cast(JsonWebKey, self._key).kid # type: ignore[attr-defined] @property - def vault_url(self): + def vault_url(self): # type: ignore # type: () -> Optional[str] """The base vault URL of the client's key. @@ -179,7 +179,7 @@ def from_jwk(cls, jwk): """ if not isinstance(jwk, JsonWebKey): jwk = JsonWebKey(**jwk) - return cls(jwk, object(), _jwk=True) + return cls(jwk, object(), _jwk=True) # type: ignore @distributed_trace def _initialize(self, **kwargs): @@ -191,7 +191,10 @@ def _initialize(self, **kwargs): if not (self._key or self._keys_get_forbidden): try: key_bundle = self._client.get_key( - self._key_id.vault_url, self._key_id.name, self._key_id.version, **kwargs + self._key_id.vault_url if self._key_id else None, + self._key_id.name if self._key_id else None, + self._key_id.version if self._key_id else None, + **kwargs ) self._key = KeyVaultKey._from_key_bundle(key_bundle).key # pylint:disable=protected-access except HttpResponseError as ex: @@ -201,7 +204,7 @@ def _initialize(self, **kwargs): # if we have the key material, create a local crypto provider with it if self._key: - self._local_provider = get_local_cryptography_provider(self._key) + self._local_provider = get_local_cryptography_provider(cast(JsonWebKey, self._key)) self._initialized = True else: # try to get the key again next time unless we know we're forbidden to do so @@ -250,9 +253,9 @@ def encrypt(self, algorithm, plaintext, **kwargs): ) operation_result = self._client.encrypt( - vault_base_url=self._key_id.vault_url, - key_name=self._key_id.name, - key_version=self._key_id.version, + vault_base_url=self._key_id.vault_url if self._key_id else None, + key_name=self._key_id.name if self._key_id else None, + key_version=self._key_id.version if self._key_id else None, parameters=self._models.KeyOperationsParameters(algorithm=algorithm, value=plaintext, iv=iv, aad=aad), **kwargs ) @@ -311,9 +314,9 @@ def decrypt(self, algorithm, ciphertext, **kwargs): ) operation_result = self._client.decrypt( - vault_base_url=self._key_id.vault_url, - key_name=self._key_id.name, - key_version=self._key_id.version, + vault_base_url=self._key_id.vault_url if self._key_id else None, + key_name=self._key_id.name if self._key_id else None, + key_version=self._key_id.version if self._key_id else None, parameters=self._models.KeyOperationsParameters( algorithm=algorithm, value=ciphertext, iv=iv, tag=tag, aad=aad ), @@ -356,9 +359,9 @@ def wrap_key(self, algorithm, key, **kwargs): ) operation_result = self._client.wrap_key( - vault_base_url=self._key_id.vault_url, - key_name=self._key_id.name, - key_version=self._key_id.version, + vault_base_url=self._key_id.vault_url if self._key_id else None, + key_name=self._key_id.name if self._key_id else None, + key_version=self._key_id.version if self._key_id else None, parameters=self._models.KeyOperationsParameters(algorithm=algorithm, value=key), **kwargs ) @@ -398,9 +401,9 @@ def unwrap_key(self, algorithm, encrypted_key, **kwargs): ) operation_result = self._client.unwrap_key( - vault_base_url=self._key_id.vault_url, - key_name=self._key_id.name, - key_version=self._key_id.version, + vault_base_url=self._key_id.vault_url if self._key_id else None, + key_name=self._key_id.name if self._key_id else None, + key_version=self._key_id.version if self._key_id else None, parameters=self._models.KeyOperationsParameters(algorithm=algorithm, value=encrypted_key), **kwargs ) @@ -440,9 +443,9 @@ def sign(self, algorithm, digest, **kwargs): ) operation_result = self._client.sign( - vault_base_url=self._key_id.vault_url, - key_name=self._key_id.name, - key_version=self._key_id.version, + vault_base_url=self._key_id.vault_url if self._key_id else None, + key_name=self._key_id.name if self._key_id else None, + key_version=self._key_id.version if self._key_id else None, parameters=self._models.KeySignParameters(algorithm=algorithm, value=digest), **kwargs ) @@ -484,9 +487,9 @@ def verify(self, algorithm, digest, signature, **kwargs): ) operation_result = self._client.verify( - vault_base_url=self._key_id.vault_url, - key_name=self._key_id.name, - key_version=self._key_id.version, + vault_base_url=self._key_id.vault_url if self._key_id else None, + key_name=self._key_id.name if self._key_id else None, + key_version=self._key_id.version if self._key_id else None, parameters=self._models.KeyVerifyParameters(algorithm=algorithm, digest=digest, signature=signature), **kwargs ) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/__init__.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/__init__.py index 942edd662580..880d4cdeb7ae 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/__init__.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/__init__.py @@ -15,7 +15,7 @@ from .symmetric_key import SymmetricKey from .transform import CryptoTransform, BlockCryptoTransform, AuthenticatedCryptoTransform, SignatureTransform -__all__ = { +__all__ = [ "Key", "EllipticCurveKey", "RsaKey", @@ -29,4 +29,4 @@ "AuthenticatedSymmetricEncryptionAlgorithm", "SignatureTransform", "SymmetricKey", -} +] diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/algorithm.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/algorithm.py index 7270b14ff43c..1b850cf2b14b 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/algorithm.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/algorithm.py @@ -4,11 +4,21 @@ # ------------------------------------ from abc import abstractmethod +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + +if TYPE_CHECKING: + # pylint:disable=unused-import + from typing import Optional, Union + from cryptography.hazmat.primitives import hashes + _alg_registry = {} class Algorithm(object): - _name = None + _name = None # type: Optional[str] @classmethod def name(cls): @@ -56,7 +66,7 @@ def create_decryptor(self, key, iv, auth_data, auth_tag): class SignatureAlgorithm(Algorithm): - _default_hash_algorithm = None + _default_hash_algorithm = None # type: Union[hashes.SHA256, hashes.SHA384, hashes.SHA512, None] @property def default_hash_algorithm(self): diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/algorithms/sha_2.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/algorithms/sha_2.py index c6656b0493c5..18b4838afeba 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/algorithms/sha_2.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/algorithms/sha_2.py @@ -8,6 +8,15 @@ from ..algorithm import HashAlgorithm from ..transform import DigestTransform +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + +if TYPE_CHECKING: + # pylint:disable=unused-import + from typing import Union, Type + class _Sha2DigestTransform(DigestTransform): def __init__(self, algorithm): @@ -23,7 +32,7 @@ def finalize(self, data): class _Sha2HashAlgorithm(HashAlgorithm): - _algorithm_cls = None + _algorithm_cls = None # type: Union[Type[hashes.SHA256], Type[hashes.SHA384], Type[hashes.SHA512], None] def create_digest(self): return _Sha2DigestTransform(self._algorithm_cls()) # pylint:disable=not-callable diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/key.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/key.py index a593279bb2d7..2a04602ba8f1 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/key.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_internal/key.py @@ -7,11 +7,19 @@ from six import with_metaclass from .algorithm import Algorithm +try: + from typing import TYPE_CHECKING +except ImportError: + TYPE_CHECKING = False + +if TYPE_CHECKING: + # pylint:disable=unused-import + from typing import Any, FrozenSet class Key(with_metaclass(ABCMeta, object)): - _supported_encryption_algorithms = [] - _supported_key_wrap_algorithms = [] - _supported_signature_algorithms = [] + _supported_encryption_algorithms = frozenset([]) # type: FrozenSet[Any] + _supported_key_wrap_algorithms = frozenset([]) # type: FrozenSet[Any] + _supported_signature_algorithms = frozenset([]) # type: FrozenSet[Any] def __init__(self): self._kid = None diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_models.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_models.py index 61fe89865625..fc9e9f9c6d2f 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_models.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_models.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from . import EncryptionAlgorithm, KeyWrapAlgorithm, SignatureAlgorithm - from typing import Any + from typing import Any, Optional class DecryptResult: @@ -19,7 +19,7 @@ class DecryptResult: """ def __init__(self, key_id, algorithm, plaintext): - # type: (str, EncryptionAlgorithm, bytes) -> None + # type: (Optional[str], EncryptionAlgorithm, bytes) -> None self.key_id = key_id self.algorithm = algorithm self.plaintext = plaintext @@ -40,7 +40,7 @@ class EncryptResult: """ def __init__(self, key_id, algorithm, ciphertext, **kwargs): - # type: (str, EncryptionAlgorithm, bytes, **Any) -> None + # type: (Optional[str], EncryptionAlgorithm, bytes, **Any) -> None self.key_id = key_id self.algorithm = algorithm self.ciphertext = ciphertext @@ -59,7 +59,7 @@ class SignResult: """ def __init__(self, key_id, algorithm, signature): - # type: (str, SignatureAlgorithm, bytes) -> None + # type: (Optional[str], SignatureAlgorithm, bytes) -> None self.key_id = key_id self.algorithm = algorithm self.signature = signature @@ -75,7 +75,7 @@ class VerifyResult: """ def __init__(self, key_id, is_valid, algorithm): - # type: (str, bool, SignatureAlgorithm) -> None + # type: (Optional[str], bool, SignatureAlgorithm) -> None self.key_id = key_id self.is_valid = is_valid self.algorithm = algorithm @@ -91,7 +91,7 @@ class UnwrapResult: """ def __init__(self, key_id, algorithm, key): - # type: (str, KeyWrapAlgorithm, bytes) -> None + # type: (Optional[str], KeyWrapAlgorithm, bytes) -> None self.key_id = key_id self.algorithm = algorithm self.key = key @@ -107,7 +107,7 @@ class WrapResult: """ def __init__(self, key_id, algorithm, encrypted_key): - # type: (str, KeyWrapAlgorithm, bytes) -> None + # type: (Optional[str], KeyWrapAlgorithm, bytes) -> None self.key_id = key_id self.algorithm = algorithm self.encrypted_key = encrypted_key diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/__init__.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/__init__.py index f13ec67b4dba..1ff9652740e7 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/__init__.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/__init__.py @@ -16,14 +16,14 @@ def get_local_cryptography_provider(key): # type: (JsonWebKey) -> LocalCryptographyProvider - if key.kty in (KeyType.ec, KeyType.ec_hsm): + if key.kty in (KeyType.ec, KeyType.ec_hsm): # type: ignore[attr-defined] return EllipticCurveCryptographyProvider(key) - if key.kty in (KeyType.rsa, KeyType.rsa_hsm): + if key.kty in (KeyType.rsa, KeyType.rsa_hsm): # type: ignore[attr-defined] return RsaCryptographyProvider(key) - if key.kty in (KeyType.oct, KeyType.oct_hsm): + if key.kty in (KeyType.oct, KeyType.oct_hsm): # type: ignore[attr-defined] return SymmetricCryptographyProvider(key) - raise ValueError('Unsupported key type "{}"'.format(key.kty)) + raise ValueError('Unsupported key type "{}"'.format(key.kty)) # type: ignore[attr-defined] class NoLocalCryptography(LocalCryptographyProvider): diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/ec.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/ec.py index e415474c90ae..5b2957062727 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/ec.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/ec.py @@ -20,7 +20,7 @@ class EllipticCurveCryptographyProvider(LocalCryptographyProvider): def _get_internal_key(self, key): # type: (JsonWebKey) -> Key - if key.kty not in (KeyType.ec, KeyType.ec_hsm): + if key.kty not in (KeyType.ec, KeyType.ec_hsm): # type: ignore[attr-defined] raise ValueError('"key" must be an EC or EC-HSM key') return EllipticCurveKey.from_jwk(key) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/local_provider.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/local_provider.py index 504ff0ae8828..df8c03638bc1 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/local_provider.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/local_provider.py @@ -28,7 +28,7 @@ class LocalCryptographyProvider(ABC): def __init__(self, key): # type: (JsonWebKey) -> None - self._allowed_ops = frozenset(key.key_ops or []) + self._allowed_ops = frozenset(key.key_ops or []) # type: ignore[attr-defined] self._internal_key = self._get_internal_key(key) self._key = key @@ -49,7 +49,7 @@ def key_id(self): :rtype: str or None """ - return self._key.kid + return self._key.kid # type: ignore[attr-defined] def _raise_if_unsupported(self, operation, algorithm): # type: (KeyOperation, Algorithm) -> None @@ -64,34 +64,40 @@ def encrypt(self, algorithm, plaintext, iv=None): # type: (EncryptionAlgorithm, bytes, Optional[bytes]) -> EncryptResult self._raise_if_unsupported(KeyOperation.encrypt, algorithm) ciphertext = self._internal_key.encrypt(plaintext, algorithm=algorithm.value, iv=iv) - return EncryptResult(key_id=self._key.kid, algorithm=algorithm, ciphertext=ciphertext, iv=iv) + return EncryptResult( + key_id=self._key.kid, algorithm=algorithm, ciphertext=ciphertext, iv=iv # type: ignore[attr-defined] + ) def decrypt(self, algorithm, ciphertext, iv=None): # type: (EncryptionAlgorithm, bytes, Optional[bytes]) -> DecryptResult self._raise_if_unsupported(KeyOperation.decrypt, algorithm) plaintext = self._internal_key.decrypt(ciphertext, iv=iv, algorithm=algorithm.value) - return DecryptResult(key_id=self._key.kid, algorithm=algorithm, plaintext=plaintext) + return DecryptResult( + key_id=self._key.kid, algorithm=algorithm, plaintext=plaintext # type: ignore[attr-defined] + ) def wrap_key(self, algorithm, key): # type: (KeyWrapAlgorithm, bytes) -> WrapResult self._raise_if_unsupported(KeyOperation.wrap_key, algorithm) encrypted_key = self._internal_key.wrap_key(key, algorithm=algorithm.value) - return WrapResult(key_id=self._key.kid, algorithm=algorithm, encrypted_key=encrypted_key) + return WrapResult( + key_id=self._key.kid, algorithm=algorithm, encrypted_key=encrypted_key # type: ignore[attr-defined] + ) def unwrap_key(self, algorithm, encrypted_key): # type: (KeyWrapAlgorithm, bytes) -> UnwrapResult self._raise_if_unsupported(KeyOperation.unwrap_key, algorithm) unwrapped_key = self._internal_key.unwrap_key(encrypted_key, algorithm=algorithm.value) - return UnwrapResult(key_id=self._key.kid, algorithm=algorithm, key=unwrapped_key) + return UnwrapResult(key_id=self._key.kid, algorithm=algorithm, key=unwrapped_key) # type: ignore[attr-defined] def sign(self, algorithm, digest): # type: (SignatureAlgorithm, bytes) -> SignResult self._raise_if_unsupported(KeyOperation.sign, algorithm) signature = self._internal_key.sign(digest, algorithm=algorithm.value) - return SignResult(key_id=self._key.kid, algorithm=algorithm, signature=signature) + return SignResult(key_id=self._key.kid, algorithm=algorithm, signature=signature) # type: ignore[attr-defined] def verify(self, algorithm, digest, signature): # type: (SignatureAlgorithm, bytes, bytes) -> VerifyResult self._raise_if_unsupported(KeyOperation.verify, algorithm) is_valid = self._internal_key.verify(digest, signature, algorithm=algorithm.value) - return VerifyResult(key_id=self._key.kid, algorithm=algorithm, is_valid=is_valid) + return VerifyResult(key_id=self._key.kid, algorithm=algorithm, is_valid=is_valid) # type: ignore[attr-defined] diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/rsa.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/rsa.py index 3498072db217..292e6bba4c97 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/rsa.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/rsa.py @@ -20,7 +20,7 @@ class RsaCryptographyProvider(LocalCryptographyProvider): def _get_internal_key(self, key): # type: (JsonWebKey) -> Key - if key.kty not in (KeyType.rsa, KeyType.rsa_hsm): + if key.kty not in (KeyType.rsa, KeyType.rsa_hsm): # type: ignore[attr-defined] raise ValueError('"key" must be an RSA or RSA-HSM key') return RsaKey.from_jwk(key) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/symmetric.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/symmetric.py index 73a07f296a6a..729377acab7b 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/symmetric.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_providers/symmetric.py @@ -18,7 +18,7 @@ class SymmetricCryptographyProvider(LocalCryptographyProvider): def _get_internal_key(self, key): # type: (JsonWebKey) -> Key - if key.kty not in (KeyType.oct, KeyType.oct_hsm): + if key.kty not in (KeyType.oct, KeyType.oct_hsm): # type: ignore[attr-defined] raise ValueError('"key" must be an oct or oct-HSM (symmetric) key') return SymmetricKey.from_jwk(key) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/aio/_client.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/aio/_client.py index ffda9d3d06f8..8c21543eddff 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/aio/_client.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/aio/_client.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from azure.core.exceptions import HttpResponseError from azure.core.tracing.decorator_async import distributed_trace_async @@ -61,7 +61,7 @@ def __init__(self, key: "Union[KeyVaultKey, str]", credential: "AsyncTokenCreden self._key_id = None # type: Optional[KeyVaultResourceId] if isinstance(key, KeyVaultKey): - self._key = key.key + self._key = key.key # type: Union[JsonWebKey, KeyVaultKey, str, None] self._key_id = parse_key_vault_id(key.id) if key.properties._attributes: # pylint:disable=protected-access self._not_before = key.properties.not_before @@ -69,18 +69,18 @@ def __init__(self, key: "Union[KeyVaultKey, str]", credential: "AsyncTokenCreden elif isinstance(key, str): self._key = None self._key_id = parse_key_vault_id(key) - self._keys_get_forbidden = None # type: Optional[bool] + self._keys_get_forbidden = False elif self._jwk: self._key = key else: raise ValueError("'key' must be a KeyVaultKey instance or a key ID string including a version") - if not (self._jwk or self._key_id.version): + if not (self._jwk or (self._key_id.version if self._key_id else None)): raise ValueError("'key' must include a version") if self._jwk: try: - self._local_provider = get_local_cryptography_provider(self._key) + self._local_provider = get_local_cryptography_provider(cast(JsonWebKey, self._key)) self._initialized = True except Exception as ex: # pylint:disable=broad-except raise ValueError("The provided jwk is not valid for local cryptography") from ex @@ -88,7 +88,7 @@ def __init__(self, key: "Union[KeyVaultKey, str]", credential: "AsyncTokenCreden self._local_provider = NoLocalCryptography() self._initialized = False - self._vault_url = None if self._jwk else self._key_id.vault_url + self._vault_url = None if (self._jwk or self._key_id is None) else self._key_id.vault_url # type: ignore super().__init__(vault_url=self._vault_url or "vault_url", credential=credential, **kwargs) @property @@ -100,11 +100,11 @@ def key_id(self) -> "Optional[str]": :rtype: str or None """ if not self._jwk: - return self._key_id.source_id - return self._key.kid + return self._key_id.source_id if self._key_id else None + return cast(JsonWebKey, self._key).kid # type: ignore[attr-defined] @property - def vault_url(self) -> "Optional[str]": + def vault_url(self) -> "Optional[str]": # type: ignore """The base vault URL of the client's key. This property may be None when a client is constructed with :func:`from_jwk`. @@ -123,7 +123,7 @@ def from_jwk(cls, jwk: "Union[JsonWebKey, dict]") -> "CryptographyClient": """ if not isinstance(jwk, JsonWebKey): jwk = JsonWebKey(**jwk) - return cls(jwk, object(), _jwk=True) + return cls(jwk, object(), _jwk=True) # type: ignore @distributed_trace_async async def _initialize(self, **kwargs): @@ -135,7 +135,10 @@ async def _initialize(self, **kwargs): if not (self._key or self._keys_get_forbidden): try: key_bundle = await self._client.get_key( - self._key_id.vault_url, self._key_id.name, self._key_id.version, **kwargs + self._key_id.vault_url if self._key_id else None, + self._key_id.name if self._key_id else None, + self._key_id.version if self._key_id else None, + **kwargs ) self._key = KeyVaultKey._from_key_bundle(key_bundle).key # pylint:disable=protected-access except HttpResponseError as ex: @@ -145,7 +148,7 @@ async def _initialize(self, **kwargs): # if we have the key material, create a local crypto provider with it if self._key: - self._local_provider = get_local_cryptography_provider(self._key) + self._local_provider = get_local_cryptography_provider(cast(JsonWebKey, self._key)) self._initialized = True else: # try to get the key again next time unless we know we're forbidden to do so @@ -193,9 +196,9 @@ async def encrypt(self, algorithm: "EncryptionAlgorithm", plaintext: bytes, **kw ) operation_result = await self._client.encrypt( - vault_base_url=self._key_id.vault_url, - key_name=self._key_id.name, - key_version=self._key_id.version, + vault_base_url=self._key_id.vault_url if self._key_id else None, + key_name=self._key_id.name if self._key_id else None, + key_version=self._key_id.version if self._key_id else None, parameters=self._models.KeyOperationsParameters(algorithm=algorithm, value=plaintext, iv=iv, aad=aad), **kwargs ) @@ -253,9 +256,9 @@ async def decrypt(self, algorithm: "EncryptionAlgorithm", ciphertext: bytes, **k ) operation_result = await self._client.decrypt( - vault_base_url=self._key_id.vault_url, - key_name=self._key_id.name, - key_version=self._key_id.version, + vault_base_url=self._key_id.vault_url if self._key_id else None, + key_name=self._key_id.name if self._key_id else None, + key_version=self._key_id.version if self._key_id else None, parameters=self._models.KeyOperationsParameters( algorithm=algorithm, value=ciphertext, iv=iv, tag=tag, aad=aad ), @@ -297,9 +300,9 @@ async def wrap_key(self, algorithm: "KeyWrapAlgorithm", key: bytes, **kwargs: "A ) operation_result = await self._client.wrap_key( - vault_base_url=self._key_id.vault_url, - key_name=self._key_id.name, - key_version=self._key_id.version, + vault_base_url=self._key_id.vault_url if self._key_id else None, + key_name=self._key_id.name if self._key_id else None, + key_version=self._key_id.version if self._key_id else None, parameters=self._models.KeyOperationsParameters(algorithm=algorithm, value=key), **kwargs ) @@ -338,14 +341,14 @@ async def unwrap_key(self, algorithm: "KeyWrapAlgorithm", encrypted_key: bytes, ) operation_result = await self._client.unwrap_key( - vault_base_url=self._key_id.vault_url, - key_name=self._key_id.name, - key_version=self._key_id.version, + vault_base_url=self._key_id.vault_url if self._key_id else None, + key_name=self._key_id.name if self._key_id else None, + key_version=self._key_id.version if self._key_id else None, parameters=self._models.KeyOperationsParameters(algorithm=algorithm, value=encrypted_key), **kwargs ) - return UnwrapResult(key_id=self._key_id, algorithm=algorithm, key=operation_result.result) + return UnwrapResult(key_id=self.key_id, algorithm=algorithm, key=operation_result.result) @distributed_trace_async async def sign(self, algorithm: "SignatureAlgorithm", digest: bytes, **kwargs: "Any") -> SignResult: @@ -380,9 +383,9 @@ async def sign(self, algorithm: "SignatureAlgorithm", digest: bytes, **kwargs: " ) operation_result = await self._client.sign( - vault_base_url=self._key_id.vault_url, - key_name=self._key_id.name, - key_version=self._key_id.version, + vault_base_url=self._key_id.vault_url if self._key_id else None, + key_name=self._key_id.name if self._key_id else None, + key_version=self._key_id.version if self._key_id else None, parameters=self._models.KeySignParameters(algorithm=algorithm, value=digest), **kwargs ) @@ -425,9 +428,9 @@ async def verify( ) operation_result = await self._client.verify( - vault_base_url=self._key_id.vault_url, - key_name=self._key_id.name, - key_version=self._key_id.version, + vault_base_url=self._key_id.vault_url if self._key_id else None, + key_name=self._key_id.name if self._key_id else None, + key_version=self._key_id.version if self._key_id else None, parameters=self._models.KeyVerifyParameters(algorithm=algorithm, digest=digest, signature=signature), **kwargs ) diff --git a/sdk/keyvault/azure-keyvault-keys/mypy.ini b/sdk/keyvault/azure-keyvault-keys/mypy.ini new file mode 100644 index 000000000000..18b37b44c426 --- /dev/null +++ b/sdk/keyvault/azure-keyvault-keys/mypy.ini @@ -0,0 +1,7 @@ +[mypy] +python_version = 3.6 +warn_unused_configs = True +ignore_missing_imports = True + +[mypy-azure.keyvault.*._generated.*] +ignore_errors = True