Skip to content

Commit

Permalink
feat: finish interceptors migration (#1631)
Browse files Browse the repository at this point in the history
This commit includes the codegen changes that move all services to
use orchestrator + interceptors, instead of the operation stack using
middleware. It also removes all references to the old types related to
Middleware, in Swift libraries and in the code generator.

Since this is already a huge PR, I tried to limit any refactoring, which
may have made it harder to review. For example, I didn't change the names
of middleware or their file names.

This commit also adds two new interceptors, AmzSdkRequestMiddleware and
AmzSdkInvocationIdMiddleware, which add corresponding headers to requests.
They are added in codegen through a SwiftIntegration. Previously, these
interceptors were implemented as a part of RetryMiddleware in smithy-swift.
But unlike all other middleware, RetryMiddleware does not have an
interceptor implementation - its logic is inlined within Orchestrator - so
the logic for adding these amz-sdk headers had to be extracted. Plus, they
seem to be aws headers, so I moved them here into the SDK. I also copied
the RetryIntegrationTests test, which was testing the headers' functionality.

Other than that, there were some codegen changes I made which impacted the
ordering that interceptors are rendered into operation bodies, which meant
I needed to fix some codegen tests. The order interceptors are added to
orchestrator shouldn't matter.
  • Loading branch information
milesziemer authored Jul 22, 2024
1 parent ee4cf62 commit 46e00d6
Show file tree
Hide file tree
Showing 545 changed files with 509,166 additions and 226,814 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,8 @@ import struct ClientRuntime.DefaultEndpointsAuthSchemeResolver
import enum ClientRuntime.EndpointsAuthScheme
import protocol ClientRuntime.EndpointsAuthSchemeResolver
import protocol ClientRuntime.EndpointsRequestContextProviding
import protocol ClientRuntime.Handler
import protocol ClientRuntime.Middleware
import struct ClientRuntime.OperationOutput

public struct EndpointResolverMiddleware<OperationStackOutput, Params: EndpointsRequestContextProviding>: Middleware {
public typealias MInput = HTTPRequestBuilder
public typealias MOutput = ClientRuntime.OperationOutput<OperationStackOutput>

public struct EndpointResolverMiddleware<OperationStackOutput, Params: EndpointsRequestContextProviding> {
public let id: Swift.String = "EndpointResolverMiddleware"

let endpointResolverBlock: (Params) throws -> Endpoint
Expand All @@ -42,30 +36,14 @@ public struct EndpointResolverMiddleware<OperationStackOutput, Params: Endpoints
self.endpointParams = endpointParams
self.authSchemeResolver = authSchemeResolver
}

public func handle<H>(
context: Smithy.Context,
input: HTTPRequestBuilder,
next: H
) async throws -> ClientRuntime.OperationOutput<OperationStackOutput>
where H: Handler,
Self.MInput == H.Input,
Self.MOutput == H.Output {
let selectedAuthScheme = context.selectedAuthScheme
let request = input.build()
let updatedRequest =
try await apply(request: request, selectedAuthScheme: selectedAuthScheme, attributes: context)
return try await next.handle(context: context, input: updatedRequest.toBuilder())
}
}

extension EndpointResolverMiddleware: ApplyEndpoint {

public func apply(
request: HTTPRequest,
request: SmithyHTTPAPI.HTTPRequest,
selectedAuthScheme: SelectedAuthScheme?,
attributes: Smithy.Context
) async throws -> HTTPRequest {
) async throws -> SmithyHTTPAPI.HTTPRequest {
let builder = request.toBuilder()

let endpoint = try endpointResolverBlock(endpointParams)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import class SmithyHTTPAPI.HTTPResponse
import ClientRuntime

/// AWS specific Service Error structure used when exact error could not be deduced from the `HttpResponse`
/// AWS specific Service Error structure used when exact error could not be deduced from the `HTTPResponse`
/// Developers should catch unknown errors by the interface `AWSServiceError`, then use the `errorCode` to determine & handle each type of error.
@_spi(UnknownAWSHTTPServiceError) public struct UnknownAWSHTTPServiceError: AWSServiceError, HTTPError, Error {

Expand All @@ -26,9 +26,9 @@ import ClientRuntime

extension UnknownAWSHTTPServiceError {

/// Creates an `UnknownAWSHttpServiceError` from a `HttpResponse` and associated parameters.
/// Creates an `UnknownAWSHttpServiceError` from a `HTTPResponse` and associated parameters.
/// - Parameters:
/// - httpResponse: The `HttpResponse` for this error.
/// - httpResponse: The `HTTPResponse` for this error.
/// - message: The message associated with this error.
/// - requestID: The request ID associated with this error.
/// - requestID2: The request ID2 associated with this error (defined on S3 only.) Defaults to `nil`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,12 @@ import class Smithy.Context
import ClientRuntime
import SmithyHTTPAPI

public struct AWSS3ErrorWith200StatusXMLMiddleware<OperationStackInput, OperationStackOutput>: Middleware {
public struct AWSS3ErrorWith200StatusXMLMiddleware<OperationStackInput, OperationStackOutput> {
public let id: String = "AWSS3ErrorWith200StatusXMLMiddleware"
private let errorStatusCode: HTTPStatusCode = .internalServerError

public init() {}

public func handle<H>(context: Context,
input: HTTPRequest,
next: H) async throws -> OperationOutput<OperationStackOutput>
where H: Handler,
Self.MInput == H.Input,
Self.MOutput == H.Output {

// Let the next handler in the chain process the input
let response = try await next.handle(context: context, input: input)

if try await isErrorWith200Status(response: response.httpResponse) {
// Handle the error as a 500 Internal Server Error
let modifiedResponse = response
modifiedResponse.httpResponse.statusCode = errorStatusCode
return modifiedResponse
}

return response
}

private func isErrorWith200Status(response: HTTPResponse) async throws -> Bool {
// Check if the status code is OK (200)
guard response.statusCode == .ok else {
Expand All @@ -50,9 +30,6 @@ public struct AWSS3ErrorWith200StatusXMLMiddleware<OperationStackInput, Operatio
let xmlString = String(decoding: data, as: UTF8.self)
return xmlString.contains("<Error>")
}

public typealias MInput = HTTPRequest
public typealias MOutput = OperationOutput<OperationStackOutput>
}

extension AWSS3ErrorWith200StatusXMLMiddleware: HttpInterceptor {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import ClientRuntime
import Smithy
import SmithyHTTPAPI
import struct Foundation.UUID

private let AMZ_SDK_INVOCATION_ID_HEADER = "amz-sdk-invocation-id"

/// Adds the amz-sdk-invocation-id header to requests.
public struct AmzSdkInvocationIdMiddleware<InputType, OperationStackOutput> {
public var id: String { "AmzSdkInvocationId" }

// The UUID string used to uniquely identify an API call and all of its subsequent retries.
private let invocationId = UUID().uuidString.lowercased()

public init() {}

private func addHeader(builder: HTTPRequestBuilder) {
builder.withHeader(name: AMZ_SDK_INVOCATION_ID_HEADER, value: invocationId)
}
}

extension AmzSdkInvocationIdMiddleware: HttpInterceptor {
public typealias InputType = InputType
public typealias OutputType = OperationStackOutput

public func modifyBeforeRetryLoop(context: some MutableRequest<InputType, HTTPRequest>) async throws {
let builder = context.getRequest().toBuilder()
addHeader(builder: builder)
context.updateRequest(updated: builder.build())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import ClientRuntime
import Smithy
import SmithyHTTPAPI
import class Foundation.DateFormatter
import struct Foundation.Locale
import struct Foundation.TimeInterval
import struct Foundation.TimeZone
import struct Foundation.UUID

private let AMZ_SDK_REQUEST_HEADER = "amz-sdk-request"

/// Adds the amz-sdk-request header to requests.
public class AmzSdkRequestMiddleware<InputType, OperationStackOutput> {
public var id: String { "AmzSdkRequest" }

// Max number of retries configured for retry strategy.
private var maxRetries: Int
private var attempt: Int = 0

public init(maxRetries: Int) {
self.maxRetries = maxRetries
}

private func addHeader(builder: HTTPRequestBuilder, context: Context) {
self.attempt += 1

// Only compute ttl after first attempt
if self.attempt == 1 {
builder.withHeader(name: AMZ_SDK_REQUEST_HEADER, value: "attempt=1; max=\(maxRetries)")
} else {
let estimatedSkew = context.estimatedSkew ?? {
context.getLogger()?.info("Estimated skew not found; defaulting to zero.")
return 0
}()
let socketTimeout = context.socketTimeout ?? {
context.getLogger()?.info("Socket timeout value not found; defaulting to 60 seconds.")
return 60.0
}()
let ttlDateUTCString = awsGetTTL(now: Date(), estimatedSkew: estimatedSkew, socketTimeout: socketTimeout)
builder.updateHeader(
name: AMZ_SDK_REQUEST_HEADER,
value: "ttl=\(ttlDateUTCString); attempt=\(self.attempt); max=\(maxRetries)"
)
}
}

}

extension AmzSdkRequestMiddleware: HttpInterceptor {
public typealias InputType = InputType
public typealias OutputType = OperationStackOutput

public func modifyBeforeSigning(context: some MutableRequest<InputType, HTTPRequest>) async throws {
let builder = context.getRequest().toBuilder()
addHeader(builder: builder, context: context.getAttributes())
context.updateRequest(updated: builder.build())
}
}

// Calculates & returns TTL datetime in strftime format `YYYYmmddTHHMMSSZ`.
func awsGetTTL(now: Date, estimatedSkew: TimeInterval, socketTimeout: TimeInterval) -> String {
let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "yyyyMMdd'T'HHmmss'Z'"
dateFormatter.locale = Locale(identifier: "en_US_POSIX")
dateFormatter.timeZone = TimeZone(abbreviation: "UTC")
let ttlDate = now.addingTimeInterval(estimatedSkew + socketTimeout)
return dateFormatter.string(from: ttlDate)
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import AWSSDKChecksums
import ClientRuntime
import SmithyHTTPAPI

public struct FlexibleChecksumsRequestMiddleware<OperationStackInput, OperationStackOutput>: Middleware {
public struct FlexibleChecksumsRequestMiddleware<OperationStackInput, OperationStackOutput> {

public let id: String = "FlexibleChecksumsRequestMiddleware"

Expand All @@ -24,16 +24,6 @@ public struct FlexibleChecksumsRequestMiddleware<OperationStackInput, OperationS
self.checksumAlgorithm = checksumAlgorithm
}

public func handle<H>(context: Context,
input: SerializeStepInput<OperationStackInput>,
next: H) async throws -> OperationOutput<OperationStackOutput>
where H: Handler,
Self.MInput == H.Input,
Self.MOutput == H.Output {
try await addHeaders(builder: input.builder, attributes: context)
return try await next.handle(context: context, input: input)
}

private func addHeaders(builder: HTTPRequestBuilder, attributes: Context) async throws {
if case(.stream(let stream)) = builder.body {
attributes.isChunkedEligibleStream = stream.isEligibleForChunkedStreaming
Expand Down Expand Up @@ -100,9 +90,6 @@ public struct FlexibleChecksumsRequestMiddleware<OperationStackInput, OperationS
throw ClientError.dataNotFound("Cannot calculate the checksum of an empty body!")
}
}

public typealias MInput = SerializeStepInput<OperationStackInput>
public typealias MOutput = OperationOutput<OperationStackOutput>
}

extension FlexibleChecksumsRequestMiddleware: HttpInterceptor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import enum SmithyChecksumsAPI.ChecksumAlgorithm
import enum SmithyChecksums.ChecksumMismatchException
import ClientRuntime

public struct FlexibleChecksumsResponseMiddleware<OperationStackInput, OperationStackOutput>: Middleware {
public struct FlexibleChecksumsResponseMiddleware<OperationStackInput, OperationStackOutput> {

public let id: String = "FlexibleChecksumsResponseMiddleware"

Expand All @@ -29,27 +29,6 @@ public struct FlexibleChecksumsResponseMiddleware<OperationStackInput, Operation
: CHECKSUM_HEADER_VALIDATION_PRIORITY_LIST
}

public func handle<H>(context: Context,
input: HTTPRequest,
next: H) async throws -> OperationOutput<OperationStackOutput>
where H: Handler,
Self.MInput == H.Input,
Self.MOutput == H.Output {

// The name of the checksum header which was validated. If `null`, validation was not performed.
context.attributes.set(key: AttributeKey<String>(name: "ChecksumHeaderValidated"), value: nil)

// Initialize logger
guard let logger = context.getLogger() else { throw ClientError.unknownError("No logger found!") }

// Get the response
let output = try await next.handle(context: context, input: input)

try await validateChecksum(response: output.httpResponse, logger: logger, attributes: context)

return output
}

private func validateChecksum(response: HTTPResponse, logger: any LogAgent, attributes: Context) async throws {
// Exit if validation should not be performed
if !validationMode {
Expand Down Expand Up @@ -111,9 +90,6 @@ public struct FlexibleChecksumsResponseMiddleware<OperationStackInput, Operation
throw ClientError.dataNotFound("Cannot calculate the checksum of an empty body!")
}
}

public typealias MInput = HTTPRequest
public typealias MOutput = OperationOutput<OperationStackOutput>
}

extension FlexibleChecksumsResponseMiddleware: HttpInterceptor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import class Smithy.Context
import ClientRuntime

public struct Route53TrimHostedZoneMiddleware<Input, Output>: ClientRuntime.Middleware {
public struct Route53TrimHostedZoneMiddleware<Input, Output> {
public let id: Swift.String = "Route53TrimHostedZoneMiddleware"
private let prefixes = ["/hostedzone/", "hostedzone/", "/hostedzone", "hostedzone"]

Expand All @@ -18,16 +18,6 @@ public struct Route53TrimHostedZoneMiddleware<Input, Output>: ClientRuntime.Midd
self.hostedZoneIDKeyPath = hostedZoneIDKeyPath
}

public func handle<H>(context: Context,
input: Input,
next: H) async throws -> ClientRuntime.OperationOutput<Output>
where H: Handler,
Self.MInput == H.Input,
Self.MOutput == H.Output {
let updatedInput = getUpdatedInput(input: input)
return try await next.handle(context: context, input: updatedInput)
}

private func getUpdatedInput(input: Input) -> Input {
guard let hostedZoneId = input[keyPath: hostedZoneIDKeyPath] else {
return input
Expand All @@ -37,9 +27,6 @@ public struct Route53TrimHostedZoneMiddleware<Input, Output>: ClientRuntime.Midd
copiedInput[keyPath: hostedZoneIDKeyPath] = stripped
return copiedInput
}

public typealias MInput = Input
public typealias MOutput = ClientRuntime.OperationOutput<Output>
}

extension Route53TrimHostedZoneMiddleware: HttpInterceptor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import SmithyHTTPAPI
import struct Foundation.Data
import struct Smithy.AttributeKey

public struct Sha256TreeHashMiddleware<OperationStackInput, OperationStackOutput>: Middleware {
public struct Sha256TreeHashMiddleware<OperationStackInput, OperationStackOutput> {
public let id: String = "Sha256TreeHash"

private let X_AMZ_SHA256_TREE_HASH_HEADER_NAME = "X-Amz-Sha256-Tree-Hash"
Expand All @@ -18,20 +18,9 @@ public struct Sha256TreeHashMiddleware<OperationStackInput, OperationStackOutput

public init() {}

public func handle<H>(context: Context,
input: MInput,
next: H) async throws -> MOutput
where H: Handler,
Self.MInput == H.Input,
Self.MOutput == H.Output {
let request = input.build()
try await addHashes(request: request, builder: input, context: context)
return try await next.handle(context: context, input: input)
}

private func addHashes(
request: SmithyHTTPAPI.HTTPRequest,
builder: HTTPRequestBuilder,
builder: SmithyHTTPAPI.HTTPRequestBuilder,
context: Context
) async throws {
switch request.body {
Expand Down Expand Up @@ -107,9 +96,6 @@ public struct Sha256TreeHashMiddleware<OperationStackInput, OperationStackOutput
let data = Data(previousLevelHashes[0])
return data.encodeToHexString()
}

public typealias MInput = HTTPRequestBuilder
public typealias MOutput = OperationOutput<OperationStackOutput>
}

extension Sha256TreeHashMiddleware: HttpInterceptor {
Expand Down
Loading

0 comments on commit 46e00d6

Please sign in to comment.