diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 33b3c84abe..bd5bacbd05 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -287,3 +287,28 @@ message = "The modules in generated client crates have been reorganized. See the references = ["smithy-rs#2448"] meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "client" } author = "jdisanti" + +[[aws-sdk-rust]] +message = """Reconnect on transient errors. + +If a transient error (timeout, 500, 503, 503) is encountered, the connection will be evicted from the pool and will not +be reused. This is enabled by default for all AWS services. It can be disabled by setting `RetryConfig::with_reconnect_mode` + +Although there is no API breakage from this change, it alters the client behavior in a way that may cause breakage for customers. +""" +references = ["aws-sdk-rust#160", "smithy-rs#2445"] +meta = { "breaking" = true, "tada" = false, "bug" = false } +author = "rcoh" + +[[smithy-rs]] +message = """Reconnect on transient errors. + +Note: **this behavior is disabled by default for generic clients**. It can be enabled with +`aws_smithy_client::Builder::reconnect_on_transient_errors` + +If a transient error (timeout, 500, 503, 503) is encountered, the connection will be evicted from the pool and will not +be reused. +""" +references = ["aws-sdk-rust#160", "smithy-rs#2445"] +meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client" } +author = "rcoh" diff --git a/aws/rust-runtime/aws-config/Cargo.toml b/aws/rust-runtime/aws-config/Cargo.toml index 2a43647379..87fb754a01 100644 --- a/aws/rust-runtime/aws-config/Cargo.toml +++ b/aws/rust-runtime/aws-config/Cargo.toml @@ -30,7 +30,7 @@ aws-smithy-types = { path = "../../sdk/build/aws-sdk/sdk/aws-smithy-types" } aws-types = { path = "../../sdk/build/aws-sdk/sdk/aws-types" } hyper = { version = "0.14.12", default-features = false } time = { version = "0.3.4", features = ["parsing"] } -tokio = { version = "1.8.4", features = ["sync"] } +tokio = { version = "1.13.1", features = ["sync"] } tracing = { version = "0.1" } # implementation detail of SSO credential caching diff --git a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt index 43cef7a3de..1c3ceed72e 100644 --- a/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt +++ b/aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsFluentClientDecorator.kt @@ -208,6 +208,7 @@ private class AwsFluentClientExtensions(types: Types) { }; let mut builder = builder .middleware(#{DynMiddleware}::new(#{Middleware}::new())) + .reconnect_mode(retry_config.reconnect_mode()) .retry_config(retry_config.into()) .operation_timeout_config(timeout_config.into()); builder.set_sleep_impl(sleep_impl); @@ -257,6 +258,7 @@ private fun renderCustomizableOperationSendMethod( "combined_generics_decl" to combinedGenerics.declaration(), "handle_generics_bounds" to handleGenerics.bounds(), "SdkSuccess" to RuntimeType.sdkSuccess(runtimeConfig), + "SdkError" to RuntimeType.sdkError(runtimeConfig), "ClassifyRetry" to RuntimeType.classifyRetry(runtimeConfig), "ParseHttpResponse" to RuntimeType.parseHttpResponse(runtimeConfig), ) @@ -272,7 +274,7 @@ private fun renderCustomizableOperationSendMethod( where E: std::error::Error + Send + Sync + 'static, O: #{ParseHttpResponse}> + Send + Sync + Clone + 'static, - Retry: #{ClassifyRetry}<#{SdkSuccess}, SdkError> + Send + Sync + Clone, + Retry: #{ClassifyRetry}<#{SdkSuccess}, #{SdkError}> + Send + Sync + Clone, { self.handle.client.call(self.operation).await } diff --git a/aws/sdk/integration-tests/s3/tests/reconnects.rs b/aws/sdk/integration-tests/s3/tests/reconnects.rs new file mode 100644 index 0000000000..85afcd40a9 --- /dev/null +++ b/aws/sdk/integration-tests/s3/tests/reconnects.rs @@ -0,0 +1,99 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_credential_types::provider::SharedCredentialsProvider; +use aws_credential_types::Credentials; +use aws_smithy_async::rt::sleep::TokioSleep; +use aws_smithy_client::test_connection::wire_mock::{ + check_matches, ReplayedEvent, WireLevelTestConnection, +}; +use aws_smithy_client::{ev, match_events}; +use aws_smithy_types::retry::{ReconnectMode, RetryConfig}; +use aws_types::region::Region; +use aws_types::SdkConfig; +use std::sync::Arc; + +#[tokio::test] +/// test that disabling reconnects on retry config disables them for the client +async fn disable_reconnects() { + let mock = WireLevelTestConnection::spinup(vec![ + ReplayedEvent::status(503), + ReplayedEvent::status(503), + ReplayedEvent::with_body("here-is-your-object"), + ]) + .await; + + let sdk_config = SdkConfig::builder() + .region(Region::from_static("us-east-2")) + .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests())) + .sleep_impl(Arc::new(TokioSleep::new())) + .endpoint_url(mock.endpoint_url()) + .http_connector(mock.http_connector()) + .retry_config( + RetryConfig::standard().with_reconnect_mode(ReconnectMode::ReuseAllConnections), + ) + .build(); + let client = aws_sdk_s3::Client::new(&sdk_config); + let resp = client + .get_object() + .bucket("bucket") + .key("key") + .send() + .await + .expect("succeeds after retries"); + assert_eq!( + resp.body.collect().await.unwrap().to_vec(), + b"here-is-your-object" + ); + match_events!( + ev!(dns), + ev!(connect), + ev!(http(503)), + ev!(http(503)), + ev!(http(200)) + )(&mock.events()); +} + +#[tokio::test] +async fn reconnect_on_503() { + let mock = WireLevelTestConnection::spinup(vec![ + ReplayedEvent::status(503), + ReplayedEvent::status(503), + ReplayedEvent::with_body("here-is-your-object"), + ]) + .await; + + let sdk_config = SdkConfig::builder() + .region(Region::from_static("us-east-2")) + .credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests())) + .sleep_impl(Arc::new(TokioSleep::new())) + .endpoint_url(mock.endpoint_url()) + .http_connector(mock.http_connector()) + .retry_config(RetryConfig::standard()) + .build(); + let client = aws_sdk_s3::Client::new(&sdk_config); + let resp = client + .get_object() + .bucket("bucket") + .key("key") + .send() + .await + .expect("succeeds after retries"); + assert_eq!( + resp.body.collect().await.unwrap().to_vec(), + b"here-is-your-object" + ); + match_events!( + ev!(dns), + ev!(connect), + ev!(http(503)), + ev!(dns), + ev!(connect), + ev!(http(503)), + ev!(dns), + ev!(connect), + ev!(http(200)) + )(&mock.events()); +} diff --git a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt index ba17a3e4a1..31ca1eae9f 100644 --- a/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt +++ b/codegen-client/src/main/kotlin/software/amazon/smithy/rust/codegen/client/smithy/generators/client/CustomizableOperationGenerator.kt @@ -41,7 +41,7 @@ class CustomizableOperationGenerator( "Operation" to smithyHttp.resolve("operation::Operation"), "Request" to smithyHttp.resolve("operation::Request"), "Response" to smithyHttp.resolve("operation::Response"), - "ClassifyRetry" to smithyHttp.resolve("retry::ClassifyRetry"), + "ClassifyRetry" to RuntimeType.classifyRetry(runtimeConfig), "RetryKind" to smithyTypes.resolve("retry::RetryKind"), ) renderCustomizableOperationModule(this) @@ -150,6 +150,9 @@ class CustomizableOperationGenerator( "ParseHttpResponse" to smithyHttp.resolve("response::ParseHttpResponse"), "NewRequestPolicy" to smithyClient.resolve("retry::NewRequestPolicy"), "SmithyRetryPolicy" to smithyClient.resolve("bounds::SmithyRetryPolicy"), + "ClassifyRetry" to RuntimeType.classifyRetry(runtimeConfig), + "SdkSuccess" to RuntimeType.sdkSuccess(runtimeConfig), + "SdkError" to RuntimeType.sdkError(runtimeConfig), ) writer.rustTemplate( @@ -164,6 +167,7 @@ class CustomizableOperationGenerator( E: std::error::Error + Send + Sync + 'static, O: #{ParseHttpResponse}> + Send + Sync + Clone + 'static, Retry: Send + Sync + Clone, + Retry: #{ClassifyRetry}<#{SdkSuccess}, #{SdkError}> + Send + Sync + Clone, ::Policy: #{SmithyRetryPolicy} + Clone, { self.handle.client.call(self.operation).await diff --git a/rust-runtime/Cargo.toml b/rust-runtime/Cargo.toml index 6a53b080e9..9c03916d1c 100644 --- a/rust-runtime/Cargo.toml +++ b/rust-runtime/Cargo.toml @@ -1,5 +1,6 @@ [workspace] + members = [ "inlineable", "aws-smithy-async", diff --git a/rust-runtime/aws-smithy-client/Cargo.toml b/rust-runtime/aws-smithy-client/Cargo.toml index d8f8041bbd..bdd038f8bf 100644 --- a/rust-runtime/aws-smithy-client/Cargo.toml +++ b/rust-runtime/aws-smithy-client/Cargo.toml @@ -9,12 +9,13 @@ repository = "https://github.com/awslabs/smithy-rs" [features] rt-tokio = ["aws-smithy-async/rt-tokio"] -test-util = ["aws-smithy-protocol-test", "serde/derive", "rustls"] +test-util = ["aws-smithy-protocol-test", "serde/derive", "rustls", "hyper/server", "hyper/h2", "tokio/full"] native-tls = ["client-hyper", "hyper-tls", "rt-tokio"] rustls = ["client-hyper", "hyper-rustls", "rt-tokio", "lazy_static"] client-hyper = ["hyper"] hyper-webpki-doctest-only = ["hyper-rustls/webpki-roots"] + [dependencies] aws-smithy-async = { path = "../aws-smithy-async" } aws-smithy-http = { path = "../aws-smithy-http" } @@ -25,7 +26,7 @@ bytes = "1" fastrand = "1.4.0" http = "0.2.3" http-body = "0.4.4" -hyper = { version = "0.14.12", features = ["client", "http2", "http1", "tcp"], optional = true } +hyper = { version = "0.14.25", features = ["client", "http2", "http1", "tcp"], optional = true } # cargo does not support optional test dependencies, so to completely disable rustls when # the native-tls feature is enabled, we need to add the webpki-roots feature here. # https://github.com/rust-lang/cargo/issues/1596 @@ -34,7 +35,7 @@ hyper-tls = { version = "0.5.0", optional = true } lazy_static = { version = "1", optional = true } pin-project-lite = "0.2.7" serde = { version = "1", features = ["derive"], optional = true } -tokio = { version = "1.8.4" } +tokio = { version = "1.13.1" } tower = { version = "0.4.6", features = ["util", "retry"] } tracing = "0.1" @@ -44,6 +45,9 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" tokio = { version = "1.8.4", features = ["full", "test-util"] } tower-test = "0.4.0" +tracing-subscriber = "0.3.16" +tracing-test = "0.2.4" + [package.metadata.docs.rs] all-features = true diff --git a/rust-runtime/aws-smithy-client/external-types.toml b/rust-runtime/aws-smithy-client/external-types.toml index fd2d76368f..0bab3e6536 100644 --- a/rust-runtime/aws-smithy-client/external-types.toml +++ b/rust-runtime/aws-smithy-client/external-types.toml @@ -21,10 +21,12 @@ allowed_external_types = [ "tokio::io::async_read::AsyncRead", "tokio::io::async_write::AsyncWrite", + # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Once tooling permits it, only allow the following types in the `test-utils` feature "bytes::bytes::Bytes", "serde::ser::Serialize", "serde::de::Deserialize", + "hyper::client::connect::dns::Name", # TODO(https://github.com/awslabs/smithy-rs/issues/1193): Decide if we want to continue exposing tower_layer "tower_layer::Layer", diff --git a/rust-runtime/aws-smithy-client/src/builder.rs b/rust-runtime/aws-smithy-client/src/builder.rs index 1fe4ba12eb..d226dcb1cf 100644 --- a/rust-runtime/aws-smithy-client/src/builder.rs +++ b/rust-runtime/aws-smithy-client/src/builder.rs @@ -7,6 +7,7 @@ use crate::{bounds, erase, retry, Client}; use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep}; use aws_smithy_http::body::SdkBody; use aws_smithy_http::result::ConnectorError; +use aws_smithy_types::retry::ReconnectMode; use aws_smithy_types::timeout::{OperationTimeoutConfig, TimeoutConfig}; use std::sync::Arc; @@ -37,6 +38,12 @@ pub struct Builder { retry_policy: MaybeRequiresSleep, operation_timeout_config: Option, sleep_impl: Option>, + reconnect_mode: Option, +} + +/// transitional default: disable this behavior by default +const fn default_reconnect_mode() -> ReconnectMode { + ReconnectMode::ReuseAllConnections } impl Default for Builder @@ -55,6 +62,7 @@ where ), operation_timeout_config: None, sleep_impl: default_async_sleep(), + reconnect_mode: Some(default_reconnect_mode()), } } } @@ -173,6 +181,7 @@ impl Builder<(), M, R> { retry_policy: self.retry_policy, operation_timeout_config: self.operation_timeout_config, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } @@ -229,6 +238,7 @@ impl Builder { operation_timeout_config: self.operation_timeout_config, middleware, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } @@ -280,6 +290,7 @@ impl Builder { operation_timeout_config: self.operation_timeout_config, middleware: self.middleware, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } } @@ -347,6 +358,7 @@ impl Builder { retry_policy: self.retry_policy, operation_timeout_config: self.operation_timeout_config, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } @@ -361,9 +373,41 @@ impl Builder { retry_policy: self.retry_policy, operation_timeout_config: self.operation_timeout_config, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } + /// Set the [`ReconnectMode`] for the retry strategy + /// + /// By default, no reconnection occurs. + /// + /// When enabled and a transient error is encountered, the connection in use will be poisoned. + /// This prevents reusing a connection to a potentially bad host. + pub fn reconnect_mode(mut self, reconnect_mode: ReconnectMode) -> Self { + self.set_reconnect_mode(Some(reconnect_mode)); + self + } + + /// Set the [`ReconnectMode`] for the retry strategy + /// + /// By default, no reconnection occurs. + /// + /// When enabled and a transient error is encountered, the connection in use will be poisoned. + /// This prevents reusing a connection to a potentially bad host. + pub fn set_reconnect_mode(&mut self, reconnect_mode: Option) -> &mut Self { + self.reconnect_mode = reconnect_mode; + self + } + + /// Enable reconnection on transient errors + /// + /// By default, when a transient error is encountered, the connection in use will be poisoned. + /// This prevents reusing a connection to a potentially bad host but may increase the load on + /// the server. + pub fn reconnect_on_transient_errors(self) -> Self { + self.reconnect_mode(ReconnectMode::ReconnectOnTransientError) + } + /// Build a Smithy service [`Client`]. pub fn build(self) -> Client { let operation_timeout_config = self @@ -392,6 +436,7 @@ impl Builder { middleware: self.middleware, operation_timeout_config, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode.unwrap_or(default_reconnect_mode()), } } } diff --git a/rust-runtime/aws-smithy-client/src/erase.rs b/rust-runtime/aws-smithy-client/src/erase.rs index 2cac5afeaa..648562192c 100644 --- a/rust-runtime/aws-smithy-client/src/erase.rs +++ b/rust-runtime/aws-smithy-client/src/erase.rs @@ -61,6 +61,7 @@ where retry_policy: self.retry_policy, operation_timeout_config: self.operation_timeout_config, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } } @@ -101,6 +102,7 @@ where retry_policy: self.retry_policy, operation_timeout_config: self.operation_timeout_config, sleep_impl: self.sleep_impl, + reconnect_mode: self.reconnect_mode, } } diff --git a/rust-runtime/aws-smithy-client/src/hyper_ext.rs b/rust-runtime/aws-smithy-client/src/hyper_ext.rs index 11a27c1d53..f059467829 100644 --- a/rust-runtime/aws-smithy-client/src/hyper_ext.rs +++ b/rust-runtime/aws-smithy-client/src/hyper_ext.rs @@ -92,13 +92,22 @@ use crate::never::stream::EmptyStream; use aws_smithy_async::future::timeout::TimedOutError; use aws_smithy_async::rt::sleep::{default_async_sleep, AsyncSleep}; use aws_smithy_http::body::SdkBody; + use aws_smithy_http::result::ConnectorError; use aws_smithy_types::error::display::DisplayErrorContext; use aws_smithy_types::retry::ErrorKind; -use http::Uri; -use hyper::client::connect::{Connected, Connection}; +use http::{Extensions, Uri}; +use hyper::client::connect::{ + capture_connection, CaptureConnection, Connected, Connection, HttpInfo, +}; + use std::error::Error; +use std::fmt::Debug; + use std::sync::Arc; + +use crate::erase::boxclone::BoxFuture; +use aws_smithy_http::connection::{CaptureSmithyConnection, ConnectionMetadata}; use tokio::io::{AsyncRead, AsyncWrite}; use tower::{BoxError, Service}; @@ -107,8 +116,30 @@ use tower::{BoxError, Service}; /// This adapter also enables TCP `CONNECT` and HTTP `READ` timeouts via [`Adapter::builder`]. For examples /// see [the module documentation](crate::hyper_ext). #[derive(Clone, Debug)] -#[non_exhaustive] -pub struct Adapter(HttpReadTimeout, SdkBody>>); +pub struct Adapter { + client: HttpReadTimeout, SdkBody>>, +} + +/// Extract a smithy connection from a hyper CaptureConnection +fn extract_smithy_connection(capture_conn: &CaptureConnection) -> Option { + let capture_conn = capture_conn.clone(); + if let Some(conn) = capture_conn.clone().connection_metadata().as_ref() { + let mut extensions = Extensions::new(); + conn.get_extras(&mut extensions); + let http_info = extensions.get::(); + let smithy_connection = ConnectionMetadata::new( + conn.is_proxied(), + http_info.map(|info| info.remote_addr()), + move || match capture_conn.connection_metadata().as_ref() { + Some(conn) => conn.poison(), + None => tracing::trace!("no connection existed to poison"), + }, + ); + Some(smithy_connection) + } else { + None + } +} impl Service> for Adapter where @@ -121,20 +152,22 @@ where type Response = http::Response; type Error = ConnectorError; - #[allow(clippy::type_complexity)] - type Future = std::pin::Pin< - Box> + Send + 'static>, - >; + type Future = BoxFuture; fn poll_ready( &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - self.0.poll_ready(cx).map_err(downcast_error) + self.client.poll_ready(cx).map_err(downcast_error) } - fn call(&mut self, req: http::Request) -> Self::Future { - let fut = self.0.call(req); + fn call(&mut self, mut req: http::Request) -> Self::Future { + let capture_connection = capture_connection(&mut req); + if let Some(capture_smithy_connection) = req.extensions().get::() { + capture_smithy_connection + .set_connection_retriever(move || extract_smithy_connection(&capture_connection)); + } + let fut = self.client.call(req); Box::pin(async move { Ok(fut.await.map_err(downcast_error)?.map(SdkBody::from)) }) } } @@ -271,7 +304,9 @@ impl Builder { ), None => HttpReadTimeout::no_timeout(base), }; - Adapter(read_timeout) + Adapter { + client: read_timeout, + } } /// Set the async sleep implementation used for timeouts @@ -343,7 +378,6 @@ mod timeout_middleware { use pin_project_lite::pin_project; use tower::BoxError; - use aws_smithy_async::future; use aws_smithy_async::future::timeout::{TimedOutError, Timeout}; use aws_smithy_async::rt::sleep::AsyncSleep; use aws_smithy_async::rt::sleep::Sleep; @@ -493,7 +527,7 @@ mod timeout_middleware { Some((sleep, duration)) => { let sleep = sleep.sleep(*duration); MaybeTimeoutFuture::Timeout { - timeout: future::timeout::Timeout::new(self.inner.call(req), sleep), + timeout: Timeout::new(self.inner.call(req), sleep), error_type: "HTTP connect", duration: *duration, } @@ -522,7 +556,7 @@ mod timeout_middleware { Some((sleep, duration)) => { let sleep = sleep.sleep(*duration); MaybeTimeoutFuture::Timeout { - timeout: future::timeout::Timeout::new(self.inner.call(req), sleep), + timeout: Timeout::new(self.inner.call(req), sleep), error_type: "HTTP read", duration: *duration, } diff --git a/rust-runtime/aws-smithy-client/src/lib.rs b/rust-runtime/aws-smithy-client/src/lib.rs index 6e5e5ba9ee..479641704d 100644 --- a/rust-runtime/aws-smithy-client/src/lib.rs +++ b/rust-runtime/aws-smithy-client/src/lib.rs @@ -26,6 +26,7 @@ pub mod bounds; pub mod erase; pub mod http_connector; pub mod never; +mod poison; pub mod retry; pub mod timeout; @@ -50,14 +51,17 @@ pub mod hyper_ext; #[doc(hidden)] pub mod static_tests; +use crate::poison::PoisonLayer; use aws_smithy_async::rt::sleep::AsyncSleep; + use aws_smithy_http::operation::Operation; use aws_smithy_http::response::ParseHttpResponse; pub use aws_smithy_http::result::{SdkError, SdkSuccess}; +use aws_smithy_http::retry::ClassifyRetry; use aws_smithy_http_tower::dispatch::DispatchLayer; use aws_smithy_http_tower::parse_response::ParseResponseLayer; use aws_smithy_types::error::display::DisplayErrorContext; -use aws_smithy_types::retry::ProvideErrorKind; +use aws_smithy_types::retry::{ProvideErrorKind, ReconnectMode}; use aws_smithy_types::timeout::OperationTimeoutConfig; use std::sync::Arc; use timeout::ClientTimeoutParams; @@ -93,6 +97,7 @@ pub struct Client< connector: Connector, middleware: Middleware, retry_policy: RetryPolicy, + reconnect_mode: ReconnectMode, operation_timeout_config: OperationTimeoutConfig, sleep_impl: Option>, } @@ -140,6 +145,7 @@ where E: std::error::Error + Send + Sync + 'static, Retry: Send + Sync, R::Policy: bounds::SmithyRetryPolicy, + Retry: ClassifyRetry, SdkError>, bounds::Parsed<>::Service, O, Retry>: Service, Response = SdkSuccess, Error = SdkError> + Clone, { @@ -159,6 +165,7 @@ where E: std::error::Error + Send + Sync + 'static, Retry: Send + Sync, R::Policy: bounds::SmithyRetryPolicy, + Retry: ClassifyRetry, SdkError>, // This bound is not _technically_ inferred by all the previous bounds, but in practice it // is because _we_ know that there is only implementation of Service for Parsed // (ParsedResponseService), and it will apply as long as the bounds on C, M, and R hold, @@ -179,6 +186,7 @@ where self.retry_policy .new_request_policy(self.sleep_impl.clone()), ) + .layer(PoisonLayer::new(self.reconnect_mode)) .layer(TimeoutLayer::new(timeout_params.operation_attempt_timeout)) .layer(ParseResponseLayer::::new()) // These layers can be considered as occurring in order. That is, first invoke the diff --git a/rust-runtime/aws-smithy-client/src/poison.rs b/rust-runtime/aws-smithy-client/src/poison.rs new file mode 100644 index 0000000000..ffbaaf8abc --- /dev/null +++ b/rust-runtime/aws-smithy-client/src/poison.rs @@ -0,0 +1,143 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Connection Poisoning +//! +//! The client supports behavior where on transient errors (e.g. timeouts, 503, etc.) it will ensure +//! that the offending connection is not reused. This happens to ensure that in the case where the +//! connection itself is broken (e.g. connected to a bad host) we don't reuse it for other requests. +//! +//! This relies on a series of mechanisms: +//! 1. [`CaptureSmithyConnection`] is a container which exists in the operation property bag. It is +//! inserted by this layer before the request is sent. +//! 2. The [`DispatchLayer`](aws_smithy_http_tower::dispatch::DispatchLayer) copies the field from operation extensions HTTP request extensions. +//! 3. The HTTP layer (e.g. Hyper) sets [`ConnectionMetadata`](aws_smithy_http::connection::ConnectionMetadata) +//! when it is available. +//! 4. When the response comes back, if indicated, this layer invokes +//! [`ConnectionMetadata::poison`](aws_smithy_http::connection::ConnectionMetadata::poison). +//! +//! ### Why isn't this integrated into `retry.rs`? +//! If the request has a streaming body, we won't attempt to retry because [`Operation::try_clone()`] will +//! return `None`. Therefore, we need to handle this inside of the retry loop. + +use std::future::Future; + +use aws_smithy_http::operation::Operation; +use aws_smithy_http::result::{SdkError, SdkSuccess}; +use aws_smithy_http::retry::ClassifyRetry; + +use aws_smithy_http::connection::CaptureSmithyConnection; +use aws_smithy_types::retry::{ErrorKind, ReconnectMode, RetryKind}; +use pin_project_lite::pin_project; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; + +/// PoisonLayer that poisons connections depending on the error kind +pub(crate) struct PoisonLayer { + inner: PhantomData, + mode: ReconnectMode, +} + +impl PoisonLayer { + pub(crate) fn new(mode: ReconnectMode) -> Self { + Self { + inner: Default::default(), + mode, + } + } +} + +impl Clone for PoisonLayer { + fn clone(&self) -> Self { + Self { + inner: Default::default(), + mode: self.mode, + } + } +} + +impl tower::Layer for PoisonLayer { + type Service = PoisonService; + + fn layer(&self, inner: S) -> Self::Service { + PoisonService { + inner, + mode: self.mode, + } + } +} + +#[derive(Clone)] +pub(crate) struct PoisonService { + inner: S, + mode: ReconnectMode, +} + +impl tower::Service> for PoisonService +where + R: ClassifyRetry, SdkError>, + S: tower::Service, Response = SdkSuccess, Error = SdkError>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = PoisonServiceFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Operation) -> Self::Future { + let classifier = req.retry_classifier().clone(); + let capture_smithy_connection = CaptureSmithyConnection::new(); + req.properties_mut() + .insert(capture_smithy_connection.clone()); + PoisonServiceFuture { + inner: self.inner.call(req), + conn: capture_smithy_connection, + mode: self.mode, + classifier, + } + } +} + +pin_project! { + pub struct PoisonServiceFuture { + #[pin] + inner: F, + classifier: R, + conn: CaptureSmithyConnection, + mode: ReconnectMode + } +} + +impl Future for PoisonServiceFuture +where + F: Future, SdkError>>, + R: ClassifyRetry, SdkError>, +{ + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match this.inner.poll(cx) { + Poll::Ready(resp) => { + let retry_kind = this.classifier.classify_retry(resp.as_ref()); + if this.mode == &ReconnectMode::ReconnectOnTransientError + && retry_kind == RetryKind::Error(ErrorKind::TransientError) + { + if let Some(smithy_conn) = this.conn.get() { + tracing::info!("poisoning connection: {:?}", smithy_conn); + smithy_conn.poison(); + } else { + tracing::trace!("No smithy connection found! The underlying HTTP connection never set a connection."); + } + } + Poll::Ready(resp) + } + Poll::Pending => Poll::Pending, + } + } +} diff --git a/rust-runtime/aws-smithy-client/src/retry.rs b/rust-runtime/aws-smithy-client/src/retry.rs index 7e6ceff10c..10df4ae5fc 100644 --- a/rust-runtime/aws-smithy-client/src/retry.rs +++ b/rust-runtime/aws-smithy-client/src/retry.rs @@ -17,14 +17,15 @@ use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::time::Duration; -use crate::{SdkError, SdkSuccess}; +use tracing::Instrument; use aws_smithy_async::rt::sleep::AsyncSleep; + use aws_smithy_http::operation::Operation; use aws_smithy_http::retry::ClassifyRetry; use aws_smithy_types::retry::{ErrorKind, RetryKind}; -use tracing::Instrument; +use crate::{SdkError, SdkSuccess}; /// A policy instantiator. /// @@ -292,9 +293,20 @@ impl RetryHandler { fn should_retry_error(&self, error_kind: &ErrorKind) -> Option<(Self, Duration)> { let quota_used = { if self.local.attempts == self.config.max_attempts { + tracing::trace!( + attempts = self.local.attempts, + max_attempts = self.config.max_attempts, + "not retrying becuase we are out of attempts" + ); return None; } - self.shared.quota_acquire(error_kind, &self.config)? + match self.shared.quota_acquire(error_kind, &self.config) { + Some(quota) => quota, + None => { + tracing::trace!(state = ?self.shared, "not retrying because no quota is available"); + return None; + } + } }; let backoff = calculate_exponential_backoff( // Generate a random base multiplier to create jitter @@ -334,7 +346,9 @@ impl RetryHandler { } fn retry_for(&self, retry_kind: RetryKind) -> Option> { - let (next, dur) = self.should_retry(&retry_kind)?; + let retry = self.should_retry(&retry_kind); + tracing::trace!(retry=?retry, retry_kind = ?retry_kind, "retry action"); + let (next, dur) = retry?; let sleep = match &self.sleep_impl { Some(sleep) => sleep, @@ -377,6 +391,7 @@ where ) -> Option { let classifier = req.retry_classifier(); let retry_kind = classifier.classify_retry(result); + tracing::trace!(retry_kind = ?retry_kind, "retry classification"); self.retry_for(retry_kind) } diff --git a/rust-runtime/aws-smithy-client/src/test_connection.rs b/rust-runtime/aws-smithy-client/src/test_connection.rs index d7b0b15ece..3a119cb314 100644 --- a/rust-runtime/aws-smithy-client/src/test_connection.rs +++ b/rust-runtime/aws-smithy-client/src/test_connection.rs @@ -90,7 +90,7 @@ impl tower::Service> for CaptureRequestHandler { /// If response is `None`, it will reply with a 200 response with an empty body /// /// Example: -/// ```rust,compile_fail +/// ```compile_fail /// let (server, request) = capture_request(None); /// let conf = aws_sdk_sts::Config::builder() /// .http_connector(server) @@ -271,6 +271,347 @@ where } } +/// [`wire_mock`] contains utilities for mocking at the socket level +/// +/// Other tools in this module actually operate at the `http::Request` / `http::Response` level. This +/// is useful, but it shortcuts the HTTP implementation (e.g. Hyper). [`wire_mock::WireLevelTestConnection`] binds +/// to an actual socket on the host +/// +/// # Examples +/// ``` +/// use tower::layer::util::Identity; +/// use aws_smithy_client::http_connector::ConnectorSettings; +/// use aws_smithy_client::{match_events, ev}; +/// use aws_smithy_client::test_connection::wire_mock::check_matches; +/// # async fn example() { +/// use aws_smithy_client::test_connection::wire_mock::{ReplayedEvent, WireLevelTestConnection}; +/// // This connection binds to a local address +/// let mock = WireLevelTestConnection::spinup(vec![ +/// ReplayedEvent::status(503), +/// ReplayedEvent::status(200) +/// ]).await; +/// let client = aws_smithy_client::Client::builder() +/// .connector(mock.http_connector().connector(&ConnectorSettings::default(), None).unwrap()) +/// .middleware(Identity::new()) +/// .build(); +/// /* do something with */ +/// // assert that you got the events you expected +/// match_events!(ev!(dns), ev!(connect), ev!(http(200)))(&mock.events()); +/// # } +/// ``` +pub mod wire_mock { + use bytes::Bytes; + use http::{Request, Response}; + use hyper::client::connect::dns::Name; + use hyper::server::conn::AddrStream; + use hyper::service::{make_service_fn, service_fn}; + use hyper::{Body, Server}; + use std::collections::HashSet; + use std::convert::Infallible; + use std::error::Error; + + use hyper::client::HttpConnector as HyperHttpConnector; + use std::iter; + use std::iter::Once; + use std::net::{SocketAddr, TcpListener}; + use std::sync::{Arc, Mutex}; + use std::task::{Context, Poll}; + + use tokio::spawn; + use tower::Service; + + /// An event recorded by [`WireLevelTestConnection`] + #[derive(Debug, Clone)] + pub enum RecordedEvent { + DnsLookup(String), + NewConnection, + Response(ReplayedEvent), + } + + type Matcher = ( + Box Result<(), Box>>, + &'static str, + ); + + /// This method should only be used by the macro + #[doc(hidden)] + pub fn check_matches(events: &[RecordedEvent], matchers: &[Matcher]) { + let mut events_iter = events.iter(); + let mut matcher_iter = matchers.iter(); + let mut idx = -1; + loop { + idx += 1; + let bail = |err: Box| panic!("failed on event {}:\n {}", idx, err); + match (events_iter.next(), matcher_iter.next()) { + (Some(event), Some((matcher, _msg))) => matcher(event).unwrap_or_else(bail), + (None, None) => return, + (Some(event), None) => { + bail(format!("got {:?} but no more events were expected", event).into()) + } + (None, Some((_expect, msg))) => { + bail(format!("expected {:?} but no more events were expected", msg).into()) + } + } + } + } + + #[macro_export] + macro_rules! matcher { + ($expect:tt) => { + ( + Box::new( + |event: &::aws_smithy_client::test_connection::wire_mock::RecordedEvent| { + if !matches!(event, $expect) { + return Err(format!( + "expected `{}` but got {:?}", + stringify!($expect), + event + ) + .into()); + } + Ok(()) + }, + ), + stringify!($expect), + ) + }; + } + + /// Helper macro to generate a series of test expectations + #[macro_export] + macro_rules! match_events { + ($( $expect:pat),*) => { + |events| { + check_matches(events, &[$( ::aws_smithy_client::matcher!($expect) ),*]); + } + }; + } + + /// Helper to generate match expressions for events + #[macro_export] + macro_rules! ev { + (http($status:expr)) => { + ::aws_smithy_client::test_connection::wire_mock::RecordedEvent::Response( + ReplayedEvent::HttpResponse { + status: $status, + .. + }, + ) + }; + (dns) => { + ::aws_smithy_client::test_connection::wire_mock::RecordedEvent::DnsLookup(_) + }; + (connect) => { + ::aws_smithy_client::test_connection::wire_mock::RecordedEvent::NewConnection + }; + (timeout) => { + ::aws_smithy_client::test_connection::wire_mock::RecordedEvent::Response( + ReplayedEvent::Timeout, + ) + }; + } + + pub use {ev, match_events, matcher}; + + #[derive(Clone, Debug, PartialEq, Eq)] + pub enum ReplayedEvent { + Timeout, + HttpResponse { status: u16, body: Bytes }, + } + + impl ReplayedEvent { + pub fn ok() -> Self { + Self::HttpResponse { + status: 200, + body: Bytes::new(), + } + } + + pub fn with_body(body: &str) -> Self { + Self::HttpResponse { + status: 200, + body: Bytes::copy_from_slice(body.as_ref()), + } + } + + pub fn status(status: u16) -> Self { + Self::HttpResponse { + status, + body: Bytes::new(), + } + } + } + + use crate::erase::boxclone::BoxFuture; + use crate::http_connector::HttpConnector; + use crate::hyper_ext; + use aws_smithy_async::future::never::Never; + use tokio::sync::oneshot; + + /// Test connection that starts a server bound to 0.0.0.0 + /// + /// See the [module docs](crate::test_connection::wire_mock) for a usage example. + /// + /// Usage: + /// - Call [`WireLevelTestConnection::spinup`] to start the server + /// - Use [`WireLevelTestConnection::http_connector`] or [`dns_resolver`](WireLevelTestConnection::dns_resolver) to configure your client. + /// - Make requests to [`endpoint_url`](WireLevelTestConnection::endpoint_url). + /// - Once the test is complete, retrieve a list of events from [`WireLevelTestConnection::events`] + #[derive(Debug)] + pub struct WireLevelTestConnection { + event_log: Arc>>, + bind_addr: SocketAddr, + // when the sender is dropped, that stops the server + shutdown_hook: oneshot::Sender<()>, + } + + impl WireLevelTestConnection { + pub async fn spinup(mut response_events: Vec) -> Self { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let (tx, rx) = oneshot::channel(); + let listener_addr = listener.local_addr().unwrap(); + response_events.reverse(); + let response_events = Arc::new(Mutex::new(response_events)); + let handler_events = response_events; + let wire_events = Arc::new(Mutex::new(vec![])); + let wire_log_for_service = wire_events.clone(); + let poisoned_conns: Arc>> = Default::default(); + let make_service = make_service_fn(move |connection: &AddrStream| { + let poisoned_conns = poisoned_conns.clone(); + let events = handler_events.clone(); + let wire_log = wire_log_for_service.clone(); + let remote_addr = connection.remote_addr(); + tracing::info!("established connection: {:?}", connection); + wire_log.lock().unwrap().push(RecordedEvent::NewConnection); + async move { + Ok::<_, Infallible>(service_fn(move |_: Request| { + if poisoned_conns.lock().unwrap().contains(&remote_addr) { + tracing::error!("poisoned connection {:?} was reused!", &remote_addr); + panic!("poisoned connection was reused!"); + } + let next_event = events.clone().lock().unwrap().pop(); + let wire_log = wire_log.clone(); + let poisoned_conns = poisoned_conns.clone(); + async move { + let next_event = next_event + .unwrap_or_else(|| panic!("no more events! Log: {:?}", wire_log)); + wire_log + .lock() + .unwrap() + .push(RecordedEvent::Response(next_event.clone())); + if next_event == ReplayedEvent::Timeout { + tracing::info!("{} is poisoned", remote_addr); + poisoned_conns.lock().unwrap().insert(remote_addr); + } + tracing::debug!("replying with {:?}", next_event); + let event = generate_response_event(next_event).await; + dbg!(event) + } + })) + } + }); + let server = Server::from_tcp(listener) + .unwrap() + .serve(make_service) + .with_graceful_shutdown(async { + rx.await.ok(); + tracing::info!("server shutdown!"); + }); + spawn(async move { server.await }); + Self { + event_log: wire_events, + bind_addr: listener_addr, + shutdown_hook: tx, + } + } + + /// Retrieve the events recorded by this connection + pub fn events(&self) -> Vec { + self.event_log.lock().unwrap().clone() + } + + fn bind_addr(&self) -> SocketAddr { + self.bind_addr + } + + pub fn dns_resolver(&self) -> LoggingDnsResolver { + let event_log = self.event_log.clone(); + let bind_addr = self.bind_addr; + LoggingDnsResolver { + log: event_log, + socket_addr: bind_addr, + } + } + + /// Prebuilt HTTP connector with correctly wired DNS resolver + /// + /// **Note**: This must be used in tandem with [`Self::dns_resolver`] + pub fn http_connector(&self) -> HttpConnector { + let http_connector = HyperHttpConnector::new_with_resolver(self.dns_resolver()); + hyper_ext::Adapter::builder().build(http_connector).into() + } + + /// Endpoint to use when connecting + /// + /// This works in tandem with the [`Self::dns_resolver`] to bind to the correct local IP Address + pub fn endpoint_url(&self) -> String { + format!( + "http://this-url-is-converted-to-localhost.com:{}", + self.bind_addr().port() + ) + } + + pub fn shutdown(self) { + let _ = self.shutdown_hook.send(()); + } + } + + async fn generate_response_event(event: ReplayedEvent) -> Result, Infallible> { + let resp = match event { + ReplayedEvent::HttpResponse { status, body } => http::Response::builder() + .status(status) + .body(hyper::Body::from(body)) + .unwrap(), + ReplayedEvent::Timeout => { + Never::new().await; + unreachable!() + } + }; + Ok::<_, Infallible>(resp) + } + + /// DNS resolver that keeps a log of all lookups + /// + /// Regardless of what hostname is requested, it will always return the same socket address. + #[derive(Clone, Debug)] + pub struct LoggingDnsResolver { + log: Arc>>, + socket_addr: SocketAddr, + } + + impl Service for LoggingDnsResolver { + type Response = Once; + type Error = Infallible; + type Future = BoxFuture; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Name) -> Self::Future { + let sock_addr = self.socket_addr; + let log = self.log.clone(); + Box::pin(async move { + println!("looking up {:?}, replying with {:?}", req, sock_addr); + log.lock() + .unwrap() + .push(RecordedEvent::DnsLookup(req.to_string())); + Ok(iter::once(sock_addr)) + }) + } + } +} + #[cfg(test)] mod tests { use hyper::service::Service; diff --git a/rust-runtime/aws-smithy-client/src/timeout.rs b/rust-runtime/aws-smithy-client/src/timeout.rs index 4cfc1938f6..85957eb11e 100644 --- a/rust-runtime/aws-smithy-client/src/timeout.rs +++ b/rust-runtime/aws-smithy-client/src/timeout.rs @@ -208,7 +208,7 @@ where InnerService: tower::Service, Error = SdkError>, { type Response = InnerService::Response; - type Error = aws_smithy_http::result::SdkError; + type Error = SdkError; type Future = TimeoutServiceFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { diff --git a/rust-runtime/aws-smithy-client/tests/e2e_test.rs b/rust-runtime/aws-smithy-client/tests/e2e_test.rs index 99b689d430..0a8594d6b3 100644 --- a/rust-runtime/aws-smithy-client/tests/e2e_test.rs +++ b/rust-runtime/aws-smithy-client/tests/e2e_test.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +mod test_operation; use crate::test_operation::{TestOperationParser, TestRetryClassifier}; use aws_smithy_async::rt::sleep::TokioSleep; use aws_smithy_client::test_connection::TestConnection; @@ -15,78 +16,6 @@ use std::sync::Arc; use std::time::Duration; use tower::layer::util::Identity; -mod test_operation { - use aws_smithy_http::operation; - use aws_smithy_http::response::ParseHttpResponse; - use aws_smithy_http::result::SdkError; - use aws_smithy_http::retry::ClassifyRetry; - use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind}; - use bytes::Bytes; - use std::error::Error; - use std::fmt::{self, Debug, Display, Formatter}; - - #[derive(Clone)] - pub(super) struct TestOperationParser; - - #[derive(Debug)] - pub(super) struct OperationError; - - impl Display for OperationError { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "{:?}", self) - } - } - - impl Error for OperationError {} - - impl ProvideErrorKind for OperationError { - fn retryable_error_kind(&self) -> Option { - Some(ErrorKind::ThrottlingError) - } - - fn code(&self) -> Option<&str> { - None - } - } - - impl ParseHttpResponse for TestOperationParser { - type Output = Result; - - fn parse_unloaded(&self, response: &mut operation::Response) -> Option { - if response.http().status().is_success() { - Some(Ok("Hello!".to_string())) - } else { - Some(Err(OperationError)) - } - } - - fn parse_loaded(&self, _response: &http::Response) -> Self::Output { - Ok("Hello!".to_string()) - } - } - - #[derive(Clone)] - pub(super) struct TestRetryClassifier; - - impl ClassifyRetry> for TestRetryClassifier - where - E: ProvideErrorKind + Debug, - T: Debug, - { - fn classify_retry(&self, err: Result<&T, &SdkError>) -> RetryKind { - let kind = match err { - Err(SdkError::ServiceError(context)) => context.err().retryable_error_kind(), - Ok(_) => return RetryKind::Unnecessary, - _ => panic!("test handler only handles modeled errors got: {:?}", err), - }; - match kind { - Some(kind) => RetryKind::Error(kind), - None => RetryKind::UnretryableFailure, - } - } - } -} - fn test_operation() -> Operation { let req = operation::Request::new( http::Request::builder() @@ -108,14 +37,14 @@ async fn end_to_end_retry_test() { fn ok() -> http::Response<&'static str> { http::Response::builder() .status(200) - .body("response body") + .body("Hello!") .unwrap() } fn err() -> http::Response<&'static str> { http::Response::builder() .status(500) - .body("response body") + .body("This was an error") .unwrap() } // 1 failing response followed by 1 successful response diff --git a/rust-runtime/aws-smithy-client/tests/reconnect_on_transient_error.rs b/rust-runtime/aws-smithy-client/tests/reconnect_on_transient_error.rs new file mode 100644 index 0000000000..475c076b12 --- /dev/null +++ b/rust-runtime/aws-smithy-client/tests/reconnect_on_transient_error.rs @@ -0,0 +1,230 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#![cfg(feature = "test-util")] + +mod test_operation; + +use aws_smithy_async::rt::sleep::TokioSleep; +use aws_smithy_client::test_connection::wire_mock; +use aws_smithy_client::test_connection::wire_mock::{check_matches, RecordedEvent, ReplayedEvent}; +use aws_smithy_client::{hyper_ext, Builder}; +use aws_smithy_client::{match_events, Client}; +use aws_smithy_http::body::SdkBody; +use aws_smithy_http::operation; +use aws_smithy_http::operation::Operation; +use aws_smithy_types::retry::ReconnectMode; +use aws_smithy_types::timeout::{OperationTimeoutConfig, TimeoutConfig}; +use http::Uri; +use http_body::combinators::BoxBody; +use hyper::client::{Builder as HyperBuilder, HttpConnector}; +use std::convert::Infallible; +use std::sync::Arc; +use std::time::Duration; +use test_operation::{TestOperationParser, TestRetryClassifier}; +use tower::layer::util::Identity; +use wire_mock::ev; + +fn end_of_test() -> &'static str { + "end_of_test" +} + +fn test_operation( + uri: Uri, + retryable: bool, +) -> Operation { + let mut req = operation::Request::new( + http::Request::builder() + .uri(uri) + .body(SdkBody::from("request body")) + .unwrap(), + ); + if !retryable { + req = req + .augment(|req, _conf| { + Ok::<_, Infallible>( + req.map(|_| SdkBody::from_dyn(BoxBody::new(SdkBody::from("body")))), + ) + }) + .unwrap(); + } + Operation::new(req, TestOperationParser).with_retry_classifier(TestRetryClassifier) +} + +async fn h1_and_h2(events: Vec, match_clause: impl Fn(&[RecordedEvent])) { + wire_level_test(events.clone(), |_b| {}, |b| b, &match_clause).await; + wire_level_test( + events, + |b| { + b.http2_only(true); + }, + |b| b, + match_clause, + ) + .await; + println!("h2 ok!"); +} + +/// Repeatedly send test operation until `end_of_test` is received +/// +/// When the test is over, match_clause is evaluated +async fn wire_level_test( + events: Vec, + hyper_builder_settings: impl Fn(&mut HyperBuilder), + client_builder_settings: impl Fn(Builder) -> Builder, + match_clause: impl Fn(&[RecordedEvent]), +) { + let connection = wire_mock::WireLevelTestConnection::spinup(events).await; + + let http_connector = HttpConnector::new_with_resolver(connection.dns_resolver()); + let mut hyper_builder = hyper::Client::builder(); + hyper_builder_settings(&mut hyper_builder); + let hyper_adapter = hyper_ext::Adapter::builder() + .hyper_builder(hyper_builder) + .build(http_connector); + let client = client_builder_settings( + Client::builder().reconnect_mode(ReconnectMode::ReconnectOnTransientError), + ) + .connector(hyper_adapter) + .middleware(Identity::new()) + .operation_timeout_config(OperationTimeoutConfig::from( + &TimeoutConfig::builder() + .operation_attempt_timeout(Duration::from_millis(100)) + .build(), + )) + .sleep_impl(Arc::new(TokioSleep::new())) + .build(); + loop { + match client + .call(test_operation( + connection.endpoint_url().parse().unwrap(), + false, + )) + .await + { + Ok(resp) => { + tracing::info!("response: {:?}", resp); + if resp == end_of_test() { + break; + } + } + Err(e) => tracing::info!("error: {:?}", e), + } + } + let events = connection.events(); + match_clause(&events); +} + +#[tokio::test] +async fn non_transient_errors_no_reconect() { + h1_and_h2( + vec![ + ReplayedEvent::status(400), + ReplayedEvent::with_body(end_of_test()), + ], + match_events!(ev!(dns), ev!(connect), ev!(http(400)), ev!(http(200))), + ) + .await +} + +#[tokio::test] +async fn reestablish_dns_on_503() { + h1_and_h2( + vec![ + ReplayedEvent::status(503), + ReplayedEvent::status(503), + ReplayedEvent::status(503), + ReplayedEvent::with_body(end_of_test()), + ], + match_events!( + // first request + ev!(dns), + ev!(connect), + ev!(http(503)), + // second request + ev!(dns), + ev!(connect), + ev!(http(503)), + // third request + ev!(dns), + ev!(connect), + ev!(http(503)), + // all good + ev!(dns), + ev!(connect), + ev!(http(200)) + ), + ) + .await; +} + +#[tokio::test] +async fn connection_shared_on_success() { + h1_and_h2( + vec![ + ReplayedEvent::ok(), + ReplayedEvent::ok(), + ReplayedEvent::status(503), + ReplayedEvent::with_body(end_of_test()), + ], + match_events!( + ev!(dns), + ev!(connect), + ev!(http(200)), + ev!(http(200)), + ev!(http(503)), + ev!(dns), + ev!(connect), + ev!(http(200)) + ), + ) + .await; +} + +#[tokio::test] +async fn no_reconnect_when_disabled() { + use wire_mock::ev; + wire_level_test( + vec![ + ReplayedEvent::status(503), + ReplayedEvent::with_body(end_of_test()), + ], + |_b| {}, + |b| b.reconnect_mode(ReconnectMode::ReuseAllConnections), + match_events!(ev!(dns), ev!(connect), ev!(http(503)), ev!(http(200))), + ) + .await; +} + +#[tokio::test] +async fn connection_reestablished_after_timeout() { + use wire_mock::ev; + h1_and_h2( + vec![ + ReplayedEvent::ok(), + ReplayedEvent::Timeout, + ReplayedEvent::ok(), + ReplayedEvent::Timeout, + ReplayedEvent::with_body(end_of_test()), + ], + match_events!( + // first connection + ev!(dns), + ev!(connect), + ev!(http(200)), + // reuse but got a timeout + ev!(timeout), + // so we reconnect + ev!(dns), + ev!(connect), + ev!(http(200)), + ev!(timeout), + ev!(dns), + ev!(connect), + ev!(http(200)) + ), + ) + .await; +} diff --git a/rust-runtime/aws-smithy-client/tests/test_operation/mod.rs b/rust-runtime/aws-smithy-client/tests/test_operation/mod.rs new file mode 100644 index 0000000000..db193e4bd9 --- /dev/null +++ b/rust-runtime/aws-smithy-client/tests/test_operation/mod.rs @@ -0,0 +1,84 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use aws_smithy_http::operation; +use aws_smithy_http::response::ParseHttpResponse; +use aws_smithy_http::result::SdkError; +use aws_smithy_http::retry::ClassifyRetry; +use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind, RetryKind}; +use bytes::Bytes; +use std::error::Error; +use std::fmt::{self, Debug, Display, Formatter}; +use std::str; + +#[derive(Clone)] +pub(super) struct TestOperationParser; + +#[derive(Debug)] +pub(super) struct OperationError(ErrorKind); + +impl Display for OperationError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + +impl Error for OperationError {} + +impl ProvideErrorKind for OperationError { + fn retryable_error_kind(&self) -> Option { + Some(self.0) + } + + fn code(&self) -> Option<&str> { + None + } +} + +impl ParseHttpResponse for TestOperationParser { + type Output = Result; + + fn parse_unloaded(&self, response: &mut operation::Response) -> Option { + tracing::debug!("got response: {:?}", response); + match response.http().status() { + s if s.is_success() => None, + s if s.is_client_error() => Some(Err(OperationError(ErrorKind::ServerError))), + s if s.is_server_error() => Some(Err(OperationError(ErrorKind::TransientError))), + _ => panic!("unexpected status: {}", response.http().status()), + } + } + + fn parse_loaded(&self, response: &http::Response) -> Self::Output { + Ok(str::from_utf8(response.body().as_ref()) + .unwrap() + .to_string()) + } +} + +#[derive(Clone)] +pub(super) struct TestRetryClassifier; + +impl ClassifyRetry> for TestRetryClassifier +where + E: ProvideErrorKind + Debug, + T: Debug, +{ + fn classify_retry(&self, err: Result<&T, &SdkError>) -> RetryKind { + tracing::info!("got response: {:?}", err); + let kind = match err { + Err(SdkError::ServiceError(context)) => context.err().retryable_error_kind(), + Err(SdkError::DispatchFailure(err)) if err.is_timeout() => { + Some(ErrorKind::TransientError) + } + Err(SdkError::TimeoutError(_)) => Some(ErrorKind::TransientError), + Ok(_) => return RetryKind::Unnecessary, + _ => panic!("test handler only handles modeled errors got: {:?}", err), + }; + match kind { + Some(kind) => RetryKind::Error(kind), + None => RetryKind::UnretryableFailure, + } + } +} diff --git a/rust-runtime/aws-smithy-http-tower/src/dispatch.rs b/rust-runtime/aws-smithy-http-tower/src/dispatch.rs index 8a1119d61b..a10693a62b 100644 --- a/rust-runtime/aws-smithy-http-tower/src/dispatch.rs +++ b/rust-runtime/aws-smithy-http-tower/src/dispatch.rs @@ -5,6 +5,7 @@ use crate::SendOperationError; use aws_smithy_http::body::SdkBody; +use aws_smithy_http::connection::CaptureSmithyConnection; use aws_smithy_http::operation; use aws_smithy_http::result::ConnectorError; use std::future::Future; @@ -41,7 +42,13 @@ where } fn call(&mut self, req: operation::Request) -> Self::Future { - let (req, property_bag) = req.into_parts(); + let (mut req, property_bag) = req.into_parts(); + // copy the smithy connection + if let Some(smithy_conn) = property_bag.acquire().get::() { + req.extensions_mut().insert(smithy_conn.clone()); + } else { + println!("nothing to copy!"); + } let mut inner = self.inner.clone(); let future = async move { trace!(request = ?req, "dispatching request"); diff --git a/rust-runtime/aws-smithy-http/src/connection.rs b/rust-runtime/aws-smithy-http/src/connection.rs new file mode 100644 index 0000000000..4649b0e8ee --- /dev/null +++ b/rust-runtime/aws-smithy-http/src/connection.rs @@ -0,0 +1,96 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +use std::fmt::{Debug, Formatter}; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex}; + +#[derive(Clone)] +pub struct ConnectionMetadata { + is_proxied: bool, + remote_addr: Option, + poison_fn: Arc, +} + +impl Debug for ConnectionMetadata { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SmithyConnection") + .field("is_proxied", &self.is_proxied) + .field("remote_addr", &self.remote_addr) + .finish() + } +} + +type LoaderFn = dyn Fn() -> Option + Send + Sync; + +#[derive(Clone, Default)] +pub struct CaptureSmithyConnection { + loader: Arc>>>, +} + +impl CaptureSmithyConnection { + pub fn new() -> Self { + Self { + loader: Default::default(), + } + } + pub fn set_connection_retriever(&self, f: F) + where + F: Fn() -> Option + Send + Sync + 'static, + { + *self.loader.lock().unwrap() = Some(Box::new(f)); + } + + pub fn get(&self) -> Option { + match self.loader.lock().unwrap().as_ref() { + Some(loader) => loader(), + None => { + println!("no loader was set :-/"); + None + } + } + } +} + +impl ConnectionMetadata { + pub fn poison(&self) { + tracing::info!("smithy connection was poisoned"); + (self.poison_fn)() + } +} + +impl ConnectionMetadata { + pub fn new( + is_proxied: bool, + remote_addr: Option, + poison: impl Fn() + Send + Sync + 'static, + ) -> Self { + Self { + is_proxied, + remote_addr, + poison_fn: Arc::new(poison), + } + } + + pub fn remote_addr(&self) -> Option { + self.remote_addr + } +} + +#[cfg(test)] +mod test { + use crate::connection::{CaptureSmithyConnection, ConnectionMetadata}; + + #[test] + fn retrieve_connection_metadata() { + let retriever = CaptureSmithyConnection::new(); + let retriever_clone = retriever.clone(); + assert!(retriever.get().is_none()); + retriever.set_connection_retriever(|| Some(ConnectionMetadata::new(true, None, || {}))); + + assert!(retriever.get().is_some()); + assert!(retriever_clone.get().is_some()); + } +} diff --git a/rust-runtime/aws-smithy-http/src/lib.rs b/rust-runtime/aws-smithy-http/src/lib.rs index f777e15c82..77156efc2f 100644 --- a/rust-runtime/aws-smithy-http/src/lib.rs +++ b/rust-runtime/aws-smithy-http/src/lib.rs @@ -39,4 +39,5 @@ pub mod event_stream; pub mod byte_stream; +pub mod connection; mod urlencode; diff --git a/rust-runtime/aws-smithy-http/src/result.rs b/rust-runtime/aws-smithy-http/src/result.rs index f11dcc2d36..a3fdfcc761 100644 --- a/rust-runtime/aws-smithy-http/src/result.rs +++ b/rust-runtime/aws-smithy-http/src/result.rs @@ -12,6 +12,7 @@ //! `Result` wrapper types for [success](SdkSuccess) and [failure](SdkError) responses. +use crate::connection::ConnectionMetadata; use crate::operation; use aws_smithy_types::error::metadata::{ProvideErrorMetadata, EMPTY_ERROR_METADATA}; use aws_smithy_types::error::ErrorMetadata; @@ -240,6 +241,11 @@ impl DispatchFailure { pub fn is_other(&self) -> Option { self.source.is_other() } + + /// Returns the inner error if it is a connector error + pub fn as_connector_error(&self) -> Option<&ConnectorError> { + Some(&self.source) + } } /// Error context for [`SdkError::ResponseError`] @@ -505,6 +511,22 @@ enum ConnectorErrorKind { pub struct ConnectorError { kind: ConnectorErrorKind, source: BoxError, + connection: ConnectionStatus, +} + +#[non_exhaustive] +#[derive(Debug)] +pub(crate) enum ConnectionStatus { + /// This request was never connected to the remote + /// + /// This indicates the failure was during connection establishment + NeverConnected, + + /// It is unknown whether a connection was established + Unknown, + + /// The request connected to the remote prior to failure + Connected(ConnectionMetadata), } impl Display for ConnectorError { @@ -532,14 +554,28 @@ impl ConnectorError { Self { kind: ConnectorErrorKind::Timeout, source, + connection: ConnectionStatus::Unknown, } } + /// Include connection information along with this error + pub fn with_connection(mut self, info: ConnectionMetadata) -> Self { + self.connection = ConnectionStatus::Connected(info); + self + } + + /// Set the connection status on this error to report that a connection was never established + pub fn never_connected(mut self) -> Self { + self.connection = ConnectionStatus::NeverConnected; + self + } + /// Construct a [`ConnectorError`] from an error caused by the user (e.g. invalid HTTP request) pub fn user(source: BoxError) -> Self { Self { kind: ConnectorErrorKind::User, source, + connection: ConnectionStatus::Unknown, } } @@ -548,6 +584,7 @@ impl ConnectorError { Self { kind: ConnectorErrorKind::Io, source, + connection: ConnectionStatus::Unknown, } } @@ -558,6 +595,7 @@ impl ConnectorError { Self { source, kind: ConnectorErrorKind::Other(kind), + connection: ConnectionStatus::Unknown, } } @@ -583,4 +621,16 @@ impl ConnectorError { _ => None, } } + + /// Returns metadata about the connection + /// + /// If a connection was established and provided by the internal connector, a connection will + /// be returned. + pub fn connection_metadata(&self) -> Option<&ConnectionMetadata> { + match &self.connection { + ConnectionStatus::NeverConnected => None, + ConnectionStatus::Unknown => None, + ConnectionStatus::Connected(conn) => Some(conn), + } + } } diff --git a/rust-runtime/aws-smithy-types/src/retry.rs b/rust-runtime/aws-smithy-types/src/retry.rs index 43be79cae4..b96ababf06 100644 --- a/rust-runtime/aws-smithy-types/src/retry.rs +++ b/rust-runtime/aws-smithy-types/src/retry.rs @@ -143,6 +143,7 @@ pub struct RetryConfigBuilder { mode: Option, max_attempts: Option, initial_backoff: Option, + reconnect_mode: Option, } impl RetryConfigBuilder { @@ -163,6 +164,30 @@ impl RetryConfigBuilder { self } + /// Set the [`ReconnectMode`] for the retry strategy + /// + /// By default, when a transient error is encountered, the connection in use will be poisoned. + /// This prevents reusing a connection to a potentially bad host but may increase the load on + /// the server. + /// + /// This behavior can be disabled by setting [`ReconnectMode::ReuseAllConnections`] instead. + pub fn reconnect_mode(mut self, reconnect_mode: ReconnectMode) -> Self { + self.set_reconnect_mode(Some(reconnect_mode)); + self + } + + /// Set the [`ReconnectMode`] for the retry strategy + /// + /// By default, when a transient error is encountered, the connection in use will be poisoned. + /// This prevents reusing a connection to a potentially bad host but may increase the load on + /// the server. + /// + /// This behavior can be disabled by setting [`ReconnectMode::ReuseAllConnections`] instead. + pub fn set_reconnect_mode(&mut self, reconnect_mode: Option) -> &mut Self { + self.reconnect_mode = reconnect_mode; + self + } + /// Sets the max attempts. This value must be greater than zero. pub fn set_max_attempts(&mut self, max_attempts: Option) -> &mut Self { self.max_attempts = max_attempts; @@ -208,6 +233,7 @@ impl RetryConfigBuilder { mode: self.mode.or(other.mode), max_attempts: self.max_attempts.or(other.max_attempts), initial_backoff: self.initial_backoff.or(other.initial_backoff), + reconnect_mode: self.reconnect_mode.or(other.reconnect_mode), } } @@ -219,6 +245,9 @@ impl RetryConfigBuilder { initial_backoff: self .initial_backoff .unwrap_or_else(|| Duration::from_secs(1)), + reconnect_mode: self + .reconnect_mode + .unwrap_or(ReconnectMode::ReconnectOnTransientError), } } } @@ -230,6 +259,23 @@ pub struct RetryConfig { mode: RetryMode, max_attempts: u32, initial_backoff: Duration, + reconnect_mode: ReconnectMode, +} + +/// Mode for connection re-establishment +/// +/// By default, when a transient error is encountered, the connection in use will be poisoned. This +/// behavior can be disabled by setting [`ReconnectMode::ReuseAllConnections`] instead. +#[derive(Debug, Clone, PartialEq, Copy)] +pub enum ReconnectMode { + /// Reconnect on [`ErrorKind::TransientError`] + ReconnectOnTransientError, + + /// Disable reconnect on error + /// + /// When this setting is applied, 503s, timeouts, and other transient errors will _not_ + /// lead to a new connection being established unless the connection is closed by the remote. + ReuseAllConnections, } impl RetryConfig { @@ -239,6 +285,7 @@ impl RetryConfig { mode: RetryMode::Standard, max_attempts: 3, initial_backoff: Duration::from_secs(1), + reconnect_mode: ReconnectMode::ReconnectOnTransientError, } } @@ -260,6 +307,18 @@ impl RetryConfig { self } + /// Set the [`ReconnectMode`] for the retry strategy + /// + /// By default, when a transient error is encountered, the connection in use will be poisoned. + /// This prevents reusing a connection to a potentially bad host but may increase the load on + /// the server. + /// + /// This behavior can be disabled by setting [`ReconnectMode::ReuseAllConnections`] instead. + pub fn with_reconnect_mode(mut self, reconnect_mode: ReconnectMode) -> Self { + self.reconnect_mode = reconnect_mode; + self + } + /// Set the multiplier used when calculating backoff times as part of an /// [exponential backoff with jitter](https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/) /// strategy. Most services should work fine with the default duration of 1 second, but if you @@ -287,6 +346,11 @@ impl RetryConfig { self.mode } + /// Returns the [`ReconnectMode`] + pub fn reconnect_mode(&self) -> ReconnectMode { + self.reconnect_mode + } + /// Returns the max attempts. pub fn max_attempts(&self) -> u32 { self.max_attempts