Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract builderInstantiator interface to prepare for nullability changes #2988

Merged
merged 2 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.generators.ClientBuilderInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.CodegenTarget
import software.amazon.smithy.rust.codegen.core.smithy.ModuleDocProvider
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol

/**
Expand All @@ -36,4 +38,7 @@ data class ClientCodegenContext(
model, symbolProvider, moduleDocProvider, serviceShape, protocol, settings, CodegenTarget.CLIENT,
) {
val enableUserConfigurableRuntimePlugins: Boolean get() = settings.codegenConfig.enableUserConfigurableRuntimePlugins
override fun builderInstantiator(): BuilderInstantiator {
return ClientBuilderInstantiator(symbolProvider)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.client.smithy.generators

import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.map
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator

fun ClientCodegenContext.builderInstantiator(): BuilderInstantiator = ClientBuilderInstantiator(symbolProvider)

class ClientBuilderInstantiator(private val symbolProvider: RustSymbolProvider) : BuilderInstantiator {
override fun setField(builder: String, value: Writable, field: MemberShape): Writable {
return setFieldWithSetter(builder, value, field)
}

override fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable?): Writable = writable {
if (BuilderGenerator.hasFallibleBuilder(shape, symbolProvider)) {
rustTemplate(
"$builder.build()#{mapErr}?",
"mapErr" to (
mapErr?.map {
rust(".map_err(#T)", it)
} ?: writable { }
),
)
} else {
rust("$builder.build()")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ private class ClientAwsJsonFactory(private val version: AwsJsonVersion) :
ProtocolGeneratorFactory<OperationGenerator, ClientCodegenContext> {
override fun protocol(codegenContext: ClientCodegenContext): Protocol =
if (compatibleWithAwsQuery(codegenContext.serviceShape, version)) {
AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version))
AwsQueryCompatible(codegenContext, AwsJson(codegenContext, version, codegenContext.builderInstantiator()))
} else {
AwsJson(codegenContext, version)
AwsJson(codegenContext, version, codegenContext.builderInstantiator())
}

override fun buildProtocolGenerator(codegenContext: ClientCodegenContext): OperationGenerator =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.core.smithy
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator

/**
* [CodegenContext] contains code-generation context that is _common to all_ smithy-rs plugins.
Expand All @@ -17,7 +18,7 @@ import software.amazon.smithy.model.shapes.ShapeId
* If your data is specific to the `rust-client-codegen` client plugin, put it in [ClientCodegenContext] instead.
* If your data is specific to the `rust-server-codegen` server plugin, put it in [ServerCodegenContext] instead.
*/
open class CodegenContext(
abstract class CodegenContext(
/**
* The smithy model.
*
Expand Down Expand Up @@ -89,4 +90,6 @@ open class CodegenContext(
fun expectModuleDocProvider(): ModuleDocProvider = checkNotNull(moduleDocProvider) {
"A ModuleDocProvider must be set on the CodegenContext"
}

abstract fun builderInstantiator(): BuilderInstantiator
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.rust.codegen.core.smithy.generators

import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable

/** Abstraction for instantiating builders.
*
* Builder abstractions vary—clients MAY use `build_with_error_correction`, e.g., and builders can vary in fallibility.
* */
interface BuilderInstantiator {
/** Set a field on a builder. */
fun setField(builder: String, value: Writable, field: MemberShape): Writable

/** Finalize a builder, turning it into a built object
* - In the case of builders-of-builders, the value should be returned directly
* - If an error is returned, you MUST use `mapErr` to convert the error type
*/
fun finalizeBuilder(builder: String, shape: StructureShape, mapErr: Writable? = null): Writable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth documenting mapErr.


/** Set a field on a builder using the `$setterName` method. $value will be passed directly. */
fun setFieldWithSetter(builder: String, value: Writable, field: MemberShape) = writable {
rustTemplate("$builder = $builder.${field.setterName()}(#{value})", "value" to value)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderInstantiator
import software.amazon.smithy.rust.codegen.core.smithy.generators.serializationError
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.JsonParserGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.parse.StructuredDataParserGenerator
Expand Down Expand Up @@ -122,6 +123,7 @@ class AwsJsonSerializerGenerator(
open class AwsJson(
val codegenContext: CodegenContext,
val awsJsonVersion: AwsJsonVersion,
val builderInstantiator: BuilderInstantiator,
) : Protocol {
private val runtimeConfig = codegenContext.runtimeConfig
private val errorScope = arrayOf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class EventStreamUnmarshallerGenerator(
private val unionShape: UnionShape,
) {
private val model = codegenContext.model
private val builderInstantiator = codegenContext.builderInstantiator()
private val symbolProvider = codegenContext.symbolProvider
private val codegenTarget = codegenContext.target
private val runtimeConfig = codegenContext.runtimeConfig
Expand Down Expand Up @@ -339,6 +340,7 @@ class EventStreamUnmarshallerGenerator(
// TODO(EventStream): Errors on the operation can be disjoint with errors in the union,
// so we need to generate a new top-level Error type for each event stream union.
when (codegenTarget) {
// TODO(https://github.com/awslabs/smithy-rs/issues/1970) It should be possible to unify these branches now
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should it be possible?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because we have the builderInstantiator, I think. But it's possible this is more divergence that can't be handled like that.

CodegenTarget.CLIENT -> {
val target = model.expectShape(member.target, StructureShape::class.java)
val parser = protocol.structuredDataParser().errorParser(target)
Expand All @@ -352,9 +354,19 @@ class EventStreamUnmarshallerGenerator(
})?;
builder.set_meta(Some(generic));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside: not sure why we bind the error metadata to a variable named generic.

return Ok(#{UnmarshalledMessage}::Error(
#{OpError}::${member.target.name}(builder.build())
#{OpError}::${member.target.name}(
#{build}
)
))
""",
"build" to builderInstantiator.finalizeBuilder(
"builder", target,
mapErr = {
rustTemplate(
"""|err|#{Error}::unmarshalling(format!("{}", err))""", *codegenScope,
)
},
),
"parser" to parser,
*codegenScope,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,16 @@ import software.amazon.smithy.utils.StringUtils
* Class describing a JSON parser section that can be used in a customization.
*/
sealed class JsonParserSection(name: String) : Section(name) {
data class BeforeBoxingDeserializedMember(val shape: MemberShape) : JsonParserSection("BeforeBoxingDeserializedMember")
data class BeforeBoxingDeserializedMember(val shape: MemberShape) :
JsonParserSection("BeforeBoxingDeserializedMember")

data class AfterTimestampDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterTimestampDeserializedMember")
data class AfterTimestampDeserializedMember(val shape: MemberShape) :
JsonParserSection("AfterTimestampDeserializedMember")

data class AfterBlobDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterBlobDeserializedMember")

data class AfterDocumentDeserializedMember(val shape: MemberShape) : JsonParserSection("AfterDocumentDeserializedMember")
data class AfterDocumentDeserializedMember(val shape: MemberShape) :
JsonParserSection("AfterDocumentDeserializedMember")
}

/**
Expand Down Expand Up @@ -100,6 +103,7 @@ class JsonParserGenerator(
private val codegenTarget = codegenContext.target
private val smithyJson = CargoDependency.smithyJson(runtimeConfig).toType()
private val protocolFunctions = ProtocolFunctions(codegenContext)
private val builderInstantiator = codegenContext.builderInstantiator()
private val codegenScope = arrayOf(
"Error" to smithyJson.resolve("deserialize::error::DeserializeError"),
"expect_blob_or_null" to smithyJson.resolve("deserialize::token::expect_blob_or_null"),
Expand Down Expand Up @@ -251,6 +255,7 @@ class JsonParserGenerator(
deserializeMember(member)
}
}

CodegenTarget.SERVER -> {
if (symbolProvider.toSymbol(member).isOptional()) {
withBlock("builder = builder.${member.setterName()}(", ");") {
Expand Down Expand Up @@ -508,12 +513,14 @@ class JsonParserGenerator(
"Builder" to symbolProvider.symbolForBuilder(shape),
)
deserializeStructInner(shape.members())
// Only call `build()` if the builder is not fallible. Otherwise, return the builder.
if (returnSymbolToParse.isUnconstrained) {
rust("Ok(Some(builder))")
} else {
rust("Ok(Some(builder.build()))")
val builder = builderInstantiator.finalizeBuilder(
"builder", shape,
) {
rustTemplate(
"""|err|#{Error}::custom_source("Response was invalid", err)""", *codegenScope,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't look correct to me. This only applies to clients, not servers. It only works because ServerBuilderInstantiator does not rely on the passed in mapErr writable.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mapErr defines how to convert the error back into the caller error—since the server never returns an error, it doesn't need to use it.

)
}
rust("Ok(Some(#T))", builder)
}
}
}
Expand Down
Loading