diff --git a/comms/core/src/protocol/rpc/client/mod.rs b/comms/core/src/protocol/rpc/client/mod.rs index 292c397930..1995715100 100644 --- a/comms/core/src/protocol/rpc/client/mod.rs +++ b/comms/core/src/protocol/rpc/client/mod.rs @@ -75,7 +75,6 @@ use crate::{ RpcError, RpcServerError, RpcStatus, - RPC_CHUNKING_MAX_CHUNKS, }, ProtocolId, }, @@ -932,53 +931,17 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin pub async fn read_response(&mut self) -> Result { let timer = Instant::now(); - let mut resp = self.next().await?; + let resp = self.next().await?; self.time_to_first_msg = Some(timer.elapsed()); self.check_response(&resp)?; - let mut chunk_count = 1; - let mut last_chunk_flags = - RpcMessageFlags::from_bits(u8::try_from(resp.flags).map_err(|_| { - RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX)) - })?) - .ok_or(RpcStatus::protocol_error(&format!( - "invalid message flag, does not match any flags ({})", - resp.flags - )))?; - let mut last_chunk_size = resp.payload.len(); - self.bytes_read += last_chunk_size; - loop { - trace!( - target: LOG_TARGET, - "Chunk {} received (flags={:?}, {} bytes, {} total)", - chunk_count, - last_chunk_flags, - last_chunk_size, - resp.payload.len() - ); - if !last_chunk_flags.is_more() { - return Ok(resp); - } - - if chunk_count >= RPC_CHUNKING_MAX_CHUNKS { - return Err(RpcError::RemotePeerExceededMaxChunkCount { - expected: RPC_CHUNKING_MAX_CHUNKS, - }); - } - - let msg = self.next().await?; - last_chunk_flags = RpcMessageFlags::from_bits(u8::try_from(msg.flags).map_err(|_| { - RpcStatus::protocol_error(&format!("invalid message flag: must be less than {}", u8::MAX)) - })?) - .ok_or(RpcStatus::protocol_error(&format!( - "invalid message flag, does not match any flags ({})", - resp.flags - )))?; - last_chunk_size = msg.payload.len(); - self.bytes_read += last_chunk_size; - self.check_response(&resp)?; - resp.payload.extend(msg.payload); - chunk_count += 1; - } + self.bytes_read = resp.payload.len(); + trace!( + target: LOG_TARGET, + "Received {} bytes in {:.2?}", + resp.payload.len(), + self.time_to_first_msg.unwrap_or_default() + ); + Ok(resp) } pub async fn read_ack(&mut self) -> Result { diff --git a/comms/core/src/protocol/rpc/message.rs b/comms/core/src/protocol/rpc/message.rs index 17ffd03bcd..ce0a62b7df 100644 --- a/comms/core/src/protocol/rpc/message.rs +++ b/comms/core/src/protocol/rpc/message.rs @@ -24,19 +24,24 @@ use std::{convert::TryFrom, fmt, time::Duration}; use bitflags::bitflags; use bytes::Bytes; +use log::warn; use super::RpcError; use crate::{ proto, proto::rpc::rpc_session_reply::SessionResult, - protocol::rpc::{ - body::{Body, IntoBody}, - context::RequestContext, - error::HandshakeRejectReason, - RpcStatusCode, + protocol::{ + rpc, + rpc::{ + body::{Body, IntoBody}, + context::RequestContext, + error::HandshakeRejectReason, + RpcStatusCode, + }, }, }; +const LOG_TARGET: &str = "comms::rpc::message"; #[derive(Debug)] pub struct Request { pub(super) context: Option, @@ -203,8 +208,6 @@ bitflags! { const FIN = 0x01; /// Typically sent with empty contents and used to confirm a substream is alive. const ACK = 0x02; - /// Another chunk to be received - const MORE = 0x04; } } impl RpcMessageFlags { @@ -215,10 +218,6 @@ impl RpcMessageFlags { pub fn is_ack(self) -> bool { self.contains(Self::ACK) } - - pub fn is_more(self) -> bool { - self.contains(Self::MORE) - } } impl Default for RpcMessageFlags { @@ -267,6 +266,51 @@ pub struct RpcResponse { pub payload: Bytes, } +impl RpcResponse { + pub fn to_proto(&self) -> proto::rpc::RpcResponse { + proto::rpc::RpcResponse { + request_id: self.request_id, + status: self.status as u32, + flags: self.flags.bits().into(), + payload: self.payload.to_vec(), + } + } + + pub fn exceeded_message_size(&mut self) -> RpcResponse { + const BYTES_PER_MB: f32 = 1024.0 * 1024.0; + // Precision loss is acceptable because this is for display purposes only + let msg = format!( + "The response size exceeded the maximum allowed payload size. Max = {:.4} MiB, Got = {:.4} MiB", + rpc::max_response_payload_size() as f32 / BYTES_PER_MB, + self.payload.len() as f32 / BYTES_PER_MB, + ); + warn!(target: LOG_TARGET, "{}", msg); + RpcResponse { + request_id: self.request_id, + status: RpcStatusCode::MalformedResponse, + flags: RpcMessageFlags::FIN, + payload: msg.into_bytes().into(), + } + } + + // pub fn exceeded_message_size(&self) -> proto::rpc::RpcResponse { + // const BYTES_PER_MB: f32 = 1024.0 * 1024.0; + // // Precision loss is acceptable because this is for display purposes only + // let msg = format!( + // "The response size exceeded the maximum allowed payload size. Max = {:.4} MiB, Got = {:.4} MiB", + // rpc::max_response_payload_size() as f32 / BYTES_PER_MB, + // self.payload.len() as f32 / BYTES_PER_MB, + // ); + // warn!(target: LOG_TARGET, "{}", msg); + // proto::rpc::RpcResponse { + // request_id: self.request_id, + // status: RpcStatusCode::MalformedResponse as u32, + // flags: RpcMessageFlags::FIN.bits().into(), + // payload: msg.into_bytes(), + // } + // } +} + impl Default for RpcResponse { fn default() -> Self { Self { diff --git a/comms/core/src/protocol/rpc/mod.rs b/comms/core/src/protocol/rpc/mod.rs index c5dd41f848..cc221868d6 100644 --- a/comms/core/src/protocol/rpc/mod.rs +++ b/comms/core/src/protocol/rpc/mod.rs @@ -31,14 +31,25 @@ mod test; /// Maximum frame size of each RPC message. This is enforced in tokio's length delimited codec. /// This can be thought of as the hard limit on message size. pub const RPC_MAX_FRAME_SIZE: usize = 3 * 1024 * 1024; // 3 MiB -/// Maximum number of chunks into which a message can be broken up. -const RPC_CHUNKING_MAX_CHUNKS: usize = 16; // 16 x 256 Kib = 4 MiB max combined message size +/// The maximum size for a single RPC response message +pub const RPC_MAX_RESPONSE_SIZE: usize = 4 * 1024 * 1024; // 4 MiB /// The maximum request payload size const fn max_request_size() -> usize { RPC_MAX_FRAME_SIZE } +/// The maximum size for a single RPC response excluding overhead +const fn max_response_payload_size() -> usize { + // RpcResponse overhead is: + // - 4 varint protobuf fields, each field ID is 1 byte + // - 3 u32 fields, VarInt(u32::MAX) is 5 bytes + // - 1 length varint for the payload, allow for 5 bytes to be safe (max_payload_size being technically too small is + // fine, being too large isn't) + const MAX_HEADER_SIZE: usize = 4 + 4 * 5; + RPC_MAX_RESPONSE_SIZE - MAX_HEADER_SIZE +} + mod body; pub use body::{Body, ClientStreaming, IntoBody, Streaming}; diff --git a/comms/core/src/protocol/rpc/server/mod.rs b/comms/core/src/protocol/rpc/server/mod.rs index 54e098836e..69bb5bb242 100644 --- a/comms/core/src/protocol/rpc/server/mod.rs +++ b/comms/core/src/protocol/rpc/server/mod.rs @@ -78,6 +78,7 @@ use crate::{ peer_manager::NodeId, proto, protocol::{ + rpc, rpc::{ body::BodyBytes, message::{RpcMethod, RpcResponse}, @@ -748,17 +749,15 @@ where let mut stream = body .into_message() .map(|result| into_response(request_id, result)) - .map(move |message| { + .map(move |mut message| { + if message.payload.len() > rpc::max_response_payload_size() { + message.exceeded_message_size(); + } #[cfg(feature = "metrics")] if !message.status.is_ok() { metrics::status_error_counter(&node_id, &protocol, message.status).inc(); } - proto::rpc::RpcResponse { - request_id, - status: message.status.as_u32(), - flags: message.flags.bits().into(), - payload: message.payload.to_vec(), - } + message.to_proto() }) .map(|resp| Bytes::from(resp.to_encoded_bytes())); diff --git a/comms/core/src/protocol/rpc/test/smoke.rs b/comms/core/src/protocol/rpc/test/smoke.rs index 78463f00f8..51c2471593 100644 --- a/comms/core/src/protocol/rpc/test/smoke.rs +++ b/comms/core/src/protocol/rpc/test/smoke.rs @@ -282,7 +282,7 @@ async fn response_too_big() { let (_inbound, outbound, _, _, _shutdown) = setup(GreetingService::new(&[]), 1).await; let socket = outbound.get_yamux_control().open_stream().await.unwrap(); - let framed = framing::canonical(socket, rpc::RPC_MAX_FRAME_SIZE); + let framed = framing::canonical(socket, rpc::RPC_MAX_RESPONSE_SIZE); let mut client = GreetingClient::builder() .with_deadline(Duration::from_secs(5)) .connect(framed) @@ -291,7 +291,7 @@ async fn response_too_big() { // RPC_MAX_FRAME_SIZE bytes will always be too large because of the overhead of the RpcResponse proto message let err = client - .reply_with_msg_of_size(rpc::RPC_MAX_FRAME_SIZE as u64 + 1) + .reply_with_msg_of_size(rpc::max_response_payload_size() as u64 + 1) .await .unwrap_err(); unpack_enum!(RpcError::RequestFailed(status) = err); @@ -299,7 +299,7 @@ async fn response_too_big() { // Check that the exact frame size boundary works and that the session is still going let _string = client - .reply_with_msg_of_size(rpc::RPC_MAX_FRAME_SIZE as u64 - 9) + .reply_with_msg_of_size(rpc::max_response_payload_size() as u64 - 9) .await .unwrap(); }