Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for requiresLength trait and Transfer-Encoding: Chunked #604

Merged
merged 20 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@ public struct ContentLengthMiddleware<OperationStackOutput: HttpResponseBinding>

private let contentLengthHeaderName = "Content-Length"

public init() {}
private var requiresLength: Bool = false

private var unsignedPayload: Bool = false

public init(requiresLength: Bool = false, unsignedPayload: Bool = false) {
self.requiresLength = requiresLength
self.unsignedPayload = unsignedPayload
}

public func handle<H>(context: Context,
input: MInput,
Expand All @@ -22,8 +29,12 @@ public struct ContentLengthMiddleware<OperationStackOutput: HttpResponseBinding>
case .stream(let stream):
if let length = stream.length {
input.headers.update(name: "Content-Length", value: String(length))
} else if !requiresLength && unsignedPayload {
// only for HTTP/1.1 requests, will be removed in all HTTP/2 requests
input.headers.update(name: "Transfer-Encoding", value: "Chunked")
} else {
input.headers.update(name: "Transfer-Encoded", value: "Chunked")
let errorMessage = unsignedPayload ? "operation requires length" : "sigv4 requires length"
dayaffe marked this conversation as resolved.
Show resolved Hide resolved
throw StreamError.notSupported(errorMessage)
}
default:
input.headers.update(name: "Content-Length", value: "0")
Expand Down
3 changes: 3 additions & 0 deletions Sources/ClientRuntime/Networking/Http/SdkHttpRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ extension SdkHttpRequest {
httpRequest.path = [endpoint.path, endpoint.queryItemString].compactMap { $0 }.joined(separator: "?")
httpRequest.addHeaders(headers: headers.toHttpHeaders())

// Remove the "Transfer-Encoding" header if it exists since h2 does not support it
httpRequest.removeHeader(name: "Transfer-Encoding")

// HTTP2Request used with manual writes hence we need to set the body to nil
// so that CRT does not write the body for us (we will write it manually)
httpRequest.body = nil
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0.

import XCTest
import SmithyTestUtil
@testable import ClientRuntime

class ContentLengthMiddlewareTests: XCTestCase {
private var builtContext: HttpContext!
private var stack: OperationStack<MockInput, MockOutput, MockMiddlewareError>!

override func setUpWithError() throws {
try super.setUpWithError()
builtContext = HttpContextBuilder()
.withMethod(value: .get)
.withPath(value: "/")
.withEncoder(value: JSONEncoder())
.withDecoder(value: JSONDecoder())
.withOperation(value: "Test Operation")
.build()
stack = OperationStack<MockInput, MockOutput, MockMiddlewareError>(id: "Test Operation")
}

func testTransferEncodingChunkedSetWhenStreamLengthIsNil() async throws {
addContentLengthMiddlewareWith(requiresLength: false, unsignedPayload: true)
forceEmptyStream()
try await AssertHeadersArePresent(expectedHeaders: ["Transfer-Encoding": "Chunked"])
}

func testContentLengthSetWhenStreamLengthAvailableAndRequiresLengthSet() async throws {
addContentLengthMiddlewareWith(requiresLength: true, unsignedPayload: false)
try await AssertHeadersArePresent(expectedHeaders: ["Content-Length": "0"])
}

func testContentLengthSetWhenRequiresLengthAndUnsignedPayload() async throws {
addContentLengthMiddlewareWith(requiresLength: true, unsignedPayload: true)
try await AssertHeadersArePresent(expectedHeaders: ["Content-Length": "0"])
}

func testRequiresLengthSetWithNilStreamShouldThrowError() async throws {
addContentLengthMiddlewareWith(requiresLength: true, unsignedPayload: false)
forceEmptyStream()
do {
try await AssertHeadersArePresent(expectedHeaders: ["Content-Length": "0"])
XCTFail("Should throw error")
} catch let error as StreamError {
switch error {
case .notSupported("sigv4 requires length"), .notSupported("operation requires length"):
// The error matches one of the expected cases, test passes
break
default:
XCTFail("Error is not StreamError.notSupported with expected message")
}
}
}

private func addContentLengthMiddlewareWith(requiresLength: Bool, unsignedPayload: Bool) {
stack.finalizeStep.intercept(
position: .before,
middleware: ContentLengthMiddleware(requiresLength: requiresLength, unsignedPayload: unsignedPayload)
)
}

private func forceEmptyStream() {
// Force stream length to be nil
stack.finalizeStep.intercept(position: .before, id: "set nil stream length") { (context, input, next) -> OperationOutput<MockOutput> in
input.body = .stream(BufferedStream()) // Set the stream length to nil
return try await next.handle(context: context, input: input)
}
}

private func AssertHeadersArePresent(expectedHeaders: [String: String]) async throws -> Void {
dayaffe marked this conversation as resolved.
Show resolved Hide resolved
_ = try await stack.handleMiddleware(context: builtContext, input: MockInput(),
next: MockHandler(handleCallback: { (_, input) in
for (key, value) in expectedHeaders {
XCTAssert(input.headers.value(for: key) == value)
}
let httpResponse = HttpResponse(body: HttpBody.none, statusCode: HttpStatusCode.ok)
let mockOutput = try! MockOutput(httpResponse: httpResponse, decoder: nil)
let output = OperationOutput<MockOutput>(httpResponse: httpResponse,
output: mockOutput)
return output
}))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package software.amazon.smithy.swift.codegen.integration

import software.amazon.smithy.aws.traits.auth.UnsignedPayloadTrait
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.knowledge.HttpBindingIndex
Expand All @@ -30,6 +31,7 @@ import software.amazon.smithy.model.traits.HttpPrefixHeadersTrait
import software.amazon.smithy.model.traits.HttpQueryParamsTrait
import software.amazon.smithy.model.traits.HttpQueryTrait
import software.amazon.smithy.model.traits.MediaTypeTrait
import software.amazon.smithy.model.traits.RequiresLengthTrait
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.model.traits.TimestampFormatTrait
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
Expand Down Expand Up @@ -60,6 +62,7 @@ import software.amazon.smithy.swift.codegen.integration.serde.UnionEncodeGenerat
import software.amazon.smithy.swift.codegen.middleware.OperationMiddlewareGenerator
import software.amazon.smithy.swift.codegen.model.ShapeMetadata
import software.amazon.smithy.swift.codegen.model.bodySymbol
import software.amazon.smithy.swift.codegen.model.findStreamingMember
import software.amazon.smithy.swift.codegen.model.hasEventStreamMember
import software.amazon.smithy.swift.codegen.model.hasTrait
import software.amazon.smithy.utils.OptionalUtils
Expand Down Expand Up @@ -91,9 +94,8 @@ fun formatHeaderOrQueryValue(
memberShape: MemberShape,
location: HttpBinding.Location,
bindingIndex: HttpBindingIndex,
defaultTimestampFormat: TimestampFormatTrait.Format
defaultTimestampFormat: TimestampFormatTrait.Format,
dayaffe marked this conversation as resolved.
Show resolved Hide resolved
): Pair<String, Boolean> {

return when (val shape = ctx.model.expectShape(memberShape.target)) {
is TimestampShape -> {
val timestampFormat = bindingIndex.determineTimestampFormat(memberShape, location, defaultTimestampFormat)
Expand Down Expand Up @@ -165,7 +167,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
writer.openBlock(
"extension $symbolName: \$N {",
"}",
SwiftTypes.Protocols.Encodable
SwiftTypes.Protocols.Encodable,
) {
writer.addImport(SwiftDependency.CLIENT_RUNTIME.target)

Expand Down Expand Up @@ -286,7 +288,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
private fun generateCodingKeysForMembers(
ctx: ProtocolGenerator.GenerationContext,
writer: SwiftWriter,
members: List<MemberShape>
members: List<MemberShape>,
) {
codingKeysGenerator.generateCodingKeysForMembers(ctx, writer, members)
}
Expand All @@ -298,7 +300,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
val inputType = ctx.model.expectShape(operation.input.get())
var metadata = mapOf<ShapeMetadata, Any>(
Pair(ShapeMetadata.OPERATION_SHAPE, operation),
Pair(ShapeMetadata.SERVICE_VERSION, ctx.service.version)
Pair(ShapeMetadata.SERVICE_VERSION, ctx.service.version),
)
shapesInfo.put(inputType, metadata)
}
Expand Down Expand Up @@ -336,7 +338,6 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
}

private fun resolveShapesNeedingCodableConformance(ctx: ProtocolGenerator.GenerationContext): Set<Shape> {

val topLevelOutputMembers = getHttpBindingOperations(ctx).flatMap {
val outputShape = ctx.model.expectShape(it.output.get())
outputShape.members()
Expand Down Expand Up @@ -390,7 +391,8 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
RelationshipType.LIST_MEMBER,
RelationshipType.SET_MEMBER,
RelationshipType.MAP_VALUE,
RelationshipType.UNION_MEMBER -> true
RelationshipType.UNION_MEMBER,
-> true
else -> false
}
}.forEach {
Expand All @@ -403,6 +405,31 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
return resolved
}

// Checks for @requiresLength trait
// Returns true if the operation:
// - has a streaming member with @httpPayload trait
// - target is a blob shape with @requiresLength trait
private fun hasRequiresLengthTrait(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean {
if (op.input.isPresent) {
val inputShape = ctx.model.expectShape(op.input.get())
val streamingMember = inputShape.findStreamingMember(ctx.model)
if (streamingMember != null) {
val targetShape = ctx.model.expectShape(streamingMember.target)
if (targetShape != null) {
return streamingMember.hasTrait<HttpPayloadTrait>() &&
targetShape.isBlobShape &&
targetShape.hasTrait<RequiresLengthTrait>()
}
}
}
return false
}

// Checks for @unsignedPayload trait on an operation
private fun hasUnsignedBody(op: OperationShape): Boolean {
dayaffe marked this conversation as resolved.
Show resolved Hide resolved
return op.hasTrait<UnsignedPayloadTrait>()
}

override fun generateProtocolClient(ctx: ProtocolGenerator.GenerationContext) {
val symbol = ctx.symbolProvider.toSymbol(ctx.service)
ctx.delegator.useFileWriter("./${ctx.settings.moduleName}/${symbol.name}.swift") { writer ->
Expand All @@ -414,7 +441,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
serviceSymbol.name,
defaultContentType,
httpProtocolCustomizable,
operationMiddleware
operationMiddleware,
)
clientGenerator.render()
}
Expand All @@ -433,7 +460,7 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
operationMiddleware.appendMiddleware(operation, ContentTypeMiddleware(ctx.model, ctx.symbolProvider, resolver.determineRequestContentType(operation)))
operationMiddleware.appendMiddleware(operation, OperationInputBodyMiddleware(ctx.model, ctx.symbolProvider))

operationMiddleware.appendMiddleware(operation, ContentLengthMiddleware(ctx.model, shouldRenderEncodableConformance))
operationMiddleware.appendMiddleware(operation, ContentLengthMiddleware(ctx.model, shouldRenderEncodableConformance, hasRequiresLengthTrait(ctx, operation), hasUnsignedBody(operation)))

operationMiddleware.appendMiddleware(operation, DeserializeMiddleware(ctx.model, ctx.symbolProvider))
operationMiddleware.appendMiddleware(operation, LoggingMiddleware(ctx.model, ctx.symbolProvider))
Expand Down Expand Up @@ -463,15 +490,15 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
members: List<MemberShape>,
writer: SwiftWriter,
defaultTimestampFormat: TimestampFormatTrait.Format,
path: String? = null
path: String? = null,
)
protected abstract fun renderStructDecode(
ctx: ProtocolGenerator.GenerationContext,
shapeMetaData: Map<ShapeMetadata, Any>,
members: List<MemberShape>,
writer: SwiftWriter,
defaultTimestampFormat: TimestampFormatTrait.Format,
path: String
path: String,
)
protected abstract fun addProtocolSpecificMiddleware(ctx: ProtocolGenerator.GenerationContext, operation: OperationShape)

Expand All @@ -487,11 +514,11 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
for (operation in topDownIndex.getContainedOperations(ctx.service)) {
OptionalUtils.ifPresentOrElse(
Optional.of(getProtocolHttpBindingResolver(ctx, defaultContentType).httpTrait(operation)::class.java),
{ containedOperations.add(operation) }
{ containedOperations.add(operation) },
) {
LOGGER.warning(
"Unable to fetch $protocolName protocol request bindings for ${operation.id} because " +
"it does not have an http binding trait"
"it does not have an http binding trait",
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import software.amazon.smithy.swift.codegen.middleware.MiddlewarePosition
import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable
import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep

class ContentLengthMiddleware(val model: Model, private val alwaysIntercept: Boolean) : MiddlewareRenderable {
class ContentLengthMiddleware(val model: Model, private val alwaysIntercept: Boolean, private val requiresLength: Boolean, private val unsignedPayload: Boolean) : MiddlewareRenderable {

override val name = "ContentLengthMiddleware"

Expand All @@ -20,17 +20,20 @@ class ContentLengthMiddleware(val model: Model, private val alwaysIntercept: Boo
override fun render(
writer: SwiftWriter,
op: OperationShape,
operationStackName: String
operationStackName: String,
) {
val hasHttpBody = MiddlewareShapeUtils.hasHttpBody(model, op)
if (hasHttpBody || alwaysIntercept) {
writer.write(
"\$L.\$L.intercept(position: \$L, middleware: \$N())",
operationStackName,
middlewareStep.stringValue(),
position.stringValue(),
ClientRuntimeTypes.Middleware.ContentLengthMiddleware
)
val middlewareArgs = if (requiresLength || unsignedPayload) {
"requiresLength: $requiresLength, unsignedPayload: $unsignedPayload"
} else {
""
}
dayaffe marked this conversation as resolved.
Show resolved Hide resolved

val interceptStatement = "$operationStackName.${middlewareStep.stringValue()}.intercept(" +
"position: ${position.stringValue()}, middleware: ${ClientRuntimeTypes.Middleware.ContentLengthMiddleware}($middlewareArgs))"

writer.write(interceptStatement)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class TestHttpProtocolClientGeneratorFactory : HttpProtocolClientGeneratorFactor
private fun getClientProperties(ctx: ProtocolGenerator.GenerationContext): List<ClientProperty> {
return mutableListOf(
DefaultRequestEncoder(),
DefaultResponseDecoder()
DefaultResponseDecoder(),
)
}

Expand Down Expand Up @@ -125,6 +125,7 @@ extension InlineDocumentAsPayloadOutput: ClientRuntime.HttpResponseBinding {
""".trimIndent()
contents.shouldContainOnlyOnce(expectedContents)
}

@Test
fun `default fooMap to an empty map if keysForFooMap is empty`() {
val contents = getModelFileContents("example", "HttpPrefixHeadersOutput+HttpResponseBinding.swift", newTestContext.manifest)
Expand Down
Loading
Loading