Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: /msg endpoint now takes vectorized encrypted messages #549

Merged
merged 3 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 92 additions & 39 deletions node/src/http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::protocol::contract::primitives::{ParticipantInfo, Participants};
use crate::protocol::message::SignedMessage;
use crate::protocol::MpcMessage;
use cait_sith::protocol::Participant;
use mpc_keys::hpke;
use mpc_keys::hpke::Ciphered;
use reqwest::{Client, IntoUrl};
use std::collections::{HashMap, HashSet, VecDeque};
use std::str::Utf8Error;
Expand Down Expand Up @@ -35,15 +35,10 @@ pub enum SendError {

async fn send_encrypted<U: IntoUrl>(
from: Participant,
cipher_pk: &hpke::PublicKey,
sign_sk: &near_crypto::SecretKey,
client: &Client,
url: U,
message: &MpcMessage,
message: Vec<Ciphered>,
) -> Result<(), SendError> {
let encrypted = SignedMessage::encrypt(message, from, sign_sk, cipher_pk)
.map_err(|err| SendError::EncryptionError(err.to_string()))?;

let _span = tracing::info_span!("message_request");
let mut url = url.into_url()?;
url.set_path("msg");
Expand All @@ -52,7 +47,7 @@ async fn send_encrypted<U: IntoUrl>(
let response = client
.post(url.clone())
.header("content-type", "application/json")
.json(&encrypted)
.json(&message)
.send()
.await
.map_err(SendError::ReqwestClientError)?;
Expand Down Expand Up @@ -111,49 +106,77 @@ impl MessageQueue {
let mut failed = VecDeque::new();
let mut errors = Vec::new();
let mut participant_counter = HashMap::new();

let outer = Instant::now();
let uncompacted = self.deque.len();
let mut encrypted = HashMap::new();
while let Some((info, msg, instant)) = self.deque.pop_front() {
let account_id = info.account_id.clone();
if instant.elapsed() > message_type_to_timeout(&msg) {
errors.push(SendError::Timeout(format!(
"{} message has timed out: {info:?}",
msg.typename(),
)));
continue;
}

if !participants.contains_key(&Participant::from(info.id)) {
if instant.elapsed() > message_type_to_timeout(&msg) {
errors.push(SendError::Timeout(format!(
"message has timed out on offline node: {info:?}",
)));
continue;
}
let counter = participant_counter.entry(info.id).or_insert(0);
*counter += 1;
failed.push_back((info, msg, instant));
continue;
}

let start = Instant::now();
crate::metrics::NUM_SEND_ENCRYPTED_TOTAL
.with_label_values(&[&account_id.as_ref()])
.inc();
if let Err(err) =
send_encrypted(from, &info.cipher_pk, sign_sk, client, &info.url, &msg).await
{
crate::metrics::NUM_SEND_ENCRYPTED_FAILURE
.with_label_values(&[&account_id.as_ref()])
.inc();
crate::metrics::FAILED_SEND_ENCRYPTED_LATENCY
.with_label_values(&[&account_id.as_ref()])
.observe(start.elapsed().as_millis() as f64);
if instant.elapsed() > message_type_to_timeout(&msg) {
errors.push(SendError::Timeout(format!(
"message has timed out: {err:?}"
)));
let encrypted_msg = match SignedMessage::encrypt(&msg, from, sign_sk, &info.cipher_pk) {
Ok(encrypted) => encrypted,
Err(err) => {
errors.push(SendError::EncryptionError(err.to_string()));
continue;
}
};
let encrypted = encrypted.entry(info.id).or_insert_with(Vec::new);
encrypted.push((encrypted_msg, (info, msg, instant)));
}

failed.push_back((info, msg, instant));
errors.push(err);
} else {
crate::metrics::SEND_ENCRYPTED_LATENCY
.with_label_values(&[&account_id.as_ref()])
.observe(start.elapsed().as_millis() as f64);
let mut compacted = 0;
for (id, encrypted) in encrypted {
for partition in partition_ciphered_256kb(encrypted) {
let (encrypted_partition, msgs): (Vec<_>, Vec<_>) = partition.into_iter().unzip();
// guaranteed to unwrap due to our previous loop check:
let info = participants.get(&Participant::from(id)).unwrap();
let account_id = &info.account_id;

let start = Instant::now();
crate::metrics::NUM_SEND_ENCRYPTED_TOTAL
.with_label_values(&[account_id.as_ref()])
.inc();
if let Err(err) = send_encrypted(from, client, &info.url, encrypted_partition).await
{
crate::metrics::NUM_SEND_ENCRYPTED_FAILURE
.with_label_values(&[account_id.as_ref()])
.inc();
crate::metrics::FAILED_SEND_ENCRYPTED_LATENCY
.with_label_values(&[account_id.as_ref()])
.observe(start.elapsed().as_millis() as f64);

// since we failed, put back all the messages related to this
failed.extend(msgs);
errors.push(err);
} else {
compacted += msgs.len();
crate::metrics::SEND_ENCRYPTED_LATENCY
.with_label_values(&[account_id.as_ref()])
.observe(start.elapsed().as_millis() as f64);
}
}
}

if uncompacted > 0 {
tracing::debug!(
uncompacted,
compacted,
"{from:?} sent messages in {:?};",
outer.elapsed()
);
}
// only add the participant count if it hasn't been seen before.
let counts = format!("{participant_counter:?}");
if !participant_counter.is_empty() && self.seen_counts.insert(counts.clone()) {
Expand All @@ -168,6 +191,36 @@ impl MessageQueue {
}
}

/// Encrypted message with a reference to the old message. Only the ciphered portion of this
/// type will be sent over the wire, while the original message is kept just in case things
/// go wrong somewhere and the message needs to be requeued to be sent later.
type EncryptedMessage = (Ciphered, (ParticipantInfo, MpcMessage, Instant));

fn partition_ciphered_256kb(encrypted: Vec<EncryptedMessage>) -> Vec<Vec<EncryptedMessage>> {
let mut result = Vec::new();
let mut current_partition = Vec::new();
let mut current_size: usize = 0;

for ciphered in encrypted {
let bytesize = ciphered.0.text.len();
if current_size + bytesize > 256 * 1024 {
// If adding this byte vector exceeds 256kb, start a new partition
result.push(current_partition);
current_partition = Vec::new();
current_size = 0;
}
current_partition.push(ciphered);
current_size += bytesize;
}

if !current_partition.is_empty() {
// Add the last partition
result.push(current_partition);
}

result
}

fn message_type_to_timeout(msg: &MpcMessage) -> Duration {
match msg {
MpcMessage::Generating(_) => MESSAGE_TIMEOUT,
Expand Down
16 changes: 14 additions & 2 deletions node/src/protocol/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ pub enum MpcMessage {
Signature(SignatureMessage),
}

impl MpcMessage {
pub const fn typename(&self) -> &'static str {
match self {
MpcMessage::Generating(_) => "Generating",
MpcMessage::Resharing(_) => "Resharing",
MpcMessage::Triple(_) => "Triple",
MpcMessage::Presignature(_) => "Presignature",
MpcMessage::Signature(_) => "Signature",
}
}
}

#[derive(Default)]
pub struct MpcMessageQueue {
generating: VecDeque<GeneratingMessage>,
Expand Down Expand Up @@ -376,12 +388,12 @@ where
T: Serialize,
{
pub fn encrypt(
msg: T,
msg: &T,
from: Participant,
sign_sk: &near_crypto::SecretKey,
cipher_pk: &hpke::PublicKey,
) -> Result<Ciphered, CryptographicError> {
let msg = serde_json::to_vec(&msg)?;
let msg = serde_json::to_vec(msg)?;
let sig = sign_sk.sign(&msg);
let msg = SignedMessage { msg, sig, from };
let msg = serde_json::to_vec(&msg)?;
Expand Down
21 changes: 14 additions & 7 deletions node/src/web/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,25 +68,32 @@ pub struct MsgRequest {
#[tracing::instrument(level = "debug", skip_all)]
async fn msg(
Extension(state): Extension<Arc<AxumState>>,
WithRejection(Json(encrypted), _): WithRejection<Json<Ciphered>, Error>,
WithRejection(Json(encrypted), _): WithRejection<Json<Vec<Ciphered>>, Error>,
) -> Result<()> {
let message =
match SignedMessage::decrypt(&state.cipher_sk, &state.protocol_state, encrypted).await {
for encrypted in encrypted.into_iter() {
let message = match SignedMessage::decrypt(
&state.cipher_sk,
&state.protocol_state,
encrypted,
)
.await
{
Ok(msg) => msg,
Err(err) => {
tracing::error!(?err, "failed to decrypt or verify an encrypted message");
return Err(err.into());
}
};

if let Err(err) = state.sender.send(message).await {
tracing::error!(?err, "failed to forward an encrypted protocol message");
return Err(err.into());
if let Err(err) = state.sender.send(message).await {
tracing::error!(?err, "failed to forward an encrypted protocol message");
return Err(err.into());
}
}
Ok(())
}

#[derive(Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum StateView {
Expand Down
Loading