Skip to content

Commit

Permalink
feat(codegen): generate client side error correction (#958)
Browse files Browse the repository at this point in the history
  • Loading branch information
aajtodd authored Sep 26, 2023
1 parent 2d4403f commit 39f85c3
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 23 deletions.
5 changes: 5 additions & 0 deletions .changes/d4722cf5-e6ef-4b42-b869-d854fd80be51.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"id": "d4722cf5-e6ef-4b42-b869-d854fd80be51",
"type": "feature",
"description": "Generate client side error correction for @required members"
}
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ object RuntimeTypes {
val merge = "kotlinx.coroutines.flow.merge".toSymbol()
val map = "kotlinx.coroutines.flow.map".toSymbol()
val take = "kotlinx.coroutines.flow.take".toSymbol()
val drop = "kotlinx.coroutines.flow.drop".toSymbol()
val single = "kotlinx.coroutines.flow.single".toSymbol()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,10 @@ import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.rendering.serde.ClientErrorCorrection
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
import software.amazon.smithy.model.shapes.*
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpLabelTrait
import software.amazon.smithy.model.traits.HttpQueryParamsTrait
import software.amazon.smithy.model.traits.HttpQueryTrait
import software.amazon.smithy.model.traits.LengthTrait
import software.amazon.smithy.model.traits.RetryableTrait
import software.amazon.smithy.model.traits.SensitiveTrait
import software.amazon.smithy.model.traits.StreamingTrait
import software.amazon.smithy.model.traits.*

/**
* Renders Smithy structure shapes
Expand Down Expand Up @@ -285,6 +279,26 @@ class StructureGenerator(
.write("this.#L = #Q.invoke(block)", memberName, memberSymbol)
.closeBlock("}")
}

write("")

// render client side error correction function to set @required members to a default
withBlock(
"internal fun correctErrors(): Builder {",
"}",
) {
sortedMembers
.filter(MemberShape::isRequired)
.filterNot {
val target = ctx.model.expectShape(it.target)
target.isStreaming
}
.forEach {
val correctedValue = ClientErrorCorrection.defaultValue(ctx, it, writer)
write("if (#1L == null) #1L = #2L", ctx.symbolProvider.toMemberName(it), correctedValue)
}
write("return this")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@ import software.amazon.smithy.kotlin.codegen.core.*
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.lang.toEscapedLiteral
import software.amazon.smithy.kotlin.codegen.model.*
import software.amazon.smithy.kotlin.codegen.rendering.serde.deserializerName
import software.amazon.smithy.kotlin.codegen.rendering.serde.formatInstant
import software.amazon.smithy.kotlin.codegen.rendering.serde.parseInstant
import software.amazon.smithy.kotlin.codegen.rendering.serde.serializerName
import software.amazon.smithy.kotlin.codegen.rendering.serde.*
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.knowledge.HttpBinding
import software.amazon.smithy.model.shapes.*
Expand Down Expand Up @@ -573,7 +570,6 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
writer: KotlinWriter,
) {
writer
.addImport(RuntimeTypes.Core.ExecutionContext)
.openBlock(
"override suspend fun deserialize(context: #T, call: #T): #T {",
RuntimeTypes.Core.ExecutionContext,
Expand Down Expand Up @@ -618,6 +614,10 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator {
renderDeserializeResponseCode(ctx, it, writer)
}
}
// Render client side error correction for `@required` members.
// NOTE: nested members bound via the document/payload will be handled by the deserializer for the relevant
// content type. All other members (e.g. bound via REST semantics) will be corrected here.
.write("builder.correctErrors()")
.write("return builder.build()")
.closeBlock("}")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.kotlin.codegen.rendering.serde

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.kotlin.codegen.core.CodegenContext
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
import software.amazon.smithy.kotlin.codegen.model.isEnum
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.ShapeType

object ClientErrorCorrection {
/**
* Determine the default value for a required member based on
* [client error correction](https://smithy.io/2.0/spec/aggregate-types.html?highlight=error%20correction#client-error-correction)
*
* @param ctx the generation context
* @param member the target member shape to get the default value for
* @param writer the writer the default value will be written to, this is used for certain shapes to format the
* default value which will mutate the writer (e.g. add imports).
* @return default value expression as a string
*/
fun defaultValue(
ctx: CodegenContext,
member: MemberShape,
writer: KotlinWriter,
): String {
val target = ctx.model.expectShape(member.target)
val targetSymbol = ctx.symbolProvider.toSymbol(target)

// In IDL v1 all enums were `ShapeType.STRING` and you had to explicitly check for the @enum trait, this handles
// the differences in IDL versions
if (target.isEnum) {
return writer.format("#T.SdkUnknown(#S)", targetSymbol, "no value provided")
}

return when (target.type) {
ShapeType.BLOB -> "ByteArray(0)"
ShapeType.BOOLEAN -> "false"
ShapeType.STRING -> "\"\""
ShapeType.BYTE -> "0.toByte()"
ShapeType.SHORT -> "0.toShort()"
ShapeType.INTEGER -> "0"
ShapeType.LONG -> "0L"
ShapeType.FLOAT -> "0f"
ShapeType.DOUBLE -> "0.0"
ShapeType.BIG_INTEGER -> writer.format("#T(\"0\")", RuntimeTypes.Core.Content.BigInteger)
ShapeType.BIG_DECIMAL -> writer.format("#T(\"0\")", RuntimeTypes.Core.Content.BigDecimal)
ShapeType.DOCUMENT -> "null"
ShapeType.UNION -> writer.format("#T.SdkUnknown", targetSymbol)
ShapeType.LIST,
ShapeType.SET,
-> "emptyList()"
ShapeType.MAP -> "emptyMap()"
ShapeType.STRUCTURE -> writer.format("#T.Builder().correctErrors().build()", targetSymbol)
ShapeType.TIMESTAMP -> writer.format("#T.fromEpochSeconds(0)", RuntimeTypes.Core.Instant)
else -> throw CodegenException("unexpected member type $member")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ open class JsonParserGenerator(
else -> {
writer.write("val builder = #T.Builder()", symbol)
renderDeserializerBody(ctx, shape, members.toList(), writer)
writer.write("builder.correctErrors()")
writer.write("return builder.build()")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ open class XmlParserGenerator(
} else {
writer.write("val builder = #T.Builder()", symbol)
renderDeserializerBody(ctx, shape, members.toList(), writer)
writer.write("builder.correctErrors()")
writer.write("return builder.build()")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@ import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.model.expectShape
import software.amazon.smithy.kotlin.codegen.test.*
import software.amazon.smithy.model.shapes.IntEnumShape
import software.amazon.smithy.model.shapes.ShapeType
import software.amazon.smithy.model.shapes.StringShape
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith

class EnumGeneratorTest {

@Test
fun `it generates unnamed string enums`() {
fun itGeneratesUnnamedStringEnumsIdlv1() {
val model = """
@enum([
{
Expand All @@ -33,10 +35,35 @@ class EnumGeneratorTest {
@documentation("Documentation for this enum")
string Baz
""".prependNamespaceAndService(namespace = "test").toSmithyModel()
"""
testUnnamedStringEnum(model, "1.0")
}

@Test
fun itGeneratesUnnamedStringEnumsIdlv2() {
val model = """
@documentation("Documentation for this enum")
enum Baz {
FOO,
@documentation("Documentation for bar")
BAR,
}
"""
testUnnamedStringEnum(model, "2.0")
}

private fun testUnnamedStringEnum(modelContents: String, idlVersion: String) {
val model = modelContents.prependNamespaceAndService(version = idlVersion, namespace = "test").toSmithyModel()

val provider = KotlinCodegenPlugin.createSymbolProvider(model, rootNamespace = "test")
val shape = model.expectShape<StringShape>("test#Baz")
val expectedShapeType = when (idlVersion) {
"1.0" -> ShapeType.STRING
else -> ShapeType.ENUM
}
assertEquals(expectedShapeType, shape.type)

val symbol = provider.toSymbol(shape)
val writer = KotlinWriter(TestModelDefault.NAMESPACE)
EnumGenerator(shape, symbol, writer).render()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ class StructureGeneratorTest {
public fun quux(block: com.test.model.Qux.Builder.() -> kotlin.Unit) {
this.quux = com.test.model.Qux.invoke(block)
}
internal fun correctErrors(): Builder {
return this
}
}
""".formatForTest()
commonTestContents.shouldContainOnlyOnceWithDiff(expected)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,6 @@ import kotlin.test.assertTrue
// NOTE: protocol conformance is mostly handled by the protocol tests suite
class HttpBindingProtocolGeneratorTest {
private val defaultModel = loadModelFromResource("http-binding-protocol-generator-test.smithy")
private val modelPrefix = """
@http(method: "POST", uri: "/foo-no-input")
operation Foo {
input: FooRequest
}
""".prependNamespaceAndService(protocol = AwsProtocolModelDeclaration.REST_JSON, operations = listOf("Foo")).trimIndent()

private fun getTransformFileContents(filename: String, testModel: Model = defaultModel): String {
val (ctx, manifest, generator) = testModel.newTestContext()
generator.generateProtocolClient(ctx)
Expand Down Expand Up @@ -380,6 +373,7 @@ internal class SmokeTestOperationDeserializer: HttpDeserialize<SmokeTestResponse
if (payload != null) {
deserializeSmokeTestOperationBody(builder, payload)
}
builder.correctErrors()
return builder.build()
}
}
Expand Down Expand Up @@ -433,6 +427,17 @@ internal class SmokeTestOperationDeserializer: HttpDeserialize<SmokeTestResponse
contents.shouldContainOnlyOnce(expectedContents)
}

@Test
fun itDeserializesExplicitEnumPayloads() {
val contents = getTransformFileContents("ExplicitEnumOperationDeserializer.kt")
contents.assertBalancedBracesAndParens()
val expectedContents = """
val contents = response.body.readAll()?.decodeToString()
builder.payload1 = contents?.let { MyEnum.fromValue(it) }
"""
contents.shouldContainOnlyOnce(expectedContents)
}

@Test
fun itDeserializesExplicitBlobPayloads() {
val contents = getTransformFileContents("ExplicitBlobOperationDeserializer.kt")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ service Test {
SmokeTest,
DuplicateInputTest,
ExplicitString,
ExplicitEnum,
ExplicitBlob,
ExplicitBlobStream,
ExplicitStruct,
Expand Down Expand Up @@ -112,6 +113,22 @@ structure ExplicitStringResponse {
payload1: String
}

@http(method: "POST", uri: "/explicit/enum")
operation ExplicitEnum {
input: ExplicitEnumRequest,
output: ExplicitEnumResponse
}

structure ExplicitEnumRequest {
@httpPayload
payload1: MyEnum
}

structure ExplicitEnumResponse {
@httpPayload
payload1: MyEnum
}

@http(method: "POST", uri: "/explicit/blob")
operation ExplicitBlob {
input: ExplicitBlobRequest,
Expand Down

0 comments on commit 39f85c3

Please sign in to comment.