From 6d384a4aac297283d786339b92ac177f54892e65 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 27 Mar 2024 17:17:35 -0700 Subject: [PATCH] Added message compacting and new /msg_multi endpoint --- node/src/http_client.rs | 96 +++++++++++++++++++++++++++++------- node/src/protocol/message.rs | 4 +- node/src/web/mod.rs | 31 +++++++++++- 3 files changed, 109 insertions(+), 22 deletions(-) diff --git a/node/src/http_client.rs b/node/src/http_client.rs index 8dccb036a..36b6eacc8 100644 --- a/node/src/http_client.rs +++ b/node/src/http_client.rs @@ -2,7 +2,7 @@ use crate::protocol::contract::primitives::{ParticipantInfo, Participants}; use crate::protocol::message::SignedMessage; use crate::protocol::MpcMessage; use cait_sith::protocol::Participant; -use mpc_keys::hpke; +use mpc_keys::hpke::Ciphered; use reqwest::{Client, IntoUrl}; use std::collections::{HashMap, HashSet, VecDeque}; use std::str::Utf8Error; @@ -33,26 +33,21 @@ pub enum SendError { ParticipantNotAlive(String), } -async fn send_encrypted( +async fn send_encrypted_multi( from: Participant, - cipher_pk: &hpke::PublicKey, - sign_sk: &near_crypto::SecretKey, client: &Client, url: U, - message: &MpcMessage, + message: Vec, ) -> Result<(), SendError> { - let encrypted = SignedMessage::encrypt(message, from, sign_sk, cipher_pk) - .map_err(|err| SendError::EncryptionError(err.to_string()))?; - let _span = tracing::info_span!("message_request"); let mut url = url.into_url()?; - url.set_path("msg"); + url.set_path("msg_multi"); tracing::debug!(?from, to = %url, "making http request: sending encrypted message"); let action = || async { let response = client .post(url.clone()) .header("content-type", "application/json") - .json(&encrypted) + .json(&message) .send() .await .map_err(SendError::ReqwestClientError)?; @@ -111,6 +106,10 @@ impl MessageQueue { let mut failed = VecDeque::new(); let mut errors = Vec::new(); let mut participant_counter = HashMap::new(); + + let outer = Instant::now(); + let uncompacted = self.deque.len(); + let mut encrypted = HashMap::new(); while let Some((info, msg, instant)) = self.deque.pop_front() { if !participants.contains_key(&Participant::from(info.id)) { if instant.elapsed() > message_type_to_timeout(&msg) { @@ -124,21 +123,55 @@ impl MessageQueue { failed.push_back((info, msg, instant)); continue; } + if instant.elapsed() > message_type_to_timeout(&msg) { + errors.push(SendError::Timeout(format!( + "message has timed out: {info:?}" + ))); + continue; + } - if let Err(err) = - send_encrypted(from, &info.cipher_pk, sign_sk, client, &info.url, &msg).await - { - if instant.elapsed() > message_type_to_timeout(&msg) { - errors.push(SendError::Timeout(format!( - "message has timed out: {err:?}" - ))); + let encrypted_msg = match SignedMessage::encrypt(&msg, from, sign_sk, &info.cipher_pk) { + Ok(encrypted) => encrypted, + Err(err) => { + errors.push(SendError::EncryptionError(err.to_string())); continue; } + }; - failed.push_back((info, msg, instant)); - errors.push(err); + let (encrypted, msgs) = encrypted + .entry(info.id) + .or_insert_with(|| (Vec::new(), Vec::new())); + + encrypted.push(encrypted_msg); + msgs.push((info, msg, instant)); + } + + let mut compacted = 0; + for (id, (encrypted, msgs)) in encrypted { + let partitioned = partition_ciphered(encrypted); + compacted += partitioned.len(); + + for encrypted_partition in partitioned { + let info = participants.get(&Participant::from(id)).unwrap(); + if let Err(err) = + send_encrypted_multi(from, client, &info.url, encrypted_partition).await + { + for (info, msg, instant) in msgs { + failed.push_back((info, msg, instant)); + } + errors.push(err); + break; + } } } + + tracing::info!( + uncompacted, + compacted, + "{from:?} sent messages in {:?};", + outer.elapsed() + ); + // only add the participant count if it hasn't been seen before. let counts = format!("{participant_counter:?}"); if !participant_counter.is_empty() && self.seen_counts.insert(counts.clone()) { @@ -153,6 +186,31 @@ impl MessageQueue { } } +fn partition_ciphered(encrypted: Vec) -> Vec> { + let mut result: Vec> = Vec::new(); + let mut current_partition: Vec = Vec::new(); + let mut current_size: usize = 0; + + for ciphered in encrypted { + let bytesize = ciphered.text.len(); + if current_size + bytesize > 256 * 1024 { + // If adding this byte vector exceeds 1MB, start a new partition + result.push(current_partition); + current_partition = Vec::new(); + current_size = 0; + } + current_partition.push(ciphered); + current_size += bytesize; + } + + if !current_partition.is_empty() { + // Add the last partition + result.push(current_partition); + } + + result +} + const fn message_type_to_timeout(msg: &MpcMessage) -> Duration { match msg { MpcMessage::Generating(_) => MESSAGE_TIMEOUT, diff --git a/node/src/protocol/message.rs b/node/src/protocol/message.rs index c0707a317..f5a1465b6 100644 --- a/node/src/protocol/message.rs +++ b/node/src/protocol/message.rs @@ -376,12 +376,12 @@ where T: Serialize, { pub fn encrypt( - msg: T, + msg: &T, from: Participant, sign_sk: &near_crypto::SecretKey, cipher_pk: &hpke::PublicKey, ) -> Result { - let msg = serde_json::to_vec(&msg)?; + let msg = serde_json::to_vec(msg)?; let sig = sign_sk.sign(&msg); let msg = SignedMessage { msg, sig, from }; let msg = serde_json::to_vec(&msg)?; diff --git a/node/src/web/mod.rs b/node/src/web/mod.rs index 2b0f5c08e..a798d46bf 100644 --- a/node/src/web/mod.rs +++ b/node/src/web/mod.rs @@ -45,6 +45,7 @@ pub async fn run( }), ) .route("/msg", post(msg)) + .route("/msg_multi", post(msg_multi)) .route("/state", get(state)) .route("/metrics", get(metrics)) .layer(Extension(Arc::new(axum_state))); @@ -86,7 +87,35 @@ async fn msg( Ok(()) } -#[derive(Serialize, Deserialize)] +#[tracing::instrument(level = "debug", skip_all)] +async fn msg_multi( + Extension(state): Extension>, + WithRejection(Json(encrypted), _): WithRejection>, Error>, +) -> Result<()> { + for encrypted in encrypted.into_iter() { + let message = match SignedMessage::decrypt( + &state.cipher_sk, + &state.protocol_state, + encrypted, + ) + .await + { + Ok(msg) => msg, + Err(err) => { + tracing::error!(?err, "failed to decrypt or verify an encrypted message"); + return Err(err.into()); + } + }; + + if let Err(err) = state.sender.send(message).await { + tracing::error!(?err, "failed to forward an encrypted protocol message"); + return Err(err.into()); + } + } + Ok(()) +} + +#[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type")] #[serde(rename_all = "snake_case")] pub enum StateView {