Skip to content

Commit

Permalink
Split RuntimeError and RequestRejection by protocol (#2517)
Browse files Browse the repository at this point in the history
As outlined in the [Protocol Specific Errors] of the [Service Builder
Improvements RFC], `RuntimeError` should be split up into smaller,
protocol specific, errors which accurately model the failure cases of
each protocol.

The same goes for `RequestRejection`.

Closes #1703.

[Protocol Specific Errors]: https://github.com/awslabs/smithy-rs/blob/main/design/src/rfcs/rfc0020_service_builder.md#protocol-specific-errors
[Service Builder Improvements RFC]: https://github.com/awslabs/smithy-rs/blob/main/design/src/rfcs/rfc0020_service_builder.md
  • Loading branch information
david-perez authored Apr 3, 2023
1 parent 3feb4be commit f708076
Show file tree
Hide file tree
Showing 33 changed files with 856 additions and 483 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ open class ServerCodegenVisitor(

val baseModel = baselineTransform(context.model)
val service = settings.getService(baseModel)
val (protocol, generator) =
val (protocolShape, protocolGeneratorFactory) =
ServerProtocolLoader(
codegenDecorator.protocols(
service.id,
ServerProtocolLoader.DefaultProtocols,
),
)
.protocolFor(context.model, service)
protocolGeneratorFactory = generator
this.protocolGeneratorFactory = protocolGeneratorFactory

model = codegenDecorator.transformModel(service, baseModel)

Expand All @@ -145,7 +145,7 @@ open class ServerCodegenVisitor(
serverSymbolProviders.symbolProvider,
null,
service,
protocol,
protocolShape,
settings,
serverSymbolProviders.unconstrainedShapeSymbolProvider,
serverSymbolProviders.constrainedShapeSymbolProvider,
Expand All @@ -169,7 +169,7 @@ open class ServerCodegenVisitor(
settings.codegenConfig,
codegenContext.expectModuleDocProvider(),
)
protocolGenerator = protocolGeneratorFactory.buildProtocolGenerator(codegenContext)
protocolGenerator = this.protocolGeneratorFactory.buildProtocolGenerator(codegenContext)
}

/**
Expand Down Expand Up @@ -315,7 +315,12 @@ open class ServerCodegenVisitor(
writer: RustWriter,
) {
if (codegenContext.settings.codegenConfig.publicConstrainedTypes || shape.isReachableFromOperationInput()) {
val serverBuilderGenerator = ServerBuilderGenerator(codegenContext, shape, validationExceptionConversionGenerator)
val serverBuilderGenerator = ServerBuilderGenerator(
codegenContext,
shape,
validationExceptionConversionGenerator,
protocolGenerator.protocol,
)
serverBuilderGenerator.render(rustCrate, writer)

if (codegenContext.settings.codegenConfig.publicConstrainedTypes) {
Expand All @@ -336,7 +341,12 @@ open class ServerCodegenVisitor(

if (!codegenContext.settings.codegenConfig.publicConstrainedTypes) {
val serverBuilderGeneratorWithoutPublicConstrainedTypes =
ServerBuilderGeneratorWithoutPublicConstrainedTypes(codegenContext, shape, validationExceptionConversionGenerator)
ServerBuilderGeneratorWithoutPublicConstrainedTypes(
codegenContext,
shape,
validationExceptionConversionGenerator,
protocolGenerator.protocol,
)
serverBuilderGeneratorWithoutPublicConstrainedTypes.render(rustCrate, writer)

writer.implBlock(codegenContext.symbolProvider.toSymbol(shape)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy

import software.amazon.smithy.rust.codegen.core.rustlang.InlineDependency
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType

Expand All @@ -15,17 +14,11 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
* For a runtime type that is used in the client, or in both the client and the server, use [RuntimeType] directly.
*/
object ServerRuntimeType {
fun forInlineDependency(inlineDependency: InlineDependency) = RuntimeType("crate::${inlineDependency.name}", inlineDependency)
fun router(runtimeConfig: RuntimeConfig) =
ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router")

fun router(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("routing::Router")

fun runtimeError(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("runtime_error::RuntimeError")

fun requestRejection(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("rejection::RequestRejection")

fun responseRejection(runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("rejection::ResponseRejection")

fun protocol(name: String, path: String, runtimeConfig: RuntimeConfig) = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("proto::$path::$name")
fun protocol(name: String, path: String, runtimeConfig: RuntimeConfig) =
ServerCargoDependency.smithyHttpServer(runtimeConfig).toType().resolve("proto::$path::$name")

fun protocol(runtimeConfig: RuntimeConfig) = protocol("Protocol", "", runtimeConfig)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator
import software.amazon.smithy.rust.codegen.server.smithy.generators.BlobLength
import software.amazon.smithy.rust.codegen.server.smithy.generators.CollectionTraitInfo
Expand All @@ -35,6 +34,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.StringTraitI
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage

/**
Expand Down Expand Up @@ -67,11 +67,7 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex
override val shapeId: ShapeId =
ShapeId.from(codegenContext.settings.codegenConfig.experimentalCustomValidationExceptionWithReasonPleaseDoNotUse)

override fun renderImplFromConstraintViolationForRequestRejection(): Writable = writable {
val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = writable {
rustTemplate(
"""
impl #{From}<ConstraintViolation> for #{RequestRejection} {
Expand All @@ -89,7 +85,8 @@ class ValidationExceptionWithReasonConversionGenerator(private val codegenContex
}
}
""",
*codegenScope,
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.util.getTrait
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator
import software.amazon.smithy.rust.codegen.server.smithy.generators.BlobLength
import software.amazon.smithy.rust.codegen.server.smithy.generators.CollectionTraitInfo
Expand All @@ -34,6 +33,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.TraitInfo
import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator
import software.amazon.smithy.rust.codegen.server.smithy.generators.isKeyConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.isValueConstrained
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.validationErrorMessage

/**
Expand Down Expand Up @@ -66,11 +66,7 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S
}
override val shapeId: ShapeId = SHAPE_ID

override fun renderImplFromConstraintViolationForRequestRejection(): Writable = writable {
val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
override fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable = writable {
rustTemplate(
"""
impl #{From}<ConstraintViolation> for #{RequestRejection} {
Expand All @@ -87,7 +83,8 @@ class SmithyValidationExceptionConversionGenerator(private val codegenContext: S
}
}
""",
*codegenScope,
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"From" to RuntimeType.From,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ import software.amazon.smithy.rust.codegen.core.util.letIf
import software.amazon.smithy.rust.codegen.core.util.redactIfNecessary
import software.amazon.smithy.rust.codegen.core.util.toSnakeCase
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.canReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.hasConstraintTraitOrTargetHasConstraintTrait
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape
import software.amazon.smithy.rust.codegen.server.smithy.traits.ConstraintViolationRustBoxTrait
Expand Down Expand Up @@ -92,6 +92,7 @@ class ServerBuilderGenerator(
val codegenContext: ServerCodegenContext,
private val shape: StructureShape,
private val customValidationExceptionWithReasonConversionGenerator: ValidationExceptionConversionGenerator,
private val protocol: ServerProtocol,
) {
companion object {
/**
Expand Down Expand Up @@ -148,7 +149,7 @@ class ServerBuilderGenerator(
ServerBuilderConstraintViolations(codegenContext, shape, takeInUnconstrainedTypes, customValidationExceptionWithReasonConversionGenerator)

private val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(runtimeConfig),
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"Structure" to structureSymbol,
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
Expand Down Expand Up @@ -222,7 +223,8 @@ class ServerBuilderGenerator(
"""
#{Converter:W}
""",
"Converter" to customValidationExceptionWithReasonConversionGenerator.renderImplFromConstraintViolationForRequestRejection(),
"Converter" to
customValidationExceptionWithReasonConversionGenerator.renderImplFromConstraintViolationForRequestRejection(protocol),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.makeOptional
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.ServerRuntimeType
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.withInMemoryInlineModule

/**
Expand All @@ -49,6 +49,7 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes(
private val codegenContext: ServerCodegenContext,
shape: StructureShape,
validationExceptionConversionGenerator: ValidationExceptionConversionGenerator,
protocol: ServerProtocol,
) {
companion object {
/**
Expand Down Expand Up @@ -85,7 +86,7 @@ class ServerBuilderGeneratorWithoutPublicConstrainedTypes(
ServerBuilderConstraintViolations(codegenContext, shape, builderTakesInUnconstrainedTypes = false, validationExceptionConversionGenerator)

private val codegenScope = arrayOf(
"RequestRejection" to ServerRuntimeType.requestRejection(codegenContext.runtimeConfig),
"RequestRejection" to protocol.requestRejection(codegenContext.runtimeConfig),
"Structure" to structureSymbol,
"From" to RuntimeType.From,
"TryFrom" to RuntimeType.TryFrom,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.traits.EnumTrait
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol

/**
* Collection of methods that will be invoked by the respective generators to generate code to convert constraint
Expand All @@ -26,7 +27,7 @@ interface ValidationExceptionConversionGenerator {
* Convert from a top-level operation input's constraint violation into
* `aws_smithy_http_server::rejection::RequestRejection`.
*/
fun renderImplFromConstraintViolationForRequestRejection(): Writable
fun renderImplFromConstraintViolationForRequestRejection(protocol: ServerProtocol): Writable

// Simple shapes.
fun stringShapeConstraintViolationImplBlock(stringConstraintsInfo: Collection<StringTraitInfo>): Writable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol

import software.amazon.smithy.model.knowledge.TopDownIndex
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
Expand Down Expand Up @@ -37,12 +36,10 @@ import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerAwsJson
import software.amazon.smithy.rust.codegen.server.smithy.protocols.ServerRestJsonSerializerGenerator
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape

private fun allOperations(codegenContext: CodegenContext): List<OperationShape> {
val index = TopDownIndex.of(codegenContext.model)
return index.getContainedOperations(codegenContext.serviceShape).sortedBy { it.id }
}

interface ServerProtocol : Protocol {
/** The path such that `aws_smithy_http_server::proto::$path` points to the protocol's module. */
val protocolModulePath: String

/** Returns the Rust marker struct enjoying `OperationShape`. */
fun markerStruct(): RuntimeType

Expand Down Expand Up @@ -76,6 +73,21 @@ interface ServerProtocol : Protocol {
* Returns a boolean indicating whether to perform this check.
*/
fun serverContentTypeCheckNoModeledInput(): Boolean = false

/** The protocol-specific `RequestRejection` type. **/
fun requestRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::$protocolModulePath::rejection::RequestRejection")

/** The protocol-specific `ResponseRejection` type. **/
fun responseRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::$protocolModulePath::rejection::ResponseRejection")

/** The protocol-specific `RuntimeError` type. **/
fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::$protocolModulePath::runtime_error::RuntimeError")
}

class ServerAwsJsonProtocol(
Expand All @@ -84,6 +96,12 @@ class ServerAwsJsonProtocol(
) : AwsJson(serverCodegenContext, awsJsonVersion), ServerProtocol {
private val runtimeConfig = codegenContext.runtimeConfig

override val protocolModulePath: String
get() = when (version) {
is AwsJsonVersion.Json10 -> "aws_json_10"
is AwsJsonVersion.Json11 -> "aws_json_11"
}

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse =
if (shape.canReachConstrainedShape(codegenContext.model, serverCodegenContext.symbolProvider)) {
Expand All @@ -107,12 +125,8 @@ class ServerAwsJsonProtocol(

override fun markerStruct(): RuntimeType {
return when (version) {
is AwsJsonVersion.Json10 -> {
ServerRuntimeType.protocol("AwsJson1_0", "aws_json_10", runtimeConfig)
}
is AwsJsonVersion.Json11 -> {
ServerRuntimeType.protocol("AwsJson1_1", "aws_json_11", runtimeConfig)
}
is AwsJsonVersion.Json10 -> ServerRuntimeType.protocol("AwsJson1_0", protocolModulePath, runtimeConfig)
is AwsJsonVersion.Json11 -> ServerRuntimeType.protocol("AwsJson1_1", protocolModulePath, runtimeConfig)
}
}

Expand All @@ -139,6 +153,16 @@ class ServerAwsJsonProtocol(
AwsJsonVersion.Json10 -> "new_aws_json_10_router"
AwsJsonVersion.Json11 -> "new_aws_json_11_router"
}

override fun requestRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::aws_json::rejection::RequestRejection")
override fun responseRejection(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::aws_json::rejection::ResponseRejection")
override fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("proto::aws_json::runtime_error::RuntimeError")
}

private fun restRouterType(runtimeConfig: RuntimeConfig) =
Expand All @@ -150,6 +174,8 @@ class ServerRestJsonProtocol(
) : RestJson(serverCodegenContext), ServerProtocol {
val runtimeConfig = codegenContext.runtimeConfig

override val protocolModulePath: String = "rest_json_1"

override fun structuredDataParser(operationShape: OperationShape): StructuredDataParserGenerator {
fun returnSymbolToParse(shape: Shape): ReturnSymbolToParse =
if (shape.canReachConstrainedShape(codegenContext.model, codegenContext.symbolProvider)) {
Expand All @@ -173,7 +199,7 @@ class ServerRestJsonProtocol(
override fun structuredDataSerializer(operationShape: OperationShape): StructuredDataSerializerGenerator =
ServerRestJsonSerializerGenerator(serverCodegenContext, httpBindingResolver)

override fun markerStruct() = ServerRuntimeType.protocol("RestJson1", "rest_json_1", runtimeConfig)
override fun markerStruct() = ServerRuntimeType.protocol("RestJson1", protocolModulePath, runtimeConfig)

override fun routerType() = restRouterType(runtimeConfig)

Expand All @@ -196,8 +222,9 @@ class ServerRestXmlProtocol(
codegenContext: CodegenContext,
) : RestXml(codegenContext), ServerProtocol {
val runtimeConfig = codegenContext.runtimeConfig
override val protocolModulePath = "rest_xml"

override fun markerStruct() = ServerRuntimeType.protocol("RestXml", "rest_xml", runtimeConfig)
override fun markerStruct() = ServerRuntimeType.protocol("RestXml", protocolModulePath, runtimeConfig)

override fun routerType() = restRouterType(runtimeConfig)

Expand Down
Loading

0 comments on commit f708076

Please sign in to comment.