Skip to content

Commit

Permalink
Allow None key_id, try JWK's .kid
Browse files Browse the repository at this point in the history
  • Loading branch information
mccoyp committed Mar 5, 2021
1 parent b801f12 commit ae2967b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 16 deletions.
14 changes: 10 additions & 4 deletions sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,12 @@ def __init__(self, key_id, attributes=None, **kwargs):
# type: (str, Optional[_models.KeyAttributes], **Any) -> None
self._attributes = attributes
self._id = key_id
self._vault_id = parse_key_vault_id(key_id)
self._managed = kwargs.get("managed", None)
self._tags = kwargs.get("tags", None)
try:
self._vault_id = parse_key_vault_id(key_id)
except:
self._vault_id = None

def __repr__(self):
# type () -> str
Expand Down Expand Up @@ -106,7 +109,8 @@ def name(self):
:rtype: str
"""
return self._vault_id.name
if self._vault_id:
return self._vault_id.name

@property
def version(self):
Expand All @@ -115,7 +119,8 @@ def version(self):
:rtype: str
"""
return self._vault_id.version
if self._vault_id:
return self._vault_id.version

@property
def enabled(self):
Expand Down Expand Up @@ -169,7 +174,8 @@ def vault_url(self):
:rtype: str
"""
return self._vault_id.vault_url
if self._vault_id:
return self._vault_id.vault_url

@property
def recoverable_days(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,27 @@ def __init__(self, key, credential, **kwargs):

if isinstance(key, KeyVaultKey):
self._key = key
self._key_id = parse_key_vault_id(key.id)
try:
self._key_id = parse_key_vault_id(key.id)
except ValueError:
if not self._local_only:
raise
self._key_id = None
elif isinstance(key, six.string_types):
self._key = None
self._key_id = parse_key_vault_id(key)
self._keys_get_forbidden = None # type: Optional[bool]
else:
raise ValueError("'key' must be a KeyVaultKey instance or a key ID string including a version")

if not (self._key_id.version or self._local_only):
if not (self._local_only or self._key_id.version):
raise ValueError("'key' must include a version")

self._local_provider = NoLocalCryptography()
self._initialized = False

super(CryptographyClient, self).__init__(vault_url=self._key_id.vault_url, credential=credential, **kwargs)
vault_url = "vault_url" if not self._key_id else self._key_id.vault_url
super(CryptographyClient, self).__init__(vault_url=vault_url, credential=credential, **kwargs)

@property
def key_id(self):
Expand All @@ -125,7 +131,8 @@ def key_id(self):
:rtype: str
"""
return self._key_id.source_id
if self._key_id:
return self._key_id.source_id

@classmethod
def from_jwk(cls, jwk):
Expand All @@ -138,10 +145,9 @@ def from_jwk(cls, jwk):
"""
if isinstance(jwk, JsonWebKey):
key = vars(jwk)
key_id = jwk.kid
else:
key = jwk
key_id = jwk.get("kid")
key_id = key.get("kid")
return cls(KeyVaultKey(key_id, key), object(), _local_only=True)

@distributed_trace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,29 +57,36 @@ def __init__(self, key: "Union[KeyVaultKey, str]", credential: "AsyncTokenCreden

if isinstance(key, KeyVaultKey):
self._key = key
self._key_id = parse_key_vault_id(key.id)
try:
self._key_id = parse_key_vault_id(key.id)
except ValueError:
if not self._local_only:
raise
self._key_id = None
elif isinstance(key, str):
self._key = None
self._key_id = parse_key_vault_id(key)
self._keys_get_forbidden = None # type: Optional[bool]
else:
raise ValueError("'key' must be a KeyVaultKey instance or a key ID string including a version")

if not (self._key_id.version or self._local_only):
if not (self._local_only or self._key_id.version):
raise ValueError("'key' must include a version")

self._local_provider = NoLocalCryptography()
self._initialized = False

super().__init__(vault_url=self._key_id.vault_url, credential=credential, **kwargs)
vault_url = "vault_url" if not self._key_id else self._key_id.vault_url
super().__init__(vault_url=vault_url, credential=credential, **kwargs)

@property
def key_id(self) -> str:
"""The full identifier of the client's key.
:rtype: str
"""
return self._key_id.source_id
if self._key_id:
return self._key_id.source_id

@classmethod
def from_jwk(cls, jwk: "Union[JsonWebKey, dict]") -> "CryptographyClient":
Expand All @@ -91,10 +98,9 @@ def from_jwk(cls, jwk: "Union[JsonWebKey, dict]") -> "CryptographyClient":
"""
if isinstance(jwk, JsonWebKey):
key = vars(jwk)
key_id = jwk.kid
else:
key = jwk
key_id = jwk.get("kid")
key_id = key.get("kid")
return cls(KeyVaultKey(key_id, key), object(), _local_only=True)

@distributed_trace_async
Expand Down

0 comments on commit ae2967b

Please sign in to comment.