Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor determining server error type when deserializing an @httpPayload #3752

Merged
merged 3 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy.generators.http

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
Expand All @@ -20,12 +19,12 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpBindi
import software.amazon.smithy.rust.codegen.core.smithy.generators.http.HttpMessageType
import software.amazon.smithy.rust.codegen.core.smithy.mapRustType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext
import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol
import software.amazon.smithy.rust.codegen.server.smithy.targetCanReachConstrainedShape

class ServerRequestBindingGenerator(
protocol: Protocol,
val protocol: ServerProtocol,
codegenContext: ServerCodegenContext,
operationShape: OperationShape,
additionalHttpBindingCustomizations: List<HttpBindingCustomization> = listOf(),
Expand All @@ -50,12 +49,11 @@ class ServerRequestBindingGenerator(

fun generateDeserializePayloadFn(
binding: HttpBindingDescriptor,
errorSymbol: Symbol,
structuredHandler: RustWriter.(String) -> Unit,
): RuntimeType =
httpBindingGenerator.generateDeserializePayloadFn(
binding,
errorSymbol,
protocol.deserializePayloadErrorType(binding).toSymbol(),
structuredHandler,
HttpMessageType.REQUEST,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators.protocol
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.rust
Expand All @@ -17,7 +18,9 @@ import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingDescriptor
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpBindingResolver
import software.amazon.smithy.rust.codegen.core.smithy.protocols.HttpLocation
import software.amazon.smithy.rust.codegen.core.smithy.protocols.Protocol
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestJson
import software.amazon.smithy.rust.codegen.core.smithy.protocols.RestXml
Expand Down Expand Up @@ -70,8 +73,8 @@ interface ServerProtocol : Protocol {
fun serverRouterRequestSpecType(requestSpecModule: RuntimeType): RuntimeType

/**
* In some protocols, such as restJson1,
* when there is no modeled body input, content type must not be set and the body must be empty.
* In some protocols, such as `restJson1` and `rpcv2Cbor`,
* when there is no modeled body input, `content-type` must not be set and the body must be empty.
* Returns a boolean indicating whether to perform this check.
*/
fun serverContentTypeCheckNoModeledInput(): Boolean = false
Expand All @@ -90,6 +93,19 @@ interface ServerProtocol : Protocol {
fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("protocol::$protocolModulePath::runtime_error::RuntimeError")

/**
* The function that deserializes a payload-bound shape takes as input a byte slab and returns a `Result` holding
* the deserialized shape if successful. What error type should we use in case of failure?
*
* The shape could be payload-bound either because of the `@httpPayload` trait, or because it's part of an event
* stream.
*
* Note that despite the trait (https://smithy.io/2.0/spec/http-bindings.html#httppayload-trait) being able to
* target any structure member shape, AWS Protocols only support binding the following shape types to the payload
* (and Smithy does indeed enforce this at model build-time): string, blob, structure, union, and document
*/
fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType
}

fun returnSymbolToParseFn(codegenContext: ServerCodegenContext): (Shape) -> ReturnSymbolToParse {
Expand Down Expand Up @@ -185,6 +201,18 @@ class ServerAwsJsonProtocol(
override fun runtimeError(runtimeConfig: RuntimeConfig): RuntimeType =
ServerCargoDependency.smithyHttpServer(runtimeConfig)
.toType().resolve("protocol::aws_json::runtime_error::RuntimeError")

/*
* Note that despite the AWS JSON 1.x protocols not supporting the `@httpPayload` trait, event streams are bound
* to the payload.
*/
override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType =
deserializePayloadErrorType(
codegenContext,
binding,
requestRejection(runtimeConfig),
RuntimeType.smithyJson(codegenContext.runtimeConfig).resolve("deserialize::error::DeserializeError"),
)
}

private fun restRouterType(runtimeConfig: RuntimeConfig) =
Expand Down Expand Up @@ -227,6 +255,14 @@ class ServerRestJsonProtocol(
override fun serverRouterRuntimeConstructor() = "new_rest_json_router"

override fun serverContentTypeCheckNoModeledInput() = true

override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType =
deserializePayloadErrorType(
codegenContext,
binding,
requestRejection(runtimeConfig),
RuntimeType.smithyJson(codegenContext.runtimeConfig).resolve("deserialize::error::DeserializeError"),
)
}

class ServerRestXmlProtocol(
Expand All @@ -252,6 +288,32 @@ class ServerRestXmlProtocol(
override fun serverRouterRuntimeConstructor() = "new_rest_xml_router"

override fun serverContentTypeCheckNoModeledInput() = true

override fun deserializePayloadErrorType(binding: HttpBindingDescriptor): RuntimeType =
deserializePayloadErrorType(
codegenContext,
binding,
requestRejection(runtimeConfig),
RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError"),
)
}

/** Just a common function to keep things DRY. **/
fun deserializePayloadErrorType(
codegenContext: CodegenContext,
binding: HttpBindingDescriptor,
requestRejection: RuntimeType,
protocolSerializationFormatError: RuntimeType,
): RuntimeType {
check(binding.location == HttpLocation.PAYLOAD)

if (codegenContext.model.expectShape(binding.member.target) is StringShape) {
// The only way deserializing a string can fail is if the HTTP body does not contain valid UTF-8.
// TODO(https://github.com/smithy-lang/smithy-rs/issues/3750): we're returning an incorrect `RequestRejection` variant here.
return requestRejection
}

return protocolSerializationFormatError
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@

package software.amazon.smithy.rust.codegen.server.smithy.protocols

import software.amazon.smithy.aws.traits.protocols.AwsJson1_0Trait
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.knowledge.HttpBindingIndex
import software.amazon.smithy.model.node.ExpectationNotMetException
Expand All @@ -20,7 +16,6 @@ import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.NumberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StringShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.model.traits.HttpErrorTrait
Expand Down Expand Up @@ -124,13 +119,7 @@ class ServerHttpBoundProtocolGenerator(
) : ServerProtocolGenerator(
protocol,
ServerHttpBoundProtocolTraitImplGenerator(codegenContext, protocol, customizations, additionalHttpBindingCustomizations),
) {
// Define suffixes for operation input / output / error wrappers
companion object {
const val OPERATION_INPUT_WRAPPER_SUFFIX = "OperationInputWrapper"
const val OPERATION_OUTPUT_WRAPPER_SUFFIX = "OperationOutputWrapper"
}
}
)

class ServerHttpBoundProtocolPayloadGenerator(
codegenContext: CodegenContext,
Expand Down Expand Up @@ -697,8 +686,6 @@ class ServerHttpBoundProtocolTraitImplGenerator(
inputShape: StructureShape,
bindings: List<HttpBindingDescriptor>,
) {
val httpBindingGenerator =
ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
val structuredDataParser = protocol.structuredDataParser()
Attribute.AllowUnusedMut.render(this)
rust(
Expand Down Expand Up @@ -740,7 +727,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
for (binding in bindings) {
val member = binding.member
val parsedValue =
serverRenderBindingParser(binding, operationShape, httpBindingGenerator, structuredDataParser)
serverRenderBindingParser(binding, operationShape, httpBindingGenerator(operationShape), structuredDataParser)
val valueToSet =
if (symbolProvider.toSymbol(binding.member).isOptional()) {
"Some(value)"
Expand Down Expand Up @@ -801,13 +788,8 @@ class ServerHttpBoundProtocolTraitImplGenerator(
val structureShapeHandler: RustWriter.(String) -> Unit = { body ->
rust("#T($body)", structuredDataParser.payloadParser(binding.member))
}
val errorSymbol = getDeserializePayloadErrorSymbol(binding)
val deserializer =
httpBindingGenerator.generateDeserializePayloadFn(
binding,
errorSymbol,
structuredHandler = structureShapeHandler,
)
httpBindingGenerator.generateDeserializePayloadFn(binding, structuredHandler = structureShapeHandler)
return writable {
if (binding.member.isStreaming(model)) {
rustTemplate(
Expand Down Expand Up @@ -1196,9 +1178,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
binding: HttpBindingDescriptor,
operationShape: OperationShape,
) {
val httpBindingGenerator =
ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
val deserializer = httpBindingGenerator.generateDeserializeHeaderFn(binding)
val deserializer = httpBindingGenerator(operationShape).generateDeserializeHeaderFn(binding)
writer.rustTemplate(
"""
#{deserializer}(&headers)?
Expand All @@ -1215,8 +1195,7 @@ class ServerHttpBoundProtocolTraitImplGenerator(
) {
check(binding.location == HttpLocation.PREFIX_HEADERS)

val httpBindingGenerator = ServerRequestBindingGenerator(protocol, codegenContext, operationShape)
val deserializer = httpBindingGenerator.generateDeserializePrefixHeadersFn(binding)
val deserializer = httpBindingGenerator(operationShape).generateDeserializePrefixHeadersFn(binding)
writer.rustTemplate(
"""
#{deserializer}(&headers)?
Expand Down Expand Up @@ -1300,33 +1279,13 @@ class ServerHttpBoundProtocolTraitImplGenerator(
}
}

/**
* Returns the error type of the function that deserializes a non-streaming HTTP payload (a byte slab) into the
* shape targeted by the `httpPayload` trait.
*/
private fun getDeserializePayloadErrorSymbol(binding: HttpBindingDescriptor): Symbol {
check(binding.location == HttpLocation.PAYLOAD)

if (model.expectShape(binding.member.target) is StringShape) {
return protocol.requestRejection(runtimeConfig).toSymbol()
}
return when (codegenContext.protocol) {
RestJson1Trait.ID, AwsJson1_0Trait.ID, AwsJson1_1Trait.ID -> {
RuntimeType.smithyJson(runtimeConfig).resolve("deserialize::error::DeserializeError").toSymbol()
}
RestXmlTrait.ID -> {
RuntimeType.smithyXml(runtimeConfig).resolve("decode::XmlDecodeError").toSymbol()
}
else -> {
TODO("Protocol ${codegenContext.protocol} not supported yet")
}
}
}

private fun streamingBodyTraitBounds(operationShape: OperationShape) =
if (operationShape.inputShape(model).hasStreamingMember(model)) {
"\n B: Into<#{SmithyTypes}::byte_stream::ByteStream>,"
} else {
""
}

private fun httpBindingGenerator(operationShape: OperationShape) =
ServerRequestBindingGenerator(protocol, codegenContext, operationShape, additionalHttpBindingCustomizations)
}