Skip to content

Commit

Permalink
Support graceful shutdown on serve (#2398)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn authored Dec 29, 2023
1 parent 56159b0 commit 12e8c62
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 139 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied
<axum::http::request::Parts as FromRequestParts<S>>
<Uri as FromRequestParts<S>>
<Version as FromRequestParts<S>>
<ConnectInfo<T> as FromRequestParts<S>>
<Extensions as FromRequestParts<S>>
<ConnectInfo<T> as FromRequestParts<S>>
and $N others
= note: required for `bool` to implement `FromRequest<(), axum_core::extract::private::ViaParts>`
note: required by a bound in `__axum_macros_check_handler_0_from_request_check`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ error[E0277]: the trait bound `String: FromRequestParts<S>` is not satisfied
<axum::http::request::Parts as FromRequestParts<S>>
<Uri as FromRequestParts<S>>
<Version as FromRequestParts<S>>
<ConnectInfo<T> as FromRequestParts<S>>
<Extensions as FromRequestParts<S>>
and $N others
2 changes: 2 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** `Body` implements `From<()>` now ([#2411])
- **change:** Update version of multer used internally for multipart ([#2433])
- **change:** Update tokio-tungstenite to 0.21 ([#2435])
- **added:** Support graceful shutdown on `serve` ([#2398])

[#2411]: https://github.com/tokio-rs/axum/pull/2411
[#2433]: https://github.com/tokio-rs/axum/pull/2433
[#2435]: https://github.com/tokio-rs/axum/pull/2435
[#2398]: https://github.com/tokio-rs/axum/pull/2398

# 0.7.2 (03. December, 2023)

Expand Down
6 changes: 3 additions & 3 deletions axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ matched-path = []
multipart = ["dep:multer"]
original-uri = []
query = ["dep:serde_urlencoded"]
tokio = ["dep:hyper-util", "dep:tokio", "tokio/net", "tokio/rt", "tower/make"]
tokio = ["dep:hyper-util", "dep:tokio", "tokio/net", "tokio/rt", "tower/make", "tokio/macros"]
tower-log = ["tower/log"]
tracing = ["dep:tracing", "axum-core/tracing"]
ws = ["dep:hyper", "tokio", "dep:tokio-tungstenite", "dep:sha1", "dep:base64"]
Expand Down Expand Up @@ -53,8 +53,8 @@ tower-service = "0.3"
# optional dependencies
axum-macros = { path = "../axum-macros", version = "0.4.0", optional = true }
base64 = { version = "0.21.0", optional = true }
hyper = { version = "1.0.0", optional = true }
hyper-util = { version = "0.1.1", features = ["tokio", "server", "server-auto"], optional = true }
hyper = { version = "1.1.0", optional = true }
hyper-util = { version = "0.1.2", features = ["tokio", "server", "server-auto"], optional = true }
multer = { version = "3.0.0", optional = true }
serde_json = { version = "1.0", features = ["raw_value"], optional = true }
serde_path_to_error = { version = "0.1.8", optional = true }
Expand Down
12 changes: 12 additions & 0 deletions axum/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,25 @@ macro_rules! all_the_tuples {
};
}

#[cfg(feature = "tracing")]
macro_rules! trace {
($($tt:tt)*) => {
tracing::trace!($($tt)*)
}
}

#[cfg(feature = "tracing")]
macro_rules! error {
($($tt:tt)*) => {
tracing::error!($($tt)*)
};
}

#[cfg(not(feature = "tracing"))]
macro_rules! trace {
($($tt:tt)*) => {};
}

#[cfg(not(feature = "tracing"))]
macro_rules! error {
($($tt:tt)*) => {};
Expand Down
249 changes: 219 additions & 30 deletions axum/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,29 @@

use std::{
convert::Infallible,
future::{Future, IntoFuture},
fmt::Debug,
future::{poll_fn, Future, IntoFuture},
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};

use axum_core::{body::Body, extract::Request, response::Response};
use futures_util::future::poll_fn;
use futures_util::{pin_mut, FutureExt};
use hyper::body::Incoming;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder,
};
use pin_project_lite::pin_project;
use tokio::net::{TcpListener, TcpStream};
use tokio::{
net::{TcpListener, TcpStream},
sync::watch,
};
use tower::util::{Oneshot, ServiceExt};
use tower_service::Service;

Expand Down Expand Up @@ -110,9 +115,45 @@ pub struct Serve<M, S> {
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> std::fmt::Debug for Serve<M, S>
impl<M, S> Serve<M, S> {
/// Prepares a server to handle graceful shutdown when the provided future completes.
///
/// # Example
///
/// ```
/// use axum::{Router, routing::get};
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router)
/// .with_graceful_shutdown(shutdown_signal())
/// .await
/// .unwrap();
/// # };
///
/// async fn shutdown_signal() {
/// // ...
/// }
/// ```
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
where
F: Future<Output = ()> + Send + 'static,
{
WithGracefulShutdown {
tcp_listener: self.tcp_listener,
make_service: self.make_service,
signal,
_marker: PhantomData,
}
}
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Debug for Serve<M, S>
where
M: std::fmt::Debug,
M: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
Expand Down Expand Up @@ -148,30 +189,9 @@ where
} = self;

loop {
let (tcp_stream, remote_addr) = match tcp_listener.accept().await {
Ok(conn) => conn,
Err(e) => {
// Connection errors can be ignored directly, continue
// by accepting the next request.
if is_connection_error(&e) {
continue;
}

// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
//
// > A possible scenario is that the process has hit the max open files
// > allowed, and so trying to accept a new connection will fail with
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
// > the application will likely close some files (or connections), and try
// > to accept the connection again. If this option is `true`, the error
// > will be logged at the `error` level, since it is still a big deal,
// > and then the listener will sleep for 1 second.
//
// hyper allowed customizing this but axum does not.
error!("accept error: {e}");
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await {
Some(conn) => conn,
None => continue,
};
let tcp_stream = TokioIo::new(tcp_stream);

Expand All @@ -191,7 +211,7 @@ where
service: tower_service,
};

tokio::task::spawn(async move {
tokio::spawn(async move {
match Builder::new(TokioExecutor::new())
// upgrades needed for websockets
.serve_connection_with_upgrades(tcp_stream, hyper_service)
Expand All @@ -212,6 +232,149 @@ where
}
}

/// Serve future with graceful shutdown enabled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub struct WithGracefulShutdown<M, S, F> {
tcp_listener: TcpListener,
make_service: M,
signal: F,
_marker: PhantomData<S>,
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
where
M: Debug,
S: Debug,
F: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
tcp_listener,
make_service,
signal,
_marker: _,
} = self;

f.debug_struct("WithGracefulShutdown")
.field("tcp_listener", tcp_listener)
.field("make_service", make_service)
.field("signal", signal)
.finish()
}
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
F: Future<Output = ()> + Send + 'static,
{
type Output = io::Result<()>;
type IntoFuture = private::ServeFuture;

fn into_future(self) -> Self::IntoFuture {
let Self {
tcp_listener,
mut make_service,
signal,
_marker: _,
} = self;

let (signal_tx, signal_rx) = watch::channel(());
let signal_tx = Arc::new(signal_tx);
tokio::spawn(async move {
signal.await;
trace!("received graceful shutdown signal. Telling tasks to shutdown");
drop(signal_rx);
});

let (close_tx, close_rx) = watch::channel(());

private::ServeFuture(Box::pin(async move {
loop {
let (tcp_stream, remote_addr) = tokio::select! {
conn = tcp_accept(&tcp_listener) => {
match conn {
Some(conn) => conn,
None => continue,
}
}
_ = signal_tx.closed() => {
trace!("signal received, not accepting new connections");
break;
}
};
let tcp_stream = TokioIo::new(tcp_stream);

trace!("connection {remote_addr} accepted");

poll_fn(|cx| make_service.poll_ready(cx))
.await
.unwrap_or_else(|err| match err {});

let tower_service = make_service
.call(IncomingStream {
tcp_stream: &tcp_stream,
remote_addr,
})
.await
.unwrap_or_else(|err| match err {});

let hyper_service = TowerToHyperService {
service: tower_service,
};

let signal_tx = Arc::clone(&signal_tx);

let close_rx = close_rx.clone();

tokio::spawn(async move {
let builder = Builder::new(TokioExecutor::new());
let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
pin_mut!(conn);

let signal_closed = signal_tx.closed().fuse();
pin_mut!(signal_closed);

loop {
tokio::select! {
result = conn.as_mut() => {
if let Err(_err) = result {
trace!("failed to serve connection: {_err:#}");
}
break;
}
_ = &mut signal_closed => {
trace!("signal received in task, starting graceful shutdown");
conn.as_mut().graceful_shutdown();
}
}
}

trace!("connection {remote_addr} closed");

drop(close_rx);
});
}

drop(close_rx);
drop(tcp_listener);

trace!(
"waiting for {} task(s) to finish",
close_tx.receiver_count()
);
close_tx.closed().await;

Ok(())
}))
}
}

fn is_connection_error(e: &io::Error) -> bool {
matches!(
e.kind(),
Expand All @@ -221,6 +384,32 @@ fn is_connection_error(e: &io::Error) -> bool {
)
}

async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
match listener.accept().await {
Ok(conn) => Some(conn),
Err(e) => {
if is_connection_error(&e) {
return None;
}

// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
//
// > A possible scenario is that the process has hit the max open files
// > allowed, and so trying to accept a new connection will fail with
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
// > the application will likely close some files (or connections), and try
// > to accept the connection again. If this option is `true`, the error
// > will be logged at the `error` level, since it is still a big deal,
// > and then the listener will sleep for 1 second.
//
// hyper allowed customizing this but axum does not.
error!("accept error: {e}");
tokio::time::sleep(Duration::from_secs(1)).await;
None
}
}
}

mod private {
use std::{
future::Future,
Expand Down
2 changes: 1 addition & 1 deletion examples/graceful-shutdown/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ edition = "2021"
publish = false

[dependencies]
axum = { path = "../../axum" }
axum = { path = "../../axum", features = ["tracing"] }
hyper = { version = "1.0", features = [] }
hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] }
tokio = { version = "1.0", features = ["full"] }
Expand Down
Loading

0 comments on commit 12e8c62

Please sign in to comment.