Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for flexible checksums on Normal payloads #1304

Merged
merged 10 commits into from
Feb 7, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package software.amazon.smithy.aws.swift.codegen.customization.flexiblechecksums

import software.amazon.smithy.aws.traits.HttpChecksumTrait
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
import software.amazon.smithy.swift.codegen.SwiftSettings
import software.amazon.smithy.swift.codegen.SwiftWriter
import software.amazon.smithy.swift.codegen.getOrNull
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.SwiftIntegration
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils
import software.amazon.smithy.swift.codegen.middleware.MiddlewarePosition
import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable
import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep
import software.amazon.smithy.swift.codegen.middleware.OperationMiddleware
import software.amazon.smithy.swift.codegen.model.expectShape
import software.amazon.smithy.swift.codegen.model.shapes

class FlexibleChecksumsRequestIntegration : SwiftIntegration {
override fun enabledForService(model: Model, settings: SwiftSettings): Boolean = model
.shapes<OperationShape>()
.any { it.hasTrait(HttpChecksumTrait::class.java) }

override fun customizeMiddleware(
ctx: ProtocolGenerator.GenerationContext,
operationShape: OperationShape,
operationMiddleware: OperationMiddleware,
) {
val httpChecksumTrait = operationShape.getTrait(HttpChecksumTrait::class.java).getOrNull()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the same .getOrNull() that I'm eliminating in the XML branch?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I think we decided to use .orElse(null) which is a Kotlin built-in)

Copy link
Collaborator Author

@dayaffe dayaffe Jan 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the .orElse(null) we were talking about I believe works on Java types not Kotlin types where we prefer to use ?. that work on optionals. I can check where .getOrNull() is used and see if we can switch it over to the kotlin optionals. It seems like that method .getOrNull() was created to work on either Java or Kotlin types

val input = operationShape.input.getOrNull()?.let { ctx.model.expectShape<StructureShape>(it) }

val useFlexibleChecksum = (httpChecksumTrait != null) &&
(httpChecksumTrait.requestAlgorithmMember?.getOrNull() != null) &&
(input?.memberNames?.any { it == httpChecksumTrait.requestAlgorithmMember.get() } == true)

if (useFlexibleChecksum) {
operationMiddleware.appendMiddleware(operationShape, FlexibleChecksumRequestMiddleware)
}
}
}

private fun String.lowercaseFirstLetter(): String =
takeIf { it.isNotEmpty() }?.let { it.first().lowercase() + it.substring(1) } ?: this

private object FlexibleChecksumRequestMiddleware : MiddlewareRenderable {
override val name = "FlexibleChecksumRequestMiddleware"

override val middlewareStep = MiddlewareStep.SERIALIZESTEP

override val position = MiddlewarePosition.AFTER

override fun render(ctx: ProtocolGenerator.GenerationContext, writer: SwiftWriter, op: OperationShape, operationStackName: String) {
val inputShapeName = MiddlewareShapeUtils.inputSymbol(ctx.symbolProvider, ctx.model, op).name
val outputShapeName = MiddlewareShapeUtils.outputSymbol(ctx.symbolProvider, ctx.model, op).name
val httpChecksumTrait = op.getTrait(HttpChecksumTrait::class.java).getOrNull()
val inputMemberName = httpChecksumTrait?.requestAlgorithmMember?.get()?.lowercaseFirstLetter()

// Convert algorithmNames list to a Swift array representation
val middlewareInit = "${ClientRuntimeTypes.Middleware.FlexibleChecksumsRequestMiddleware}<$inputShapeName, $outputShapeName>(checksumAlgorithm: input.$inputMemberName?.rawValue)"

writer.write("$operationStackName.${middlewareStep.stringValue()}.intercept(position: ${position.stringValue()}, middleware: $middlewareInit)")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package software.amazon.smithy.aws.swift.codegen.customization.flexiblechecksums

import software.amazon.smithy.aws.traits.HttpChecksumTrait
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
import software.amazon.smithy.swift.codegen.SwiftSettings
import software.amazon.smithy.swift.codegen.SwiftWriter
import software.amazon.smithy.swift.codegen.getOrNull
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.SwiftIntegration
import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils
import software.amazon.smithy.swift.codegen.middleware.MiddlewarePosition
import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable
import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep
import software.amazon.smithy.swift.codegen.middleware.OperationMiddleware
import software.amazon.smithy.swift.codegen.model.expectShape
import software.amazon.smithy.swift.codegen.model.shapes

class FlexibleChecksumsResponseIntegration : SwiftIntegration {
override fun enabledForService(model: Model, settings: SwiftSettings): Boolean = model
.shapes<OperationShape>()
.any { it.hasTrait(HttpChecksumTrait::class.java) }

override fun customizeMiddleware(
ctx: ProtocolGenerator.GenerationContext,
operationShape: OperationShape,
operationMiddleware: OperationMiddleware,
) {
val httpChecksumTrait = operationShape.getTrait(HttpChecksumTrait::class.java).getOrNull()
val input = operationShape.input.getOrNull()?.let { ctx.model.expectShape<StructureShape>(it) }

val useFlexibleChecksum = (httpChecksumTrait != null) &&
(httpChecksumTrait.requestValidationModeMember?.getOrNull() != null) &&
(input?.memberNames?.any { it == httpChecksumTrait.requestValidationModeMember.get() } == true)

if (useFlexibleChecksum) {
operationMiddleware.appendMiddleware(operationShape, FlexibleChecksumResponseMiddleware)
}
}
}

private object FlexibleChecksumResponseMiddleware : MiddlewareRenderable {
override val name = "FlexibleChecksumResponseMiddleware"

override val middlewareStep = MiddlewareStep.DESERIALIZESTEP

override val position = MiddlewarePosition.AFTER

override fun render(ctx: ProtocolGenerator.GenerationContext, writer: SwiftWriter, op: OperationShape, operationStackName: String) {
val outputShapeName = MiddlewareShapeUtils.outputSymbol(ctx.symbolProvider, ctx.model, op).name
val httpChecksumTrait = op.getTrait(HttpChecksumTrait::class.java).getOrNull()
val inputMemberName = httpChecksumTrait?.requestValidationModeMember?.get()
val validationModeMember = ctx.model.expectShape(op.inputShape).getMember(inputMemberName)
val requestValidationModeEnumShape = ctx.model.expectShape(validationModeMember.getOrNull()?.target)

// Will pass the validation mode to validation middleware
val validationMode: Boolean = requestValidationModeEnumShape.members().map { it.memberName }.first().equals("ENABLED")
val middlewareInit = "${ClientRuntimeTypes.Middleware.FlexibleChecksumsResponseMiddleware}<$outputShapeName>(validationMode: $validationMode)"
writer.write("$operationStackName.${middlewareStep.stringValue()}.intercept(position: ${position.stringValue()}, middleware: $middlewareInit)")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ software.amazon.smithy.aws.swift.codegen.customization.s3.TruncatablePaginationI
software.amazon.smithy.aws.swift.codegen.customization.route53.Route53TrimHostedZone
software.amazon.smithy.aws.swift.codegen.customization.route53.Route53InvalidBatchErrorIntegration
software.amazon.smithy.aws.swift.codegen.customization.apigateway.ApiGatewayAddAcceptHeader
software.amazon.smithy.aws.swift.codegen.customization.flexiblechecksums.FlexibleChecksumsRequestIntegration
software.amazon.smithy.aws.swift.codegen.customization.flexiblechecksums.FlexibleChecksumsResponseIntegration
software.amazon.smithy.aws.swift.codegen.customization.glacier.GlacierAddVersionHeader
software.amazon.smithy.aws.swift.codegen.customization.glacier.GlacierAccountIdDefault
software.amazon.smithy.aws.swift.codegen.customization.glacier.GlacierChecksum
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package software.amazon.smithy.aws.swift.codegen.customizations

import io.kotest.matchers.string.shouldContainOnlyOnce
import org.junit.jupiter.api.Test
import software.amazon.smithy.aws.swift.codegen.TestContext
import software.amazon.smithy.aws.swift.codegen.TestContextGenerator
import software.amazon.smithy.aws.swift.codegen.restjson.AWSRestJson1ProtocolGenerator
import software.amazon.smithy.aws.swift.codegen.shouldSyntacticSanityCheck
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait

class FlexibleChecksumMiddlewareTests {

@Test
fun `Test that FlexibleChecksumsRequestMiddleware and FlexibleChecksumsResponseMiddleware are properly generated`() {
val context = setupTests("flexible-checksums.smithy", "aws.flex.checks#ChecksumTests")
val contents = TestContextGenerator.getFileContents(context.manifest, "/Example/ChecksumTestsClient.swift")
contents.shouldSyntacticSanityCheck()
val expectedContents = """
extension ChecksumTestsClient {
/// Performs the `SomeOperation` operation on the `ChecksumTests` service.
///
///
/// - Parameter SomeOperationInput : [no documentation found]
///
/// - Returns: `SomeOperationOutput` : [no documentation found]
public func someOperation(input: SomeOperationInput) async throws -> SomeOperationOutput {
let context = ClientRuntime.HttpContextBuilder()
.withEncoder(value: encoder)
.withDecoder(value: decoder)
.withMethod(value: .post)
.withServiceName(value: serviceName)
.withOperation(value: "someOperation")
.withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator)
.withLogger(value: config.logger)
.withPartitionID(value: config.partitionID)
.withCredentialsProvider(value: config.credentialsProvider)
.withRegion(value: config.region)
.build()
var operation = ClientRuntime.OperationStack<SomeOperationInput, SomeOperationOutput>(id: "someOperation")
operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLPathMiddleware<SomeOperationInput, SomeOperationOutput>(SomeOperationInput.urlPathProvider(_:)))
operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLHostMiddleware<SomeOperationInput, SomeOperationOutput>())
operation.buildStep.intercept(position: .before, middleware: ClientRuntime.ContentMD5Middleware<SomeOperationOutput>())
let endpointParams = EndpointParams()
operation.buildStep.intercept(position: .before, middleware: EndpointResolverMiddleware<SomeOperationOutput>(endpointResolver: config.serviceSpecific.endpointResolver, endpointParams: endpointParams))
operation.buildStep.intercept(position: .before, middleware: AWSClientRuntime.UserAgentMiddleware(metadata: AWSClientRuntime.AWSUserAgentMetadata.fromConfig(serviceID: serviceName, version: "1.0.0", config: config)))
operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.HeaderMiddleware<SomeOperationInput, SomeOperationOutput>(SomeOperationInput.headerProvider(_:)))
operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware<SomeOperationInput, SomeOperationOutput>(contentType: "application/octet-stream"))
operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.BlobBodyMiddleware<SomeOperationInput, SomeOperationOutput>(keyPath: \.content))
operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.FlexibleChecksumsRequestMiddleware<SomeOperationInput, SomeOperationOutput>(checksumAlgorithm: input.checksumAlgorithm?.rawValue))
operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware())
operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryMiddleware<ClientRuntime.DefaultRetryStrategy, AWSClientRuntime.AWSRetryErrorInfoProvider, SomeOperationOutput>(options: config.retryStrategyOptions))
operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware<SomeOperationOutput>(responseClosure(decoder: decoder), responseErrorClosure(SomeOperationOutputError.self, decoder: decoder)))
operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.LoggerMiddleware<SomeOperationOutput>(clientLogMode: config.clientLogMode))
operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.FlexibleChecksumsResponseMiddleware<SomeOperationOutput>(validationMode: true))
let result = try await operation.handleMiddleware(context: context, input: input, next: client.getHandler())
return result
}

}
""".trimIndent()
contents.shouldContainOnlyOnce(expectedContents)
}

private fun setupTests(smithyFile: String, serviceShapeId: String): TestContext {
val context =
TestContextGenerator.initContextFrom(smithyFile, serviceShapeId, RestJson1Trait.ID)

val generator = AWSRestJson1ProtocolGenerator()
generator.generateProtocolUnitTests(context.ctx)
context.ctx.delegator.flushWriters()
return context
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
$version: "2.0"

namespace aws.flex.checks

use aws.api#service
use aws.protocols#httpChecksum
use aws.protocols#restJson1

@restJson1
@service(sdkId: "ChecksumTests")
service ChecksumTests {
version: "1.0.0",
operations: [SomeOperation]
}

// Define the operation
@httpChecksum(
requestChecksumRequired: true,
requestAlgorithmMember: "checksumAlgorithm",
requestValidationModeMember: "validationMode",
responseAlgorithms: ["CRC32C", "CRC32", "SHA1", "SHA256"]
)
@http(method: "POST", uri: "/foo")
operation SomeOperation {
input: PutSomethingInput
output: PutSomethingOutput
}

structure PutSomethingInput {
@httpHeader("x-amz-request-algorithm")
checksumAlgorithm: ChecksumAlgorithm

@httpHeader("x-amz-response-validation-mode")
validationMode: ValidationMode

@httpPayload
content: Blob
}

structure PutSomethingOutput {
foo: String
}

enum ChecksumAlgorithm {
CRC32C
CRC32
SHA1
SHA256
}

enum ValidationMode {
ENABLED
}
Loading