Skip to content

Commit

Permalink
refactor: organize code for constructing error packet
Browse files Browse the repository at this point in the history
  • Loading branch information
doitian committed Sep 24, 2024
1 parent 89f8f6b commit 74047e9
Showing 1 changed file with 146 additions and 88 deletions.
234 changes: 146 additions & 88 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
//! let hops_data = vec![vec![0], vec![1, 0], vec![5, 0, 1, 2, 3, 4]];
//! let get_length = |packet_data: &[u8]| Some(packet_data[0] as usize + 1);
//! let assoc_data = vec![0x42u8; 32];

//!
//! let packet = new_onion_packet(
//! 1300,
//! hops_path,
Expand Down Expand Up @@ -91,6 +91,8 @@ use thiserror::Error;
const HMAC_KEY_RHO: &[u8] = b"rho";
const HMAC_KEY_MU: &[u8] = b"mu";
const HMAC_KEY_PAD: &[u8] = b"pad";
const HMAC_KEY_UM: &[u8] = b"um";
const HMAC_KEY_AMMAG: &[u8] = b"ammag";
const CHACHA_NONCE: [u8; 12] = [0u8; 12];

#[derive(Debug, Clone, Eq, PartialEq)]
Expand All @@ -105,6 +107,11 @@ pub struct OnionPacket {
pub hmac: [u8; 32],
}

pub struct OnionErrorPacket {
/// Encrypted error-returning packet data.
pub packet_data: Vec<u8>,
}

impl OnionPacket {
/// Converts the onion packet into a byte vector.
pub fn into_bytes(self) -> Vec<u8> {
Expand All @@ -116,6 +123,11 @@ impl OnionPacket {
bytes
}

/// Derives the shared secret using the node secret key and the ephemeral public key in the onion packet.
pub fn shared_secret(&self, secret_key: &SecretKey) -> [u8; 32] {
SharedSecret::new(&self.public_key, secret_key).secret_bytes()
}

/// Peels the onion packet at the current hop.
///
/// - `secret_key`: the node private key. _x_<sub>i</sub> in the specification.
Expand All @@ -136,7 +148,7 @@ impl OnionPacket {
F: FnOnce(&[u8]) -> Option<usize>,
{
let packet_data_len = self.packet_data.len();
let shared_secret = SharedSecret::new(&self.public_key, secret_key);
let shared_secret = self.shared_secret(secret_key);
let rho = derive_key(HMAC_KEY_RHO, shared_secret.as_ref());
let mu = derive_key(HMAC_KEY_MU, shared_secret.as_ref());

Expand Down Expand Up @@ -196,6 +208,117 @@ pub enum SphinxError {
HopDataLenTooLarge,
}

/// Keys used to forward the onion packet.
pub struct ForwardKeys {
/// Key derived from the shared secret for the hop. It is used to encrypt the packet data.
pub rho: [u8; 32],
/// Key derived from the shared secret for the hop. It is used to compute the HMAC of the packet data.
pub mu: [u8; 32],
}

/// Keys used to return the error packet.
pub struct ReturnKeys {
/// Key derived from the shared secret for the hop. It is used to encrypt the error packet data.
pub ammag: [u8; 32],
/// Key derived from the shared secret for the hop. It is used to compute the HMAC of the error packet data.
pub um: [u8; 32],
}

/// Shared secrets generator.
///
/// ## Example
///
/// ```rust
/// use secp256k1::{PublicKey, SecretKey, Secp256k1};
/// use fiber_sphinx::{SharedSecretIter};
///
/// let secp = Secp256k1::new();
/// let hops_keys = vec![
/// SecretKey::from_slice(&[0x20; 32]).expect("32 bytes, within curve order"),
/// SecretKey::from_slice(&[0x21; 32]).expect("32 bytes, within curve order"),
/// SecretKey::from_slice(&[0x22; 32]).expect("32 bytes, within curve order"),
/// ];
/// let hops_path: Vec<_> = hops_keys.iter().map(|sk| sk.public_key(&secp)).collect();
/// let session_key = SecretKey::from_slice(&[0x41; 32]).expect("32 bytes, within curve order");
/// // Gets shared secrets for each hop
/// let hops_ss: Vec<_> = SharedSecretIter::new(hops_path.into_iter(), session_key, &secp).collect();
/// ```
#[derive(Clone)]
pub struct SharedSecretIter<'a, I, C: Signing> {
/// A list of node public keys
hops_path_iter: I,
ephemeral_secret_key: SecretKey,
secp_ctx: &'a Secp256k1<C>,
}

impl<'a, I, C: Signing> SharedSecretIter<'a, I, C> {
/// Creates an iterator to generate shared secrets for each hop.
///
/// - `hops_path`: The public keys for each hop. These are _y_<sub>i</sub> in the specification.
/// - `session_key`: The ephemeral secret key for the onion packet. It must be generated securely using a random process.
/// This is _x_ in the specification.
pub fn new(
hops_path_iter: I,
session_key: SecretKey,
secp_ctx: &'a Secp256k1<C>,
) -> SharedSecretIter<I, C> {
SharedSecretIter {
hops_path_iter,
secp_ctx,
ephemeral_secret_key: session_key,
}
}
}

impl<'a, I: Iterator<Item = PublicKey>, C: Signing> Iterator for SharedSecretIter<'a, I, C> {
type Item = [u8; 32];

fn next(&mut self) -> Option<Self::Item> {
self.hops_path_iter.next().map(|pk| {
let shared_secret = SharedSecret::new(&pk, &self.ephemeral_secret_key);

let ephemeral_public_key = self.ephemeral_secret_key.public_key(self.secp_ctx);
self.ephemeral_secret_key = derive_next_hop_ephemeral_secret_key(
self.ephemeral_secret_key,
&ephemeral_public_key,
shared_secret.as_ref(),
);

shared_secret.secret_bytes()
})
}
}

/// Derive keys for forwarding the onion packet from the shared secret.
pub fn derive_forward_keys(shared_secret: &[u8]) -> ForwardKeys {
ForwardKeys {
rho: derive_key(HMAC_KEY_RHO, shared_secret),
mu: derive_key(HMAC_KEY_MU, shared_secret),
}
}

/// Derive keys for returning the error onion packet from the shared secret.
pub fn derive_return_keys(shared_secret: &[u8]) -> ReturnKeys {
ReturnKeys {
ammag: derive_key(HMAC_KEY_AMMAG, shared_secret),
um: derive_key(HMAC_KEY_UM, shared_secret),
}
}

/// Derives keys for forwarding the onion packet.
pub fn derive_hops_forward_keys<C: Signing>(
hops_path: &Vec<PublicKey>,
session_key: SecretKey,
secp_ctx: &Secp256k1<C>,
) -> Vec<ForwardKeys> {
SharedSecretIter::new(hops_path.iter().cloned(), session_key, secp_ctx)
.map(|shared_secret| ForwardKeys {
rho: derive_key(HMAC_KEY_RHO, shared_secret.as_ref()),
mu: derive_key(HMAC_KEY_MU, shared_secret.as_ref()),
})
.collect()
}

#[inline]
fn shift_slice_right(arr: &mut [u8], amt: usize) {
for i in (amt..arr.len()).rev() {
Expand All @@ -217,9 +340,9 @@ fn shift_slice_left(arr: &mut [u8], amt: usize) {
}
}

/// Computes hmac of packet_data and optional associated data using the key `mu`.
fn compute_hmac(mu: &[u8; 32], packet_data: &[u8], assoc_data: Option<&[u8]>) -> [u8; 32] {
let mut hmac_engine = Hmac::<Sha256>::new_from_slice(mu).expect("valid hmac key");
/// Computes hmac of packet_data and optional associated data using the key `hmac_key`.
fn compute_hmac(hmac_key: &[u8; 32], packet_data: &[u8], assoc_data: Option<&[u8]>) -> [u8; 32] {
let mut hmac_engine = Hmac::<Sha256>::new_from_slice(hmac_key).expect("valid hmac key");
hmac_engine.update(&packet_data);
if let Some(ref assoc_data) = assoc_data {
hmac_engine.update(assoc_data);
Expand Down Expand Up @@ -288,46 +411,6 @@ fn derive_next_hop_ephemeral_public_key<C: Verification>(
.expect("valid mul tweak")
}

// Keys manager for each hop
struct HopKeys {
/// Ephemeral public key for the hop. The _alpha_ in the specification.
ephemeral_public_key: PublicKey,
/// Key derived from the shared secret for the hop. It is used to encrypt the packet data.
rho: [u8; 32],
/// Key derived from the shared secret for the hop. It is used to compute the HMAC of the packet data.
mu: [u8; 32],
}

/// Derives HopKeys for each hop.
fn derive_hops_keys<C: Signing>(
hops_path: &Vec<PublicKey>,
session_key: SecretKey,
secp_ctx: &Secp256k1<C>,
) -> Vec<HopKeys> {
hops_path
.iter()
.scan(session_key, |ephemeral_secret_key, pk| {
let ephemeral_public_key = ephemeral_secret_key.public_key(secp_ctx);

let shared_secret = SharedSecret::new(pk, ephemeral_secret_key);
let rho = derive_key(HMAC_KEY_RHO, shared_secret.as_ref());
let mu = derive_key(HMAC_KEY_MU, shared_secret.as_ref());

*ephemeral_secret_key = derive_next_hop_ephemeral_secret_key(
*ephemeral_secret_key,
&ephemeral_public_key,
shared_secret.as_ref(),
);

Some(HopKeys {
ephemeral_public_key,
rho,
mu,
})
})
.collect()
}

/// Derives a key from the shared secret using HMAC.
fn derive_key(hmac_key: &[u8], shared_secret: &[u8]) -> [u8; 32] {
let mut mac = Hmac::<Sha256>::new_from_slice(hmac_key).expect("valid hmac key");
Expand All @@ -348,7 +431,7 @@ fn generate_padding_data(packet_data_len: usize, pad_key: &[u8]) -> Vec<u8> {
/// Generates the filler to obfuscate the onion packet.
fn generate_filler(
packet_data_len: usize,
hops_keys: &[HopKeys],
hops_keys: &[ForwardKeys],
hops_data: &[Vec<u8>],
) -> Result<Vec<u8>, SphinxError> {
let mut filler = Vec::new();
Expand Down Expand Up @@ -378,14 +461,16 @@ fn generate_filler(
/// Constructs the onion packet internally.
///
/// - `packet_data`: The initial 1300 bytes of the onion packet generated by `generate_padding_data`.
/// - `hops_keys`: The keys for each hop generated by `derive_hops_keys`.
/// - `public_key`: The ephemeral public key for the first hop.
/// - `hops_keys`: The keys for each hop generated by `derive_hops_forward_keys`.
/// - `hops_data`: The unencrypted data for each hop.
/// - `assoc_data`: The associated data. It will not be included in the packet itself but will be covered by the packet's
/// HMAC. This allows each hop to verify that the associated data has not been tampered with.
/// - `filler`: The filler to obfuscate the packet data, which is generated by `generate_filler`.
fn construct_onion_packet(
mut packet_data: Vec<u8>,
hops_keys: &[HopKeys],
public_key: PublicKey,
hops_keys: &[ForwardKeys],
hops_data: &[Vec<u8>],
assoc_data: Option<Vec<u8>>,
filler: Vec<u8>,
Expand Down Expand Up @@ -414,7 +499,7 @@ fn construct_onion_packet(

Ok(OnionPacket {
version: 0,
public_key: hops_keys.first().unwrap().ephemeral_public_key,
public_key,
packet_data,
hmac,
})
Expand Down Expand Up @@ -446,12 +531,20 @@ pub fn new_onion_packet(
return Err(SphinxError::HopsIsEmpty);
}

let hops_keys = derive_hops_keys(&hops_path, session_key, &Secp256k1::new());
let secp_ctx = Secp256k1::new();
let hops_keys = derive_hops_forward_keys(&hops_path, session_key, &secp_ctx);
let pad_key = derive_key(HMAC_KEY_PAD, &session_key.secret_bytes());
let packet_data = generate_padding_data(packet_data_len, &pad_key);
let filler = generate_filler(packet_data_len, &hops_keys, &hops_data)?;

construct_onion_packet(packet_data, &hops_keys, &hops_data, assoc_data, filler)
construct_onion_packet(
packet_data,
session_key.public_key(&secp_ctx),
&hops_keys,
&hops_data,
assoc_data,
filler,
)
}

#[cfg(test)]
Expand Down Expand Up @@ -491,18 +584,11 @@ mod tests {
fn test_derive_hops_keys() {
let hops_path = get_test_hops_path();
let session_key = get_test_session_key();
let hops_keys = derive_hops_keys(&hops_path, session_key, &Secp256k1::new());
let hops_keys = derive_hops_forward_keys(&hops_path, session_key, &Secp256k1::new());

assert_eq!(hops_keys.len(), 5);

// hop 0
assert_eq!(
hops_keys[0]
.ephemeral_public_key
.serialize()
.to_lower_hex_string(),
"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619",
);
assert_eq!(
hops_keys[0].rho.to_lower_hex_string(),
"ce496ec94def95aadd4bec15cdb41a740c9f2b62347c4917325fcc6fb0453986",
Expand All @@ -513,13 +599,6 @@ mod tests {
);

// hop 1
assert_eq!(
hops_keys[1]
.ephemeral_public_key
.serialize()
.to_lower_hex_string(),
"028f9438bfbf7feac2e108d677e3a82da596be706cc1cf342b75c7b7e22bf4e6e2",
);
assert_eq!(
hops_keys[1].rho.to_lower_hex_string(),
"450ffcabc6449094918ebe13d4f03e433d20a3d28a768203337bc40b6e4b2c59",
Expand All @@ -530,13 +609,6 @@ mod tests {
);

// hop 2
assert_eq!(
hops_keys[2]
.ephemeral_public_key
.serialize()
.to_lower_hex_string(),
"03bfd8225241ea71cd0843db7709f4c222f62ff2d4516fd38b39914ab6b83e0da0",
);
assert_eq!(
hops_keys[2].rho.to_lower_hex_string(),
"11bf5c4f960239cb37833936aa3d02cea82c0f39fd35f566109c41f9eac8deea",
Expand All @@ -547,13 +619,6 @@ mod tests {
);

// hop 3
assert_eq!(
hops_keys[3]
.ephemeral_public_key
.serialize()
.to_lower_hex_string(),
"031dde6926381289671300239ea8e57ffaf9bebd05b9a5b95beaf07af05cd43595",
);
assert_eq!(
hops_keys[3].rho.to_lower_hex_string(),
"cbe784ab745c13ff5cffc2fbe3e84424aa0fd669b8ead4ee562901a4a4e89e9e",
Expand All @@ -564,13 +629,6 @@ mod tests {
);

// hop 4
assert_eq!(
hops_keys[4]
.ephemeral_public_key
.serialize()
.to_lower_hex_string(),
"03a214ebd875aab6ddfd77f22c5e7311d7f77f17a169e599f157bbcdae8bf071f4",
);
assert_eq!(
hops_keys[4].rho.to_lower_hex_string(),
"034e18b8cc718e8af6339106e706c52d8df89e2b1f7e9142d996acf88df8799b",
Expand Down Expand Up @@ -606,7 +664,7 @@ mod tests {
fn test_generate_filler() {
let hops_path = get_test_hops_path();
let session_key = get_test_session_key();
let hops_keys = derive_hops_keys(&hops_path, session_key, &Secp256k1::new());
let hops_keys = derive_hops_forward_keys(&hops_path, session_key, &Secp256k1::new());
let hops_data = get_test_hops_data();

let filler = generate_filler(PACKET_DATA_LEN, &hops_keys, &hops_data);
Expand Down

0 comments on commit 74047e9

Please sign in to comment.