Skip to content

Commit

Permalink
Allow <Error> to trigger error handling for S3 (#2958)
Browse files Browse the repository at this point in the history
## Motivation and Context
- awslabs/aws-sdk-rust#873

## Description
Add a customization for S3 so that if `<Error>` is the root element, we
trigger the error parsing flow

## Testing
- Added an integration test

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the
smithy-rs codegen or runtime crates
- [x] I have updated `CHANGELOG.next.toml` if I made changes to the AWS
SDK, generated SDK code, or SDK runtime crates

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
rcoh authored Aug 30, 2023
1 parent 9f39607 commit 2db1c34
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 7 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ references = ["smithy-rs#2948"]
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "all" }
author = "Velfi"

[[aws-sdk-rust]]
message = "Correctly identify HTTP 200 responses from S3 with `<Error>` as the root Element as errors. **Note**: This a behavior change and will change the error type returned by the SDK in some cases."
references = ["smithy-rs#2958", "aws-sdk-rust#873"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "rcoh"

[[aws-sdk-rust]]
message = "Allow `no_credentials` to be used with all S3 operations."
references = ["smithy-rs#2955", "aws-sdk-rust#878"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientRustSettings
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustomization
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationGenerator
import software.amazon.smithy.rust.codegen.client.smithy.generators.OperationSection
import software.amazon.smithy.rust.codegen.client.smithy.protocols.ClientRestXmlFactory
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
Expand Down Expand Up @@ -106,6 +108,34 @@ class S3Decorator : ClientCodegenDecorator {
)
}

override fun operationCustomizations(
codegenContext: ClientCodegenContext,
operation: OperationShape,
baseCustomizations: List<OperationCustomization>,
): List<OperationCustomization> {
return baseCustomizations + object : OperationCustomization() {
override fun section(section: OperationSection): Writable {
return writable {
when (section) {
is OperationSection.BeforeParseResponse -> {
section.body?.also { body ->
rustTemplate(
"""
if matches!(#{errors}::body_is_error($body), Ok(true)) {
${section.forceError} = true;
}
""",
"errors" to RuntimeType.unwrappedXmlErrors(codegenContext.runtimeConfig),
)
}
}
else -> {}
}
}
}
}
}

private fun isInInvalidXmlRootAllowList(shape: Shape): Boolean {
return shape.isStructureShape && invalidXmlRootAllowList.contains(shape.id)
}
Expand All @@ -115,15 +145,15 @@ class FilterEndpointTests(
private val testFilter: (EndpointTestCase) -> EndpointTestCase? = { a -> a },
private val operationInputFilter: (EndpointTestOperationInput) -> EndpointTestOperationInput? = { a -> a },
) {
fun updateEndpointTests(endpointTests: List<EndpointTestCase>): List<EndpointTestCase> {
private fun updateEndpointTests(endpointTests: List<EndpointTestCase>): List<EndpointTestCase> {
val filteredTests = endpointTests.mapNotNull { test -> testFilter(test) }
return filteredTests.map { test ->
val operationInputs = test.operationInputs
test.toBuilder().operationInputs(operationInputs.mapNotNull { operationInputFilter(it) }).build()
}
}

fun transform(model: Model) = ModelTransformer.create().mapTraits(model) { _, trait ->
fun transform(model: Model): Model = ModelTransformer.create().mapTraits(model) { _, trait ->
when (trait) {
is EndpointTestsTrait -> EndpointTestsTrait.builder().testCases(updateEndpointTests(trait.testCases))
.version(trait.version).build()
Expand All @@ -135,7 +165,7 @@ class FilterEndpointTests(

// TODO(P96049742): This model transform may need to change depending on if and how the S3 model is updated.
private class AddOptionalAuth {
fun transform(model: Model) = ModelTransformer.create().mapShapes(model) { shape ->
fun transform(model: Model): Model = ModelTransformer.create().mapShapes(model) { shape ->
// Add @optionalAuth to all S3 operations
if (shape is OperationShape && !shape.hasTrait<OptionalAuthTrait>()) {
shape.toBuilder()
Expand Down
40 changes: 40 additions & 0 deletions aws/sdk/integration-tests/s3/tests/status-200-errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

use aws_credential_types::provider::SharedCredentialsProvider;
use aws_credential_types::Credentials;
use aws_sdk_s3::Client;
use aws_smithy_client::test_connection::infallible_connection_fn;
use aws_smithy_http::body::SdkBody;
use aws_smithy_types::error::metadata::ProvideErrorMetadata;
use aws_types::region::Region;
use aws_types::SdkConfig;

const ERROR_RESPONSE: &str = r#"<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>SlowDown</Code>
<Message>Please reduce your request rate.</Message>
<RequestId>K2H6N7ZGQT6WHCEG</RequestId>
<HostId>WWoZlnK4pTjKCYn6eNV7GgOurabfqLkjbSyqTvDMGBaI9uwzyNhSaDhOCPs8paFGye7S6b/AB3A=</HostId>
</Error>
"#;

#[tokio::test]
async fn status_200_errors() {
let conn = infallible_connection_fn(|_req| http::Response::new(SdkBody::from(ERROR_RESPONSE)));
let sdk_config = SdkConfig::builder()
.credentials_provider(SharedCredentialsProvider::new(Credentials::for_tests()))
.region(Region::new("us-west-4"))
.http_connector(conn)
.build();
let client = Client::new(&sdk_config);
let error = client
.delete_objects()
.bucket("bucket")
.send()
.await
.expect_err("should fail");
assert_eq!(error.into_service_error().code(), Some("SlowDown"));
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ sealed class OperationSection(name: String) : Section(name) {
data class BeforeParseResponse(
override val customizations: List<OperationCustomization>,
val responseName: String,
/**
* Name of the `force_error` variable. Set this to true to trigger error parsing.
*/
val forceError: String,
/**
* When set, the name of the response body data field
*/
val body: String?,
) : OperationSection("BeforeParseResponse")

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,12 @@ class ResponseDeserializerGenerator(
rustTemplate(
"""
fn deserialize_streaming(&self, response: &mut #{HttpResponse}) -> #{Option}<#{OutputOrError}> {
##[allow(unused_mut)]
let mut force_error = false;
#{BeforeParseResponse}
// If this is an error, defer to the non-streaming parser
if !response.status().is_success() && response.status().as_u16() != $successCode {
if (!response.status().is_success() && response.status().as_u16() != $successCode) || force_error {
return #{None};
}
#{Some}(#{type_erase_result}(#{parse_streaming_response}(response)))
Expand All @@ -106,7 +108,7 @@ class ResponseDeserializerGenerator(
*codegenScope,
"parse_streaming_response" to parserGenerator.parseStreamingResponseFn(operationShape, customizations),
"BeforeParseResponse" to writable {
writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response"))
writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response", "force_error", body = null))
},
)
}
Expand Down Expand Up @@ -136,8 +138,10 @@ class ResponseDeserializerGenerator(
let (success, status) = (response.status().is_success(), response.status().as_u16());
let headers = response.headers();
let body = response.body().bytes().expect("body loaded");
##[allow(unused_mut)]
let mut force_error = false;
#{BeforeParseResponse}
let parse_result = if !success && status != $successCode {
let parse_result = if !success && status != $successCode || force_error {
#{parse_error}(status, headers, body)
} else {
#{parse_response}(status, headers, body)
Expand All @@ -148,7 +152,7 @@ class ResponseDeserializerGenerator(
"parse_error" to parserGenerator.parseErrorFn(operationShape, customizations),
"parse_response" to parserGenerator.parseResponseFn(operationShape, customizations),
"BeforeParseResponse" to writable {
writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response"))
writeCustomizations(customizations, OperationSection.BeforeParseResponse(customizations, "response", "force_error", "body"))
},
)
}
Expand Down

0 comments on commit 2db1c34

Please sign in to comment.