Skip to content

Commit

Permalink
Merge pull request #330 from andrewwhitehead/symm-jwk-import
Browse files Browse the repository at this point in the history
Implement JWK import for symmetric keys
  • Loading branch information
andrewwhitehead authored Dec 2, 2024
2 parents e902e7f + 57e3422 commit 6a3f9fa
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 51 deletions.
41 changes: 40 additions & 1 deletion askar-crypto/src/alg/aes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
encrypt::{KeyAeadInPlace, KeyAeadMeta, KeyAeadParams},
error::Error,
generic_array::{typenum::Unsigned, GenericArray},
jwk::{JwkEncoder, ToJwk},
jwk::{FromJwk, JwkEncoder, JwkParts, ToJwk},
kdf::{FromKeyDerivation, FromKeyExchange, KeyDerivation, KeyExchange},
random::KeyMaterial,
repr::{KeyGen, KeyMeta, KeySecretBytes},
Expand Down Expand Up @@ -122,6 +122,24 @@ impl<T: AesType> FromKeyDerivation for AesKey<T> {
}
}

impl<T: AesType> FromJwk for AesKey<T> {
fn from_jwk_parts(jwk: JwkParts<'_>) -> Result<Self, Error> {
if jwk.kty != JWK_KEY_TYPE {
return Err(err_msg!(InvalidKeyData, "Unsupported key type"));
}
if jwk.alg.is_some() && jwk.alg != T::JWK_ALG {
return Err(err_msg!(InvalidKeyData, "Unsupported key algorithm"));
}
Ok(Self(ArrayKey::try_new_with(|buf| {
if jwk.k.decode_base64(buf)? != buf.len() {
Err(err_msg!(InvalidKeyData))
} else {
Ok(())
}
})?))
}
}

impl<T: AesType> ToJwk for AesKey<T> {
fn encode_jwk(&self, enc: &mut dyn JwkEncoder) -> Result<(), Error> {
if enc.is_public() {
Expand Down Expand Up @@ -292,6 +310,27 @@ mod tests {
.unwrap();
}

#[cfg(feature = "any_key")]
#[test]
fn jwk_any_compat() {
use crate::alg::{any::AnyKey, AesTypes, KeyAlg};
use alloc::boxed::Box;

let test_jwk_compat = r#"
{"alg": "A128CBC-HS256",
"k": "6scajSsnjo2fI-wjCCvBC2xNSYyErNyN93CAsyzVVGI",
"kty": "oct"}
"#;
let key = Box::<AnyKey>::from_jwk(test_jwk_compat).expect("Error decoding AES key JWK");
assert_eq!(key.algorithm(), KeyAlg::Aes(AesTypes::A128CbcHs256));
let as_aes = key
.downcast_ref::<AesKey<A128CbcHs256>>()
.expect("Error downcasting AES key");
let _ = as_aes
.to_jwk_secret(None)
.expect("Error converting key to JWK");
}

#[test]
fn serialize_round_trip() {
fn test_serialize<T: AesType>() {
Expand Down
51 changes: 31 additions & 20 deletions askar-crypto/src/alg/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use core::{

#[cfg(feature = "aes")]
use super::{
aes::{A128CbcHs256, A128Gcm, A128Kw, A256CbcHs512, A256Gcm, A256Kw, AesKey},
aes::{A128CbcHs256, A128Gcm, A128Kw, A256CbcHs512, A256Gcm, A256Kw, AesKey, AesType},
AesTypes,
};

Expand All @@ -19,7 +19,7 @@ use super::{

#[cfg(feature = "chacha")]
use super::{
chacha20::{Chacha20Key, C20P, XC20P},
chacha20::{Chacha20Key, Chacha20Type, C20P, XC20P},
Chacha20Types,
};

Expand Down Expand Up @@ -601,34 +601,45 @@ impl FromJwk for Arc<AnyKey> {

#[inline]
fn from_jwk_any<R: AllocKey>(jwk: JwkParts<'_>) -> Result<R, Error> {
match (jwk.kty, jwk.crv.as_ref()) {
#[cfg(feature = "ed25519")]
("OKP", c) if c == ed25519::JWK_CURVE => {
Ed25519KeyPair::from_jwk_parts(jwk).map(R::alloc_key)
match (jwk.kty, jwk.crv.as_ref(), jwk.alg.as_ref()) {
#[cfg(feature = "aes")]
("oct", _, A128Gcm::JWK_ALG) => AesKey::<A128Gcm>::from_jwk_parts(jwk).map(R::alloc_key),
#[cfg(feature = "aes")]
("oct", _, A256Gcm::JWK_ALG) => AesKey::<A256Gcm>::from_jwk_parts(jwk).map(R::alloc_key),
#[cfg(feature = "aes")]
("oct", _, A128CbcHs256::JWK_ALG) => {
AesKey::<A128CbcHs256>::from_jwk_parts(jwk).map(R::alloc_key)
}
#[cfg(feature = "ed25519")]
("OKP", c) if c == x25519::JWK_CURVE => {
X25519KeyPair::from_jwk_parts(jwk).map(R::alloc_key)
#[cfg(feature = "aes")]
("oct", _, A256CbcHs512::JWK_ALG) => {
AesKey::<A256CbcHs512>::from_jwk_parts(jwk).map(R::alloc_key)
}
#[cfg(feature = "aes")]
("oct", _, A128Kw::JWK_ALG) => AesKey::<A128Kw>::from_jwk_parts(jwk).map(R::alloc_key),
#[cfg(feature = "aes")]
("oct", _, A256Kw::JWK_ALG) => AesKey::<A256Kw>::from_jwk_parts(jwk).map(R::alloc_key),
#[cfg(feature = "bls")]
("OKP" | "EC", c) if c == G1::JWK_CURVE => {
BlsKeyPair::<G1>::from_jwk_parts(jwk).map(R::alloc_key)
}
("OKP" | "EC", G1::JWK_CURVE, _) => BlsKeyPair::<G1>::from_jwk_parts(jwk).map(R::alloc_key),
#[cfg(feature = "bls")]
("OKP" | "EC", c) if c == G2::JWK_CURVE => {
BlsKeyPair::<G2>::from_jwk_parts(jwk).map(R::alloc_key)
}
("OKP" | "EC", G2::JWK_CURVE, _) => BlsKeyPair::<G2>::from_jwk_parts(jwk).map(R::alloc_key),
#[cfg(feature = "bls")]
("OKP" | "EC", c) if c == G1G2::JWK_CURVE => {
("OKP" | "EC", G1G2::JWK_CURVE, _) => {
BlsKeyPair::<G1G2>::from_jwk_parts(jwk).map(R::alloc_key)
}
#[cfg(feature = "chacha")]
("oct", _, C20P::JWK_ALG) => Chacha20Key::<C20P>::from_jwk_parts(jwk).map(R::alloc_key),
#[cfg(feature = "chacha")]
("oct", _, XC20P::JWK_ALG) => Chacha20Key::<XC20P>::from_jwk_parts(jwk).map(R::alloc_key),
#[cfg(feature = "ed25519")]
("OKP", ed25519::JWK_CURVE, _) => Ed25519KeyPair::from_jwk_parts(jwk).map(R::alloc_key),
#[cfg(feature = "ed25519")]
("OKP", x25519::JWK_CURVE, _) => X25519KeyPair::from_jwk_parts(jwk).map(R::alloc_key),
#[cfg(feature = "k256")]
("EC", c) if c == k256::JWK_CURVE => K256KeyPair::from_jwk_parts(jwk).map(R::alloc_key),
("EC", k256::JWK_CURVE, _) => K256KeyPair::from_jwk_parts(jwk).map(R::alloc_key),
#[cfg(feature = "p256")]
("EC", c) if c == p256::JWK_CURVE => P256KeyPair::from_jwk_parts(jwk).map(R::alloc_key),
("EC", p256::JWK_CURVE, _) => P256KeyPair::from_jwk_parts(jwk).map(R::alloc_key),
#[cfg(feature = "p384")]
("EC", c) if c == p384::JWK_CURVE => P384KeyPair::from_jwk_parts(jwk).map(R::alloc_key),
// FIXME implement symmetric keys?
("EC", p384::JWK_CURVE, _) => P384KeyPair::from_jwk_parts(jwk).map(R::alloc_key),
_ => Err(err_msg!(Unsupported, "Unsupported JWK for key import")),
}
}
Expand Down
41 changes: 40 additions & 1 deletion askar-crypto/src/alg/chacha20.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
encrypt::{KeyAeadInPlace, KeyAeadMeta, KeyAeadParams},
error::Error,
generic_array::{typenum::Unsigned, GenericArray},
jwk::{JwkEncoder, ToJwk},
jwk::{FromJwk, JwkEncoder, JwkParts, ToJwk},
kdf::{FromKeyDerivation, FromKeyExchange, KeyDerivation, KeyExchange},
random::KeyMaterial,
repr::{KeyGen, KeyMeta, KeySecretBytes},
Expand Down Expand Up @@ -207,6 +207,24 @@ impl<T: Chacha20Type> KeyAeadInPlace for Chacha20Key<T> {
}
}

impl<T: Chacha20Type> FromJwk for Chacha20Key<T> {
fn from_jwk_parts(jwk: JwkParts<'_>) -> Result<Self, Error> {
if jwk.kty != JWK_KEY_TYPE {
return Err(err_msg!(InvalidKeyData, "Unsupported key type"));
}
if jwk.alg.is_some() && jwk.alg != T::JWK_ALG {
return Err(err_msg!(InvalidKeyData, "Unsupported key algorithm"));
}
Ok(Self(ArrayKey::try_new_with(|buf| {
if jwk.k.decode_base64(buf)? != buf.len() {
Err(err_msg!(InvalidKeyData))
} else {
Ok(())
}
})?))
}
}

impl<T: Chacha20Type> ToJwk for Chacha20Key<T> {
fn encode_jwk(&self, enc: &mut dyn JwkEncoder) -> Result<(), Error> {
if enc.is_public() {
Expand Down Expand Up @@ -263,6 +281,27 @@ mod tests {
test_encrypt::<XC20P>();
}

#[cfg(feature = "any_key")]
#[test]
fn jwk_any_compat() {
use crate::alg::{any::AnyKey, Chacha20Types, KeyAlg};
use alloc::boxed::Box;

let test_jwk_compat = r#"
{"alg": "XC20P",
"k": "IateWalmifmgIAtA6XhbPVKPmjBUiwrs3p0ePHpMivU",
"kty": "oct"}
"#;
let key = Box::<AnyKey>::from_jwk(test_jwk_compat).expect("Error decoding ChaCha key JWK");
assert_eq!(key.algorithm(), KeyAlg::Chacha20(Chacha20Types::XC20P));
let as_chacha = key
.downcast_ref::<Chacha20Key<XC20P>>()
.expect("Error downcasting ChaCha key");
let _ = as_chacha
.to_jwk_secret(None)
.expect("Error converting key to JWK");
}

#[test]
fn serialize_round_trip() {
fn test_serialize<T: Chacha20Type>() {
Expand Down
4 changes: 2 additions & 2 deletions askar-crypto/src/alg/ed25519.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ use crate::{
};

/// The 'kty' value of an Ed25519 JWK
pub static JWK_KEY_TYPE: &str = "OKP";
pub const JWK_KEY_TYPE: &str = "OKP";
/// The 'crv' value of an Ed25519 JWK
pub static JWK_CURVE: &str = "Ed25519";
pub const JWK_CURVE: &str = "Ed25519";

/// An Ed25519 public key or keypair
#[derive(Clone)]
Expand Down
4 changes: 2 additions & 2 deletions askar-crypto/src/alg/k256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ pub const SECRET_KEY_LENGTH: usize = 32;
pub const KEYPAIR_LENGTH: usize = SECRET_KEY_LENGTH + PUBLIC_KEY_LENGTH;

/// The 'kty' value of an elliptic curve key JWK
pub static JWK_KEY_TYPE: &str = "EC";
pub const JWK_KEY_TYPE: &str = "EC";
/// The 'crv' value of a K-256 key JWK
pub static JWK_CURVE: &str = "secp256k1";
pub const JWK_CURVE: &str = "secp256k1";

type FieldSize = elliptic_curve::FieldBytesSize<k256::Secp256k1>;

Expand Down
4 changes: 2 additions & 2 deletions askar-crypto/src/alg/p256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ pub const SECRET_KEY_LENGTH: usize = 32;
pub const KEYPAIR_LENGTH: usize = SECRET_KEY_LENGTH + PUBLIC_KEY_LENGTH;

/// The 'kty' value of an elliptic curve key JWK
pub static JWK_KEY_TYPE: &str = "EC";
pub const JWK_KEY_TYPE: &str = "EC";
/// The 'crv' value of a P-256 key JWK
pub static JWK_CURVE: &str = "P-256";
pub const JWK_CURVE: &str = "P-256";

type FieldSize = elliptic_curve::FieldBytesSize<p256::NistP256>;

Expand Down
4 changes: 2 additions & 2 deletions askar-crypto/src/alg/p384.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ pub const SECRET_KEY_LENGTH: usize = 48;
pub const KEYPAIR_LENGTH: usize = SECRET_KEY_LENGTH + PUBLIC_KEY_LENGTH;

/// The 'kty' value of an elliptic curve key JWK
pub static JWK_KEY_TYPE: &str = "EC";
pub const JWK_KEY_TYPE: &str = "EC";
/// The 'crv' value of a P-384 key JWK
pub static JWK_CURVE: &str = "P-384";
pub const JWK_CURVE: &str = "P-384";

type FieldSize = elliptic_curve::FieldBytesSize<p384::NistP384>;

Expand Down
4 changes: 2 additions & 2 deletions askar-crypto/src/alg/x25519.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ pub const SECRET_KEY_LENGTH: usize = 32;
pub const KEYPAIR_LENGTH: usize = SECRET_KEY_LENGTH + PUBLIC_KEY_LENGTH;

/// The 'kty' value of an X25519 JWK
pub static JWK_KEY_TYPE: &str = "OKP";
pub const JWK_KEY_TYPE: &str = "OKP";
/// The 'crv' value of an X25519 JWK
pub static JWK_CURVE: &str = "X25519";
pub const JWK_CURVE: &str = "X25519";

/// An X25519 public key or keypair
#[derive(Clone)]
Expand Down
31 changes: 13 additions & 18 deletions wrappers/python/tests/test_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,27 @@ def test_get_supported_backends():
assert backends == [str(KeyBackend.Software)]


def test_aes_cbc_hmac():
key = Key.generate(KeyAlg.A128CBC_HS256)
assert key.algorithm == KeyAlg.A128CBC_HS256
@pytest.mark.parametrize(
"key_alg",
[KeyAlg.A128CBC_HS256, KeyAlg.A128GCM, KeyAlg.XC20P],
)
def test_symmetric(key_alg: KeyAlg):
key = Key.generate(key_alg)
assert key.algorithm == key_alg

data = b"test message"
nonce = key.aead_random_nonce()
params = key.aead_params()
assert params.nonce_length == 16
assert params.tag_length == 16
assert isinstance(params.nonce_length, int)
assert isinstance(params.tag_length, int)
enc = key.aead_encrypt(data, nonce=nonce, aad=b"aad")
dec = key.aead_decrypt(enc, nonce=nonce, aad=b"aad")
assert data == bytes(dec)


def test_aes_gcm():
key = Key.generate(KeyAlg.A128GCM)
assert key.algorithm == KeyAlg.A128GCM

data = b"test message"
nonce = key.aead_random_nonce()
params = key.aead_params()
assert params.nonce_length == 12
assert params.tag_length == 16
enc = key.aead_encrypt(data, nonce=nonce, aad=b"aad")
dec = key.aead_decrypt(enc, nonce=nonce, aad=b"aad")
assert data == bytes(dec)
jwk = json.loads(key.get_jwk_secret())
assert jwk["kty"] == "oct"
assert KeyAlg.from_key_alg(jwk["alg"].lower().replace("-", "")) == key_alg
assert jwk["k"]


def test_bls_keygen():
Expand Down
41 changes: 40 additions & 1 deletion wrappers/python/tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ async def inc():


@mark.asyncio
async def test_key_store(store: Store):
async def test_key_store_ed25519(store: Store):
# test key operations in a new session
async with store as session:
# Create a new keypair
Expand Down Expand Up @@ -279,6 +279,45 @@ async def test_key_store(store: Store):
assert await session.fetch_key(key_name) is None


@mark.asyncio
@mark.parametrize(
"key_alg",
[KeyAlg.A128CBC_HS256, KeyAlg.XC20P],
)
async def test_key_store_symmetric(store: Store, key_alg: KeyAlg):
# test key operations in a new session
async with store as session:
# Create a new keypair
symm = Key.generate(key_alg)

# Store symmetric key
key_name = "testkey"
await session.insert_key(key_name, symm, metadata="metadata", tags={"a": "b"})

# Fetch keypair
fetch_key = await session.fetch_key(key_name)
assert fetch_key and fetch_key.name == key_name and fetch_key.tags == {"a": "b"}

# Update keypair
await session.update_key(key_name, metadata="updated metadata", tags={"a": "c"})

# Fetch keypair
fetch_key = await session.fetch_key(key_name)
assert fetch_key and fetch_key.name == key_name and fetch_key.tags == {"a": "c"}

# Check key equality
jwk_secret = symm.get_jwk_secret()
assert fetch_key.key.get_jwk_secret() == jwk_secret

# Fetch with filters
keys = await session.fetch_all_keys(alg=key_alg, tag_filter={"a": "c"}, limit=1)
assert len(keys) == 1 and keys[0].name == key_name

# Remove
await session.remove_key(key_name)
assert await session.fetch_key(key_name) is None


@mark.asyncio
async def test_profile(store: Store):
# New session in the default profile
Expand Down

0 comments on commit 6a3f9fa

Please sign in to comment.