From 5fee0e9ed6fb17126ce19865aae6679f91ea55c2 Mon Sep 17 00:00:00 2001 From: Daniyar Itegulov Date: Thu, 9 Nov 2023 16:57:49 +1100 Subject: [PATCH] address PR comments --- Cargo.lock | 1 - mpc-recovery/Cargo.toml | 3 +- mpc-recovery/src/oauth.rs | 2 +- .../src/sign_node/aggregate_signer.rs | 4 +- node/src/protocol/cryptography.rs | 12 +- node/src/protocol/message.rs | 48 +++++-- node/src/protocol/mod.rs | 2 +- node/src/protocol/triple.rs | 136 ++++++++++-------- node/src/util.rs | 4 +- 9 files changed, 129 insertions(+), 83 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c85ebd9dc..6d0718f64 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3629,7 +3629,6 @@ dependencies = [ "opentelemetry-otlp 0.13.0", "opentelemetry-semantic-conventions 0.12.0", "prometheus", - "rand 0.7.3", "rand 0.8.5", "reqwest", "rsa", diff --git a/mpc-recovery/Cargo.toml b/mpc-recovery/Cargo.toml index 77e3962df..a26baf621 100644 --- a/mpc-recovery/Cargo.toml +++ b/mpc-recovery/Cargo.toml @@ -34,8 +34,7 @@ opentelemetry-otlp = { version = "0.13.0", features = [ ] } opentelemetry-semantic-conventions = "0.12.0" prometheus = { version = "0.13.3", features = ["process"] } -rand = "0.7" -rand8 = { package = "rand", version = "0.8" } +rand = "0.8" reqwest = { version = "0.11.16", features = ["blocking"] } serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/mpc-recovery/src/oauth.rs b/mpc-recovery/src/oauth.rs index e16ad110c..a22e57a1c 100644 --- a/mpc-recovery/src/oauth.rs +++ b/mpc-recovery/src/oauth.rs @@ -113,7 +113,7 @@ mod tests { use super::*; use chrono::{Duration, Utc}; use jsonwebtoken::{encode, EncodingKey, Header}; - use rand8::rngs::OsRng; + use rand::rngs::OsRng; use rsa::{ pkcs1::{EncodeRsaPrivateKey, EncodeRsaPublicKey}, RsaPrivateKey, RsaPublicKey, diff --git a/mpc-recovery/src/sign_node/aggregate_signer.rs b/mpc-recovery/src/sign_node/aggregate_signer.rs index af8e7544f..bd810d605 100644 --- a/mpc-recovery/src/sign_node/aggregate_signer.rs +++ b/mpc-recovery/src/sign_node/aggregate_signer.rs @@ -10,8 +10,8 @@ use curv::BigInt; use ed25519_dalek::{Sha512, Signature, Verifier}; use multi_party_eddsa::protocols; use multi_party_eddsa::protocols::aggsig::{self, KeyAgg, SignSecondMsg}; -use rand8::rngs::OsRng; -use rand8::Rng; +use rand::rngs::OsRng; +use rand::Rng; use serde::{Deserialize, Serialize}; use tokio::sync::RwLock; diff --git a/node/src/protocol/cryptography.rs b/node/src/protocol/cryptography.rs index 0d0698554..8f4af1415 100644 --- a/node/src/protocol/cryptography.rs +++ b/node/src/protocol/cryptography.rs @@ -4,7 +4,7 @@ use crate::protocol::message::{GeneratingMessage, ResharingMessage}; use crate::protocol::state::WaitingForConsensusState; use crate::protocol::MpcMessage; use async_trait::async_trait; -use cait_sith::protocol::{Action, Participant}; +use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError}; use k256::elliptic_curve::group::GroupEncoding; pub trait CryptographicCtx { @@ -18,6 +18,10 @@ pub enum CryptographicError { SendError(#[from] SendError), #[error("unknown participant: {0:?}")] UnknownParticipant(Participant), + #[error("cait-sith initialization error: {0}")] + CaitSithInitializationError(#[from] InitializationError), + #[error("cait-sith protocol error: {0}")] + CaitSithProtocolError(#[from] ProtocolError), } #[async_trait] @@ -36,7 +40,7 @@ impl CryptographicProtocol for GeneratingState { ) -> Result { tracing::info!("progressing key generation"); loop { - let action = self.protocol.poke().unwrap(); + let action = self.protocol.poke()?; match action { Action::Wait => { tracing::debug!("waiting"); @@ -170,9 +174,9 @@ impl CryptographicProtocol for RunningState { ctx: C, ) -> Result { if self.triple_manager.potential_len() < 2 { - self.triple_manager.generate(); + self.triple_manager.generate()?; } - for (p, msg) in self.triple_manager.poke() { + for (p, msg) in self.triple_manager.poke()? { let url = self.participants.get(&p).unwrap(); http_client::message(ctx.http_client(), url.clone(), MpcMessage::Triple(msg)).await?; } diff --git a/node/src/protocol/message.rs b/node/src/protocol/message.rs index 699c92c7a..0ef2724a3 100644 --- a/node/src/protocol/message.rs +++ b/node/src/protocol/message.rs @@ -1,7 +1,7 @@ use std::collections::{HashMap, VecDeque}; use super::state::{GeneratingState, NodeState, ResharingState, RunningState}; -use cait_sith::protocol::{MessageData, Participant}; +use cait_sith::protocol::{InitializationError, MessageData, Participant, ProtocolError}; use serde::{Deserialize, Serialize}; pub trait MessageCtx { @@ -63,49 +63,81 @@ impl MpcMessageQueue { } } +#[derive(thiserror::Error, Debug)] +pub enum MessageHandleError { + #[error("cait-sith initialization error: {0}")] + CaitSithInitializationError(#[from] InitializationError), + #[error("cait-sith protocol error: {0}")] + CaitSithProtocolError(#[from] ProtocolError), +} + pub trait MessageHandler { - fn handle(&mut self, ctx: C, queue: &mut MpcMessageQueue); + fn handle( + &mut self, + ctx: C, + queue: &mut MpcMessageQueue, + ) -> Result<(), MessageHandleError>; } impl MessageHandler for GeneratingState { - fn handle(&mut self, _ctx: C, queue: &mut MpcMessageQueue) { + fn handle( + &mut self, + _ctx: C, + queue: &mut MpcMessageQueue, + ) -> Result<(), MessageHandleError> { while let Some(msg) = queue.generating.pop_front() { tracing::debug!("handling new generating message"); self.protocol.message(msg.from, msg.data); } + Ok(()) } } impl MessageHandler for ResharingState { - fn handle(&mut self, _ctx: C, queue: &mut MpcMessageQueue) { + fn handle( + &mut self, + _ctx: C, + queue: &mut MpcMessageQueue, + ) -> Result<(), MessageHandleError> { let q = queue.resharing_bins.entry(self.old_epoch).or_default(); while let Some(msg) = q.pop_front() { tracing::debug!("handling new resharing message"); self.protocol.message(msg.from, msg.data); } + Ok(()) } } impl MessageHandler for RunningState { - fn handle(&mut self, _ctx: C, queue: &mut MpcMessageQueue) { + fn handle( + &mut self, + _ctx: C, + queue: &mut MpcMessageQueue, + ) -> Result<(), MessageHandleError> { for (id, queue) in queue.triple_bins.entry(self.epoch).or_default() { - if let Some(protocol) = self.triple_manager.get_or_generate(*id) { + if let Some(protocol) = self.triple_manager.get_or_generate(*id)? { while let Some(message) = queue.pop_front() { protocol.message(message.from, message.data); } } } + Ok(()) } } impl MessageHandler for NodeState { - fn handle(&mut self, ctx: C, queue: &mut MpcMessageQueue) { + fn handle( + &mut self, + ctx: C, + queue: &mut MpcMessageQueue, + ) -> Result<(), MessageHandleError> { match self { NodeState::Generating(state) => state.handle(ctx, queue), NodeState::Resharing(state) => state.handle(ctx, queue), NodeState::Running(state) => state.handle(ctx, queue), _ => { - tracing::debug!("skipping message processing") + tracing::debug!("skipping message processing"); + Ok(()) } } } diff --git a/node/src/protocol/mod.rs b/node/src/protocol/mod.rs index dad13c018..1c155efd9 100644 --- a/node/src/protocol/mod.rs +++ b/node/src/protocol/mod.rs @@ -148,7 +148,7 @@ impl MpcSignProtocol { let mut state = std::mem::take(&mut *state_guard); state = state.progress(&self.ctx).await?; state = state.advance(&self.ctx, contract_state).await?; - state.handle(&self.ctx, &mut queue); + state.handle(&self.ctx, &mut queue)?; *state_guard = state; drop(state_guard); tokio::time::sleep(Duration::from_millis(1000)).await; diff --git a/node/src/protocol/triple.rs b/node/src/protocol/triple.rs index ca9239b42..7c1bd7dc1 100644 --- a/node/src/protocol/triple.rs +++ b/node/src/protocol/triple.rs @@ -1,7 +1,7 @@ use super::message::TripleMessage; use crate::types::TripleProtocol; use crate::util::AffinePointExt; -use cait_sith::protocol::{Action, Participant}; +use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError}; use cait_sith::triples::TripleGenerationOutput; use k256::Secp256k1; use std::collections::btree_map::Entry; @@ -57,24 +57,23 @@ impl TripleManager { } /// Starts a new Beaver triple generation protocol. - pub fn generate(&mut self) { + pub fn generate(&mut self) -> Result<(), InitializationError> { let id = rand::random(); tracing::info!(id, "starting protocol to generate a new triple"); - let protocol: TripleProtocol = Box::new( - cait_sith::triples::generate_triple(&self.participants, self.me, self.threshold) - .unwrap(), - ); + let protocol: TripleProtocol = Box::new(cait_sith::triples::generate_triple( + &self.participants, + self.me, + self.threshold, + )?); self.generators.insert(id, protocol); + Ok(()) } /// Take an unspent triple by its id with no way to return it. /// It is very important to NOT reuse the same triple twice for two different /// protocols. pub fn take(&mut self, id: TripleId) -> Option> { - match self.triples.entry(id) { - Entry::Vacant(_) => None, - Entry::Occupied(entry) => Some(entry.remove()), - } + self.triples.remove(&id) } /// Ensures that the triple with the given id is either: @@ -82,21 +81,25 @@ impl TripleManager { /// 2) Is currently being generated by `protocol` in which case returns `Some(protocol)`, or /// 3) Has never been seen by the manager in which case start a new protocol and returns `Some(protocol)` // TODO: What if the triple completed generation and is already spent? - pub fn get_or_generate(&mut self, id: TripleId) -> Option<&mut TripleProtocol> { + pub fn get_or_generate( + &mut self, + id: TripleId, + ) -> Result, InitializationError> { if self.triples.contains_key(&id) { - None + Ok(None) } else { - Some(self.generators.entry(id).or_insert_with(|| { - tracing::info!(id, "joining protocol to generate a new triple"); - Box::new( - cait_sith::triples::generate_triple( + match self.generators.entry(id) { + Entry::Vacant(e) => { + tracing::info!(id, "joining protocol to generate a new triple"); + let protocol = cait_sith::triples::generate_triple( &self.participants, self.me, self.threshold, - ) - .unwrap(), - ) - })) + )?; + Ok(Some(e.insert(Box::new(protocol)))) + } + Entry::Occupied(e) => Ok(Some(e.into_mut())), + } } } @@ -104,52 +107,61 @@ impl TripleManager { /// messages to be sent to the respective participant. /// /// An empty vector means we cannot progress until we receive a new message. - pub fn poke(&mut self) -> Vec<(Participant, TripleMessage)> { + pub fn poke(&mut self) -> Result, ProtocolError> { let mut messages = Vec::new(); - self.generators.retain(|id, protocol| loop { - let action = protocol.poke().unwrap(); - match action { - Action::Wait => { - tracing::debug!("waiting"); - // Retain protocol until we are finished - return true; - } - Action::SendMany(data) => { - for p in &self.participants { - messages.push(( - *p, - TripleMessage { - id: *id, - epoch: self.epoch, - from: self.me, - data: data.clone(), - }, - )) + let mut result = Ok(()); + self.generators.retain(|id, protocol| { + loop { + let action = match protocol.poke() { + Ok(action) => action, + Err(e) => { + result = Err(e); + break false; + } + }; + match action { + Action::Wait => { + tracing::debug!("waiting"); + // Retain protocol until we are finished + break true; + } + Action::SendMany(data) => { + for p in &self.participants { + messages.push(( + *p, + TripleMessage { + id: *id, + epoch: self.epoch, + from: self.me, + data: data.clone(), + }, + )) + } + } + Action::SendPrivate(p, data) => messages.push(( + p, + TripleMessage { + id: *id, + epoch: self.epoch, + from: self.me, + data: data.clone(), + }, + )), + Action::Return(output) => { + tracing::info!( + id, + big_a = ?output.1.big_a.to_base58(), + big_b = ?output.1.big_b.to_base58(), + big_c = ?output.1.big_c.to_base58(), + "completed triple generation" + ); + self.triples.insert(*id, output); + // Do not retain the protocol + break false; } - } - Action::SendPrivate(p, data) => messages.push(( - p, - TripleMessage { - id: *id, - epoch: self.epoch, - from: self.me, - data: data.clone(), - }, - )), - Action::Return(output) => { - tracing::info!( - id, - big_a = ?output.1.big_a.into_base58(), - big_b = ?output.1.big_b.into_base58(), - big_c = ?output.1.big_c.into_base58(), - "completed triple generation" - ); - self.triples.insert(*id, output); - // Do not retain the protocol - return false; } } }); - messages + result.map(|_| messages) } } diff --git a/node/src/util.rs b/node/src/util.rs index 0a364f16a..cae923cc1 100644 --- a/node/src/util.rs +++ b/node/src/util.rs @@ -35,7 +35,7 @@ impl NearPublicKeyExt for near_crypto::PublicKey { pub trait AffinePointExt { fn into_near_public_key(self) -> near_crypto::PublicKey; - fn into_base58(self) -> String; + fn to_base58(&self) -> String; } impl AffinePointExt for AffinePoint { @@ -48,7 +48,7 @@ impl AffinePointExt for AffinePoint { ) } - fn into_base58(self) -> String { + fn to_base58(&self) -> String { let key = near_crypto::Secp256K1PublicKey::try_from( &self.to_encoded_point(false).as_bytes()[1..65], )