Skip to content

Commit

Permalink
Make SdkError::into_service_error infallible (#1974)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdisanti authored Nov 16, 2022
1 parent 4563849 commit 9f0bc36
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.RetryableTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.asType
import software.amazon.smithy.rust.codegen.core.rustlang.deprecatedShape
import software.amazon.smithy.rust.codegen.core.rustlang.documentShape
import software.amazon.smithy.rust.codegen.core.rustlang.rust
Expand Down Expand Up @@ -124,6 +126,7 @@ class CombinedErrorGenerator(
) {
private val runtimeConfig = symbolProvider.config().runtimeConfig
private val genericError = RuntimeType.GenericError(symbolProvider.config().runtimeConfig)
private val createUnhandledError = CargoDependency.SmithyHttp(runtimeConfig).asType().member("result::CreateUnhandledError")

fun render(writer: RustWriter) {
val errorSymbol = RuntimeType("${operationSymbol.name}Error", null, "crate::error")
Expand Down Expand Up @@ -154,6 +157,15 @@ class CombinedErrorGenerator(
RuntimeType.GenericError(runtimeConfig),
)
}
writer.rustBlock("impl #T for ${errorSymbol.name}", createUnhandledError) {
rustBlock("fn create_unhandled_error(source: Box<dyn std::error::Error + Send + Sync + 'static>) -> Self") {
rustBlock("Self") {
rust("kind: ${errorSymbol.name}Kind::Unhandled(#T::new(source)),", unhandledError())
rust("meta: Default::default()")
}
}
}

writer.rust("/// Types of errors that can occur for the `${operationSymbol.name}` operation.")
meta.render(writer)
writer.rustBlock("enum ${errorSymbol.name}Kind") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class TopLevelErrorGenerator(private val codegenContext: CodegenContext, private

private fun RustWriter.renderImplFrom(errorSymbol: RuntimeType, errors: List<ShapeId>) {
if (errors.isNotEmpty() || CodegenTarget.CLIENT == codegenContext.target) {
val operationErrors = errors.map { model.expectShape(it) }
rustBlock(
"impl<R> From<#T<#T, R>> for Error where R: Send + Sync + std::fmt::Debug + 'static",
sdkError,
Expand All @@ -106,22 +107,27 @@ class TopLevelErrorGenerator(private val codegenContext: CodegenContext, private
"OpError" to errorSymbol,
) {
rustBlock("match err") {
val operationErrors = errors.map { model.expectShape(it) }
rustBlock("#T::ServiceError(context) => match context.into_err().kind", sdkError) {
operationErrors.forEach { errorShape ->
val errSymbol = symbolProvider.toSymbol(errorShape)
rust(
"#TKind::${errSymbol.name}(inner) => Error::${errSymbol.name}(inner),",
errorSymbol,
)
}
rustTemplate(
"#{errorSymbol}Kind::Unhandled(inner) => Error::Unhandled(#{unhandled}::new(inner.into())),",
"errorSymbol" to errorSymbol,
"unhandled" to unhandledError(),
rust("#T::ServiceError(context) => Self::from(context.into_err()),", sdkError)
rust("_ => Error::Unhandled(#T::new(err.into())),", unhandledError())
}
}
}

rustBlock("impl From<#T> for Error", errorSymbol) {
rustBlock("fn from(err: #T) -> Self", errorSymbol) {
rustBlock("match err.kind") {
operationErrors.forEach { errorShape ->
val errSymbol = symbolProvider.toSymbol(errorShape)
rust(
"#TKind::${errSymbol.name}(inner) => Error::${errSymbol.name}(inner),",
errorSymbol,
)
}
rust("_ => Error::Unhandled(#T::new(err.into())),", unhandledError())
rustTemplate(
"#{errorSymbol}Kind::Unhandled(inner) => Error::Unhandled(#{unhandled}::new(inner.into())),",
"errorSymbol" to errorSymbol,
"unhandled" to unhandledError(),
)
}
}
}
Expand Down
38 changes: 27 additions & 11 deletions rust-runtime/aws-smithy-http/src/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ impl<E, R> ServiceError<E, R> {
}
}

/// Constructs the unhandled variant of a code generated error.
///
/// This trait exists so that [`SdkError::into_service_error`] can be infallible.
pub trait CreateUnhandledError {
/// Creates an unhandled error variant with the given `source`.
fn create_unhandled_error(source: Box<dyn Error + Send + Sync + 'static>) -> Self;
}

/// Failed SDK Result
#[non_exhaustive]
#[derive(Debug)]
Expand Down Expand Up @@ -179,34 +187,42 @@ impl<E, R> SdkError<E, R> {

/// Returns the underlying service error `E` if there is one
///
/// If a service error is not available (for example, the error is a network timeout),
/// then the full `SdkError` is returned. This makes it easy to match on the service's
/// error response while simultaneously bubbling up transient failures. For example,
/// handling the `NoSuchKey` error for S3's `GetObject` operation may look as follows:
/// If the `SdkError` is not a `ServiceError` (for example, the error is a network timeout),
/// then it will be converted into an unhandled variant of `E`. This makes it easy to match
/// on the service's error response while simultaneously bubbling up transient failures.
/// For example, handling the `NoSuchKey` error for S3's `GetObject` operation may look as
/// follows:
///
/// ```no_run
/// # use aws_smithy_http::result::SdkError;
/// # use aws_smithy_http::result::{SdkError, CreateUnhandledError};
/// # #[derive(Debug)] enum GetObjectErrorKind { NoSuchKey(()), Other(()) }
/// # #[derive(Debug)] struct GetObjectError { kind: GetObjectErrorKind }
/// # impl std::fmt::Display for GetObjectError {
/// # fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { unimplemented!() }
/// # }
/// # impl std::error::Error for GetObjectError {}
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// # impl CreateUnhandledError for GetObjectError {
/// # fn create_unhandled_error(_: Box<dyn std::error::Error + Send + Sync + 'static>) -> Self { unimplemented!() }
/// # }
/// # fn example() -> Result<(), GetObjectError> {
/// # let sdk_err = SdkError::service_error(GetObjectError { kind: GetObjectErrorKind::NoSuchKey(()) }, ());
/// match sdk_err.into_service_error()? {
/// match sdk_err.into_service_error() {
/// GetObjectError { kind: GetObjectErrorKind::NoSuchKey(_) } => {
/// // handle NoSuchKey
/// }
/// err @ _ => return Err(err.into()),
/// err @ _ => return Err(err),
/// }
/// # Ok(())
/// # }
/// ```
pub fn into_service_error(self) -> Result<E, Self> {
pub fn into_service_error(self) -> E
where
E: std::error::Error + Send + Sync + CreateUnhandledError + 'static,
R: Debug + Send + Sync + 'static,
{
match self {
Self::ServiceError(context) => Ok(context.source),
_ => Err(self),
Self::ServiceError(context) => context.source,
_ => E::create_unhandled_error(self.into()),
}
}

Expand Down
5 changes: 1 addition & 4 deletions tools/ci-cdk/canary-lambda/src/s3_canary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ pub async fn s3_canary(client: s3::Client, s3_bucket_name: String) -> anyhow::Re
CanaryError(format!("Expected object {} to not exist in S3", test_key)).into(),
);
}
Err(err) => match err
.into_service_error()
.context("unexpected s3::GetObject failure")?
{
Err(err) => match err.into_service_error() {
GetObjectError {
kind: GetObjectErrorKind::NoSuchKey(..),
..
Expand Down

0 comments on commit 9f0bc36

Please sign in to comment.