Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Key Vault] Add local-only mode to CryptographyClient #16565

Merged
merged 18 commits into from
Mar 10, 2021
Merged
120 changes: 106 additions & 14 deletions sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
from ._key_validity import raise_if_time_invalid
from ._providers import get_local_cryptography_provider, NoLocalCryptography
from .. import KeyOperation
from .._models import KeyVaultKey
from .._models import JsonWebKey, KeyVaultKey
from .._shared import KeyVaultClientBase, parse_key_vault_id

if TYPE_CHECKING:
# pylint:disable=unused-import
# pylint:disable=unused-import,ungrouped-imports
from datetime import datetime
from typing import Any, Optional, Union
from azure.core.credentials import TokenCredential
from . import EncryptionAlgorithm, KeyWrapAlgorithm, SignatureAlgorithm
from .._shared import KeyVaultResourceId

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -98,33 +100,87 @@ class CryptographyClient(KeyVaultClientBase):

def __init__(self, key, credential, **kwargs):
# type: (Union[KeyVaultKey, str], TokenCredential, **Any) -> None
self._jwk = kwargs.pop("_jwk", False)
self._not_before = None # type: Optional[datetime]
self._expires_on = None # type: Optional[datetime]
self._key_id = None # type: Optional[KeyVaultResourceId]

if isinstance(key, KeyVaultKey):
self._key = key
self._key = key.key
self._key_id = parse_key_vault_id(key.id)
if key.properties._attributes: # pylint:disable=protected-access
mccoyp marked this conversation as resolved.
Show resolved Hide resolved
self._not_before = key.properties.not_before
self._expires_on = key.properties.expires_on
elif isinstance(key, six.string_types):
self._key = None
self._key_id = parse_key_vault_id(key)
self._keys_get_forbidden = None # type: Optional[bool]
elif self._jwk:
self._key = key
else:
raise ValueError("'key' must be a KeyVaultKey instance or a key ID string including a version")

if not self._key_id.version:
if not (self._jwk or self._key_id.version):
raise ValueError("'key' must include a version")

self._local_provider = NoLocalCryptography()
self._initialized = False
if self._jwk:
try:
self._local_provider = get_local_cryptography_provider(self._key)
self._initialized = True
except Exception as ex: # pylint:disable=broad-except
raise ValueError("The provided jwk is not valid for local cryptography: {}".format(ex))
mccoyp marked this conversation as resolved.
Show resolved Hide resolved
else:
self._local_provider = NoLocalCryptography()
self._initialized = False

super(CryptographyClient, self).__init__(vault_url=self._key_id.vault_url, credential=credential, **kwargs)
self._vault_url = None if self._jwk else self._key_id.vault_url
super(CryptographyClient, self).__init__(
vault_url=self._vault_url or "vault_url", credential=credential, **kwargs
)

@property
def key_id(self):
# type: () -> str
# type: () -> Optional[str]
"""The full identifier of the client's key.

This property may be None when a client is constructed with :func:`from_jwk`.

:rtype: str
"""
if not self._jwk:
return self._key_id.source_id
return self._key.kid

@property
def vault_url(self):
# type: () -> Optional[str]
"""The base vault URL of the client's key.

This property may be None when a client is constructed with :func:`from_jwk`.

:rtype: str
mccoyp marked this conversation as resolved.
Show resolved Hide resolved
"""
return self._key_id.source_id
return self._vault_url

@classmethod
def from_jwk(cls, jwk):
# type: (Union[JsonWebKey, dict]) -> CryptographyClient
"""Creates a client that can only perform cryptographic operations locally.

:param jwk: the key's cryptographic material, as a JsonWebKey or dictionary.
:type jwk: JsonWebKey or dict
:rtype: CryptographyClient

.. literalinclude:: ../tests/test_examples_crypto.py
:start-after: [START from_jwk]
:end-before: [END from_jwk]
:caption: Create a CryptographyClient from a JsonWebKey
:language: python
:dedent: 8
"""
if not isinstance(jwk, JsonWebKey):
jwk = JsonWebKey(**jwk)
return cls(jwk, object(), _jwk=True)

@distributed_trace
def _initialize(self, **kwargs):
Expand All @@ -138,15 +194,15 @@ def _initialize(self, **kwargs):
key_bundle = self._client.get_key(
self._key_id.vault_url, self._key_id.name, self._key_id.version, **kwargs
)
self._key = KeyVaultKey._from_key_bundle(key_bundle) # pylint:disable=protected-access
self._key = KeyVaultKey._from_key_bundle(key_bundle).key # pylint:disable=protected-access
except HttpResponseError as ex:
# if we got a 403, we don't have keys/get permission and won't try to get the key again
# (other errors may be transient)
self._keys_get_forbidden = ex.status_code == 403

# if we have the key material, create a local crypto provider with it
if self._key:
self._local_provider = get_local_cryptography_provider(self._key)
self._local_provider = get_local_cryptography_provider(self._key, _key_id=self.key_id)
self._initialized = True
else:
# try to get the key again next time unless we know we're forbidden to do so
Expand Down Expand Up @@ -181,11 +237,17 @@ def encrypt(self, algorithm, plaintext, **kwargs):
self._initialize(**kwargs)

if self._local_provider.supports(KeyOperation.encrypt, algorithm):
raise_if_time_invalid(self._key)
raise_if_time_invalid(self._not_before, self._expires_on)
try:
return self._local_provider.encrypt(algorithm, plaintext)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local encrypt operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "encrypt" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.encrypt(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -240,6 +302,12 @@ def decrypt(self, algorithm, ciphertext, **kwargs):
return self._local_provider.decrypt(algorithm, ciphertext)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local decrypt operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "decrypt" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.decrypt(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -272,11 +340,17 @@ def wrap_key(self, algorithm, key, **kwargs):
"""
self._initialize(**kwargs)
if self._local_provider.supports(KeyOperation.wrap_key, algorithm):
raise_if_time_invalid(self._key)
raise_if_time_invalid(self._not_before, self._expires_on)
try:
return self._local_provider.wrap_key(algorithm, key)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local wrap operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "wrapKey" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.wrap_key(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -311,6 +385,12 @@ def unwrap_key(self, algorithm, encrypted_key, **kwargs):
return self._local_provider.unwrap_key(algorithm, encrypted_key)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local unwrap operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "unwrapKey" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.unwrap_key(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -340,11 +420,17 @@ def sign(self, algorithm, digest, **kwargs):
"""
self._initialize(**kwargs)
if self._local_provider.supports(KeyOperation.sign, algorithm):
raise_if_time_invalid(self._key)
raise_if_time_invalid(self._not_before, self._expires_on)
try:
return self._local_provider.sign(algorithm, digest)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local sign operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "sign" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.sign(
vault_base_url=self._key_id.vault_url,
Expand Down Expand Up @@ -381,6 +467,12 @@ def verify(self, algorithm, digest, signature, **kwargs):
return self._local_provider.verify(algorithm, digest, signature)
except Exception as ex: # pylint:disable=broad-except
_LOGGER.warning("Local verify operation failed: %s", ex, exc_info=_LOGGER.isEnabledFor(logging.DEBUG))
if self._jwk:
raise
elif self._jwk:
raise NotImplementedError(
'This key does not support the "verify" operation with algorithm "{}"'.format(algorithm)
)

operation_result = self._client.verify(
vault_base_url=self._key_id.vault_url,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import
from .. import KeyVaultKey
from typing import Optional


class _UTC_TZ(tzinfo):
Expand All @@ -28,20 +28,12 @@ def dst(self, dt):
_UTC = _UTC_TZ()


def raise_if_time_invalid(key):
# type: (KeyVaultKey) -> None
try:
nbf = key.properties.not_before
exp = key.properties.expires_on
except AttributeError:
# we consider the key valid because a user must have deliberately created it
# (if it came from Key Vault, it would have those attributes)
return

def raise_if_time_invalid(not_before, expires_on):
# type: (Optional[datetime], Optional[datetime]) -> None
now = datetime.now(_UTC)
if (nbf and exp) and not nbf <= now <= exp:
raise ValueError("This client's key is useable only between {} and {} (UTC)".format(nbf, exp))
if nbf and nbf > now:
raise ValueError("This client's key is not useable until {} (UTC)".format(nbf))
if exp and exp <= now:
raise ValueError("This client's key expired at {} (UTC)".format(exp))
if (not_before and expires_on) and not not_before <= now <= expires_on:
raise ValueError("This client's key is useable only between {} and {} (UTC)".format(not_before, expires_on))
if not_before and not_before > now:
raise ValueError("This client's key is not useable until {} (UTC)".format(not_before))
if expires_on and expires_on <= now:
raise ValueError("This client's key expires_onired at {} (UTC)".format(expires_on))
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@
from ... import KeyType

if TYPE_CHECKING:
from ... import KeyVaultKey
from typing import Any
from ... import JsonWebKey


def get_local_cryptography_provider(key):
# type: (KeyVaultKey) -> LocalCryptographyProvider
if key.key_type in (KeyType.ec, KeyType.ec_hsm):
return EllipticCurveCryptographyProvider(key)
if key.key_type in (KeyType.rsa, KeyType.rsa_hsm):
return RsaCryptographyProvider(key)
if key.key_type in (KeyType.oct, KeyType.oct_hsm):
return SymmetricCryptographyProvider(key)
def get_local_cryptography_provider(key, **kwargs):
# type: (JsonWebKey, **Any) -> LocalCryptographyProvider
if key.kty in (KeyType.ec, KeyType.ec_hsm):
return EllipticCurveCryptographyProvider(key, **kwargs)
if key.kty in (KeyType.rsa, KeyType.rsa_hsm):
return RsaCryptographyProvider(key, **kwargs)
if key.kty in (KeyType.oct, KeyType.oct_hsm):
return SymmetricCryptographyProvider(key, **kwargs)

raise ValueError('Unsupported key type "{}"'.format(key.key_type))
raise ValueError('Unsupported key type "{}"'.format(key.kty))


class NoLocalCryptography(LocalCryptographyProvider):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
# pylint:disable=unused-import
from .local_provider import Algorithm
from .._internal import Key
from ... import KeyVaultKey
from ... import JsonWebKey

_PRIVATE_KEY_OPERATIONS = frozenset((KeyOperation.decrypt, KeyOperation.sign, KeyOperation.unwrap_key))


class EllipticCurveCryptographyProvider(LocalCryptographyProvider):
def _get_internal_key(self, key):
# type: (KeyVaultKey) -> Key
if key.key_type not in (KeyType.ec, KeyType.ec_hsm):
# type: (JsonWebKey) -> Key
if key.kty not in (KeyType.ec, KeyType.ec_hsm):
raise ValueError('"key" must be an EC or EC-HSM key')
return EllipticCurveKey.from_jwk(key.key)
return EllipticCurveKey.from_jwk(key)

def supports(self, operation, algorithm):
# type: (KeyOperation, Algorithm) -> bool
Expand Down
Loading