Skip to content

Commit

Permalink
Cleanup use of unsafe blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
justsmth committed Jan 17, 2024
1 parent a2e83eb commit 7c68274
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 94 deletions.
26 changes: 12 additions & 14 deletions aws-lc-rs/src/agreement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ use crate::ec::{
};
use crate::error::{KeyRejected, Unspecified};
use crate::fips::indicator_check;
use crate::ptr::{ConstPointer, DetachableLcPtr, LcPtr, Pointer};
use crate::ptr::{ConstPointer, LcPtr, Pointer};
use crate::{ec, hex};
use aws_lc::{
EVP_PKEY_CTX_new, EVP_PKEY_CTX_new_id, EVP_PKEY_derive, EVP_PKEY_derive_init,
EVP_PKEY_derive_set_peer, EVP_PKEY_get0_EC_KEY, EVP_PKEY_get_raw_private_key,
EVP_PKEY_get_raw_public_key, EVP_PKEY_keygen, EVP_PKEY_keygen_init,
EVP_PKEY_new_raw_private_key, EVP_PKEY_new_raw_public_key, NID_X9_62_prime256v1, NID_secp384r1,
NID_secp521r1, EVP_PKEY, EVP_PKEY_X25519, NID_X25519,
NID_secp521r1, BIGNUM, EC_GROUP, EVP_PKEY, EVP_PKEY_X25519, NID_X25519,
};

use crate::buffer::Buffer;
Expand Down Expand Up @@ -307,13 +307,11 @@ impl PrivateKey {
)
})?
} else {
let ec_group = unsafe { ec_group_from_nid(alg.id.nid())? };
let private_bn = DetachableLcPtr::try_from(key_bytes)?;
let ec_group: LcPtr<EC_GROUP> = ec_group_from_nid(alg.id.nid())?;
let private_bn: LcPtr<BIGNUM> = LcPtr::try_from(key_bytes)?;

unsafe {
ec::evp_pkey_from_private(&ec_group.as_const(), &private_bn.as_const())
.map_err(|_| KeyRejected::invalid_encoding())?
}
ec::evp_pkey_from_private(&ec_group.as_const(), &private_bn.as_const())
.map_err(|_| KeyRejected::invalid_encoding())?
};
Ok(Self::new(alg, evp_pkey))
}
Expand Down Expand Up @@ -511,10 +509,10 @@ impl AsBigEndian<Curve25519SeedBin<'static>> for PrivateKey {

#[cfg(test)]
fn from_ec_private_key(priv_key: &[u8], nid: i32) -> Result<LcPtr<EVP_PKEY>, Unspecified> {
let ec_group = unsafe { ec_group_from_nid(nid)? };
let priv_key = DetachableLcPtr::try_from(priv_key)?;
let ec_group: LcPtr<EC_GROUP> = ec_group_from_nid(nid)?;
let priv_key: LcPtr<BIGNUM> = LcPtr::try_from(priv_key)?;

let pkey = unsafe { ec::evp_pkey_from_private(&ec_group.as_const(), &priv_key.as_const())? };
let pkey = ec::evp_pkey_from_private(&ec_group.as_const(), &priv_key.as_const())?;

Ok(pkey)
}
Expand Down Expand Up @@ -696,9 +694,9 @@ fn ec_key_ecdh<'a>(
peer_pub_key_bytes: &[u8],
nid: i32,
) -> Result<&'a [u8], ()> {
let ec_group = unsafe { ec_group_from_nid(nid)? };
let pub_key_point = unsafe { ec_point_from_bytes(&ec_group, peer_pub_key_bytes) }?;
let pub_key = unsafe { evp_pkey_from_public_point(&ec_group, &pub_key_point) }?;
let ec_group = ec_group_from_nid(nid)?;
let pub_key_point = ec_point_from_bytes(&ec_group, peer_pub_key_bytes)?;
let pub_key = evp_pkey_from_public_point(&ec_group, &pub_key_point)?;

let pkey_ctx = LcPtr::new(unsafe { EVP_PKEY_CTX_new(**priv_key, null_mut()) })?;

Expand Down
10 changes: 9 additions & 1 deletion aws-lc-rs/src/bn.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 OR ISC

use crate::ptr::{ConstPointer, DetachableLcPtr};
use crate::ptr::{ConstPointer, DetachableLcPtr, LcPtr};
use aws_lc::{BN_bin2bn, BN_bn2bin, BN_cmp, BN_new, BN_num_bits, BN_num_bytes, BN_set_u64, BIGNUM};
use mirai_annotations::unrecoverable;
use std::cmp::Ordering;
use std::ptr::null_mut;

impl TryFrom<&[u8]> for LcPtr<BIGNUM> {
type Error = ();

fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
unsafe { LcPtr::new(BN_bin2bn(bytes.as_ptr(), bytes.len(), null_mut())) }
}
}

impl TryFrom<&[u8]> for DetachableLcPtr<BIGNUM> {
type Error = ();

Expand Down
78 changes: 41 additions & 37 deletions aws-lc-rs/src/ec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ impl AsDer<EcPublicKeyX509Der<'static>> for PublicKey {
/// # Errors
/// Returns an error if the underlying implementation is unable to marshal the point.
fn as_der(&self) -> Result<EcPublicKeyX509Der<'static>, Unspecified> {
let ec_group = unsafe { LcPtr::new(EC_GROUP_new_by_curve_name(self.algorithm.id.nid()))? };
let ec_point = unsafe { ec_point_from_bytes(&ec_group, self.as_ref())? };
let ec_key = unsafe { LcPtr::new(EC_KEY_new())? };
let ec_group = LcPtr::new(unsafe { EC_GROUP_new_by_curve_name(self.algorithm.id.nid()) })?;
let ec_point = ec_point_from_bytes(&ec_group, self.as_ref())?;
let ec_key = LcPtr::new(unsafe { EC_KEY_new() })?;
if 1 != unsafe { EC_KEY_set_group(*ec_key, *ec_group) } {
return Err(Unspecified);
}
Expand Down Expand Up @@ -266,9 +266,9 @@ fn evp_pkey_from_public_key(
alg: &'static AlgorithmID,
public_key: &[u8],
) -> Result<LcPtr<EVP_PKEY>, Unspecified> {
let ec_group = unsafe { ec_group_from_nid(alg.nid())? };
let ec_point = unsafe { ec_point_from_bytes(&ec_group, public_key)? };
let pkey = unsafe { evp_pkey_from_public_point(&ec_group, &ec_point)? };
let ec_group = ec_group_from_nid(alg.nid())?;
let ec_point = ec_point_from_bytes(&ec_group, public_key)?;
let pkey = evp_pkey_from_public_point(&ec_group, &ec_point)?;

Ok(pkey)
}
Expand Down Expand Up @@ -397,16 +397,16 @@ pub(crate) fn marshal_public_key(
}

#[inline]
pub(crate) unsafe fn evp_pkey_from_public_point(
pub(crate) fn evp_pkey_from_public_point(
ec_group: &LcPtr<EC_GROUP>,
public_ec_point: &LcPtr<EC_POINT>,
) -> Result<LcPtr<EVP_PKEY>, Unspecified> {
let nid = EC_GROUP_get_curve_name(ec_group.as_const_ptr());
let ec_key = DetachableLcPtr::new(EC_KEY_new())?;
if 1 != EC_KEY_set_group(*ec_key, **ec_group) {
let nid = unsafe { EC_GROUP_get_curve_name(ec_group.as_const_ptr()) };
let ec_key = DetachableLcPtr::new(unsafe { EC_KEY_new() })?;
if 1 != unsafe { EC_KEY_set_group(*ec_key, **ec_group) } {
return Err(Unspecified);
}
if 1 != EC_KEY_set_public_key(*ec_key, **public_ec_point) {
if 1 != unsafe { EC_KEY_set_public_key(*ec_key, **public_ec_point) } {
return Err(Unspecified);
}

Expand All @@ -423,32 +423,34 @@ pub(crate) unsafe fn evp_pkey_from_public_point(
Ok(pkey)
}

pub(crate) unsafe fn evp_pkey_from_private(
pub(crate) fn evp_pkey_from_private(
ec_group: &ConstPointer<EC_GROUP>,
private_big_num: &ConstPointer<BIGNUM>,
) -> Result<LcPtr<EVP_PKEY>, Unspecified> {
let ec_key = DetachableLcPtr::new(EC_KEY_new())?;
if 1 != EC_KEY_set_group(*ec_key, **ec_group) {
let ec_key = DetachableLcPtr::new(unsafe { EC_KEY_new() })?;
if 1 != unsafe { EC_KEY_set_group(*ec_key, **ec_group) } {
return Err(Unspecified);
}
if 1 != EC_KEY_set_private_key(*ec_key, **private_big_num) {
if 1 != unsafe { EC_KEY_set_private_key(*ec_key, **private_big_num) } {
return Err(Unspecified);
}
let pub_key = LcPtr::new(EC_POINT_new(**ec_group))?;
if 1 != EC_POINT_mul(
**ec_group,
*pub_key,
**private_big_num,
null(),
null(),
null_mut(),
) {
let pub_key = LcPtr::new(unsafe { EC_POINT_new(**ec_group) })?;
if 1 != unsafe {
EC_POINT_mul(
**ec_group,
*pub_key,
**private_big_num,
null(),
null(),
null_mut(),
)
} {
return Err(Unspecified);
}
if 1 != EC_KEY_set_public_key(*ec_key, *pub_key) {
if 1 != unsafe { EC_KEY_set_public_key(*ec_key, *pub_key) } {
return Err(Unspecified);
}
let expected_curve_nid = EC_GROUP_get_curve_name(**ec_group);
let expected_curve_nid = unsafe { EC_GROUP_get_curve_name(**ec_group) };

let pkey = LcPtr::new(unsafe { EVP_PKEY_new() })?;

Expand Down Expand Up @@ -519,24 +521,26 @@ pub(crate) unsafe fn evp_key_from_public_private(
}

#[inline]
pub(crate) unsafe fn ec_group_from_nid(nid: i32) -> Result<LcPtr<EC_GROUP>, ()> {
LcPtr::new(EC_GROUP_new_by_curve_name(nid))
pub(crate) fn ec_group_from_nid(nid: i32) -> Result<LcPtr<EC_GROUP>, ()> {
LcPtr::new(unsafe { EC_GROUP_new_by_curve_name(nid) })
}

#[inline]
pub(crate) unsafe fn ec_point_from_bytes(
pub(crate) fn ec_point_from_bytes(
ec_group: &LcPtr<EC_GROUP>,
bytes: &[u8],
) -> Result<LcPtr<EC_POINT>, Unspecified> {
let ec_point = LcPtr::new(EC_POINT_new(**ec_group))?;
let ec_point = LcPtr::new(unsafe { EC_POINT_new(**ec_group) })?;

if 1 != EC_POINT_oct2point(
**ec_group,
*ec_point,
bytes.as_ptr(),
bytes.len(),
null_mut(),
) {
if 1 != unsafe {
EC_POINT_oct2point(
**ec_group,
*ec_point,
bytes.as_ptr(),
bytes.len(),
null_mut(),
)
} {
return Err(Unspecified);
}

Expand Down
2 changes: 1 addition & 1 deletion aws-lc-rs/src/ec/key_pair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl EcdsaKeyPair {
pkcs8: &[u8],
) -> Result<Self, KeyRejected> {
// Includes a call to `EC_KEY_check_key`
let evp_pkey = LcPtr::try_from(pkcs8)?;
let evp_pkey: LcPtr<EVP_PKEY> = LcPtr::try_from(pkcs8)?;

#[cfg(not(feature = "fips"))]
verify_evp_key_nid(&evp_pkey.as_const(), alg.id.nid())?;
Expand Down
83 changes: 42 additions & 41 deletions aws-lc-rs/src/ed25519.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,16 +238,16 @@ impl Ed25519KeyPair {
/// `error::Unspecified` on internal error.
///
pub fn to_pkcs8v1(&self) -> Result<Document, Unspecified> {
unsafe {
let evp_pkey: LcPtr<EVP_PKEY> = LcPtr::new(EVP_PKEY_new_raw_private_key(
let evp_pkey: LcPtr<EVP_PKEY> = LcPtr::new(unsafe {
EVP_PKEY_new_raw_private_key(
EVP_PKEY_ED25519,
null_mut(),
self.private_key.as_ref().as_ptr(),
ED25519_PRIVATE_KEY_SEED_LEN,
))?;
)
})?;

evp_pkey.marshall_private_key(Version::V1)
}
evp_pkey.marshall_private_key(Version::V1)
}

/// Constructs an Ed25519 key pair from the private key seed `seed` and its
Expand All @@ -268,27 +268,27 @@ impl Ed25519KeyPair {
return Err(KeyRejected::inconsistent_components());
}

let mut derived_public_key = MaybeUninit::<[u8; ED25519_PUBLIC_KEY_LEN]>::uninit();
let mut private_key = MaybeUninit::<[u8; ED25519_PRIVATE_KEY_LEN]>::uninit();
unsafe {
let mut derived_public_key = MaybeUninit::<[u8; ED25519_PUBLIC_KEY_LEN]>::uninit();
let mut private_key = MaybeUninit::<[u8; ED25519_PRIVATE_KEY_LEN]>::uninit();
ED25519_keypair_from_seed(
derived_public_key.as_mut_ptr().cast(),
private_key.as_mut_ptr().cast(),
seed.as_ptr(),
);
let derived_public_key = derived_public_key.assume_init();
let mut private_key = private_key.assume_init();

constant_time::verify_slices_are_equal(public_key, &derived_public_key)
.map_err(|_| KeyRejected::inconsistent_components())?;

let key_pair = Self {
private_key: Box::new(private_key),
public_key: PublicKey(derived_public_key),
};
private_key.zeroize();
Ok(key_pair)
}
let derived_public_key = unsafe { derived_public_key.assume_init() };
let mut private_key = unsafe { private_key.assume_init() };

constant_time::verify_slices_are_equal(public_key, &derived_public_key)
.map_err(|_| KeyRejected::inconsistent_components())?;

let key_pair = Self {
private_key: Box::new(private_key),
public_key: PublicKey(derived_public_key),
};
private_key.zeroize();
Ok(key_pair)
}

/// Constructs an Ed25519 key pair by parsing an unencrypted PKCS#8 v1 or v2
Expand Down Expand Up @@ -330,33 +330,34 @@ impl Ed25519KeyPair {
}

fn parse_pkcs8(pkcs8: &[u8]) -> Result<Self, KeyRejected> {
unsafe {
let evp_pkey = LcPtr::try_from(pkcs8)?;
let evp_pkey: LcPtr<EVP_PKEY> = LcPtr::try_from(pkcs8)?;

evp_pkey.validate_as_ed25519()?;
evp_pkey.validate_as_ed25519()?;

let mut private_key = [0u8; ED25519_PRIVATE_KEY_LEN];
let mut out_len: usize = ED25519_PRIVATE_KEY_LEN;
if 1 != EVP_PKEY_get_raw_private_key(*evp_pkey, private_key.as_mut_ptr(), &mut out_len)
{
return Err(KeyRejected::wrong_algorithm());
}
let mut private_key = [0u8; ED25519_PRIVATE_KEY_LEN];
let mut out_len: usize = ED25519_PRIVATE_KEY_LEN;
if 1 != unsafe {
EVP_PKEY_get_raw_private_key(*evp_pkey, private_key.as_mut_ptr(), &mut out_len)
} {
return Err(KeyRejected::wrong_algorithm());
}

let mut public_key = [0u8; ED25519_PUBLIC_KEY_LEN];
let mut out_len: usize = ED25519_PUBLIC_KEY_LEN;
if 1 != EVP_PKEY_get_raw_public_key(*evp_pkey, public_key.as_mut_ptr(), &mut out_len) {
return Err(KeyRejected::wrong_algorithm());
}
private_key[ED25519_PRIVATE_KEY_SEED_LEN..].copy_from_slice(&public_key);
let mut public_key = [0u8; ED25519_PUBLIC_KEY_LEN];
let mut out_len: usize = ED25519_PUBLIC_KEY_LEN;
if 1 != unsafe {
EVP_PKEY_get_raw_public_key(*evp_pkey, public_key.as_mut_ptr(), &mut out_len)
} {
return Err(KeyRejected::wrong_algorithm());
}
private_key[ED25519_PRIVATE_KEY_SEED_LEN..].copy_from_slice(&public_key);

let key_pair = Self {
private_key: Box::new(private_key),
public_key: PublicKey(public_key),
};
private_key.zeroize();
let key_pair = Self {
private_key: Box::new(private_key),
public_key: PublicKey(public_key),
};
private_key.zeroize();

Ok(key_pair)
}
Ok(key_pair)
}

/// Returns the signature of the message msg.
Expand Down

0 comments on commit 7c68274

Please sign in to comment.