diff --git a/Cargo.toml b/Cargo.toml index eb585cab4..400480ea1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -102,7 +102,7 @@ mime_guess = { version = "2.0", default-features = false, optional = true } encoding_rs = "0.8" http-body = "0.4.0" hyper = { version = "0.14.21", default-features = false, features = ["tcp", "http1", "http2", "client", "runtime"] } -h2 = "0.3.10" +h2 = "0.3.14" once_cell = "1" log = "0.4" mime = "0.3.16" @@ -155,6 +155,7 @@ libflate = "1.0" brotli_crate = { package = "brotli", version = "3.3.0" } doc-comment = "0.3" tokio = { version = "1.0", default-features = false, features = ["macros", "rt-multi-thread"] } +futures-util = { version = "0.3.0", default-features = false, features = ["std", "alloc"] } [target.'cfg(windows)'.dependencies] winreg = "0.50.0" diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index f40f2486f..ae977f019 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -2233,9 +2233,16 @@ fn is_retryable_error(err: &(dyn std::error::Error + 'static)) -> bool { if let Some(cause) = err.source() { if let Some(err) = cause.downcast_ref::() { // They sent us a graceful shutdown, try with a new connection! - return err.is_go_away() - && err.is_remote() - && err.reason() == Some(h2::Reason::NO_ERROR); + if err.is_go_away() && err.is_remote() && err.reason() == Some(h2::Reason::NO_ERROR) { + return true; + } + + // REFUSED_STREAM was sent from the server, which is safe to retry. + // https://www.rfc-editor.org/rfc/rfc9113.html#section-8.7-3.2 + if err.is_reset() && err.is_remote() && err.reason() == Some(h2::Reason::REFUSED_STREAM) + { + return true; + } } } false diff --git a/tests/client.rs b/tests/client.rs index 012751e5b..5b11b4a73 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -2,6 +2,7 @@ mod support; use futures_util::stream::StreamExt; +use support::delay_server; use support::server; #[cfg(feature = "json")] @@ -487,3 +488,70 @@ async fn test_tls_info() { let tls_info = resp.extensions().get::(); assert!(tls_info.is_none()); } + +// NOTE: using the default "curernt_thread" runtime here would cause the test to +// fail, because the only thread would block until `panic_rx` receives a +// notification while the client needs to be driven to get the graceful shutdown +// done. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn highly_concurrent_requests_to_http2_server_with_low_max_concurrent_streams() { + let client = reqwest::Client::builder() + .http2_prior_knowledge() + .build() + .unwrap(); + + let server = server::http_with_config( + move |req| async move { + assert_eq!(req.version(), http::Version::HTTP_2); + http::Response::default() + }, + |builder| builder.http2_only(true).http2_max_concurrent_streams(1), + ); + + let url = format!("http://{}", server.addr()); + + let futs = (0..100).map(|_| { + let client = client.clone(); + let url = url.clone(); + async move { + let res = client.get(&url).send().await.unwrap(); + assert_eq!(res.status(), reqwest::StatusCode::OK); + } + }); + futures_util::future::join_all(futs).await; +} + +#[tokio::test] +async fn highly_concurrent_requests_to_slow_http2_server_with_low_max_concurrent_streams() { + let client = reqwest::Client::builder() + .http2_prior_knowledge() + .build() + .unwrap(); + + let server = delay_server::Server::new( + move |req| async move { + assert_eq!(req.version(), http::Version::HTTP_2); + http::Response::default() + }, + |mut http| { + http.http2_only(true).http2_max_concurrent_streams(1); + http + }, + std::time::Duration::from_secs(2), + ) + .await; + + let url = format!("http://{}", server.addr()); + + let futs = (0..100).map(|_| { + let client = client.clone(); + let url = url.clone(); + async move { + let res = client.get(&url).send().await.unwrap(); + assert_eq!(res.status(), reqwest::StatusCode::OK); + } + }); + futures_util::future::join_all(futs).await; + + server.shutdown().await; +} diff --git a/tests/support/delay_server.rs b/tests/support/delay_server.rs new file mode 100644 index 000000000..08f421598 --- /dev/null +++ b/tests/support/delay_server.rs @@ -0,0 +1,119 @@ +#![cfg(not(target_arch = "wasm32"))] +use std::convert::Infallible; +use std::future::Future; +use std::net; +use std::sync::Arc; +use std::time::Duration; + +use futures_util::FutureExt; +use http::{Request, Response}; +use hyper::service::service_fn; +use hyper::Body; +use tokio::net::TcpListener; +use tokio::select; +use tokio::sync::oneshot; + +/// This server, unlike [`super::server::Server`], allows for delaying the +/// specified amount of time after each TCP connection is established. This is +/// useful for testing the behavior of the client when the server is slow. +/// +/// For example, in case of HTTP/2, once the TCP/TLS connection is established, +/// both endpoints are supposed to send a preface and an initial `SETTINGS` +/// frame (See [RFC9113 3.4] for details). What if these frames are delayed for +/// whatever reason? This server allows for testing such scenarios. +/// +/// [RFC9113 3.4]: https://www.rfc-editor.org/rfc/rfc9113.html#name-http-2-connection-preface +pub struct Server { + addr: net::SocketAddr, + shutdown_tx: Option>, + server_terminated_rx: oneshot::Receiver<()>, +} + +impl Server { + pub async fn new(func: F1, apply_config: F2, delay: Duration) -> Self + where + F1: Fn(Request) -> Fut + Clone + Send + 'static, + Fut: Future> + Send + 'static, + F2: FnOnce(hyper::server::conn::Http) -> hyper::server::conn::Http + Send + 'static, + { + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let (server_terminated_tx, server_terminated_rx) = oneshot::channel(); + + let tcp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = tcp_listener.local_addr().unwrap(); + + tokio::spawn(async move { + let http = Arc::new(apply_config(hyper::server::conn::Http::new())); + + tokio::spawn(async move { + let (connection_shutdown_tx, connection_shutdown_rx) = oneshot::channel(); + let connection_shutdown_rx = connection_shutdown_rx.shared(); + let mut shutdown_rx = std::pin::pin!(shutdown_rx); + + let mut handles = Vec::new(); + loop { + select! { + _ = shutdown_rx.as_mut() => { + connection_shutdown_tx.send(()).unwrap(); + break; + } + res = tcp_listener.accept() => { + let (stream, _) = res.unwrap(); + + + let handle = tokio::spawn({ + let connection_shutdown_rx = connection_shutdown_rx.clone(); + let http = http.clone(); + let func = func.clone(); + + async move { + tokio::time::sleep(delay).await; + + let mut conn = std::pin::pin!(http.serve_connection( + stream, + service_fn(move |req| { + let fut = func(req); + async move { + Ok::<_, Infallible>(fut.await) + }}) + )); + + select! { + _ = conn.as_mut() => {} + _ = connection_shutdown_rx => { + conn.as_mut().graceful_shutdown(); + conn.await.unwrap(); + } + } + } + }); + + handles.push(handle); + } + } + } + + futures_util::future::join_all(handles).await; + server_terminated_tx.send(()).unwrap(); + }); + }); + + Self { + addr, + shutdown_tx: Some(shutdown_tx), + server_terminated_rx, + } + } + + pub async fn shutdown(mut self) { + if let Some(tx) = self.shutdown_tx.take() { + let _ = tx.send(()); + } + + self.server_terminated_rx.await.unwrap(); + } + + pub fn addr(&self) -> net::SocketAddr { + self.addr + } +} diff --git a/tests/support/mod.rs b/tests/support/mod.rs index cef2170f2..c796956d8 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -1,3 +1,4 @@ +pub mod delay_server; pub mod server; // TODO: remove once done converting to new support server? diff --git a/tests/support/server.rs b/tests/support/server.rs index 4ac1a4a77..5193a5fbe 100644 --- a/tests/support/server.rs +++ b/tests/support/server.rs @@ -1,15 +1,14 @@ #![cfg(not(target_arch = "wasm32"))] -use std::convert::Infallible; +use std::convert::{identity, Infallible}; use std::future::Future; use std::net; use std::sync::mpsc as std_mpsc; use std::thread; use std::time::Duration; -use tokio::sync::oneshot; - -pub use http::Response; +use hyper::server::conn::AddrIncoming; use tokio::runtime; +use tokio::sync::oneshot; pub struct Server { addr: net::SocketAddr, @@ -42,24 +41,35 @@ where F: Fn(http::Request) -> Fut + Clone + Send + 'static, Fut: Future> + Send + 'static, { - //Spawn new runtime in thread to prevent reactor execution context conflict + http_with_config(func, identity) +} + +pub fn http_with_config(func: F1, apply_config: F2) -> Server +where + F1: Fn(http::Request) -> Fut + Clone + Send + 'static, + Fut: Future> + Send + 'static, + F2: FnOnce(hyper::server::Builder) -> hyper::server::Builder + + Send + + 'static, +{ + // Spawn new runtime in thread to prevent reactor execution context conflict thread::spawn(move || { let rt = runtime::Builder::new_current_thread() .enable_all() .build() .expect("new rt"); let srv = rt.block_on(async move { - hyper::Server::bind(&([127, 0, 0, 1], 0).into()).serve(hyper::service::make_service_fn( - move |_| { - let func = func.clone(); - async move { - Ok::<_, Infallible>(hyper::service::service_fn(move |req| { - let fut = func(req); - async move { Ok::<_, Infallible>(fut.await) } - })) - } - }, - )) + let builder = hyper::Server::bind(&([127, 0, 0, 1], 0).into()); + + apply_config(builder).serve(hyper::service::make_service_fn(move |_| { + let func = func.clone(); + async move { + Ok::<_, Infallible>(hyper::service::service_fn(move |req| { + let fut = func(req); + async move { Ok::<_, Infallible>(fut.await) } + })) + } + })) }); let addr = srv.local_addr();