From 2587ae753ef642ca110d94b33a292bd1a09deb9e Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Thu, 16 Nov 2023 09:23:04 -0800 Subject: [PATCH] mendes: upgrade to hyper/http 1 --- mendes/Cargo.toml | 23 ++- mendes/src/application.rs | 100 +++------ mendes/src/body.rs | 393 ++++++++++++++++++++++++++++++++++++ mendes/src/hyper.rs | 413 ++++++++++++++++++++++---------------- mendes/src/lib.rs | 6 + mendes/tests/hyper.rs | 27 +-- mendes/tests/readme.rs | 16 +- 7 files changed, 708 insertions(+), 270 deletions(-) create mode 100644 mendes/src/body.rs diff --git a/mendes/Cargo.toml b/mendes/Cargo.toml index ddf291b..14e238d 100644 --- a/mendes/Cargo.toml +++ b/mendes/Cargo.toml @@ -14,21 +14,22 @@ readme = "../README.md" [features] default = ["application"] -application = ["http", "dep:async-trait", "dep:mendes-macros", "dep:percent-encoding", "dep:serde", "dep:serde_urlencoded"] +application = ["http", "dep:async-trait", "dep:bytes", "dep:http-body", "dep:mendes-macros", "dep:percent-encoding", "dep:pin-project", "dep:serde", "dep:serde_urlencoded"] brotli = ["compression", "async-compression?/brotli"] chrono = ["dep:chrono"] -compression = ["dep:async-compression", "dep:futures-util", "dep:tokio-util"] +compression = ["dep:async-compression", "dep:tokio", "dep:tokio-util"] cookies = ["http", "key", "dep:chrono", "dep:data-encoding", "dep:mendes-macros", "dep:postcard", "serde?/derive"] deflate = ["compression", "async-compression?/deflate"] forms = ["dep:mendes-macros", "dep:serde_urlencoded", "serde?/derive"] gzip = ["compression", "async-compression?/gzip"] -http = ["dep:http"] -http-body = ["dep:bytes", "dep:http-body", "dep:pin-utils"] -hyper = ["application", "http", "dep:async-trait", "dep:bytes", "dep:futures-util", "futures-util?/std", "dep:hyper"] +hyper = ["application", "http", "dep:async-trait", "dep:bytes", "dep:futures-util", "futures-util?/std", "dep:hyper", "dep:hyper-util", "dep:tokio", "tokio?/macros", "tracing"] key = ["dep:data-encoding", "dep:ring"] json = ["dep:serde_json"] uploads = ["http", "dep:httparse", "dep:memchr"] +body = ["dep:http-body"] +body-util = ["dep:http-body-util", "dep:bytes", "dep:http-body"] static = ["application", "http", "dep:mime_guess", "dep:tokio", "tokio?/fs"] +tracing = ["dep:tracing"] [dependencies] async-compression = { version = "0.4.0", features = ["tokio"], optional = true } @@ -37,15 +38,17 @@ bytes = { version = "1", optional = true } chrono = { version = "0.4.23", optional = true, features = ["serde"] } data-encoding = { version = "2.1.2", optional = true } futures-util = { version = "0.3.7", optional = true, default-features = false } -http = { version = "0.2", optional = true } -http-body = { version = "0.4", optional = true } +http = { version = "1", optional = true } +http-body = { version = "1", optional = true } +http-body-util = { version = "0.1", optional = true } httparse = { version = "1.3.4", optional = true } -hyper = { version = "0.14.1", optional = true, features = ["http1", "http2", "runtime", "server", "stream"] } +hyper = { version = "1", optional = true, features = ["http1", "http2", "server"] } +hyper-util = { version = "0.1.3", features = ["http1", "http2", "server", "tokio"], optional = true } memchr = { version = "2.5", optional = true } mendes-macros = { version = "0.4", path = "../mendes-macros", optional = true } mime_guess = { version = "2.0.3", default-features = false, optional = true } percent-encoding = { version = "2.1.0", default-features = false, optional = true } -pin-utils = { version = "0.1.0", optional = true } +pin-project = { version = "1.1.5", optional = true } postcard = { version = "1.0.6", default-features = false, features = ["use-std"], optional = true } ring = { version = "0.17.0", optional = true } serde = { version = "1.0.104", optional = true } @@ -57,7 +60,7 @@ tokio-util = { version = "0.7", optional = true, features = ["codec", "compat", tracing = { version = "0.1.26", optional = true } [dev-dependencies] -reqwest = { version = "0.11.11", default-features = false } +reqwest = { version = "0.12", default-features = false } tokio = { version = "1", features = ["macros", "rt"] } [package.metadata.docs.rs] diff --git a/mendes/src/application.rs b/mendes/src/application.rs index f4bc0e0..48e0ccb 100644 --- a/mendes/src/application.rs +++ b/mendes/src/application.rs @@ -1,27 +1,21 @@ use std::borrow::Cow; -#[cfg(feature = "http-body")] +#[cfg(feature = "body-util")] use std::error::Error as StdError; -use std::str; use std::str::FromStr; use std::sync::Arc; use async_trait::async_trait; -#[cfg(feature = "http-body")] -use bytes::{Buf, BufMut, Bytes}; +#[cfg(feature = "body-util")] +use bytes::Bytes; use http::header::LOCATION; use http::request::Parts; use http::Request; use http::{Response, StatusCode}; -#[cfg(feature = "http-body")] -use http_body::Body as HttpBody; use percent_encoding::percent_decode_str; use thiserror::Error; pub use mendes_macros::{handler, route, scope}; -#[cfg(feature = "hyper")] -use crate::hyper::ApplicationService; - /// Main interface for an application or service /// /// The `Application` holds state and routes request to the proper handlers. A handler gets @@ -34,7 +28,7 @@ use crate::hyper::ApplicationService; #[async_trait] pub trait Application: Send + Sized { type RequestBody: Send; - type ResponseBody: Send; + type ResponseBody: http_body::Body; type Error: IntoResponse + WithStatus + From + Send; async fn handle(cx: Context) -> Response; @@ -53,17 +47,17 @@ pub trait Application: Send + Sized { from_bytes::(req, bytes) } - #[cfg(feature = "http-body")] - #[cfg_attr(docsrs, doc(cfg(feature = "http-body")))] + #[cfg(feature = "http-body-util")] + #[cfg_attr(docsrs, doc(cfg(feature = "http-body-util")))] async fn from_body( req: &Parts, body: Self::RequestBody, max_len: usize, ) -> Result where - Self::RequestBody: HttpBody + Send, - ::Data: Send, - ::Error: Into>, + Self::RequestBody: Body + Send, + ::Data: Send, + ::Error: Into>, { // Check if the Content-Length header suggests the body is larger than our max len // to avoid allocation if we drop the request in any case. @@ -78,12 +72,11 @@ pub trait Application: Send + Sized { from_body::(req, body, max_len).await } - #[cfg(feature = "http-body")] - #[cfg_attr(docsrs, doc(cfg(feature = "http-body")))] - async fn body_bytes(body: B, max_len: usize) -> Result + #[cfg(feature = "body-util")] + #[cfg_attr(docsrs, doc(cfg(feature = "body-util")))] + async fn body_bytes(body: B, max_len: usize) -> Result where - B: HttpBody + Send, - ::Data: Send, + B::Data: Send, B::Error: Into>, { // Check if the Content-Length header suggests the body is larger than our max len @@ -92,6 +85,7 @@ pub trait Application: Send + Sized { Some(length) => length, None => body.size_hint().lower(), }; + if expected_len > max_len as u64 { return Err(Error::BodyTooLarge); } @@ -109,11 +103,6 @@ pub trait Application: Send + Sized { .body(Self::ResponseBody::default()) .unwrap() } - - #[cfg(feature = "hyper")] - fn into_service(self) -> ApplicationService { - ApplicationService(Arc::new(self)) - } } pub trait WithStatus {} @@ -513,56 +502,21 @@ fn from_bytes<'de, T: serde::de::Deserialize<'de>>( deserialize_body!(req, bytes) } -#[cfg(feature = "http-body")] -#[cfg_attr(docsrs, doc(cfg(feature = "http-body")))] +#[cfg(feature = "body-util")] +#[cfg_attr(docsrs, doc(cfg(feature = "body-util")))] #[cfg_attr(feature = "tracing", tracing::instrument(skip(body)))] -async fn to_bytes(body: B, max_len: usize) -> Result +async fn to_bytes(body: B, max_len: usize) -> Result where - B: HttpBody, B::Error: Into>, { - pin_utils::pin_mut!(body); - - // If there's only 1 chunk, we can just return Buf::to_bytes() - let mut first = if let Some(buf) = body.data().await { - buf.map_err(|err| Error::BodyReceive(err.into()))? - } else { - return Ok(Bytes::new()); - }; + #[cfg(feature = "body-util")] + use http_body_util::BodyExt; - let mut received = first.remaining(); - if received > max_len { - return Err(Error::BodyTooLarge); + let limited = http_body_util::Limited::new(body, max_len); + match limited.collect().await { + Ok(collected) => Ok(collected.to_bytes()), + Err(err) => Err(Error::BodyReceive(err)), } - - let second = if let Some(buf) = body.data().await { - buf.map_err(|err| Error::BodyReceive(err.into()))? - } else { - return Ok(first.copy_to_bytes(first.remaining())); - }; - - received += second.remaining(); - if received > max_len { - return Err(Error::BodyTooLarge); - } - - // With more than 1 buf, we gotta flatten into a Vec first. - let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize; - let mut vec = Vec::with_capacity(cap); - vec.put(first); - vec.put(second); - - while let Some(buf) = body.data().await { - let buf = buf.map_err(|err| Error::BodyReceive(err.into()))?; - received += buf.remaining(); - if received > max_len { - return Err(Error::BodyTooLarge); - } - - vec.put(buf); - } - - Ok(vec.into()) } // This should only be used by procedural routing macros. @@ -645,10 +599,10 @@ pub enum Error { QueryMissing, #[error("unable to decode request URI query: {0}")] QueryDecode(serde_urlencoded::de::Error), - #[cfg(feature = "http-body")] + #[cfg(feature = "body-util")] #[error("unable to receive request body: {0}")] BodyReceive(Box), - #[cfg(feature = "http-body")] + #[cfg(feature = "body-util")] #[error("request body too large")] BodyTooLarge, #[cfg(feature = "json")] @@ -676,9 +630,9 @@ impl From<&Error> for StatusCode { QueryMissing | QueryDecode(_) | BodyNoType => StatusCode::BAD_REQUEST, BodyUnknownType(_) => StatusCode::UNSUPPORTED_MEDIA_TYPE, PathNotFound | PathComponentMissing | PathParse | PathDecode => StatusCode::NOT_FOUND, - #[cfg(feature = "http-body")] + #[cfg(feature = "body-util")] BodyReceive(_) => StatusCode::INTERNAL_SERVER_ERROR, - #[cfg(feature = "http-body")] + #[cfg(feature = "body-util")] BodyTooLarge => StatusCode::BAD_REQUEST, BodyDecodeForm(_) => StatusCode::UNPROCESSABLE_ENTITY, #[cfg(feature = "json")] diff --git a/mendes/src/body.rs b/mendes/src/body.rs new file mode 100644 index 0000000..82861d3 --- /dev/null +++ b/mendes/src/body.rs @@ -0,0 +1,393 @@ +use std::future::Future; +use std::pin::Pin; +use std::str::FromStr; +use std::task::ready; +use std::task::Poll; +use std::{io, mem, str}; + +#[cfg(feature = "brotli")] +use async_compression::tokio::bufread::BrotliEncoder; +#[cfg(feature = "deflate")] +use async_compression::tokio::bufread::DeflateEncoder; +#[cfg(feature = "gzip")] +use async_compression::tokio::bufread::GzipEncoder; +use bytes::{Buf, Bytes, BytesMut}; +#[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] +use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING}; +#[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] +use http::HeaderMap; +#[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] +use http::{request, HeaderValue, Response}; +use http_body::{Frame, SizeHint}; +use pin_project::pin_project; +#[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] +use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; +#[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] +use tokio_util::io::poll_read_buf; + +#[pin_project] +pub struct Body { + #[pin] + inner: InnerBody, + full_size: u64, + done: bool, +} + +impl Body { + pub fn empty() -> Self { + Self { + inner: InnerBody::Bytes(Bytes::new()), + full_size: 0, + done: true, + } + } + + pub fn lazy(future: impl Future> + Send + 'static) -> Self { + Self { + inner: InnerBody::Lazy { + future: Box::pin(future), + encoding: Encoding::Identity, + }, + full_size: 0, + done: false, + } + } + + pub fn stream( + stream: impl http_body::Body + Send + 'static, + ) -> Self { + Self { + inner: InnerBody::Streaming(Box::pin(stream)), + full_size: 0, + done: false, + } + } +} + +impl http_body::Body for Body { + type Data = Bytes; + type Error = io::Error; + + #[allow(unused_variables)] // Depends on features + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll, Self::Error>>> { + let this = self.project(); + if *this.done { + return Poll::Ready(None); + } + + #[allow(unused_mut)] // Depends on features + let mut buf = BytesMut::new(); + let result = match this.inner.project() { + #[cfg(feature = "brotli")] + PinnedBody::Brotli(encoder) => poll_read_buf(encoder, cx, &mut buf), + #[cfg(feature = "deflate")] + PinnedBody::Deflate(encoder) => poll_read_buf(encoder, cx, &mut buf), + #[cfg(feature = "gzip")] + PinnedBody::Gzip(encoder) => poll_read_buf(encoder, cx, &mut buf), + PinnedBody::Bytes(bytes) => { + *this.done = true; + let bytes = mem::take(bytes.get_mut()); + return Poll::Ready(match bytes.has_remaining() { + true => Some(Ok(Frame::data(bytes))), + false => None, + }); + } + PinnedBody::Streaming(inner) => match ready!(inner.as_mut().poll_frame(cx)) { + Some(item) => return Poll::Ready(Some(item)), + None => { + *this.done = true; + return Poll::Ready(None); + } + }, + PinnedBody::Lazy { future, encoding } => { + let bytes = match ready!(future.as_mut().poll(cx)) { + Ok(bytes) => bytes, + Err(error) => return Poll::Ready(Some(Err(error))), + }; + + let len = bytes.len(); + let mut inner = InnerBody::wrap(bytes, *encoding); + *this.full_size = len as u64; + // The duplication here is pretty ugly, but I couldn't come up with anything better. + match &mut inner { + #[cfg(feature = "brotli")] + InnerBody::Brotli(encoder) => poll_read_buf(Pin::new(encoder), cx, &mut buf), + #[cfg(feature = "deflate")] + InnerBody::Deflate(encoder) => poll_read_buf(Pin::new(encoder), cx, &mut buf), + #[cfg(feature = "gzip")] + InnerBody::Gzip(encoder) => poll_read_buf(Pin::new(encoder), cx, &mut buf), + InnerBody::Bytes(bytes) => { + *this.done = true; + let bytes = mem::take(bytes); + return Poll::Ready(match bytes.has_remaining() { + true => Some(Ok(Frame::data(bytes))), + false => None, + }); + } + InnerBody::Lazy { .. } | InnerBody::Streaming(_) => unreachable!(), + } + } + }; + + #[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] + match ready!(result) { + Ok(0) => { + *this.done = true; + Poll::Ready(None) + } + Ok(n) => { + *this.full_size = this.full_size.saturating_sub(n as u64); + Poll::Ready(Some(Ok(Frame::data(buf.freeze())))) + } + Err(error) => Poll::Ready(Some(Err(error))), + } + } + + fn is_end_stream(&self) -> bool { + self.done + } + + fn size_hint(&self) -> http_body::SizeHint { + match (self.done, &self.inner) { + (true, _) => SizeHint::with_exact(0), + (false, InnerBody::Bytes(body)) => SizeHint::with_exact(body.len() as u64), + (false, InnerBody::Lazy { .. } | InnerBody::Streaming(_)) => SizeHint::default(), + #[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] + (false, InnerBody::Brotli(_) | InnerBody::Deflate(_) | InnerBody::Gzip(_)) => { + let mut hint = SizeHint::default(); + hint.set_lower(1); + hint.set_upper(self.full_size + 256); + hint + } + } + } +} + +impl From> for Body { + fn from(data: Vec) -> Self { + Self::from(Bytes::from(data)) + } +} + +impl From for Body { + fn from(data: String) -> Self { + Self::from(Bytes::from(data)) + } +} + +impl From<&'static str> for Body { + fn from(data: &'static str) -> Self { + Self::from(Bytes::from(data)) + } +} + +impl From for Body { + fn from(data: Bytes) -> Self { + Self { + done: !data.has_remaining(), + full_size: data.len() as u64, + inner: InnerBody::Bytes(data), + } + } +} + +#[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] +impl EncodeResponse for Response { + fn encoded(mut self, req: &request::Parts) -> Response { + let buf = match self.body_mut() { + Body { done: true, .. } => return self, + Body { + inner: InnerBody::Bytes(buf), + .. + } => mem::take(buf), + Body { + inner: + InnerBody::Lazy { + encoding: enc @ Encoding::Identity, + .. + }, + .. + } => { + let new = Encoding::from_accept(&req.headers).unwrap_or(Encoding::Identity); + *enc = new; + return self; + } + Body { + inner: + InnerBody::Brotli(_) + | InnerBody::Deflate(_) + | InnerBody::Gzip(_) + | InnerBody::Lazy { .. } + | InnerBody::Streaming(_), + .. + } => return self, + }; + + let len = buf.len(); + let encoding = Encoding::from_accept(&req.headers).unwrap_or(Encoding::Identity); + let inner = InnerBody::wrap(buf, encoding); + if let Some(encoding) = encoding.as_str() { + self.headers_mut() + .insert(CONTENT_ENCODING, HeaderValue::from_static(encoding)); + } + + let body = self.body_mut(); + body.full_size = len as u64; + body.inner = inner; + self + } +} + +#[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] +pub trait EncodeResponse { + fn encoded(self, req: &request::Parts) -> Self; +} + +#[pin_project(project = PinnedBody)] +enum InnerBody { + #[cfg(feature = "brotli")] + Brotli(#[pin] BrotliEncoder), + #[cfg(feature = "deflate")] + Deflate(#[pin] DeflateEncoder), + #[cfg(feature = "gzip")] + Gzip(#[pin] GzipEncoder), + Bytes(#[pin] Bytes), + Lazy { + future: Pin> + Send>>, + encoding: Encoding, + }, + Streaming(Pin + Send>>), +} + +impl InnerBody { + fn wrap(buf: Bytes, encoding: Encoding) -> Self { + match encoding { + #[cfg(feature = "brotli")] + Encoding::Brotli => Self::Brotli(BrotliEncoder::new(BufReader { buf })), + #[cfg(feature = "deflate")] + Encoding::Deflate => Self::Deflate(DeflateEncoder::new(BufReader { buf })), + #[cfg(feature = "gzip")] + Encoding::Gzip => Self::Gzip(GzipEncoder::new(BufReader { buf })), + Encoding::Identity => Self::Bytes(buf), + } + } +} + +#[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] +struct BufReader { + pub(crate) buf: Bytes, +} + +#[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] +impl AsyncBufRead for BufReader { + fn poll_fill_buf( + self: Pin<&mut Self>, + _: &mut std::task::Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(self.get_mut().buf.chunk())) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.get_mut().buf.advance(amt); + } +} + +#[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] +impl AsyncRead for BufReader { + fn poll_read( + self: Pin<&mut Self>, + _: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let len = Ord::min(self.buf.remaining(), buf.remaining()); + self.get_mut() + .buf + .copy_to_slice(buf.initialize_unfilled_to(len)); + Poll::Ready(Ok(())) + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Ord)] +enum Encoding { + #[cfg(feature = "brotli")] + Brotli, + #[cfg(feature = "deflate")] + Deflate, + #[cfg(feature = "gzip")] + Gzip, + Identity, +} + +impl Encoding { + #[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] + fn from_accept(headers: &HeaderMap) -> Option { + let accept = match headers.get(ACCEPT_ENCODING).map(|hv| hv.to_str()) { + Some(Ok(accept)) => accept, + _ => return None, + }; + + let mut encodings = accept + .split(',') + .filter_map(|s| { + let mut parts = s.splitn(2, ';'); + let alg = match Encoding::from_str(parts.next()?.trim()) { + Ok(encoding) => encoding, + Err(()) => return None, + }; + + let qual = parts + .next() + .and_then(|s| { + let mut parts = s.splitn(2, '='); + if parts.next()?.trim() != "q" { + return None; + } + + let value = parts.next()?; + f64::from_str(value).ok() + }) + .unwrap_or(1.0); + + Some((alg, (qual * 100.0) as u64)) + }) + .collect::>(); + encodings.sort_by_key(|(algo, qual)| (-(*qual as i64), *algo)); + + encodings.into_iter().next().map(|(algo, _)| algo) + } +} + +impl Encoding { + #[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] + pub fn as_str(self) -> Option<&'static str> { + match self { + #[cfg(feature = "brotli")] + Self::Brotli => Some("br"), + #[cfg(feature = "deflate")] + Self::Deflate => Some("deflate"), + #[cfg(feature = "gzip")] + Self::Gzip => Some("gzip"), + Self::Identity => None, + } + } +} + +impl FromStr for Encoding { + type Err = (); + + fn from_str(s: &str) -> Result { + Ok(match s { + #[cfg(feature = "brotli")] + "br" => Encoding::Brotli, + #[cfg(feature = "deflate")] + "deflate" => Encoding::Deflate, + #[cfg(feature = "gzip")] + "gzip" => Encoding::Gzip, + "identity" => Encoding::Identity, + _ => return Err(()), + }) + } +} diff --git a/mendes/src/hyper.rs b/mendes/src/hyper.rs index 3d244d1..884f75c 100644 --- a/mendes/src/hyper.rs +++ b/mendes/src/hyper.rs @@ -1,51 +1,257 @@ use std::convert::Infallible; -use std::future::Future; +use std::error::Error as StdError; +use std::future::{poll_fn, Future, Pending}; +use std::io; use std::net::SocketAddr; use std::panic::AssertUnwindSafe; use std::pin::Pin; use std::sync::Arc; use std::task::Poll; +use std::time::Duration; -use futures_util::future::{ready, CatchUnwind, FutureExt, Map, Ready}; +use futures_util::future::{CatchUnwind, FutureExt, Map}; +use futures_util::pin_mut; use http::request::Parts; use http::{Request, Response, StatusCode}; -use hyper::server::conn::AddrStream; +use hyper::body::{Body, Incoming}; use hyper::service::Service; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn::auto::Builder; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::watch; +use tokio::time::sleep; +use tracing::{debug, error, info, trace}; use super::Application; use crate::application::{Context, FromContext, PathState}; -pub use hyper::Body; +pub use hyper::body; -/// `ApplicationService` wraps an `Arc` to implement the service trait used in hyper -pub struct ApplicationService(pub(crate) Arc); +pub struct Server { + listener: TcpListener, + app: Arc, + signal: Option, +} -impl<'t, A: Application> Service<&'t AddrStream> for ApplicationService { - type Response = ConnectionService; - type Error = hyper::Error; - type Future = Ready>; +impl Server> { + pub fn new(listener: TcpListener, app: A) -> Server> { + Server { + listener, + app: Arc::new(app), + signal: None, + } + } +} - fn poll_ready( - &mut self, - _: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - Poll::Ready(Ok(())) +impl Server> { + pub fn with_graceful_shutdown>(self, signal: F) -> Server { + let Server { listener, app, .. } = self; + Server { + listener, + app, + signal: Some(signal), + } } +} + +impl Server +where + A: Application + Sync + 'static, + <::ResponseBody as Body>::Data: Send, + <::ResponseBody as Body>::Error: StdError + Send + Sync, + ::ResponseBody: From<&'static str> + Send, + F: Future + Send + 'static, +{ + pub async fn serve(self) -> Result<(), io::Error> { + let Server { + listener, + app, + signal, + } = self; + + let (listener_state, conn_state) = states(signal); + loop { + let (stream, addr) = tokio::select! { + res = listener.accept() => { + match res { + Ok((stream, addr)) => (stream, addr), + Err(error) => { + use io::ErrorKind::*; + if matches!(error.kind(), ConnectionRefused | ConnectionAborted | ConnectionReset) { + continue; + } + + // Sleep for a bit to see if the error clears + error!(%error, "error accepting connection"); + sleep(Duration::from_secs(1)).await; + continue; + } + } + } + _ = listener_state.is_shutting_down() => break, + }; + + debug!("connection accepted from {addr}"); + tokio::spawn( + Connection { + stream, + addr, + state: conn_state.clone(), + app: app.clone(), + } + .run(), + ); + } - fn call(&mut self, conn: &'t AddrStream) -> Self::Future { - ready(Ok(ConnectionService { - app: self.0.clone(), - addr: conn.remote_addr(), - })) + let ListenerState { task_monitor, .. } = listener_state; + drop(listener); + if let Some(task_monitor) = task_monitor { + trace!( + "waiting for {} task(s) to finish", + task_monitor.receiver_count() + ); + task_monitor.closed().await; + } + + Ok(()) } } -pub struct ConnectionService { +fn states( + future: Option + Send + 'static>, +) -> (ListenerState, ConnectionState) { + let future = match future { + Some(future) => future, + None => return (ListenerState::default(), ConnectionState::default()), + }; + + let (shutting_down, signal) = watch::channel(()); // Axum: `signal_tx`, `signal_rx` + let shutting_down = Arc::new(shutting_down); + tokio::spawn(async move { + future.await; + info!("shutdown signal received, draining..."); + drop(signal); + }); + + let (task_monitor, task_done) = watch::channel(()); // Axum: `close_tx`, `close_rx` + ( + ListenerState { + shutting_down: Some(shutting_down.clone()), + task_monitor: Some(task_monitor), + _task_done: Some(task_done.clone()), + }, + ConnectionState { + shutting_down: Some(shutting_down), + _task_done: Some(task_done), + }, + ) +} + +#[derive(Default)] +struct ListenerState { + /// If `Some` and `closed()`, the server is shutting down + shutting_down: Option>>, + /// If `Some`, `receiver_count()` can be used whether any connections are still going + /// + /// Call `closed().await` to wait for all connections to finish. + task_monitor: Option>, + /// Given to each connection so we can monitor the number of receivers via `_task_monitor` + _task_done: Option>, +} + +impl ListenerState { + async fn is_shutting_down(&self) { + poll_fn(|cx| match &self.shutting_down { + Some(tx) => { + let future = tx.closed(); + pin_mut!(future); + future.poll(cx) + } + None => Poll::Pending, + }) + .await + } +} + +struct Connection { + stream: TcpStream, + addr: SocketAddr, + state: ConnectionState, app: Arc, +} + +impl + 'static> Connection +where + A::ResponseBody: From<&'static str> + Send, + ::Data: Send, + ::Error: StdError + Send + Sync, +{ + async fn run(self) { + let Connection { + stream, + addr, + state, + app, + } = self; + + let service = ConnectionService { addr, app }; + + let builder = Builder::new(TokioExecutor::new()); + let stream = TokioIo::new(stream); + let conn = builder.serve_connection_with_upgrades(stream, service); + pin_mut!(conn); + + let shutting_down = state.is_shutting_down(); + pin_mut!(shutting_down); + + loop { + tokio::select! { + result = conn.as_mut() => { + if let Err(error) = result { + error!(%addr, %error, "failed to serve connection"); + } + break; + } + _ = &mut shutting_down => { + debug!("shutting down connection to {addr}"); + conn.as_mut().graceful_shutdown(); + } + } + } + + debug!("connection to {addr} closed"); + } +} + +#[derive(Clone, Default)] +struct ConnectionState { + /// If `Some` and `closed()`, the server is shutting down; don't accept new requests + shutting_down: Option>>, + /// Keeping this around will allow the server to wait for the connection to finish + _task_done: Option>, +} + +impl ConnectionState { + async fn is_shutting_down(&self) { + poll_fn(|cx| match &self.shutting_down { + Some(tx) => { + let future = tx.closed().fuse(); + pin_mut!(future); + future.poll(cx) + } + None => Poll::Pending, + }) + .await + } +} + +pub struct ConnectionService { addr: SocketAddr, + app: Arc, } -impl Service> for ConnectionService +impl + 'static> Service> + for ConnectionService where A::ResponseBody: From<&'static str>, { @@ -53,11 +259,7 @@ where type Error = Infallible; type Future = UnwindSafeHandlerFuture; - fn poll_ready(&mut self, _: &mut std::task::Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, mut req: Request) -> Self::Future { + fn call(&self, mut req: Request) -> Self::Future { req.extensions_mut().insert(ClientAddr(self.addr)); let cx = Context::new(self.app.clone(), req); AssertUnwindSafe(A::handle(cx)) @@ -99,15 +301,12 @@ fn panic_response>( .unwrap()) } -impl<'a, A: Application> FromContext<'a, A> for Body -where - A: Application, -{ +impl<'a, A: Application> FromContext<'a, A> for Incoming { fn from_context( _: &'a Arc, _: &'a Parts, _: &mut PathState, - body: &mut Option, + body: &mut Option, ) -> Result { match body.take() { Some(body) => Ok(body), @@ -116,15 +315,12 @@ where } } -impl<'a, A: Application> FromContext<'a, A> for ClientAddr -where - A: Application, -{ +impl<'a, A: Application> FromContext<'a, A> for ClientAddr { fn from_context( _: &'a Arc, req: &'a Parts, _: &mut PathState, - _: &mut Option, + _: &mut Option, ) -> Result { // This is safe because we insert ClientAddr into the request extensions // unconditionally in the ConnectionService::call method. @@ -132,143 +328,24 @@ where } } -#[cfg(feature = "compression")] -#[cfg_attr(docsrs, doc(cfg(feature = "compression")))] -mod encoding { - use std::str::FromStr; - #[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] - use std::{io, mem}; - - #[cfg(feature = "brotli")] - use async_compression::tokio::bufread::BrotliEncoder; - #[cfg(feature = "deflate")] - use async_compression::tokio::bufread::DeflateEncoder; - #[cfg(feature = "gzip")] - use async_compression::tokio::bufread::GzipEncoder; - #[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] - use futures_util::stream::TryStreamExt; - use http::header::ACCEPT_ENCODING; - #[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] - use http::header::{HeaderValue, CONTENT_ENCODING}; - use http::request::Parts; - use http::Response; - use hyper::body::Body; - #[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] - use tokio_util::codec::{BytesCodec, FramedRead}; - #[cfg(any(feature = "brotli", feature = "deflate", feature = "gzip"))] - use tokio_util::io::StreamReader; - - #[allow(unused_mut)] // Depends on features - pub fn encode_content(req: &Parts, mut rsp: Response) -> Response { - let accept = match req.headers.get(ACCEPT_ENCODING).map(|hv| hv.to_str()) { - Some(Ok(accept)) => accept, - _ => return rsp, - }; - - let mut encodings = accept - .split(',') - .filter_map(|s| { - let mut parts = s.splitn(2, ';'); - let alg = match Encoding::from_str(parts.next()?.trim()) { - Ok(encoding) => encoding, - Err(()) => return None, - }; - - let qual = parts - .next() - .and_then(|s| { - let mut parts = s.splitn(2, '='); - if parts.next()?.trim() != "q" { - return None; - } - - let value = parts.next()?; - f64::from_str(value).ok() - }) - .unwrap_or(1.0); - - Some((alg, (qual * 100.0) as u64)) - }) - .collect::>(); - encodings.sort_by_key(|(algo, qual)| (-(*qual as i64), *algo)); - - match encodings.first().map(|v| v.0) { - #[cfg(feature = "brotli")] - Some(Encoding::Brotli) => { - let orig = mem::replace(rsp.body_mut(), Body::empty()); - rsp.headers_mut() - .insert(CONTENT_ENCODING, HeaderValue::from_static("br")); - *rsp.body_mut() = Body::wrap_stream(FramedRead::new( - BrotliEncoder::new(StreamReader::new( - orig.map_err(|e| io::Error::new(io::ErrorKind::Other, e)), - )), - BytesCodec::new(), - )); - rsp - } - #[cfg(feature = "gzip")] - Some(Encoding::Gzip) => { - rsp.headers_mut() - .insert(CONTENT_ENCODING, HeaderValue::from_static("gzip")); - let orig = mem::replace(rsp.body_mut(), Body::empty()); - *rsp.body_mut() = Body::wrap_stream(FramedRead::new( - GzipEncoder::new(StreamReader::new( - orig.map_err(|e| io::Error::new(io::ErrorKind::Other, e)), - )), - BytesCodec::new(), - )); - rsp - } - #[cfg(feature = "deflate")] - Some(Encoding::Deflate) => { - rsp.headers_mut() - .insert(CONTENT_ENCODING, HeaderValue::from_static("deflate")); - let orig = mem::replace(rsp.body_mut(), Body::empty()); - *rsp.body_mut() = Body::wrap_stream(FramedRead::new( - DeflateEncoder::new(StreamReader::new( - orig.map_err(|e| io::Error::new(io::ErrorKind::Other, e)), - )), - BytesCodec::new(), - )); - rsp - } - Some(Encoding::Identity) | None => rsp, - } - } +#[derive(Debug)] +pub struct IncomingStream<'a> { + tcp_stream: &'a TokioIo, + remote_addr: SocketAddr, +} - #[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Ord)] - enum Encoding { - #[cfg(feature = "brotli")] - Brotli, - #[cfg(feature = "gzip")] - Gzip, - #[cfg(feature = "deflate")] - Deflate, - Identity, +impl IncomingStream<'_> { + /// Returns the local address that this stream is bound to. + pub fn local_addr(&self) -> std::io::Result { + self.tcp_stream.inner().local_addr() } - impl FromStr for Encoding { - type Err = (); - - fn from_str(s: &str) -> Result { - Ok(match s { - "identity" => Encoding::Identity, - #[cfg(feature = "gzip")] - "gzip" => Encoding::Gzip, - #[cfg(feature = "deflate")] - "deflate" => Encoding::Deflate, - #[cfg(feature = "brotli")] - "br" => Encoding::Brotli, - _ => return Err(()), - }) - } + /// Returns the remote address that this stream is bound to. + pub fn remote_addr(&self) -> SocketAddr { + self.remote_addr } } -#[cfg(feature = "compression")] -#[cfg_attr(docsrs, doc(cfg(feature = "application")))] -pub use encoding::encode_content; - #[derive(Debug, Clone, Copy)] pub struct ClientAddr(SocketAddr); diff --git a/mendes/src/lib.rs b/mendes/src/lib.rs index d7a7b9c..09762d7 100644 --- a/mendes/src/lib.rs +++ b/mendes/src/lib.rs @@ -12,6 +12,12 @@ pub mod application; #[cfg(feature = "application")] pub use application::{handler, route, scope, Application, Context, Error, FromContext}; +#[cfg(feature = "application")] +#[cfg_attr(docsrs, doc(cfg(feature = "application")))] +pub mod body; +#[cfg(feature = "application")] +pub use body::Body; + #[cfg(feature = "cookies")] #[cfg_attr(docsrs, doc(cfg(feature = "cookies")))] /// Cookie support diff --git a/mendes/tests/hyper.rs b/mendes/tests/hyper.rs index 14f00ea..66badd6 100644 --- a/mendes/tests/hyper.rs +++ b/mendes/tests/hyper.rs @@ -1,30 +1,30 @@ #![cfg(feature = "hyper")] use std::fmt::{self, Display}; +use std::io; use std::net::SocketAddr; use std::time::Duration; use async_trait::async_trait; +use bytes::Bytes; use mendes::application::IntoResponse; use mendes::http::request::Parts; use mendes::http::{Response, StatusCode}; -use mendes::hyper::{Body, ClientAddr}; -use mendes::{handler, route, Application, Context}; +use mendes::hyper::body::Incoming; +use mendes::hyper::{ClientAddr, Server}; +use mendes::{handler, route, Application, Body, Context}; +use tokio::net::TcpListener; use tokio::task::JoinHandle; use tokio::time::sleep; struct ServerRunner { - handle: JoinHandle<()>, + handle: JoinHandle>, } impl ServerRunner { async fn run(addr: SocketAddr) -> Self { - let handle = tokio::spawn(async move { - hyper::Server::bind(&addr) - .serve(App::default().into_service()) - .await - .unwrap(); - }); + let listener = TcpListener::bind(addr).await.unwrap(); + let handle = tokio::spawn(Server::new(listener, App::default()).serve()); sleep(Duration::from_millis(10)).await; Self { handle } } @@ -61,7 +61,7 @@ struct App {} #[async_trait] impl Application for App { - type RequestBody = Body; + type RequestBody = Incoming; type ResponseBody = Body; type Error = Error; @@ -76,7 +76,10 @@ impl Application for App { async fn client_addr(_: &App, client_addr: ClientAddr) -> Result, Error> { Ok(Response::builder() .status(StatusCode::OK) - .body(Body::from(format!("client_addr: {}", client_addr.ip()))) + .body(Body::from(Bytes::from(format!( + "client_addr: {}", + client_addr.ip() + )))) .unwrap()) } @@ -113,7 +116,7 @@ impl IntoResponse for Error { let Error::Mendes(err) = self; Response::builder() .status(StatusCode::from(&err)) - .body(Body::from(err.to_string())) + .body(Body::from(Bytes::from(err.to_string()))) .unwrap() } } diff --git a/mendes/tests/readme.rs b/mendes/tests/readme.rs index 3ee559f..08ec6dc 100644 --- a/mendes/tests/readme.rs +++ b/mendes/tests/readme.rs @@ -1,14 +1,16 @@ #![cfg(all(feature = "application", feature = "hyper"))] use async_trait::async_trait; -use hyper::Body; +use bytes::Bytes; +use http_body_util::Full; use mendes::application::IntoResponse; use mendes::http::request::Parts; use mendes::http::{Response, StatusCode}; +use mendes::hyper::body::Incoming; use mendes::{handler, route, Application, Context}; #[handler(GET)] -async fn hello(_: &App) -> Result, Error> { +async fn hello(_: &App) -> Result>, Error> { Ok(Response::builder() .status(StatusCode::OK) .body("Hello, world".into()) @@ -19,11 +21,11 @@ struct App {} #[async_trait] impl Application for App { - type RequestBody = (); - type ResponseBody = Body; + type RequestBody = Incoming; + type ResponseBody = Full; type Error = Error; - async fn handle(mut cx: Context) -> Response { + async fn handle(mut cx: Context) -> Response> { route!(match cx.path() { _ => hello, }) @@ -49,11 +51,11 @@ impl From<&Error> for StatusCode { } impl IntoResponse for Error { - fn into_response(self, _: &App, _: &Parts) -> Response { + fn into_response(self, _: &App, _: &Parts) -> Response> { let Error::Mendes(err) = self; Response::builder() .status(StatusCode::from(&err)) - .body(err.to_string().into()) + .body(Full::new(Bytes::from(err.to_string()))) .unwrap() } }