Skip to content

Commit

Permalink
Refactor/fix S3 express integrations
Browse files Browse the repository at this point in the history
  • Loading branch information
0marperez committed Dec 6, 2024
1 parent 2761bf9 commit 5328641
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerato
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolMiddleware
import software.amazon.smithy.kotlin.codegen.rendering.util.ConfigProperty
import software.amazon.smithy.kotlin.codegen.rendering.util.ConfigPropertyType
import software.amazon.smithy.kotlin.codegen.utils.dq
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.*
Expand Down Expand Up @@ -99,7 +97,6 @@ class S3ExpressIntegration : KotlinIntegration {
resolved + listOf(
addClientToExecutionContext,
addBucketToExecutionContext,
useCrc32Checksum,
uploadPartDisableChecksum,
)

Expand Down Expand Up @@ -132,44 +129,14 @@ class S3ExpressIntegration : KotlinIntegration {
}
}

/**
* For any operations that require a checksum, set CRC32 if the user has not already configured a checksum.
*/
private val useCrc32Checksum = object : ProtocolMiddleware {
override val name: String = "UseCrc32Checksum"

override val order: Byte = -1 // Render before flexible checksums

override fun isEnabledFor(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean = !op.isS3UploadPart &&
(op.hasTrait<HttpChecksumRequiredTrait>() || (op.hasTrait<HttpChecksumTrait>() && op.expectTrait<HttpChecksumTrait>().isRequestChecksumRequired))

override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
val interceptorSymbol = buildSymbol {
namespace = "aws.sdk.kotlin.services.s3.express"
name = "S3ExpressCrc32ChecksumInterceptor"
}

val httpChecksumTrait = op.getTrait<HttpChecksumTrait>()

val checksumAlgorithmMember = ctx.model.expectShape<StructureShape>(op.input.get())
.members()
.firstOrNull { it.memberName == httpChecksumTrait?.requestAlgorithmMember?.getOrNull() }

// S3 models a header name x-amz-sdk-checksum-algorithm representing the name of the checksum algorithm used
val checksumHeaderName = checksumAlgorithmMember?.getTrait<HttpHeaderTrait>()?.value

writer.write("op.interceptors.add(#T(${checksumHeaderName?.dq() ?: ""}))", interceptorSymbol)
}
}

/**
* Disable all checksums for s3:UploadPart
*/
private val uploadPartDisableChecksum = object : ProtocolMiddleware {
override val name: String = "UploadPartDisableChecksum"

override fun isEnabledFor(ctx: ProtocolGenerator.GenerationContext, op: OperationShape): Boolean =
op.isS3UploadPart
op.isS3UploadPart && op.hasTrait<HttpChecksumTrait>()

override fun render(ctx: ProtocolGenerator.GenerationContext, op: OperationShape, writer: KotlinWriter) {
val interceptorSymbol = buildSymbol {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@ package aws.sdk.kotlin.services.s3.express

import aws.smithy.kotlin.runtime.client.ProtocolRequestInterceptorContext
import aws.smithy.kotlin.runtime.collections.AttributeKey
import aws.smithy.kotlin.runtime.http.DeferredHeadersBuilder
import aws.smithy.kotlin.runtime.http.HeadersBuilder
import aws.smithy.kotlin.runtime.http.interceptors.HttpInterceptor
import aws.smithy.kotlin.runtime.http.operation.HttpOperationContext
import aws.smithy.kotlin.runtime.http.request.HttpRequest
import aws.smithy.kotlin.runtime.http.request.toBuilder
import aws.smithy.kotlin.runtime.telemetry.logging.logger
import kotlin.coroutines.coroutineContext

private const val CHECKSUM_HEADER_PREFIX = "x-amz-checksum-"

/**
* Disable checksums entirely for s3:UploadPart requests.
* Disables checksums for s3:UploadPart requests that use S3 express.
*/
internal class S3ExpressDisableChecksumInterceptor : HttpInterceptor {
override suspend fun modifyBeforeSigning(context: ProtocolRequestInterceptorContext<Any, HttpRequest>): HttpRequest {
Expand All @@ -22,14 +26,45 @@ internal class S3ExpressDisableChecksumInterceptor : HttpInterceptor {
}

val logger = coroutineContext.logger<S3ExpressDisableChecksumInterceptor>()
logger.warn { "Checksums must not be sent with S3 express upload part operation, removing checksum(s)" }

val request = context.protocolRequest.toBuilder()

request.headers.removeChecksumHeaders()
request.trailingHeaders.removeChecksumTrailingHeaders()
request.headers.removeChecksumTrailingHeadersFromXAmzTrailer()

return request.build()
}
}

val configuredChecksumAlgorithm = context.executionContext.getOrNull(HttpOperationContext.ChecksumAlgorithm)
/**
* Removes any checksums sent in the request's headers
*/
internal fun HeadersBuilder.removeChecksumHeaders(): Unit =
names().forEach { name ->
if (name.startsWith(CHECKSUM_HEADER_PREFIX)) {
remove(name)
}
}

configuredChecksumAlgorithm?.let {
logger.warn { "Disabling configured checksum $it for S3 Express UploadPart" }
context.executionContext.remove(HttpOperationContext.ChecksumAlgorithm)
/**
* Removes any checksums sent in the request's trailing headers
*/
internal fun DeferredHeadersBuilder.removeChecksumTrailingHeaders(): Unit =
names().forEach { name ->
if (name.startsWith(CHECKSUM_HEADER_PREFIX)) {
remove(name)
}
}

return context.protocolRequest
/**
* Removes any checksums sent in the request's trailing headers from `x-amz-trailer`
*/
internal fun HeadersBuilder.removeChecksumTrailingHeadersFromXAmzTrailer() {
this.getAll("x-amz-trailer")?.forEach { trailingHeader ->
if (trailingHeader.startsWith(CHECKSUM_HEADER_PREFIX)) {
this.remove("x-amz-trailer", trailingHeader)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package aws.sdk.kotlin.services.s3.express

import aws.smithy.kotlin.runtime.http.DeferredHeadersBuilder
import aws.smithy.kotlin.runtime.http.HeadersBuilder
import kotlin.test.Test
import kotlin.test.assertFalse
import kotlin.test.assertTrue

class ChecksumRemovalTest {
@Test
fun removeChecksumHeaders() {
val headers = HeadersBuilder()

headers.append("x-amz-checksum-crc32", "foo")
headers.append("x-amz-checksum-sha256", "bar")

assertTrue(
headers.contains("x-amz-checksum-crc32"),
)
assertTrue(
headers.contains("x-amz-checksum-sha256"),
)

headers.removeChecksumHeaders()

assertFalse(
headers.contains("x-amz-checksum-crc32"),
)
assertFalse(
headers.contains("x-amz-checksum-sha256"),
)
}

@Test
fun removeChecksumTrailingHeaders() {
val trailingHeaders = DeferredHeadersBuilder()

trailingHeaders.add("x-amz-checksum-crc32", "foo")
trailingHeaders.add("x-amz-checksum-sha256", "bar")

assertTrue(
trailingHeaders.contains("x-amz-checksum-crc32"),
)
assertTrue(
trailingHeaders.contains("x-amz-checksum-sha256"),
)

trailingHeaders.removeChecksumTrailingHeaders()

assertFalse(
trailingHeaders.contains("x-amz-checksum-crc32"),
)
assertFalse(
trailingHeaders.contains("x-amz-checksum-sha256"),
)
}

@Test
fun removeChecksumTrailingHeadersFromXAmzTrailer() {
val headers = HeadersBuilder()

headers.append("x-amz-trailer", "x-amz-checksum-crc32")
headers.append("x-amz-trailer", "x-amz-trailing-header")

val xAmzTrailer = headers.getAll("x-amz-trailer")

assertTrue(
xAmzTrailer?.contains("x-amz-checksum-crc32") ?: false,
)
assertTrue(
xAmzTrailer?.contains("x-amz-trailing-header") ?: false,
)

headers.removeChecksumTrailingHeadersFromXAmzTrailer()

assertFalse(
xAmzTrailer?.contains("x-amz-checksum-crc32") ?: false,
)
assertTrue(
xAmzTrailer?.contains("x-amz-trailing-header") ?: false,
)
}
}

0 comments on commit 5328641

Please sign in to comment.