Skip to content

Commit

Permalink
Use CBOR encoded string for marhsalling tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Fahad Zubair committed Aug 30, 2024
1 parent 7bfe038 commit 39d57f3
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 77 deletions.
1 change: 1 addition & 0 deletions codegen-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies {
implementation("org.jsoup:jsoup:1.16.2")
api("software.amazon.smithy:smithy-codegen-core:$smithyVersion")
api("com.moandjiezana.toml:toml4j:0.7.2")
implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:2.13.0")
implementation("software.amazon.smithy:smithy-aws-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-protocol-test-traits:$smithyVersion")
implementation("software.amazon.smithy:smithy-waiters:$smithyVersion")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingReso
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolFunctions
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
import software.amazon.smithy.rust.codegen.core.util.dq
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.core.util.inputShape
Expand Down Expand Up @@ -447,7 +446,24 @@ class CborParserGenerator(
}

override fun payloadParser(member: MemberShape): RuntimeType {
UNREACHABLE("No protocol using CBOR serialization supports payload binding")
val shape = model.expectShape(member.target)
val returnSymbol = returnSymbolToParse(shape)
check(shape is UnionShape || shape is StructureShape) {
"Payload parser should only be used on structure and union shapes."
}
return protocolFunctions.deserializeFn(shape, fnNameSuffix = "payload") { fnName ->
rustTemplate(
"""
pub(crate) fn $fnName(value: &[u8]) -> #{Result}<#{ReturnType}, #{Error}> {
let decoder = &mut #{Decoder}::new(value);
#{DeserializeMember}
}
""",
"ReturnType" to returnSymbol.symbol,
"DeserializeMember" to deserializeMember(member),
*codegenScope,
)
}
}

override fun operationParser(operationShape: OperationShape): RuntimeType? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,26 @@ class CborSerializerGenerator(
}
}

// TODO(https://github.com/smithy-lang/smithy-rs/issues/3573)
override fun payloadSerializer(member: MemberShape): RuntimeType {
TODO("We only call this when serializing in event streams, which are not supported yet: https://github.com/smithy-lang/smithy-rs/issues/3573")
val target = model.expectShape(member.target)
return protocolFunctions.serializeFn(member, fnNameSuffix = "payload") { fnName ->
rustBlockTemplate(
"pub fn $fnName(input: &#{target}) -> std::result::Result<#{Vec}<u8>, #{Error}>",
*codegenScope,
"target" to symbolProvider.toSymbol(target),
) {
rustTemplate("let mut encoder = #{Encoder}::new(#{Vec}::new());", *codegenScope)
rustBlock("") {
rust("let encoder = &mut encoder;")
when (target) {
is StructureShape -> serializeStructure(StructContext("input", target))
is UnionShape -> serializeUnion(Context(ValueExpression.Reference("input"), target))
else -> throw IllegalStateException("CBOR payloadSerializer only supports structs and unions")
}
}
rustTemplate("#{Ok}(encoder.into_writer())", *codegenScope)
}
}
}

override fun unsetStructure(structure: StructureShape): RuntimeType =
Expand All @@ -223,6 +240,7 @@ class CborSerializerGenerator(
}

val httpDocumentMembers = httpBindingResolver.requestMembers(operationShape, HttpLocation.DOCUMENT)

val inputShape = operationShape.inputShape(model)
return protocolFunctions.serializeFn(operationShape, fnNameSuffix = "input") { fnName ->
rustBlockTemplate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,27 @@

package software.amazon.smithy.rust.codegen.core.testutil

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.dataformat.cbor.CBORFactory
import software.amazon.smithy.model.Model
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RpcV2Cbor
import java.util.Base64

private fun fillInBaseModel(
protocolName: String,
namespacedProtocolName: String,
extraServiceAnnotations: String = "",
): String =
"""
namespace test
use smithy.framework#ValidationException
use aws.protocols#$protocolName
use $namespacedProtocolName
union TestUnion {
Foo: String,
Expand Down Expand Up @@ -86,22 +90,24 @@ private fun fillInBaseModel(
}
$extraServiceAnnotations
@$protocolName
@${namespacedProtocolName.substringAfter("#")}
service TestService { version: "123", operations: [TestStreamOp] }
"""

object EventStreamTestModels {
private fun restJson1(): Model = fillInBaseModel("restJson1").asSmithyModel()
private fun restJson1(): Model = fillInBaseModel("aws.protocols#restJson1").asSmithyModel()

private fun restXml(): Model = fillInBaseModel("restXml").asSmithyModel()
private fun restXml(): Model = fillInBaseModel("aws.protocols#restXml").asSmithyModel()

private fun awsJson11(): Model = fillInBaseModel("awsJson1_1").asSmithyModel()
private fun awsJson11(): Model = fillInBaseModel("aws.protocols#awsJson1_1").asSmithyModel()

private fun rpcv2Cbor(): Model = fillInBaseModel("smithy.protocols#rpcv2Cbor").asSmithyModel()

private fun awsQuery(): Model =
fillInBaseModel("awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
fillInBaseModel("aws.protocols#awsQuery", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()

private fun ec2Query(): Model =
fillInBaseModel("ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()
fillInBaseModel("aws.protocols#ec2Query", "@xmlNamespace(uri: \"https://example.com\")").asSmithyModel()

data class TestCase(
val protocolShapeId: String,
Expand All @@ -120,39 +126,67 @@ object EventStreamTestModels {
override fun toString(): String = protocolShapeId
}

private fun base64Encode(input: ByteArray): String {
val encodedBytes = Base64.getEncoder().encode(input)
return String(encodedBytes)
}

private fun createCBORFromJSON(jsonString: String): ByteArray {
val jsonMapper = ObjectMapper()
val cborMapper = ObjectMapper(CBORFactory())
// Parse JSON string to a generic type.
val jsonData = jsonMapper.readValue(jsonString, Any::class.java)
// Convert the parsed data to CBOR.
return cborMapper.writeValueAsBytes(jsonData)
}

private val restJsonTestCase =
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
mediaType = "application/json",
requestContentType = "application/vnd.amazon.eventstream",
responseContentType = "application/json",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { RestJson(it) }

val TEST_CASES =
listOf(
//
// restJson1
//
TestCase(
protocolShapeId = "aws.protocols#restJson1",
model = restJson1(),
mediaType = "application/json",
requestContentType = "application/vnd.amazon.eventstream",
responseContentType = "application/json",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { RestJson(it) },
restJsonTestCase,
//
// rpcV2Cbor
//
restJsonTestCase.copy(
protocolShapeId = "smithy.protocols#rpcv2Cbor",
model = rpcv2Cbor(),
mediaType = "application/cbor",
responseContentType = "application/cbor",
eventStreamMessageContentType = "application/cbor",
validTestStruct = base64Encode(createCBORFromJSON(restJsonTestCase.validTestStruct)),
validMessageWithNoHeaderPayloadTraits = base64Encode(createCBORFromJSON(restJsonTestCase.validMessageWithNoHeaderPayloadTraits)),
validTestUnion = base64Encode(createCBORFromJSON(restJsonTestCase.validTestUnion)),
validSomeError = base64Encode(createCBORFromJSON(restJsonTestCase.validSomeError)),
validUnmodeledError = base64Encode(createCBORFromJSON(restJsonTestCase.validUnmodeledError)),
protocolBuilder = { RpcV2Cbor(it) },
),
//
// awsJson1_1
//
TestCase(
restJsonTestCase.copy(
protocolShapeId = "aws.protocols#awsJson1_1",
model = awsJson11(),
mediaType = "application/x-amz-json-1.1",
requestContentType = "application/x-amz-json-1.1",
responseContentType = "application/x-amz-json-1.1",
eventStreamMessageContentType = "application/json",
validTestStruct = """{"someString":"hello","someInt":5}""",
validMessageWithNoHeaderPayloadTraits = """{"someString":"hello","someInt":5}""",
validTestUnion = """{"Foo":"hello"}""",
validSomeError = """{"Message":"some error"}""",
validUnmodeledError = """{"Message":"unmodeled error"}""",
) { AwsJson(it, AwsJsonVersion.Json11) },
//
// restXml
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ open class ServerCodegenVisitor(
.protocolFor(context.model, service)
this.protocolGeneratorFactory = protocolGeneratorFactory

val protocolTransformedModel = ServerProtocolBasedTransformationFactory.createTransformer(protocolShape).transform(baseModel, service)
model = codegenDecorator.transformModel(service, protocolTransformedModel, settings)
model = codegenDecorator.transformModel(service, baseModel, settings)

val serverSymbolProviders =
ServerSymbolProviders.from(
Expand Down Expand Up @@ -210,6 +209,8 @@ open class ServerCodegenVisitor(
.let(AttachValidationExceptionToConstrainedOperationInputsInAllowList::transform)
// Tag aggregate shapes reachable from operation input
.let(ShapesReachableFromOperationInputTagger::transform)
// Remove traits that are not supported by the chosen protocol
.let { ServerProtocolBasedTransformationFactory.transform(it, settings) }
// Normalize event stream operations
.let(EventStreamNormalizer::transform)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.AbstractShapeBuilder
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.HttpLabelTrait
import software.amazon.smithy.model.traits.HttpPayloadTrait
import software.amazon.smithy.model.traits.HttpTrait
import software.amazon.smithy.model.transform.ModelTransformer
import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait
import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings
import software.amazon.smithy.utils.SmithyBuilder
import software.amazon.smithy.utils.ToSmithyBuilder

Expand All @@ -21,17 +22,27 @@ import software.amazon.smithy.utils.ToSmithyBuilder
* object that transforms the model and removes specific traits based on the protocol being instantiated.
*/
object ServerProtocolBasedTransformationFactory {
fun createTransformer(protocolShapeId: ShapeId): Transformer =
when (protocolShapeId) {
Rpcv2CborTrait.ID -> Rpcv2Transformer()
else -> IdentityTransformer()
fun transform(
model: Model,
settings: ServerRustSettings,
): Model {
val service = settings.getService(model)
if (!service.hasTrait<Rpcv2CborTrait>()) {
return model
}

interface Transformer {
fun transform(
model: Model,
service: ServiceShape,
): Model
return ModelTransformer.create().mapShapes(model) { shape ->
when (shape) {
is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID)
is MemberShape -> {
shape
.removeTraitIfPresent(HttpLabelTrait.ID)
.removeTraitIfPresent(HttpPayloadTrait.ID)
}

else -> shape
}
}
}

fun <T : Shape, B> T.removeTraitIfPresent(
Expand All @@ -47,36 +58,4 @@ object ServerProtocolBasedTransformationFactory {
this
}
}

class Rpcv2Transformer() : Transformer {
override fun transform(
model: Model,
service: ServiceShape,
): Model {
val transformedModel =
ModelTransformer.create().mapShapes(model) { shape ->
when (shape) {
is OperationShape -> shape.removeTraitIfPresent(HttpTrait.ID)
is MemberShape -> {
shape
.removeTraitIfPresent(HttpLabelTrait.ID)
.removeTraitIfPresent(HttpPayloadTrait.ID)
}

else -> shape
}
}

return transformedModel
}
}

class IdentityTransformer() : Transformer {
override fun transform(
model: Model,
service: ServiceShape,
): Model {
return model
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package software.amazon.smithy.rust.codegen.server.smithy
import io.kotest.inspectors.forAll
import io.kotest.matchers.ints.shouldBeGreaterThan
import io.kotest.matchers.shouldBe
import java.io.File
import org.junit.jupiter.api.Test
import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
Expand All @@ -29,6 +28,7 @@ import software.amazon.smithy.protocol.traits.Rpcv2CborTrait
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.util.lookup
import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverTestSymbolProvider
import java.io.File

enum class ModelProtocol(val trait: AbstractTrait) {
AwsJson10(AwsJson1_0Trait.builder().build()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.protocols.serialize
import org.junit.jupiter.api.Test
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.core.testutil.IntegrationTestParams
import software.amazon.smithy.rust.codegen.core.testutil.ServerAdditionalSettings
import software.amazon.smithy.rust.codegen.server.smithy.ModelProtocol
import software.amazon.smithy.rust.codegen.server.smithy.loadSmithyConstraintsModelForProtocol
import software.amazon.smithy.rust.codegen.server.smithy.removeOperations
Expand All @@ -15,7 +16,7 @@ class CborConstraintsIntegrationTest {
// Event streaming operations are not supported by `Rpcv2Cbor` implementation.
// https://github.com/smithy-lang/smithy-rs/issues/3573
val nonSupportedOperations =
listOf("EventStreamsOperation", "StreamingBlobOperation")
listOf("StreamingBlobOperation")
.map { ShapeId.from("${serviceShape.namespace}#$it") }
val model =
constraintModel
Expand All @@ -25,6 +26,7 @@ class CborConstraintsIntegrationTest {
model,
IntegrationTestParams(
service = serviceShape.toString(),
additionalSettings = ServerAdditionalSettings.builder().generateCodegenComments().toObjectNode(),
),
) { _, _ ->
}
Expand Down

0 comments on commit 39d57f3

Please sign in to comment.