diff --git a/applications/tari_base_node/src/grpc/base_node_grpc_server.rs b/applications/tari_base_node/src/grpc/base_node_grpc_server.rs index e3a19ef185..b799ca4585 100644 --- a/applications/tari_base_node/src/grpc/base_node_grpc_server.rs +++ b/applications/tari_base_node/src/grpc/base_node_grpc_server.rs @@ -1003,11 +1003,7 @@ impl tari_rpc::base_node_server::BaseNode for BaseNodeGrpcServer { .state_info .get_block_sync_info() .map(|info| { - let node_ids = info - .sync_peers - .iter() - .map(|x| x.to_string().as_bytes().to_vec()) - .collect(); + let node_ids = info.sync_peers.iter().map(|x| x.to_string().into_bytes()).collect(); tari_rpc::SyncInfoResponse { tip_height: info.tip_height, local_height: info.local_height, diff --git a/applications/tari_base_node/src/main.rs b/applications/tari_base_node/src/main.rs index 8fabbf59ca..7a35a38021 100644 --- a/applications/tari_base_node/src/main.rs +++ b/applications/tari_base_node/src/main.rs @@ -1,4 +1,3 @@ -#![recursion_limit = "1024"] // Copyright 2019. The Tari Project // // Redistribution and use in source and binary forms, with or without modification, are permitted provided that the diff --git a/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs b/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs index c18784d82a..cf20599a9b 100644 --- a/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs +++ b/applications/tari_console_wallet/src/grpc/wallet_grpc_server.rs @@ -98,9 +98,9 @@ impl wallet_server::Wallet for WalletGrpcServer { async fn identify(&self, _: Request) -> Result, Status> { let identity = self.wallet.comms.node_identity(); Ok(Response::new(GetIdentityResponse { - public_key: identity.public_key().to_string().as_bytes().to_vec(), + public_key: identity.public_key().to_string().into_bytes(), public_address: identity.public_address().to_string(), - node_id: identity.node_id().to_string().as_bytes().to_vec(), + node_id: identity.node_id().to_string().into_bytes(), })) } diff --git a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs index 052bad580b..5459d6e17f 100644 --- a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs @@ -1546,8 +1546,7 @@ struct KeyManagerStateUpdateSql { impl Encryptable for KeyManagerStateSql { fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), Error> { let encrypted_master_key = encrypt_bytes_integral_nonce(&cipher, self.master_key.clone())?; - let encrypted_branch_seed = - encrypt_bytes_integral_nonce(&cipher, self.branch_seed.clone().as_bytes().to_vec())?; + let encrypted_branch_seed = encrypt_bytes_integral_nonce(&cipher, self.branch_seed.clone().into_bytes())?; self.master_key = encrypted_master_key; self.branch_seed = encrypted_branch_seed.to_hex(); Ok(()) diff --git a/base_layer/wallet/src/storage/sqlite_db.rs b/base_layer/wallet/src/storage/sqlite_db.rs index 95c7cc52ce..881dd7aa04 100644 --- a/base_layer/wallet/src/storage/sqlite_db.rs +++ b/base_layer/wallet/src/storage/sqlite_db.rs @@ -588,7 +588,7 @@ impl ClientKeyValueSql { impl Encryptable for ClientKeyValueSql { #[allow(unused_assignments)] fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), AeadError> { - let encrypted_value = encrypt_bytes_integral_nonce(&cipher, self.clone().value.as_bytes().to_vec())?; + let encrypted_value = encrypt_bytes_integral_nonce(&cipher, self.value.as_bytes().to_vec())?; self.value = 
encrypted_value.to_hex(); Ok(()) } diff --git a/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs b/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs index 700590562e..f3287dcc8d 100644 --- a/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/transaction_service/storage/sqlite_db.rs @@ -1028,8 +1028,7 @@ impl InboundTransactionSql { impl Encryptable for InboundTransactionSql { fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), AeadError> { - let encrypted_protocol = - encrypt_bytes_integral_nonce(&cipher, self.receiver_protocol.clone().as_bytes().to_vec())?; + let encrypted_protocol = encrypt_bytes_integral_nonce(&cipher, self.receiver_protocol.as_bytes().to_vec())?; self.receiver_protocol = encrypted_protocol.to_hex(); Ok(()) } @@ -1211,8 +1210,7 @@ impl OutboundTransactionSql { impl Encryptable for OutboundTransactionSql { fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), AeadError> { - let encrypted_protocol = - encrypt_bytes_integral_nonce(&cipher, self.sender_protocol.clone().as_bytes().to_vec())?; + let encrypted_protocol = encrypt_bytes_integral_nonce(&cipher, self.sender_protocol.as_bytes().to_vec())?; self.sender_protocol = encrypted_protocol.to_hex(); Ok(()) } @@ -1534,8 +1532,7 @@ impl CompletedTransactionSql { impl Encryptable for CompletedTransactionSql { fn encrypt(&mut self, cipher: &Aes256Gcm) -> Result<(), AeadError> { - let encrypted_protocol = - encrypt_bytes_integral_nonce(&cipher, self.transaction_protocol.clone().as_bytes().to_vec())?; + let encrypted_protocol = encrypt_bytes_integral_nonce(&cipher, self.transaction_protocol.as_bytes().to_vec())?; self.transaction_protocol = encrypted_protocol.to_hex(); Ok(()) } diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index c6224a7af7..0e056bf62b 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -599,7 +599,7 @@ mod test { dht_header: DhtMessageHeader, stored_at: NaiveDateTime, ) -> StoredMessage { - let body = message.as_bytes().to_vec(); + let body = message.into_bytes(); let body_hash = hex::to_hex(&Challenge::new().chain(body.clone()).finalize()); StoredMessage { id: 1, diff --git a/comms/src/proto/rpc.proto b/comms/src/proto/rpc.proto index d82b006d9d..5d03b89224 100644 --- a/comms/src/proto/rpc.proto +++ b/comms/src/proto/rpc.proto @@ -16,7 +16,7 @@ message RpcRequest { uint64 deadline = 4; // The message payload - bytes message = 10; + bytes payload = 10; } // Message type for all RPC responses @@ -29,7 +29,7 @@ message RpcResponse { uint32 flags = 3; // The message payload. If the status is non-zero, this contains additional error details. - bytes message = 10; + bytes payload = 10; } // Message sent by the client when negotiating an RPC session. 
A server may close the substream if it does diff --git a/comms/src/protocol/rpc/body.rs b/comms/src/protocol/rpc/body.rs index e563d6483e..cda655afc5 100644 --- a/comms/src/protocol/rpc/body.rs +++ b/comms/src/protocol/rpc/body.rs @@ -177,6 +177,10 @@ impl BodyBytes { pub fn into_vec(self) -> Vec { self.0.map(|bytes| bytes.to_vec()).unwrap_or_else(Vec::new) } + + pub fn into_bytes(self) -> Option { + self.0 + } } #[allow(clippy::from_over_into)] @@ -186,10 +190,9 @@ impl Into for BodyBytes { } } -#[allow(clippy::from_over_into)] -impl Into> for BodyBytes { - fn into(self) -> Vec { - self.into_vec() +impl From for Vec { + fn from(body: BodyBytes) -> Self { + body.into_vec() } } diff --git a/comms/src/protocol/rpc/client.rs b/comms/src/protocol/rpc/client.rs index a6d6554f3e..acf9d19114 100644 --- a/comms/src/protocol/rpc/client.rs +++ b/comms/src/protocol/rpc/client.rs @@ -34,6 +34,7 @@ use crate::{ Response, RpcError, RpcStatus, + RPC_CHUNKING_MAX_CHUNKS, }, ProtocolId, }, @@ -239,7 +240,7 @@ where TClient: From + NamedProtocolService } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct RpcClientConfig { pub deadline: Option, pub deadline_grace_period: Duration, @@ -489,7 +490,8 @@ impl RpcClientWorker { self.protocol_name(), start.elapsed() ); - let resp = match self.read_reply().await { + let mut reader = RpcResponseReader::new(&mut self.framed, self.config, 0); + let resp = match reader.read_ack().await { Ok(resp) => resp, Err(RpcError::ReplyTimeout) => { debug!( @@ -529,7 +531,7 @@ impl RpcClientWorker { Ok(()) } - #[tracing::instrument(name = "rpc_do_request_response", skip(self, reply))] + #[tracing::instrument(name = "rpc_do_request_response", skip(self, reply, request), fields(request_method = ?request.method, request_size = request.message.len()))] async fn do_request_response( &mut self, request: BaseRequest, @@ -542,7 +544,7 @@ impl RpcClientWorker { method, deadline: self.config.deadline.map(|t| t.as_secs()).unwrap_or(0), flags: 0, - message: request.message.to_vec(), + payload: request.message.to_vec(), }; debug!(target: LOG_TARGET, "Sending request: {}", req); @@ -575,14 +577,14 @@ impl RpcClientWorker { } loop { - let resp = match self.read_reply().await { + let resp = match self.read_response(request_id).await { Ok(resp) => { let latency = start.elapsed(); event!(Level::TRACE, "Message received"); trace!( target: LOG_TARGET, "Received response ({} byte(s)) from request #{} (protocol = {}, method={}) in {:.0?}", - resp.message.len(), + resp.payload.len(), request_id, self.protocol_name(), method, @@ -617,12 +619,19 @@ impl RpcClientWorker { break; }, Err(err) => { - event!(Level::ERROR, "Errored:{}", err); + event!( + Level::WARN, + "Request {} (method={}) returned an error after {:.0?}: {}", + request_id, + method, + start.elapsed(), + err + ); return Err(err); }, }; - match Self::convert_to_result(resp, request_id) { + match Self::convert_to_result(resp) { Ok(Ok(resp)) => { // The consumer may drop the receiver before all responses are received. // We just ignore that as we still want obey the protocol and receive messages until the FIN flag or @@ -665,27 +674,10 @@ impl RpcClientWorker { Ok(()) } - async fn read_reply(&mut self) -> Result { - // Wait until the timeout, allowing an extra grace period to account for latency - let next_msg_fut = match self.config.timeout_with_grace_period() { - Some(timeout) => Either::Left(time::timeout(timeout, self.framed.next())), - None => Either::Right(self.framed.next().map(Ok)), - }; - - let result = tokio::select! 
{ - biased; - _ = &mut self.shutdown_signal => { - return Err(RpcError::ClientClosed); - } - result = next_msg_fut => result, - }; - - match result { - Ok(Some(Ok(resp))) => Ok(proto::rpc::RpcResponse::decode(resp)?), - Ok(Some(Err(err))) => Err(err.into()), - Ok(None) => Err(RpcError::ServerClosedRequest), - Err(_) => Err(RpcError::ReplyTimeout), - } + async fn read_response(&mut self, request_id: u16) -> Result { + let mut reader = RpcResponseReader::new(&mut self.framed, self.config, request_id); + let resp = reader.read_response().await?; + Ok(resp) } fn next_request_id(&mut self) -> u16 { @@ -695,25 +687,7 @@ impl RpcClientWorker { next_id } - fn convert_to_result( - resp: proto::rpc::RpcResponse, - request_id: u16, - ) -> Result, RpcStatus>, RpcError> { - let resp_id = u16::try_from(resp.request_id) - .map_err(|_| RpcStatus::protocol_error(format!("invalid request_id: must be less than {}", u16::MAX)))?; - - let flags = RpcMessageFlags::from_bits_truncate(resp.flags as u8); - if flags.contains(RpcMessageFlags::ACK) { - return Err(RpcError::UnexpectedAckResponse); - } - - if resp_id != request_id { - return Err(RpcError::ResponseIdDidNotMatchRequest { - expected: request_id, - actual: resp.request_id as u16, - }); - } - + fn convert_to_result(resp: proto::rpc::RpcResponse) -> Result, RpcStatus>, RpcError> { let status = RpcStatus::from(&resp); if !status.is_ok() { return Ok(Err(status)); @@ -721,7 +695,7 @@ impl RpcClientWorker { let resp = Response { flags: resp.flags(), - message: resp.message.into(), + payload: resp.payload.into(), }; Ok(Ok(resp)) @@ -736,3 +710,91 @@ pub enum ClientRequest { GetLastRequestLatency(oneshot::Sender>), SendPing(oneshot::Sender>), } + +struct RpcResponseReader<'a> { + framed: &'a mut CanonicalFraming, + config: RpcClientConfig, + request_id: u16, +} +impl<'a> RpcResponseReader<'a> { + pub fn new(framed: &'a mut CanonicalFraming, config: RpcClientConfig, request_id: u16) -> Self { + Self { + framed, + config, + request_id, + } + } + + pub async fn read_response(&mut self) -> Result { + let mut resp = self.next().await?; + self.check_response(&resp)?; + let mut chunk_count = 1; + let mut last_chunk_flags = RpcMessageFlags::from_bits_truncate(resp.flags as u8); + let mut last_chunk_size = resp.payload.len(); + loop { + trace!( + target: LOG_TARGET, + "Chunk {} received (flags={:?}, {} bytes, {} total)", + chunk_count, + last_chunk_flags, + last_chunk_size, + resp.payload.len() + ); + if !last_chunk_flags.is_more() { + return Ok(resp); + } + + if chunk_count >= RPC_CHUNKING_MAX_CHUNKS { + return Err(RpcError::ExceededMaxChunkCount { + expected: RPC_CHUNKING_MAX_CHUNKS, + }); + } + + let msg = self.next().await?; + last_chunk_flags = RpcMessageFlags::from_bits_truncate(msg.flags as u8); + last_chunk_size = msg.payload.len(); + self.check_response(&resp)?; + resp.payload.extend(msg.payload); + chunk_count += 1; + } + } + + pub async fn read_ack(&mut self) -> Result { + let resp = self.next().await?; + Ok(resp) + } + + fn check_response(&self, resp: &proto::rpc::RpcResponse) -> Result<(), RpcError> { + let resp_id = u16::try_from(resp.request_id) + .map_err(|_| RpcStatus::protocol_error(format!("invalid request_id: must be less than {}", u16::MAX)))?; + + let flags = RpcMessageFlags::from_bits_truncate(resp.flags as u8); + if flags.contains(RpcMessageFlags::ACK) { + return Err(RpcError::UnexpectedAckResponse); + } + + if resp_id != self.request_id { + return Err(RpcError::ResponseIdDidNotMatchRequest { + expected: self.request_id, + actual: resp.request_id as 
u16, + }); + } + + Ok(()) + } + + async fn next(&mut self) -> Result { + // Wait until the timeout, allowing an extra grace period to account for latency + let next_msg_fut = match self.config.timeout_with_grace_period() { + Some(timeout) => Either::Left(time::timeout(timeout, self.framed.next())), + None => Either::Right(self.framed.next().map(Ok)), + }; + + match next_msg_fut.await { + Ok(Some(Ok(resp))) => Ok(proto::rpc::RpcResponse::decode(resp)?), + Ok(Some(Err(err))) => Err(err.into()), + Ok(None) => Err(RpcError::ServerClosedRequest), + Err(_) => Err(RpcError::ReplyTimeout), + } + } +} diff --git a/comms/src/protocol/rpc/error.rs b/comms/src/protocol/rpc/error.rs index 64f811f9d8..42abfebfd3 100644 --- a/comms/src/protocol/rpc/error.rs +++ b/comms/src/protocol/rpc/error.rs @@ -65,6 +65,8 @@ pub enum RpcError { InvalidPingResponse, #[error("Unexpected ACK response. This is likely because of a previous ACK timeout")] UnexpectedAckResponse, + #[error("Attempted to send more than {expected} payload chunks")] + ExceededMaxChunkCount { expected: usize }, #[error(transparent)] UnknownError(#[from] anyhow::Error), } diff --git a/comms/src/protocol/rpc/message.rs b/comms/src/protocol/rpc/message.rs index dedd7e04fb..099a70331b 100644 --- a/comms/src/protocol/rpc/message.rs +++ b/comms/src/protocol/rpc/message.rs @@ -28,6 +28,7 @@ use crate::{ body::{Body, IntoBody}, context::RequestContext, error::HandshakeRejectReason, + RpcStatusCode, }, }; use bitflags::bitflags; @@ -136,14 +137,14 @@ impl BaseRequest { #[derive(Debug, Clone)] pub struct Response { pub flags: RpcMessageFlags, - pub message: T, + pub payload: T, } impl Response { pub fn from_message(message: T) -> Self { Self { flags: Default::default(), - message: message.into_body(), + payload: message.into_body(), } } } @@ -151,7 +152,7 @@ impl Response { impl Response { pub fn new(message: T) -> Self { Self { - message, + payload: message, flags: Default::default(), } } @@ -160,7 +161,7 @@ impl Response { where F: FnMut(T) -> U { Response { flags: self.flags, - message: f(self.message), + payload: f(self.payload), } } @@ -169,7 +170,7 @@ impl Response { } pub fn into_message(self) -> T { - self.message + self.payload } } @@ -201,6 +202,8 @@ bitflags! { const FIN = 0x01; /// Typically sent with empty contents and used to confirm a substream is alive. 
const ACK = 0x02; + /// Another chunk to be received + const MORE = 0x04; } } impl RpcMessageFlags { @@ -211,6 +214,10 @@ impl RpcMessageFlags { pub fn is_ack(&self) -> bool { self.contains(Self::ACK) } + + pub fn is_more(&self) -> bool { + self.contains(Self::MORE) + } } impl Default for RpcMessageFlags { @@ -239,12 +246,42 @@ impl fmt::Display for proto::rpc::RpcRequest { self.request_id, self.deadline(), self.flags(), - self.message.len() + self.payload.len() ) } } //---------------------------------- RpcResponse --------------------------------------------// +#[derive(Debug, Clone)] +pub struct RpcResponse { + pub request_id: u32, + pub status: RpcStatusCode, + pub flags: RpcMessageFlags, + pub payload: Bytes, +} + +impl RpcResponse { + pub fn to_proto(&self) -> proto::rpc::RpcResponse { + proto::rpc::RpcResponse { + request_id: self.request_id, + status: self.status as u32, + flags: self.flags.bits().into(), + payload: self.payload.to_vec(), + } + } +} + +impl Default for RpcResponse { + fn default() -> Self { + Self { + request_id: 0, + status: RpcStatusCode::Ok, + flags: Default::default(), + payload: Default::default(), + } + } +} + impl proto::rpc::RpcResponse { pub fn flags(&self) -> RpcMessageFlags { RpcMessageFlags::from_bits_truncate(self.flags as u8) @@ -262,7 +299,7 @@ impl fmt::Display for proto::rpc::RpcResponse { "RequestID={}, Flags={:?}, Message={} byte(s)", self.request_id, self.flags(), - self.message.len() + self.payload.len() ) } } diff --git a/comms/src/protocol/rpc/mod.rs b/comms/src/protocol/rpc/mod.rs index 33208df391..f48b286ba0 100644 --- a/comms/src/protocol/rpc/mod.rs +++ b/comms/src/protocol/rpc/mod.rs @@ -26,6 +26,24 @@ #[cfg(test)] mod test; +/// Maximum frame size of each RPC message. This is enforced in tokio's length delimited codec. +pub const RPC_MAX_FRAME_SIZE: usize = 1024 * 1024; // 1 MiB +/// Maximum allowed chunks +const RPC_CHUNKING_MAX_CHUNKS: usize = 16; // 16 x 256 Kib = 4 MiB max combined message size +const RPC_CHUNKING_SIZE_LIMIT: usize = 384 * 1024; +const RPC_CHUNKING_THRESHOLD: usize = 256 * 1024; + +/// Convenience function that returns the maximum size for a single RPC message +const fn max_message_size() -> usize { + RPC_CHUNKING_MAX_CHUNKS * RPC_CHUNKING_THRESHOLD +} + +const fn max_payload_size() -> usize { + // 3 fields. VarInt(u32::MAX) is 5 bytes and each field id will be 1 byte each + const MAX_HEADER_SIZE: usize = 3 + 3 * 5; + max_message_size() - MAX_HEADER_SIZE +} + mod body; pub use body::{Body, ClientStreaming, IntoBody, Streaming}; @@ -56,9 +74,6 @@ pub use status::{RpcStatus, RpcStatusCode}; mod not_found; -/// Maximum frame size of each RPC message. This is enforced in tokio's length delimited codec. -pub const RPC_MAX_FRAME_SIZE: usize = 4 * 1024 * 1024; // 4 MiB - // Re-exports used to keep things orderly in the #[tari_rpc] proc macro pub mod __macro_reexports { pub use crate::{ diff --git a/comms/src/protocol/rpc/server/chunking.rs b/comms/src/protocol/rpc/server/chunking.rs new file mode 100644 index 0000000000..69e3c6153e --- /dev/null +++ b/comms/src/protocol/rpc/server/chunking.rs @@ -0,0 +1,256 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use super::LOG_TARGET; +use crate::{ + proto, + protocol::{ + rpc, + rpc::{ + message::{RpcMessageFlags, RpcResponse}, + RpcStatusCode, + RPC_CHUNKING_SIZE_LIMIT, + RPC_CHUNKING_THRESHOLD, + }, + }, +}; +use bytes::Bytes; +use log::*; +use std::cmp; + +pub(super) struct ChunkedResponseIter { + message: RpcResponse, + is_empty_payload: bool, + has_emitted_once: bool, + num_chunks: usize, + total_chunks: usize, +} + +fn calculate_total_chunk_count(payload_len: usize) -> usize { + let mut total_chunks = payload_len / RPC_CHUNKING_THRESHOLD; + let excess = (payload_len % RPC_CHUNKING_THRESHOLD) + RPC_CHUNKING_THRESHOLD; + if total_chunks == 0 || excess > RPC_CHUNKING_SIZE_LIMIT { + // If the chunk (threshold size) + excess cannot fit in the RPC_CHUNKING_SIZE_LIMIT, then we'll emit another + // frame smaller than threshold size + total_chunks += 1; + } + + total_chunks +} + +impl ChunkedResponseIter { + pub fn new(message: RpcResponse) -> Self { + let len = message.payload.len(); + Self { + is_empty_payload: message.payload.is_empty(), + message, + has_emitted_once: false, + num_chunks: 0, + total_chunks: calculate_total_chunk_count(len), + } + } + + fn remaining(&self) -> usize { + self.message.payload.len() + } + + fn payload_mut(&mut self) -> &mut Bytes { + &mut self.message.payload + } + + fn payload(&self) -> &Bytes { + &self.message.payload + } + + fn next_chunk(&mut self) -> Option { + let len = self.payload().len(); + if len == 0 { + if self.num_chunks > 1 { + debug!(target: LOG_TARGET, "Sent {} chunks", self.num_chunks); + } + return None; + } + + // If the payload is within the maximum chunk size, simply return the rest of it + if len <= RPC_CHUNKING_SIZE_LIMIT { + let chunk = self.payload_mut().split_to(len); + self.num_chunks += 1; + trace!( + target: LOG_TARGET, + "Emitting chunk {}/{} ({} bytes)", + self.num_chunks, + self.total_chunks, + chunk.len() + ); + return Some(chunk); + } + + let chunk_size = cmp::min(len, RPC_CHUNKING_THRESHOLD); + let chunk = self.payload_mut().split_to(chunk_size); + + self.num_chunks += 1; + trace!( + target: LOG_TARGET, + "Emitting chunk {}/{} ({} bytes)", + self.num_chunks, + self.total_chunks, + chunk.len() + ); + Some(chunk) + } + + fn is_last_chunk(&self) -> bool { + self.num_chunks == self.total_chunks + } + + fn exceeded_message_size(&self) -> proto::rpc::RpcResponse 
{ + const BYTES_PER_MB: f32 = 1024.0 * 1024.0; + let msg = format!( + "The response size exceeded the maximum allowed payload size. Max = {:.4} MiB, Got = {:.4} MiB", + rpc::max_payload_size() as f32 / BYTES_PER_MB, + self.message.payload.len() as f32 / BYTES_PER_MB, + ); + warn!(target: LOG_TARGET, "{}", msg); + proto::rpc::RpcResponse { + request_id: self.message.request_id, + status: RpcStatusCode::MalformedResponse as u32, + flags: RpcMessageFlags::FIN.bits().into(), + payload: msg.into_bytes(), + } + } +} + +impl Iterator for ChunkedResponseIter { + type Item = proto::rpc::RpcResponse; + + fn next(&mut self) -> Option { + // Edge case: the initial message has an empty payload. + if self.is_empty_payload { + if self.has_emitted_once { + return None; + } + self.has_emitted_once = true; + return Some(self.message.to_proto()); + } + + // Edge case: the total message size cannot fit into the maximum allowed chunks + if self.remaining() > rpc::max_payload_size() { + if self.has_emitted_once { + return None; + } + self.has_emitted_once = true; + return Some(self.exceeded_message_size()); + } + + let request_id = self.message.request_id; + let chunk = self.next_chunk()?; + + // status MUST be set for the first chunked message, all subsequent chunk messages MUST have a status of 0 + let mut status = 0; + if !self.has_emitted_once { + status = self.message.status as u32; + } + self.has_emitted_once = true; + + let mut flags = self.message.flags; + if !self.is_last_chunk() { + // For all chunks except the last the MORE flag MUST be set + flags |= RpcMessageFlags::MORE; + } + let msg = proto::rpc::RpcResponse { + request_id, + status, + flags: flags.bits().into(), + payload: chunk.to_vec(), + }; + + Some(msg) + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::iter; + + fn create(size: usize) -> ChunkedResponseIter { + let msg = RpcResponse { + payload: iter::repeat(0).take(size).collect(), + ..Default::default() + }; + ChunkedResponseIter::new(msg) + } + + #[test] + fn it_emits_a_zero_size_message() { + let iter = create(0); + assert_eq!(iter.total_chunks, 1); + let msgs = iter.collect::>(); + assert_eq!(msgs.len(), 1); + assert!(!RpcMessageFlags::from_bits_truncate(msgs[0].flags as u8).is_more()); + } + + #[test] + fn it_emits_one_message_below_threshold() { + let iter = create(RPC_CHUNKING_THRESHOLD - 1); + assert_eq!(iter.total_chunks, 1); + let msgs = iter.collect::>(); + assert_eq!(msgs.len(), 1); + assert!(!RpcMessageFlags::from_bits_truncate(msgs[0].flags as u8).is_more()); + } + + #[test] + fn it_emits_a_single_message() { + let iter = create(RPC_CHUNKING_SIZE_LIMIT - 1); + let msgs = iter.collect::>(); + assert_eq!(msgs.len(), 1); + + let iter = create(RPC_CHUNKING_SIZE_LIMIT); + let msgs = iter.collect::>(); + assert_eq!(msgs.len(), 1); + } + + #[test] + fn it_emits_an_expected_number_of_chunks() { + let iter = create(RPC_CHUNKING_THRESHOLD * 2); + let msgs = iter.collect::>(); + assert_eq!(msgs.len(), 2); + + let diff = RPC_CHUNKING_SIZE_LIMIT - RPC_CHUNKING_THRESHOLD; + let iter = create(RPC_CHUNKING_THRESHOLD * 2 + diff); + let msgs = iter.collect::>(); + assert_eq!(msgs.len(), 2); + + let iter = create(RPC_CHUNKING_THRESHOLD * 2 + diff + 1); + let msgs = iter.collect::>(); + assert_eq!(msgs.len(), 3); + } + + #[test] + fn it_sets_the_more_flag_except_last() { + let iter = create(RPC_CHUNKING_THRESHOLD * 3); + let msgs = iter.collect::>(); + assert!(RpcMessageFlags::from_bits_truncate(msgs[0].flags as u8).is_more()); + 
assert!(RpcMessageFlags::from_bits_truncate(msgs[1].flags as u8).is_more()); + assert!(!RpcMessageFlags::from_bits_truncate(msgs[2].flags as u8).is_more()); + } +} diff --git a/comms/src/protocol/rpc/server/mod.rs b/comms/src/protocol/rpc/server/mod.rs index d2b5f842eb..50f5244781 100644 --- a/comms/src/protocol/rpc/server/mod.rs +++ b/comms/src/protocol/rpc/server/mod.rs @@ -20,6 +20,9 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +mod chunking; +use chunking::ChunkedResponseIter; + mod error; pub use error::RpcServerError; @@ -40,7 +43,6 @@ use super::{ not_found::ProtocolServiceNotFound, status::RpcStatus, Handshake, - RpcStatusCode, RPC_MAX_FRAME_SIZE, }; use crate::{ @@ -50,11 +52,11 @@ use crate::{ message::MessageExt, peer_manager::NodeId, proto, - protocol::{ProtocolEvent, ProtocolId, ProtocolNotification, ProtocolNotificationRx}, + protocol::{rpc::message::RpcResponse, ProtocolEvent, ProtocolId, ProtocolNotification, ProtocolNotificationRx}, Bytes, Substream, }; -use futures::SinkExt; +use futures::{stream, SinkExt, StreamExt}; use prost::Message; use std::{ future::Future, @@ -62,7 +64,6 @@ use std::{ time::{Duration, Instant}, }; use tokio::{sync::mpsc, time}; -use tokio_stream::StreamExt; use tower::Service; use tower_make::MakeService; use tracing::{debug, error, instrument, span, trace, warn, Instrument, Level}; @@ -431,7 +432,7 @@ where match result { Ok(frame) => { let start = Instant::now(); - if let Err(err) = self.handle(frame.freeze()).await { + if let Err(err) = self.handle_request(frame.freeze()).await { self.framed.close().await?; return Err(err); } @@ -460,8 +461,8 @@ where Ok(()) } - #[instrument(name = "rpc::server::handle_req", skip(self), err)] - async fn handle(&mut self, mut request: Bytes) -> Result<(), RpcServerError> { + #[instrument(name = "rpc::server::handle_req", skip(self, request), err, fields(request_size = request.len()))] + async fn handle_request(&mut self, mut request: Bytes) -> Result<(), RpcServerError> { let decoded_msg = proto::rpc::RpcRequest::decode(&mut request)?; let request_id = decoded_msg.request_id; @@ -483,7 +484,8 @@ where request_id, status: status.as_code(), flags: RpcMessageFlags::FIN.bits().into(), - message: status.details_bytes(), + payload: status.to_details_bytes(), + ..Default::default() }; self.framed.send(bad_request.to_encoded_bytes().into()).await?; return Ok(()); @@ -507,13 +509,13 @@ where debug!( target: LOG_TARGET, - "({}) Got request {}", self.logging_context_string, decoded_msg + "({}) Request: {}", self.logging_context_string, decoded_msg ); let req = Request::with_context( self.create_request_context(request_id), method, - decoded_msg.message.into(), + decoded_msg.payload.into(), ); let service_call = log_timing( @@ -536,90 +538,64 @@ where match service_result { Ok(body) => { - // This is the most basic way we can push responses back to the peer. 
Keeping this here for reference - // and possible future evaluation - // - // body.into_message() - // .map(|msg| match msg { - // Ok(msg) => { - // trace!(target: LOG_TARGET, "Sending body len = {}", msg.len()); - // let mut flags = RpcMessageFlags::empty(); - // if msg.is_finished() { - // flags |= RpcMessageFlags::FIN; - // } - // proto::rpc::RpcResponse { - // request_id, - // status: RpcStatus::ok().as_code(), - // flags: flags.bits().into(), - // message: msg.into(), - // } - // }, - // Err(err) => { - // debug!(target: LOG_TARGET, "Body contained an error: {}", err); - // proto::rpc::RpcResponse { - // request_id, - // status: err.as_code(), - // flags: RpcMessageFlags::FIN.bits().into(), - // message: err.details().as_bytes().to_vec(), - // } - // }, - // }) - // .map(|resp| Ok(resp.to_encoded_bytes().into())) - // .forward(PreventClose::new(sink)) - // .await?; - - let mut message = body.into_message(); + trace!(target: LOG_TARGET, "Service call succeeded"); + let mut stream = body + .into_message() + .map(|msg| match msg { + Ok(msg) => { + trace!(target: LOG_TARGET, "Sending body len = {}", msg.len()); + let mut flags = RpcMessageFlags::empty(); + if msg.is_finished() { + flags |= RpcMessageFlags::FIN; + } + RpcResponse { + request_id, + status: RpcStatus::ok().status_code(), + flags, + payload: msg.into_bytes().unwrap_or_else(Bytes::new), + } + }, + Err(err) => { + debug!(target: LOG_TARGET, "Body contained an error: {}", err); + RpcResponse { + request_id, + status: err.status_code(), + flags: RpcMessageFlags::FIN, + payload: Bytes::from(err.to_details_bytes()), + } + }, + }) + .flat_map(|message| stream::iter(ChunkedResponseIter::new(message))) + .map(|resp| Bytes::from(resp.to_encoded_bytes())); + loop { - let msg_read = log_timing( + let next_item = log_timing( self.logging_context_string.clone(), request_id, "message read", - message.next(), + stream.next(), ); - match time::timeout(deadline, msg_read).await { + match time::timeout(deadline, next_item).await { Ok(Some(msg)) => { - let resp = match msg { - Ok(msg) => { - trace!(target: LOG_TARGET, "Sending body len = {}", msg.len()); - let mut flags = RpcMessageFlags::empty(); - if msg.is_finished() { - flags |= RpcMessageFlags::FIN; - } - proto::rpc::RpcResponse { - request_id, - status: RpcStatus::ok().as_code(), - flags: flags.bits().into(), - message: msg.into(), - } - }, - Err(err) => { - debug!(target: LOG_TARGET, "Body contained an error: {}", err); - proto::rpc::RpcResponse { - request_id, - status: err.as_code(), - flags: RpcMessageFlags::FIN.bits().into(), - message: err.details().as_bytes().to_vec(), - } - }, - }; - - let is_valid = log_timing( - self.logging_context_string.clone(), - request_id, - "transmit", - self.send_response(request_id, resp), - ) - .await?; + trace!( + target: LOG_TARGET, + "({}) Sending body len = {}", + self.logging_context_string, + msg.len() + ); - if !is_valid { - break; - } + self.framed.send(msg).await?; + }, + Ok(None) => { + debug!(target: LOG_TARGET, "{} Request complete", self.logging_context_string,); + break; }, - Ok(None) => break, Err(_) => { debug!( target: LOG_TARGET, - "Failed to return result within client deadline ({:.0?})", deadline + "({}) Failed to return result within client deadline ({:.0?})", + self.logging_context_string, + deadline ); break; @@ -633,7 +609,7 @@ where request_id, status: err.as_code(), flags: RpcMessageFlags::FIN.bits().into(), - message: err.details_bytes(), + payload: err.to_details_bytes(), }; 
self.framed.send(resp.to_encoded_bytes().into()).await?; @@ -643,40 +619,6 @@ where Ok(()) } - /// Sends an RpcResponse on the given Sink. If the size of the message exceeds the RPC_MAX_FRAME_SIZE, an error is - /// returned to the client and false is returned from this function, otherwise the message is sent and true is - /// returned - async fn send_response(&mut self, request_id: u32, resp: proto::rpc::RpcResponse) -> Result { - match resp.to_encoded_bytes() { - buf if buf.len() > RPC_MAX_FRAME_SIZE => { - let msg = format!( - "This node tried to return a message that exceeds the maximum frame size. Max = {:.4} MiB, Got = \ - {:.4} MiB", - RPC_MAX_FRAME_SIZE as f32 / (1024.0 * 1024.0), - buf.len() as f32 / (1024.0 * 1024.0) - ); - warn!(target: LOG_TARGET, "{}", msg); - self.framed - .send( - proto::rpc::RpcResponse { - request_id, - status: RpcStatusCode::MalformedResponse as u32, - flags: RpcMessageFlags::FIN.bits().into(), - message: msg.as_bytes().to_vec(), - } - .to_encoded_bytes() - .into(), - ) - .await?; - Ok(false) - }, - buf => { - self.framed.send(buf.into()).await?; - Ok(true) - }, - } - } - fn create_request_context(&self, request_id: u32) -> RequestContext { RequestContext::new(request_id, self.node_id.clone(), Box::new(self.comms_provider.clone())) } diff --git a/comms/src/protocol/rpc/status.rs b/comms/src/protocol/rpc/status.rs index e0ddf7fe22..c0453f161d 100644 --- a/comms/src/protocol/rpc/status.rs +++ b/comms/src/protocol/rpc/status.rs @@ -118,7 +118,7 @@ impl RpcStatus { &self.details } - pub fn details_bytes(&self) -> Vec { + pub fn to_details_bytes(&self) -> Vec { self.details.as_bytes().to_vec() } @@ -155,7 +155,7 @@ impl<'a> From<&'a proto::rpc::RpcResponse> for RpcStatus { RpcStatus { code: status_code, - details: String::from_utf8_lossy(&resp.message).to_string(), + details: String::from_utf8_lossy(&resp.payload).to_string(), } } } diff --git a/comms/src/protocol/rpc/test/smoke.rs b/comms/src/protocol/rpc/test/smoke.rs index 553c0001cd..757db3f73d 100644 --- a/comms/src/protocol/rpc/test/smoke.rs +++ b/comms/src/protocol/rpc/test/smoke.rs @@ -24,6 +24,7 @@ use crate::{ framing, multiplexing::Yamux, protocol::{ + rpc, rpc::{ context::RpcCommsBackend, error::HandshakeRejectReason, @@ -192,8 +193,8 @@ async fn request_response_errors_and_streaming() { match err { // Because of the race between closing the request stream and sending on that stream in the above call // We can either get "this client was closed" or "the request you made was cancelled". 
- // If we delay some small time, we'll always get the former (but arbitrary delays cause flakiness and should be - // avoided) + // If we delay some small time, we'll probably always get the former (but arbitrary delays cause flakiness and + // should be avoided) RpcError::ClientClosed | RpcError::RequestCancelled => {}, err => panic!("Unexpected error {:?}", err), } @@ -248,21 +249,26 @@ async fn response_too_big() { let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; let socket = muxer.incoming_mut().next().await.unwrap(); - let framed = framing::canonical(socket, RPC_MAX_FRAME_SIZE); - let mut client = GreetingClient::builder().connect(framed).await.unwrap(); + let framed = framing::canonical(socket, rpc::max_message_size()); + let mut client = GreetingClient::builder() + .with_deadline(Duration::from_secs(5)) + .connect(framed) + .await + .unwrap(); // RPC_MAX_FRAME_SIZE bytes will always be too large because of the overhead of the RpcResponse proto message let err = client - .reply_with_msg_of_size(RPC_MAX_FRAME_SIZE as u64) + .reply_with_msg_of_size(rpc::max_payload_size() as u64 + 1) .await .unwrap_err(); unpack_enum!(RpcError::RequestFailed(status) = err); unpack_enum!(RpcStatusCode::MalformedResponse = status.status_code()); // Check that the exact frame size boundary works and that the session is still going - // Take off 14 bytes for the RpcResponse overhead (i.e request_id + status + flags + msg field + vec_char(len(msg))) - let max_size = RPC_MAX_FRAME_SIZE - 14; - let _ = client.reply_with_msg_of_size(max_size as u64).await.unwrap(); + let _ = client + .reply_with_msg_of_size(rpc::max_payload_size() as u64 - 10) + .await + .unwrap(); } #[runtime::test] @@ -270,7 +276,7 @@ async fn ping_latency() { let (mut muxer, _outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; let socket = muxer.incoming_mut().next().await.unwrap(); - let framed = framing::canonical(socket, RPC_MAX_FRAME_SIZE); + let framed = framing::canonical(socket, 1024); let mut client = GreetingClient::builder().connect(framed).await.unwrap(); let latency = client.ping().await.unwrap();
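Note on the new size constants: the additions to comms/src/protocol/rpc/mod.rs pin the frame size at 1 MiB and derive the maximum combined message size from the chunk threshold and the chunk limit. The sketch below simply replays that arithmetic with the values from this patch as a sanity check; the constant and function names are copied from the diff, while the asserted figures are worked out here.

// Replaying the size arithmetic from comms/src/protocol/rpc/mod.rs.
const RPC_MAX_FRAME_SIZE: usize = 1024 * 1024; // 1 MiB per length-delimited frame
const RPC_CHUNKING_MAX_CHUNKS: usize = 16;
const RPC_CHUNKING_SIZE_LIMIT: usize = 384 * 1024; // hard cap for a single chunk
const RPC_CHUNKING_THRESHOLD: usize = 256 * 1024; // preferred chunk size

const fn max_message_size() -> usize {
    RPC_CHUNKING_MAX_CHUNKS * RPC_CHUNKING_THRESHOLD
}

const fn max_payload_size() -> usize {
    // 3 fields, each field id 1 byte, VarInt(u32::MAX) up to 5 bytes
    const MAX_HEADER_SIZE: usize = 3 + 3 * 5;
    max_message_size() - MAX_HEADER_SIZE
}

fn main() {
    assert_eq!(max_message_size(), 4 * 1024 * 1024); // 16 x 256 KiB = 4 MiB
    assert_eq!(max_payload_size(), 4 * 1024 * 1024 - 18);
    // Every emitted chunk (at most RPC_CHUNKING_SIZE_LIMIT bytes plus protobuf
    // overhead) stays comfortably below the 1 MiB frame limit enforced by the codec.
    assert!(RPC_CHUNKING_SIZE_LIMIT < RPC_MAX_FRAME_SIZE);
}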
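On the chunk count: calculate_total_chunk_count in the new chunking.rs is not a plain ceiling division. A trailing remainder is folded into the final threshold-sized chunk whenever the combined size still fits under RPC_CHUNKING_SIZE_LIMIT, which is exactly what it_emits_an_expected_number_of_chunks exercises. A standalone sketch with worked values follows, assuming the constants above; the function body is copied from the patch, the example sizes are chosen here.

const RPC_CHUNKING_SIZE_LIMIT: usize = 384 * 1024;
const RPC_CHUNKING_THRESHOLD: usize = 256 * 1024;

// Copied from the new chunking.rs in this patch.
fn calculate_total_chunk_count(payload_len: usize) -> usize {
    let mut total_chunks = payload_len / RPC_CHUNKING_THRESHOLD;
    let excess = (payload_len % RPC_CHUNKING_THRESHOLD) + RPC_CHUNKING_THRESHOLD;
    if total_chunks == 0 || excess > RPC_CHUNKING_SIZE_LIMIT {
        // The remainder cannot be folded into the previous threshold-sized chunk
        // without exceeding RPC_CHUNKING_SIZE_LIMIT, so one more chunk is emitted.
        total_chunks += 1;
    }
    total_chunks
}

fn main() {
    let t = RPC_CHUNKING_THRESHOLD;
    let slack = RPC_CHUNKING_SIZE_LIMIT - RPC_CHUNKING_THRESHOLD; // 128 KiB
    assert_eq!(calculate_total_chunk_count(0), 1); // an empty payload still emits one frame
    assert_eq!(calculate_total_chunk_count(t - 1), 1);
    assert_eq!(calculate_total_chunk_count(2 * t), 2);
    assert_eq!(calculate_total_chunk_count(2 * t + slack), 2); // remainder folded into the last chunk
    assert_eq!(calculate_total_chunk_count(2 * t + slack + 1), 3); // one byte too many
}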
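The iterator in chunking.rs also encodes two framing rules spelled out in its comments: the RPC status is carried only by the first chunk (subsequent chunks carry 0), and the MORE flag is set on every chunk except the last, with any flags on the original message (such as FIN) carried through unchanged. A reduced model of that per-chunk metadata assignment; FrameMeta and frame_metadata are illustrative names, not part of the patch.

#[derive(Debug, PartialEq)]
struct FrameMeta {
    status: u32, // RpcStatusCode as u32; only the first frame may carry a non-zero value
    more: bool,  // RpcMessageFlags::MORE
}

// Mirrors the rule in ChunkedResponseIter::next: status on the first chunk only,
// MORE on every chunk except the last.
fn frame_metadata(total_chunks: usize, status: u32) -> Vec<FrameMeta> {
    (0..total_chunks)
        .map(|i| FrameMeta {
            status: if i == 0 { status } else { 0 },
            more: i + 1 < total_chunks,
        })
        .collect()
}

fn main() {
    let frames = frame_metadata(3, 0 /* RpcStatusCode::Ok */);
    assert!(frames[0].more && frames[1].more && !frames[2].more);
    assert!(frames.iter().all(|f| f.status == 0));
}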
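On the receiving side, RpcResponseReader::read_response reassembles the logical response: it appends each chunk's payload while the MORE flag is set and aborts with ExceededMaxChunkCount once RPC_CHUNKING_MAX_CHUNKS frames have been consumed. The sketch below shows that control flow against a plain iterator of already-decoded frames rather than the framed substream; RpcResponseFrame, ReadError, and reassemble are simplified stand-ins, not the patch's types.

const RPC_CHUNKING_MAX_CHUNKS: usize = 16;

struct RpcResponseFrame {
    flags_more: bool, // RpcMessageFlags::MORE in the real message
    payload: Vec<u8>,
}

#[derive(Debug)]
enum ReadError {
    ExceededMaxChunkCount { expected: usize },
    ServerClosed,
}

// Mirrors the loop in RpcResponseReader::read_response: keep appending chunk
// payloads while the MORE flag is set, bail out after RPC_CHUNKING_MAX_CHUNKS frames.
fn reassemble<I>(mut frames: I) -> Result<Vec<u8>, ReadError>
where I: Iterator<Item = RpcResponseFrame> {
    let mut resp = frames.next().ok_or(ReadError::ServerClosed)?;
    let mut chunk_count = 1;
    let mut more = resp.flags_more;
    while more {
        if chunk_count >= RPC_CHUNKING_MAX_CHUNKS {
            return Err(ReadError::ExceededMaxChunkCount {
                expected: RPC_CHUNKING_MAX_CHUNKS,
            });
        }
        let next = frames.next().ok_or(ReadError::ServerClosed)?;
        more = next.flags_more;
        resp.payload.extend(next.payload);
        chunk_count += 1;
    }
    Ok(resp.payload)
}

fn main() {
    let frames = vec![
        RpcResponseFrame { flags_more: true, payload: vec![1, 2] },
        RpcResponseFrame { flags_more: false, payload: vec![3] },
    ];
    assert_eq!(reassemble(frames.into_iter()).unwrap(), vec![1, 2, 3]);
}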