Skip to content

Commit

Permalink
refactor XML deserialize (#1233)
Browse files Browse the repository at this point in the history
  • Loading branch information
aajtodd authored Feb 29, 2024
1 parent a65dc90 commit 44b8249
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@ import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AwsHttpBindingPr
import software.amazon.smithy.kotlin.codegen.aws.protocols.core.QueryHttpBindingProtocolGenerator
import software.amazon.smithy.kotlin.codegen.aws.protocols.formurl.QuerySerdeFormUrlDescriptorGenerator
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.model.traits.OperationOutput
import software.amazon.smithy.kotlin.codegen.rendering.protocol.*
import software.amazon.smithy.kotlin.codegen.rendering.serde.*
import software.amazon.smithy.kotlin.codegen.utils.dq
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.*

Expand Down Expand Up @@ -68,24 +67,6 @@ private class AwsQuerySerdeFormUrlDescriptorGenerator(
member.hasTrait<XmlFlattenedTrait>()
}

private class AwsQuerySerdeXmlDescriptorGenerator(
ctx: RenderingContext<Shape>,
memberShapes: List<MemberShape>? = null,
) : XmlSerdeDescriptorGenerator(ctx, memberShapes) {

override fun getObjectDescriptorTraits(): List<SdkFieldDescriptorTrait> {
val traits = super.getObjectDescriptorTraits().toMutableList()

if (objectShape.hasTrait<OperationOutput>()) {
traits.removeIf { it.symbol == RuntimeTypes.Serde.SerdeXml.XmlSerialName }
val serialName = objectShape.changeNameSuffix("Response" to "Result")
traits.add(RuntimeTypes.Serde.SerdeXml.XmlSerialName, serialName.dq())
}

return traits
}
}

private class AwsQuerySerializerGenerator(
private val protocolGenerator: AwsQuery,
) : AbstractQueryFormUrlSerializerGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {
Expand All @@ -98,50 +79,76 @@ private class AwsQuerySerializerGenerator(
}

private class AwsQueryXmlParserGenerator(
private val protocolGenerator: AwsQuery,
) : XmlParserGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {

override fun descriptorGenerator(
ctx: ProtocolGenerator.GenerationContext,
shape: Shape,
members: List<MemberShape>,
writer: KotlinWriter,
): XmlSerdeDescriptorGenerator = AwsQuerySerdeXmlDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members)

override fun renderDeserializeOperationBody(
ctx: ProtocolGenerator.GenerationContext,
op: OperationShape,
documentMembers: List<MemberShape>,
writer: KotlinWriter,
) {
writer.write("val deserializer = #T(payload)", RuntimeTypes.Serde.SerdeXml.XmlDeserializer)
unwrapOperationResponseBody(op.id.name, writer)
val shape = ctx.model.expectShape(op.output.get())
renderDeserializerBody(ctx, shape, documentMembers, writer)
}
protocolGenerator: AwsQuery,
) : XmlParserGenerator(protocolGenerator.defaultTimestampFormat) {

/**
* Unwraps the response body as specified by
* https://awslabs.github.io/smithy/1.0/spec/aws/aws-query-protocol.html#response-serialization so that the
* deserializer is in the correct state.
*
* ```
* <SomeOperationResponse>
* <SomeOperationResult>
* <-- SAME AS REST XML -->
* </SomeOperationResult>
*</SomeOperationResponse>
* ```
*/
private fun unwrapOperationResponseBody(
operationName: String,
override fun unwrapOperationBody(
ctx: ProtocolGenerator.GenerationContext,
serdeCtx: SerdeCtx,
op: OperationShape,
writer: KotlinWriter,
) {
writer.write("// begin unwrap response wrapper")
.write("val resultDescriptor = #T(#T.Struct, #T(#S))", RuntimeTypes.Serde.SdkFieldDescriptor, RuntimeTypes.Serde.SerialKind, RuntimeTypes.Serde.SerdeXml.XmlSerialName, "${operationName}Result")
.withBlock("val wrapperDescriptor = #T.build {", "}", RuntimeTypes.Serde.SdkObjectDescriptor) {
write("trait(#T(#S))", RuntimeTypes.Serde.SerdeXml.XmlSerialName, "${operationName}Response")
write("#T(resultDescriptor)", RuntimeTypes.Serde.field)
): SerdeCtx {
val operationName = op.id.getName(ctx.service)

val unwrapAwsQueryOperation = buildSymbol {
name = "unwrapAwsQueryResponse"
namespace = ctx.settings.pkg.serde
definitionFile = "AwsQueryUtil.kt"
renderBy = { writer ->

writer.withBlock(
"internal fun $name(root: #1T, operationName: #2T): #1T {",
"}",
RuntimeTypes.Serde.SerdeXml.XmlTagReader,
KotlinTypes.String,
) {
write("val responseWrapperName = \"\${operationName}Response\"")
write("val resultWrapperName = \"\${operationName}Result\"")
withBlock(
"if (root.tagName != responseWrapperName) {",
"}",
) {
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid root, expected \$responseWrapperName; found `\${root.tag}`")
}

write("val resultTag = ${serdeCtx.tagReader}.nextTag()")
withBlock(
"if (resultTag == null || resultTag.tagName != resultWrapperName) {",
"}",
) {
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid result, expected \$resultWrapperName; found `\${resultTag?.tag}`")
}

write("return resultTag")
}
}
.write("")
// abandon the iterator, this only occurs at the top level operational output
.write("val wrapper = deserializer.#T(wrapperDescriptor)", RuntimeTypes.Serde.deserializeStruct)
.withBlock("if (wrapper.findNextFieldIndex() != resultDescriptor.index) {", "}") {
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "failed to unwrap $operationName response")
}
.write("// end unwrap response wrapper")
.write("")
}

writer.write("val unwrapped = #T(#L, #S)", unwrapAwsQueryOperation, serdeCtx.tagReader, operationName)

return SerdeCtx("unwrapped")
}

override fun unwrapOperationError(
ctx: ProtocolGenerator.GenerationContext,
serdeCtx: SerdeCtx,
errorShape: StructureShape,
writer: KotlinWriter,
): SerdeCtx {
writer.write("val errReader = #T(${serdeCtx.tagReader})", RestXmlErrors.wrappedErrorResponseDeserializer(ctx))
return SerdeCtx("errReader")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,19 @@ package software.amazon.smithy.kotlin.codegen.aws.protocols

import software.amazon.smithy.aws.traits.protocols.Ec2QueryNameTrait
import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AbstractQueryFormUrlSerializerGenerator
import software.amazon.smithy.kotlin.codegen.aws.protocols.core.QueryHttpBindingProtocolGenerator
import software.amazon.smithy.kotlin.codegen.aws.protocols.formurl.QuerySerdeFormUrlDescriptorGenerator
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.core.RenderingContext
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
import software.amazon.smithy.kotlin.codegen.model.changeNameSuffix
import software.amazon.smithy.kotlin.codegen.core.withBlock
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
import software.amazon.smithy.kotlin.codegen.model.getTrait
import software.amazon.smithy.kotlin.codegen.model.hasTrait
import software.amazon.smithy.kotlin.codegen.model.traits.OperationOutput
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
import software.amazon.smithy.kotlin.codegen.rendering.protocol.toRenderingContext
import software.amazon.smithy.kotlin.codegen.rendering.serde.*
import software.amazon.smithy.kotlin.codegen.utils.dq
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.XmlNameTrait

Expand Down Expand Up @@ -73,24 +72,6 @@ private class Ec2QuerySerdeFormUrlDescriptorGenerator(
targetShape.type == ShapeType.LIST
}

private class Ec2QuerySerdeXmlDescriptorGenerator(
ctx: RenderingContext<Shape>,
memberShapes: List<MemberShape>? = null,
) : XmlSerdeDescriptorGenerator(ctx, memberShapes) {

override fun getObjectDescriptorTraits(): List<SdkFieldDescriptorTrait> {
val traits = super.getObjectDescriptorTraits().toMutableList()

if (objectShape.hasTrait<OperationOutput>()) {
traits.removeIf { it.symbol == RuntimeTypes.Serde.SerdeXml.XmlSerialName }
val serialName = objectShape.changeNameSuffix("Response" to "Result")
traits.add(RuntimeTypes.Serde.SerdeXml.XmlSerialName, serialName.dq())
}

return traits
}
}

private class Ec2QuerySerializerGenerator(
private val protocolGenerator: Ec2Query,
) : AbstractQueryFormUrlSerializerGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {
Expand All @@ -104,13 +85,73 @@ private class Ec2QuerySerializerGenerator(
}

private class Ec2QueryParserGenerator(
private val protocolGenerator: Ec2Query,
) : XmlParserGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {

override fun descriptorGenerator(
protocolGenerator: Ec2Query,
) : XmlParserGenerator(protocolGenerator.defaultTimestampFormat) {
override fun unwrapOperationError(
ctx: ProtocolGenerator.GenerationContext,
shape: Shape,
members: List<MemberShape>,
serdeCtx: SerdeCtx,
errorShape: StructureShape,
writer: KotlinWriter,
): XmlSerdeDescriptorGenerator = Ec2QuerySerdeXmlDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members)
): SerdeCtx {
val unwrapFn = unwrapErrorResponse(ctx)
writer.write("val errReader = #T(${serdeCtx.tagReader})", unwrapFn)
return SerdeCtx("errReader")
}

/**
* Error deserializer for a wrapped error response
*
* ```
* <Response>
* <Errors>
* <Error>
* <-- DATA -->>
* </Error>
* </Errors>
* </Response>
* ```
*
* See https://smithy.io/2.0/aws/protocols/aws-ec2-query-protocol.html#operation-error-serialization
*/
private fun unwrapErrorResponse(ctx: ProtocolGenerator.GenerationContext): Symbol = buildSymbol {
name = "unwrapXmlErrorResponse"
namespace = ctx.settings.pkg.serde
definitionFile = "XmlErrorUtils.kt"
renderBy = { writer ->
writer.dokka("Handle [wrapped](https://smithy.io/2.0/aws/protocols/aws-ec2-query-protocol.html#operation-error-serialization) error responses")
writer.withBlock(
"internal fun $name(root: #1T): #1T {",
"}",
RuntimeTypes.Serde.SerdeXml.XmlTagReader,
) {
withBlock(
"if (root.tagName != #S) {",
"}",
"Response",
) {
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid root, expected <Response>; found `\${root.tag}`")
}

write("val errorsTag = root.nextTag()")
withBlock(
"if (errorsTag == null || errorsTag.tagName != #S) {",
"}",
"Errors",
) {
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid error, expected <Errors>; found `\${errorsTag?.tag}`")
}

write("val errTag = errorsTag.nextTag()")
withBlock(
"if (errTag == null || errTag.tagName != #S) {",
"}",
"Error",
) {
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid error, expected <Error>; found `\${errTag?.tag}`")
}

write("return errTag")
}
}
}
}
Loading

0 comments on commit 44b8249

Please sign in to comment.