From 9f716d841184b8521720c6ed941af137ca2ee6a0 Mon Sep 17 00:00:00 2001 From: Moncef AOUDIA <22281426+aoudiamoncef@users.noreply.github.com> Date: Fri, 17 Feb 2023 16:58:22 +0100 Subject: [PATCH] feat(codec): Configure max request message size (#1274) * feat(codec): add max_message_size parameter resolves #1097 * refactor(client): add max size parameters * refactor(tonic-build): update server gen template * refactor(tonic-build): update client template * fix(tonic-build): update client template * fix(tonic-build): small typo in server.rs * fix(tonic-build): client.rs generator * fix(tonic): add apply max message setting size to server * fix(test): wrong message size * fix: doctest + generated rs --- tonic-build/src/client.rs | 14 ++ tonic-build/src/server.rs | 44 ++++++- tonic-health/src/generated/grpc.health.v1.rs | 42 ++++++ .../src/generated/grpc.reflection.v1alpha.rs | 36 ++++++ tonic/benches/decode.rs | 2 +- tonic/src/client/grpc.rs | 95 +++++++++++++- tonic/src/codec/decode.rs | 43 +++++- tonic/src/codec/encode.rs | 31 ++++- tonic/src/codec/mod.rs | 3 + tonic/src/codec/prost.rs | 73 ++++++++++- tonic/src/server/grpc.rs | 122 +++++++++++++++++- 11 files changed, 482 insertions(+), 23 deletions(-) diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index ae7861cc1..ce35d6616 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -137,6 +137,20 @@ pub(crate) fn generate_internal( self } + /// Limits the maximum size of a decoded message. + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + + /// Limits the maximum size of an encoded message. + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + #methods } } diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index efaa70014..e2ba3bb4e 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -84,6 +84,22 @@ pub(crate) fn generate_internal( } }; + let configure_max_message_size_methods = quote! { + /// Limits the maximum size of a decoded message. + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + + /// Limits the maximum size of an encoded message. + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + }; + quote! { /// Generated server implementations. #(#mod_attributes)* @@ -106,6 +122,8 @@ pub(crate) fn generate_internal( inner: _Inner, accept_compression_encodings: EnabledCompressionEncodings, send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, } struct _Inner(Arc); @@ -121,6 +139,8 @@ pub(crate) fn generate_internal( inner, accept_compression_encodings: Default::default(), send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, } } @@ -132,6 +152,8 @@ pub(crate) fn generate_internal( } #configure_compression_methods + + #configure_max_message_size_methods } impl tonic::codegen::Service> for #server_service @@ -173,6 +195,8 @@ pub(crate) fn generate_internal( inner, accept_compression_encodings: self.accept_compression_encodings, send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, } } } @@ -414,6 +438,8 @@ fn generate_unary( let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -421,7 +447,8 @@ fn generate_unary( let codec = #codec_name::default(); let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config(accept_compression_encodings, send_compression_encodings); + .apply_compression_config(accept_compression_encodings, send_compression_encodings) + .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size); let res = grpc.unary(method, req).await; Ok(res) @@ -466,6 +493,8 @@ fn generate_server_streaming( let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -473,7 +502,8 @@ fn generate_server_streaming( let codec = #codec_name::default(); let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config(accept_compression_encodings, send_compression_encodings); + .apply_compression_config(accept_compression_encodings, send_compression_encodings) + .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size); let res = grpc.server_streaming(method, req).await; Ok(res) @@ -516,6 +546,8 @@ fn generate_client_streaming( let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -523,7 +555,8 @@ fn generate_client_streaming( let codec = #codec_name::default(); let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config(accept_compression_encodings, send_compression_encodings); + .apply_compression_config(accept_compression_encodings, send_compression_encodings) + .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size); let res = grpc.client_streaming(method, req).await; Ok(res) @@ -569,6 +602,8 @@ fn generate_streaming( let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -576,7 +611,8 @@ fn generate_streaming( let codec = #codec_name::default(); let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config(accept_compression_encodings, send_compression_encodings); + .apply_compression_config(accept_compression_encodings, send_compression_encodings) + .apply_max_message_size_config(max_decoding_message_size, max_encoding_message_size); let res = grpc.streaming(method, req).await; Ok(res) diff --git a/tonic-health/src/generated/grpc.health.v1.rs b/tonic-health/src/generated/grpc.health.v1.rs index 835376aa2..f9794058c 100644 --- a/tonic-health/src/generated/grpc.health.v1.rs +++ b/tonic-health/src/generated/grpc.health.v1.rs @@ -114,6 +114,18 @@ pub mod health_client { self.inner = self.inner.accept_compressed(encoding); self } + /// Limits the maximum size of a decoded message. + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } /// If the requested service is unknown, the call will fail with status /// NOT_FOUND. pub async fn check( @@ -224,6 +236,8 @@ pub mod health_server { inner: _Inner, accept_compression_encodings: EnabledCompressionEncodings, send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, } struct _Inner(Arc); impl HealthServer { @@ -236,6 +250,8 @@ pub mod health_server { inner, accept_compression_encodings: Default::default(), send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, } } pub fn with_interceptor( @@ -259,6 +275,18 @@ pub mod health_server { self.send_compression_encodings.enable(encoding); self } + /// Limits the maximum size of a decoded message. + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } } impl tonic::codegen::Service> for HealthServer where @@ -301,6 +329,8 @@ pub mod health_server { } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -310,6 +340,10 @@ pub mod health_server { .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.unary(method, req).await; Ok(res) @@ -340,6 +374,8 @@ pub mod health_server { } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -349,6 +385,10 @@ pub mod health_server { .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.server_streaming(method, req).await; Ok(res) @@ -377,6 +417,8 @@ pub mod health_server { inner, accept_compression_encodings: self.accept_compression_encodings, send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, } } } diff --git a/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs b/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs index 450a2b1ff..b3879091e 100644 --- a/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs +++ b/tonic-reflection/src/generated/grpc.reflection.v1alpha.rs @@ -211,6 +211,18 @@ pub mod server_reflection_client { self.inner = self.inner.accept_compressed(encoding); self } + /// Limits the maximum size of a decoded message. + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } /// The reflection service is structured as a bidirectional stream, ensuring /// all related requests go to a single server. pub async fn server_reflection_info( @@ -270,6 +282,8 @@ pub mod server_reflection_server { inner: _Inner, accept_compression_encodings: EnabledCompressionEncodings, send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, } struct _Inner(Arc); impl ServerReflectionServer { @@ -282,6 +296,8 @@ pub mod server_reflection_server { inner, accept_compression_encodings: Default::default(), send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, } } pub fn with_interceptor( @@ -305,6 +321,18 @@ pub mod server_reflection_server { self.send_compression_encodings.enable(encoding); self } + /// Limits the maximum size of a decoded message. + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } } impl tonic::codegen::Service> for ServerReflectionServer where @@ -352,6 +380,8 @@ pub mod server_reflection_server { } let accept_compression_encodings = self.accept_compression_encodings; let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; let inner = self.inner.clone(); let fut = async move { let inner = inner.0; @@ -361,6 +391,10 @@ pub mod server_reflection_server { .apply_compression_config( accept_compression_encodings, send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, ); let res = grpc.streaming(method, req).await; Ok(res) @@ -389,6 +423,8 @@ pub mod server_reflection_server { inner, accept_compression_encodings: self.accept_compression_encodings, send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, } } } diff --git a/tonic/benches/decode.rs b/tonic/benches/decode.rs index 96f5b498d..5c7cd0159 100644 --- a/tonic/benches/decode.rs +++ b/tonic/benches/decode.rs @@ -22,7 +22,7 @@ macro_rules! bench { b.iter(|| { rt.block_on(async { let decoder = MockDecoder::new($message_size); - let mut stream = Streaming::new_request(decoder, body.clone(), None); + let mut stream = Streaming::new_request(decoder, body.clone(), None, None); let mut count = 0; while let Some(msg) = stream.message().await.unwrap() { diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index 85134ceae..70b7f07cd 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -39,6 +39,10 @@ struct GrpcConfig { accept_compression_encodings: EnabledCompressionEncodings, /// The compression encoding that will be applied to requests. send_compression_encodings: Option, + /// Limits the maximum size of a decoded message. + max_decoding_message_size: Option, + /// Limits the maximum size of an encoded message. + max_encoding_message_size: Option, } impl Grpc { @@ -58,6 +62,8 @@ impl Grpc { origin, send_compression_encodings: None, accept_compression_encodings: EnabledCompressionEncodings::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, }, } } @@ -124,6 +130,66 @@ impl Grpc { self } + /// Limits the maximum size of a decoded message. + /// + /// # Example + /// + /// The most common way of using this is through a client generated by tonic-build: + /// + /// ```rust + /// use tonic::transport::Channel; + /// # struct TestClient(T); + /// # impl TestClient { + /// # fn new(channel: T) -> Self { Self(channel) } + /// # fn max_decoding_message_size(self, _: usize) -> Self { self } + /// # } + /// + /// # async { + /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) + /// .connect() + /// .await + /// .unwrap(); + /// + /// // Set the limit to 2MB, Defaults to 4MB. + /// let limit = 2 * 1024 * 1024; + /// let client = TestClient::new(channel).max_decoding_message_size(limit); + /// # }; + /// ``` + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.config.max_decoding_message_size = Some(limit); + self + } + + /// Limits the maximum size of an ecoded message. + /// + /// # Example + /// + /// The most common way of using this is through a client generated by tonic-build: + /// + /// ```rust + /// use tonic::transport::Channel; + /// # struct TestClient(T); + /// # impl TestClient { + /// # fn new(channel: T) -> Self { Self(channel) } + /// # fn max_encoding_message_size(self, _: usize) -> Self { self } + /// # } + /// + /// # async { + /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap()) + /// .connect() + /// .await + /// .unwrap(); + /// + /// // Set the limit to 2MB, Defaults to 4MB. + /// let limit = 2 * 1024 * 1024; + /// let client = TestClient::new(channel).max_encoding_message_size(limit); + /// # }; + /// ``` + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.config.max_encoding_message_size = Some(limit); + self + } + /// Check if the inner [`GrpcService`] is able to accept a new request. /// /// This will call [`GrpcService::poll_ready`] until it returns ready or @@ -228,7 +294,14 @@ impl Grpc { M2: Send + Sync + 'static, { let request = request - .map(|s| encode_client(codec.encoder(), s, self.config.send_compression_encodings)) + .map(|s| { + encode_client( + codec.encoder(), + s, + self.config.send_compression_encodings, + self.config.max_encoding_message_size, + ) + }) .map(BoxBody::new); let request = self.config.prepare_request(request, path); @@ -278,7 +351,13 @@ impl Grpc { let response = response.map(|body| { if expect_additional_trailers { - Streaming::new_response(decoder, body, status_code, encoding) + Streaming::new_response( + decoder, + body, + status_code, + encoding, + self.config.max_decoding_message_size, + ) } else { Streaming::new_empty(decoder, body) } @@ -350,6 +429,8 @@ impl Clone for Grpc { origin: self.config.origin.clone(), send_compression_encodings: self.config.send_compression_encodings, accept_compression_encodings: self.config.accept_compression_encodings, + max_encoding_message_size: self.config.max_encoding_message_size, + max_decoding_message_size: self.config.max_decoding_message_size, }, } } @@ -373,6 +454,16 @@ impl fmt::Debug for Grpc { &self.config.accept_compression_encodings, ); + f.field( + "max_decoding_message_size", + &self.config.max_decoding_message_size, + ); + + f.field( + "max_encoding_message_size", + &self.config.max_encoding_message_size, + ); + f.finish() } } diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index bb422652f..eb08ec219 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -1,5 +1,5 @@ use super::compression::{decompress, CompressionEncoding}; -use super::{DecodeBuf, Decoder, HEADER_SIZE}; +use super::{DecodeBuf, Decoder, DEFAULT_MAX_MESSAGE_SIZE, HEADER_SIZE}; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use futures_core::Stream; @@ -32,6 +32,7 @@ struct StreamingInner { trailers: Option, decompress_buf: BytesMut, encoding: Option, + max_message_size: Option, } impl Unpin for Streaming {} @@ -59,13 +60,20 @@ impl Streaming { body: B, status_code: StatusCode, encoding: Option, + max_message_size: Option, ) -> Self where B: Body + Send + 'static, B::Error: Into, D: Decoder + Send + 'static, { - Self::new(decoder, body, Direction::Response(status_code), encoding) + Self::new( + decoder, + body, + Direction::Response(status_code), + encoding, + max_message_size, + ) } pub(crate) fn new_empty(decoder: D, body: B) -> Self @@ -74,17 +82,28 @@ impl Streaming { B::Error: Into, D: Decoder + Send + 'static, { - Self::new(decoder, body, Direction::EmptyResponse, None) + Self::new(decoder, body, Direction::EmptyResponse, None, None) } #[doc(hidden)] - pub fn new_request(decoder: D, body: B, encoding: Option) -> Self + pub fn new_request( + decoder: D, + body: B, + encoding: Option, + max_message_size: Option, + ) -> Self where B: Body + Send + 'static, B::Error: Into, D: Decoder + Send + 'static, { - Self::new(decoder, body, Direction::Request, encoding) + Self::new( + decoder, + body, + Direction::Request, + encoding, + max_message_size, + ) } fn new( @@ -92,6 +111,7 @@ impl Streaming { body: B, direction: Direction, encoding: Option, + max_message_size: Option, ) -> Self where B: Body + Send + 'static, @@ -111,6 +131,7 @@ impl Streaming { trailers: None, decompress_buf: BytesMut::new(), encoding, + max_message_size, }, } } @@ -151,7 +172,19 @@ impl StreamingInner { return Err(Status::new(Code::Internal, message)); } }; + let len = self.buf.get_u32() as usize; + let limit = self.max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE); + if len > limit { + return Err(Status::new( + Code::OutOfRange, + format!( + "Error, message length too large: found {} bytes, the limit is: {} bytes", + len, limit + ), + )); + } + self.buf.reserve(len); self.state = State::ReadBody { diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index d94a1f0ba..5efc40ef3 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -1,5 +1,5 @@ use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride}; -use super::{EncodeBuf, Encoder, HEADER_SIZE}; +use super::{EncodeBuf, Encoder, DEFAULT_MAX_MESSAGE_SIZE, HEADER_SIZE}; use crate::{Code, Status}; use bytes::{BufMut, Bytes, BytesMut}; use futures_core::{Stream, TryStream}; @@ -19,12 +19,20 @@ pub(crate) fn encode_server( source: U, compression_encoding: Option, compression_override: SingleMessageCompressionOverride, + max_message_size: Option, ) -> EncodeBody>> where T: Encoder, U: Stream>, { - let stream = encode(encoder, source, compression_encoding, compression_override).into_stream(); + let stream = encode( + encoder, + source, + compression_encoding, + compression_override, + max_message_size, + ) + .into_stream(); EncodeBody::new_server(stream) } @@ -33,6 +41,7 @@ pub(crate) fn encode_client( encoder: T, source: U, compression_encoding: Option, + max_message_size: Option, ) -> EncodeBody>> where T: Encoder, @@ -43,6 +52,7 @@ where source.map(Ok), compression_encoding, SingleMessageCompressionOverride::default(), + max_message_size, ) .into_stream(); EncodeBody::new_client(stream) @@ -53,6 +63,7 @@ fn encode( source: U, compression_encoding: Option, compression_override: SingleMessageCompressionOverride, + max_message_size: Option, ) -> impl TryStream where T: Encoder, @@ -81,6 +92,7 @@ where &mut buf, &mut uncompression_buf, compression_encoding, + max_message_size, item, ) }) @@ -91,6 +103,7 @@ fn encode_item( buf: &mut BytesMut, uncompression_buf: &mut BytesMut, compression_encoding: Option, + max_message_size: Option, item: T::Item, ) -> Result where @@ -119,14 +132,26 @@ where } // now that we know length, we can write the header - finish_encoding(compression_encoding, buf) + finish_encoding(compression_encoding, max_message_size, buf) } fn finish_encoding( compression_encoding: Option, + max_message_size: Option, buf: &mut BytesMut, ) -> Result { let len = buf.len() - HEADER_SIZE; + let limit = max_message_size.unwrap_or(DEFAULT_MAX_MESSAGE_SIZE); + if len > limit { + return Err(Status::new( + Code::OutOfRange, + format!( + "Error, message length too large: found {} bytes, the limit is: {} bytes", + len, limit + ), + )); + } + if len > std::u32::MAX as usize { return Err(Status::resource_exhausted(format!( "Cannot return body with more than 4GB of data but got {len} bytes" diff --git a/tonic/src/codec/mod.rs b/tonic/src/codec/mod.rs index cc330b14c..30ca36a2c 100644 --- a/tonic/src/codec/mod.rs +++ b/tonic/src/codec/mod.rs @@ -29,6 +29,9 @@ const HEADER_SIZE: usize = // data length std::mem::size_of::(); +// The default maximum uncompressed size in bytes for a message. Defaults to 4MB. +const DEFAULT_MAX_MESSAGE_SIZE: usize = 4 * 1024 * 1024; + /// Trait that knows how to encode and decode gRPC messages. pub trait Codec { /// The encodable message. diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index 6c2899706..f983d3cd7 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -81,11 +81,13 @@ mod tests { use crate::codec::{ encode_server, DecodeBuf, Decoder, EncodeBuf, Encoder, Streaming, HEADER_SIZE, }; - use crate::Status; + use crate::{Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use http_body::Body; const LEN: usize = 10000; + // The maximum uncompressed size in bytes for a message. Set to 2MB. + const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024; #[tokio::test] async fn decode() { @@ -103,7 +105,7 @@ mod tests { let body = body::MockBody::new(&buf[..], 10005, 0); - let mut stream = Streaming::new_request(decoder, body, None); + let mut stream = Streaming::new_request(decoder, body, None, None); let mut i = 0usize; while let Some(output_msg) = stream.message().await.unwrap() { @@ -113,6 +115,39 @@ mod tests { assert_eq!(i, 1); } + #[tokio::test] + async fn decode_max_message_size_exceeded() { + let decoder = MockDecoder::default(); + + let msg = vec![0u8; MAX_MESSAGE_SIZE + 1]; + + let mut buf = BytesMut::new(); + + buf.reserve(msg.len() + HEADER_SIZE); + buf.put_u8(0); + buf.put_u32(msg.len() as u32); + + buf.put(&msg[..]); + + let body = body::MockBody::new(&buf[..], MAX_MESSAGE_SIZE + HEADER_SIZE + 1, 0); + + let mut stream = Streaming::new_request(decoder, body, None, Some(MAX_MESSAGE_SIZE)); + + let actual = stream.message().await.unwrap_err(); + + let expected = Status::new( + Code::OutOfRange, + format!( + "Error, message length too large: found {} bytes, the limit is: {} bytes", + msg.len(), + MAX_MESSAGE_SIZE + ), + ); + + assert_eq!(actual.code(), expected.code()); + assert_eq!(actual.message(), expected.message()); + } + #[tokio::test] async fn encode() { let encoder = MockEncoder::default(); @@ -127,6 +162,7 @@ mod tests { source, None, SingleMessageCompressionOverride::default(), + None, ); futures_util::pin_mut!(body); @@ -136,6 +172,38 @@ mod tests { } } + #[tokio::test] + async fn encode_max_message_size_exceeded() { + let encoder = MockEncoder::default(); + + let msg = vec![0u8; MAX_MESSAGE_SIZE + 1]; + + let messages = std::iter::once(Ok::<_, Status>(msg)); + let source = futures_util::stream::iter(messages); + + let body = encode_server( + encoder, + source, + None, + SingleMessageCompressionOverride::default(), + Some(MAX_MESSAGE_SIZE), + ); + + futures_util::pin_mut!(body); + + assert!(body.data().await.is_none()); + assert_eq!( + body.trailers() + .await + .expect("no error polling trailers") + .expect("some trailers") + .get("grpc-status") + .expect("grpc-status header"), + "11" + ); + assert!(body.is_end_stream()); + } + // skip on windows because CI stumbles over our 4GB allocation #[cfg(not(target_family = "windows"))] #[tokio::test] @@ -152,6 +220,7 @@ mod tests { source, None, SingleMessageCompressionOverride::default(), + Some(usize::MAX), ); futures_util::pin_mut!(body); diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index cbe8450ff..0749bac48 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -36,6 +36,10 @@ pub struct Grpc { accept_compression_encodings: EnabledCompressionEncodings, /// Which compression encodings might the server use for responses. send_compression_encodings: EnabledCompressionEncodings, + /// Limits the maximum size of a decoded message. + max_decoding_message_size: Option, + /// Limits the maximum size of an encoded message. + max_encoding_message_size: Option, } impl Grpc @@ -48,6 +52,8 @@ where codec, accept_compression_encodings: EnabledCompressionEncodings::default(), send_compression_encodings: EnabledCompressionEncodings::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, } } @@ -114,6 +120,66 @@ where self } + /// Limits the maximum size of a decoded message. + /// + /// # Example + /// + /// The most common way of using this is through a server generated by tonic-build: + /// + /// ```rust + /// # struct Svc; + /// # struct ExampleServer(T); + /// # impl ExampleServer { + /// # fn new(svc: T) -> Self { Self(svc) } + /// # fn max_decoding_message_size(self, _: usize) -> Self { self } + /// # } + /// # #[tonic::async_trait] + /// # trait Example {} + /// + /// #[tonic::async_trait] + /// impl Example for Svc { + /// // ... + /// } + /// + /// // Set the limit to 2MB, Defaults to 4MB. + /// let limit = 2 * 1024 * 1024; + /// let service = ExampleServer::new(Svc).max_decoding_message_size(limit); + /// ``` + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + + /// Limits the maximum size of a encoded message. + /// + /// # Example + /// + /// The most common way of using this is through a server generated by tonic-build: + /// + /// ```rust + /// # struct Svc; + /// # struct ExampleServer(T); + /// # impl ExampleServer { + /// # fn new(svc: T) -> Self { Self(svc) } + /// # fn max_encoding_message_size(self, _: usize) -> Self { self } + /// # } + /// # #[tonic::async_trait] + /// # trait Example {} + /// + /// #[tonic::async_trait] + /// impl Example for Svc { + /// // ... + /// } + /// + /// // Set the limit to 2MB, Defaults to 4MB. + /// let limit = 2 * 1024 * 1024; + /// let service = ExampleServer::new(Svc).max_encoding_message_size(limit); + /// ``` + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + #[doc(hidden)] pub fn apply_compression_config( self, @@ -134,6 +200,24 @@ where this } + #[doc(hidden)] + pub fn apply_max_message_size_config( + self, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + ) -> Self { + let mut this = self; + + if let Some(limit) = max_decoding_message_size { + this = this.max_decoding_message_size(limit); + } + if let Some(limit) = max_encoding_message_size { + this = this.max_encoding_message_size(limit); + } + + this + } + /// Handle a single unary gRPC request. pub async fn unary( &mut self, @@ -158,6 +242,7 @@ where Err(status), accept_encoding, SingleMessageCompressionOverride::default(), + self.max_encoding_message_size, ); } }; @@ -169,7 +254,12 @@ where let compression_override = compression_override_from_response(&response); - self.map_response(response, accept_encoding, compression_override) + self.map_response( + response, + accept_encoding, + compression_override, + self.max_encoding_message_size, + ) } /// Handle a server side streaming request. @@ -196,6 +286,7 @@ where Err(status), accept_encoding, SingleMessageCompressionOverride::default(), + self.max_encoding_message_size, ); } }; @@ -208,6 +299,7 @@ where // disabling compression of individual stream items must be done on // the items themselves SingleMessageCompressionOverride::default(), + self.max_encoding_message_size, ) } @@ -236,7 +328,12 @@ where let compression_override = compression_override_from_response(&response); - self.map_response(response, accept_encoding, compression_override) + self.map_response( + response, + accept_encoding, + compression_override, + self.max_encoding_message_size, + ) } /// Handle a bi-directional streaming gRPC request. @@ -264,6 +361,7 @@ where response, accept_encoding, SingleMessageCompressionOverride::default(), + self.max_encoding_message_size, ) } @@ -279,8 +377,12 @@ where let (parts, body) = request.into_parts(); - let stream = - Streaming::new_request(self.codec.decoder(), body, request_compression_encoding); + let stream = Streaming::new_request( + self.codec.decoder(), + body, + request_compression_encoding, + self.max_decoding_message_size, + ); futures_util::pin_mut!(stream); @@ -308,8 +410,14 @@ where { let encoding = self.request_encoding_if_supported(&request)?; - let request = - request.map(|body| Streaming::new_request(self.codec.decoder(), body, encoding)); + let request = request.map(|body| { + Streaming::new_request( + self.codec.decoder(), + body, + encoding, + self.max_decoding_message_size, + ) + }); Ok(Request::from_http(request)) } @@ -319,6 +427,7 @@ where response: Result, Status>, accept_encoding: Option, compression_override: SingleMessageCompressionOverride, + max_message_size: Option, ) -> http::Response where B: TryStream + Send + 'static, @@ -349,6 +458,7 @@ where body.into_stream(), accept_encoding, compression_override, + max_message_size, ); http::Response::from_parts(parts, BoxBody::new(body))