Skip to content

Commit

Permalink
fix: S3 200 error customization should only not apply to streaming + …
Browse files Browse the repository at this point in the history
…blob shapes (#1633)
  • Loading branch information
dayaffe authored Aug 6, 2024
1 parent a453825 commit 4aecaa0
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,18 @@ import AwsCommonRuntimeKit
import ClientRuntime

public class MockHttpClientEngine: HTTPClient {
private let errorResponsePayload: String

// Public initializer
public init() {}
public init(response: String) {
self.errorResponsePayload = response
}

func successHttpResponse(request: SmithyHTTPAPI.HTTPRequest) -> HTTPResponse {
let errorResponsePayload = """
<Error>
<Code>SlowDown</Code>
<Message>Please reduce your request rate.</Message>
<RequestId>K2H6N7ZGQT6WHCEG</RequestId>
<HostId>WWoZlnK4pTjKCYn6eNV7GgOurabfqLkjbSyqTvDMGBaI9uwzyNhSaDhOCPs8paFGye7S6b/AB3A=</HostId>
</Error>
"""
request.withHeader(name: "Date", value: "Wed, 21 Oct 2015 07:28:00 GMT")
return HTTPResponse(
headers: request.headers,
body: ByteStream.data(errorResponsePayload.data(using: .utf8)),
body: ByteStream.data(self.errorResponsePayload.data(using: .utf8)),
statusCode: .ok
)
}
Expand All @@ -44,25 +39,101 @@ public class MockHttpClientEngine: HTTPClient {

class S3ErrorIn200Test: XCTestCase {

let errorInternalErrorResponsePayload = """
<Error>
<Code>InternalError</Code>
<Message>We encountered an internal error. Please try again.</Message>
<RequestId>656c76696e6727732072657175657374</RequestId>
<HostId>Uuag1LuByRx9e6j5Onimru9pO4ZVKnJ2Qz7/C1NPcfTWAtRPfTaOFg==</HostId>
</Error>
"""

let errorSlowDownResponsePayload = """
<Error>
<Code>SlowDown</Code>
<Message>Please reduce your request rate.</Message>
<RequestId>K2H6N7ZGQT6WHCEG</RequestId>
<HostId>WWoZlnK4pTjKCYn6eNV7GgOurabfqLkjbSyqTvDMGBaI9uwzyNhSaDhOCPs8paFGye7S6b/AB3A=</HostId>
</Error>
"""

let shouldNotApplyResponsePayload = """
<DeleteResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<Deleted>
<Key>sample1.txt</Key>
</Deleted>
<Error>
<Key>sample2.txt</Key>
<Code>AccessDenied</Code>
<Message>Access Denied</Message>
</Error>
</DeleteResult>
"""

override class func setUp() {
AwsCommonRuntimeKit.CommonRuntimeKit.initialize()
}

/// S3Client throws expected error in response (200) with <Error> tag
func test_foundExpectedError() async throws {
/// S3Client throws expected InternalError error in response (200) with <Error> tag
func test_foundInternalErrorExpectedError() async throws {
let config = try await S3Client.S3ClientConfiguration(region: "us-west-2")
config.httpClientEngine = MockHttpClientEngine()
config.httpClientEngine = MockHttpClientEngine(response: errorInternalErrorResponsePayload)
let client = S3Client(config: config)

do {
// any method on S3Client where the output shape doesnt have a stream
// any method on S3Client where the output shape doesnt have a blob stream
_ = try await client.listBuckets(input: .init())
XCTFail("Expected an error to be thrown, but it was not.")
} catch let error as UnknownAWSHTTPServiceError {
// check for the error we added in our mock client
XCTAssertEqual("InternalError", error.typeName)
XCTAssertEqual("We encountered an internal error. Please try again.", error.message)
} catch {
XCTFail("Unexpected error: \(error)")
}
}

/// S3Client throws expected SlowDown error in response (200) with <Error> tag
func test_foundSlowDownExpectedError() async throws {
let config = try await S3Client.S3ClientConfiguration(region: "us-west-2")
config.httpClientEngine = MockHttpClientEngine(response: errorSlowDownResponsePayload)
let client = S3Client(config: config)

do {
// any method on S3Client where the output shape doesnt have a blob stream
_ = try await client.listBuckets(input: .init())
XCTFail("Expected an error to be thrown, but it was not.")
} catch let error as UnknownAWSHTTPServiceError {
// check for the error we added in our mock client
XCTAssertEqual("SlowDown", error.typeName)
XCTAssertEqual("Please reduce your request rate.", error.message)
} catch {
XCTFail("Unexpected error: \(error)")
}
}

/// S3Client does not throw error when <Error> is not at the root
func test_noErrorExpected() async throws {
let config = try await S3Client.S3ClientConfiguration(region: "us-west-2")
config.httpClientEngine = MockHttpClientEngine(response: shouldNotApplyResponsePayload)
let client = S3Client(config: config)

do {
// any method on S3Client where the output shape doesnt have a stream
let result = try await client.deleteObjects(input: .init(delete: .init(objects: [.init(key: "test")])))

// Check results
XCTAssertEqual(result.deleted?.count, 1)
XCTAssertEqual(result.errors?.count, 1)

let actualDeleted = result.deleted?.first
XCTAssertEqual(actualDeleted?.key, "sample1.txt")

let actualError = result.errors?.first
XCTAssertEqual(actualError?.code, "AccessDenied")
XCTAssertEqual(actualError?.key, "sample2.txt")
} catch let error {
XCTFail("Expected success, but received \(error).")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@ class S3EventStreamTests: S3XCTestCase {
outputSerialization: S3ClientTypes.OutputSerialization(json: S3ClientTypes.JSONOutput())
))

let outputStream = result.payload
guard let outputStream = result.payload else {
XCTFail("result.payload is nil")
return
}

var actualOutput = ""

for try await event in outputStream! {
for try await event in outputStream {
switch event {
case .records(let record):
actualOutput = actualOutput + (String(data: record.payload ?? Data(), encoding: .utf8) ?? "")
Expand Down
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ extension Target.Dependency {
static var smithyWaitersAPI: Self { .product(name: "SmithyWaitersAPI", package: "smithy-swift") }
static var smithyTestUtils: Self { .product(name: "SmithyTestUtil", package: "smithy-swift") }
static var smithyStreams: Self { .product(name: "SmithyStreams", package: "smithy-swift") }
static var smithyXML: Self { .product(name: "SmithyXML", package: "smithy-swift") }
}

// MARK: - Base Package
Expand Down Expand Up @@ -79,6 +80,7 @@ let package = Package(
.smithyRetries,
.smithyEventStreamsAPI,
.smithyEventStreamsAuthAPI,
.smithyXML,
.awsSDKCommon,
.awsSDKHTTPAuth,
.awsSDKIdentity
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,25 @@
// SPDX-License-Identifier: Apache-2.0
//

import enum Smithy.ByteStream
import class Smithy.Context
import ClientRuntime
import SmithyHTTPAPI
import SmithyXML
import struct Foundation.Data
import SmithyStreams

public struct AWSS3ErrorWith200StatusXMLMiddleware<OperationStackInput, OperationStackOutput> {
public let id: String = "AWSS3ErrorWith200StatusXMLMiddleware"
private let errorStatusCode: HTTPStatusCode = .internalServerError

public init() {}

private func isErrorWith200Status(response: HTTPResponse) async throws -> Bool {
// Check if the status code is OK (200)
guard response.statusCode == .ok else {
return false
}

// Check if the response body contains an XML Error
guard let data = try await response.body.readData() else {
return false
}
private func isRootErrorElement(data: Data) throws -> Bool {
let reader = try Reader.from(data: data)

response.body = .data(data)
let xmlString = String(decoding: data, as: UTF8.self)
return xmlString.contains("<Error>")
// Check if there's an "Error" node at the root of the XML response
return reader.nodeInfo.name == "Error"
}
}

Expand All @@ -40,9 +35,41 @@ extension AWSS3ErrorWith200StatusXMLMiddleware: HttpInterceptor {
context: some MutableResponse<Self.InputType, Self.RequestType, Self.ResponseType>
) async throws {
let response = context.getResponse()
if try await isErrorWith200Status(response: response) {
response.statusCode = errorStatusCode
context.updateResponse(updated: response)

// Check if the status code is OK (200)
guard response.statusCode == .ok else {
return
}

guard let data = try await response.body.readData() else {
return
}

let statusCode = try isRootErrorElement(data: data) ? errorStatusCode : response.statusCode

// For event streams the body needs to be copied as buffered streams are non-seekable
let updatedBody = response.body.copy(data: data)

let updatedResponse = response.copy(
body: updatedBody,
statusCode: statusCode
)

context.updateResponse(updated: updatedResponse)
}
}

extension ByteStream {

// Copy an existing ByteStream, optionally with new data
public func copy(data: Data?) -> ByteStream {
switch self {
case .data(let existingData):
return .data(data ?? existingData)
case .stream(let existingStream):
return .stream(data != nil ? BufferedStream(data: data, isClosed: true) : existingStream)
case .noStream:
return .noStream
}
}
}
1 change: 1 addition & 0 deletions Sources/Services/AWSS3/Sources/AWSS3/S3Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7604,6 +7604,7 @@ extension S3Client {
builder.selectAuthScheme(ClientRuntime.AuthSchemeMiddleware<SelectObjectContentOutput>())
builder.interceptors.add(AWSClientRuntime.AmzSdkInvocationIdMiddleware<SelectObjectContentInput, SelectObjectContentOutput>())
builder.interceptors.add(AWSClientRuntime.AmzSdkRequestMiddleware<SelectObjectContentInput, SelectObjectContentOutput>(maxRetries: config.retryStrategyOptions.maxRetriesBase))
builder.interceptors.add(AWSClientRuntime.AWSS3ErrorWith200StatusXMLMiddleware<SelectObjectContentInput, SelectObjectContentOutput>())
var metricsAttributes = Smithy.Attributes()
metricsAttributes.set(key: ClientRuntime.OrchestratorMetricsAttributesKeys.service, value: "S3")
metricsAttributes.set(key: ClientRuntime.OrchestratorMetricsAttributesKeys.method, value: "SelectObjectContent")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import software.amazon.smithy.aws.swift.codegen.swiftmodules.AWSClientRuntimeTyp
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.swift.codegen.SwiftSettings
import software.amazon.smithy.swift.codegen.SwiftWriter
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
Expand All @@ -17,7 +16,7 @@ import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.Mid
import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable
import software.amazon.smithy.swift.codegen.middleware.OperationMiddleware
import software.amazon.smithy.swift.codegen.model.expectShape
import software.amazon.smithy.swift.codegen.model.hasTrait
import software.amazon.smithy.swift.codegen.model.isStreaming

/**
* Register interceptor to handle S3 error responses returned with an HTTP 200 status code.
Expand All @@ -39,10 +38,14 @@ class S3ErrorWith200StatusIntegration : SwiftIntegration {
// Instead of playing whack-a-mole broadly apply this interceptor to everything but streaming responses
// which adds a small amount of overhead to response processing.
val output = ctx.model.expectShape(operationShape.output.get())
val outputIsNotStreaming = output.members().none {
it.hasTrait<StreamingTrait>() || ctx.model.expectShape(it.target).hasTrait<StreamingTrait>()
val outputIsNotAStreamingBlobShape = output.members().none {
val targetShape = ctx.model.expectShape(it.target)
val isBlob = it.isBlobShape || targetShape.isBlobShape
val isStreaming = it.isStreaming || targetShape.isStreaming
isBlob && isStreaming
}
if (outputIsNotStreaming) {

if (outputIsNotAStreamingBlobShape) {
operationMiddleware.appendMiddleware(operationShape, S3HandleError200ResponseMiddleware)
}
}
Expand Down

0 comments on commit 4aecaa0

Please sign in to comment.