diff --git a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt index c443110805..9fdab3965d 100644 --- a/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt +++ b/codegen-core/src/main/kotlin/software/amazon/smithy/rust/codegen/core/testutil/Rust.kt @@ -294,7 +294,7 @@ class TestWriterDelegator( } /** - * Generate a newtest module + * Generate a new test module * * This should only be used in test codeā€”the generated module name will be something like `tests_123` */ diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt index 956a88ecf2..779fcb1686 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCargoDependency.kt @@ -22,6 +22,7 @@ object ServerCargoDependency { val Nom: CargoDependency = CargoDependency("nom", CratesIo("7")) val OnceCell: CargoDependency = CargoDependency("once_cell", CratesIo("1.13")) val PinProjectLite: CargoDependency = CargoDependency("pin-project-lite", CratesIo("0.2")) + val ThisError: CargoDependency = CargoDependency("thiserror", CratesIo("1.0")) val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TokioDev: CargoDependency = CargoDependency("tokio", CratesIo("1.23.1"), scope = DependencyScope.Dev) val Regex: CargoDependency = CargoDependency("regex", CratesIo("1.5.5")) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt index 6f66c72166..adfa466257 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/ServerCodegenVisitor.kt @@ -78,6 +78,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.generators.Unconstraine import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedMapGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.UnconstrainedUnionGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator +import software.amazon.smithy.rust.codegen.server.smithy.generators.isBuilderFallible import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocol import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolTestGenerator @@ -591,11 +592,15 @@ open class ServerCodegenVisitor( logger.info("[rust-server-codegen] Generating a service $shape") val serverProtocol = protocolGeneratorFactory.protocol(codegenContext) as ServerProtocol + val configMethods = codegenDecorator.configMethods(codegenContext) + val isConfigBuilderFallible = configMethods.isBuilderFallible() + // Generate root. rustCrate.lib { ServerRootGenerator( serverProtocol, codegenContext, + isConfigBuilderFallible, ).render(this) } @@ -612,9 +617,10 @@ open class ServerCodegenVisitor( ServerServiceGenerator( codegenContext, serverProtocol, + isConfigBuilderFallible, ).render(this) - ServiceConfigGenerator(codegenContext).render(this) + ServiceConfigGenerator(codegenContext, configMethods).render(this) ScopeMacroGenerator(codegenContext).render(this) } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt index 06f9c3c09b..5470c0902c 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customize/ServerCodegenDecorator.kt @@ -15,6 +15,7 @@ import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolMap import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext import software.amazon.smithy.rust.codegen.server.smithy.ServerRustSettings import software.amazon.smithy.rust.codegen.server.smithy.ValidationResult +import software.amazon.smithy.rust.codegen.server.smithy.generators.ConfigMethod import software.amazon.smithy.rust.codegen.server.smithy.generators.ValidationExceptionConversionGenerator import software.amazon.smithy.rust.codegen.server.smithy.generators.protocol.ServerProtocolGenerator import java.util.logging.Logger @@ -41,6 +42,12 @@ interface ServerCodegenDecorator : CoreCodegenDecorator = emptyList() + + /** + * Configuration methods that should be injected into the `${serviceName}Config` struct to allow users to configure + * pre-applied layers and plugins. + */ + fun configMethods(codegenContext: ServerCodegenContext): List = emptyList() } /** @@ -74,10 +81,11 @@ class CombinedServerCodegenDecorator(decorators: List) : decorator.postprocessValidationExceptionNotAttachedErrorMessage(accumulated) } - override fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List { - return orderedDecorators.map { decorator -> decorator.postprocessGenerateAdditionalStructures(operationShape) } - .flatten() - } + override fun postprocessGenerateAdditionalStructures(operationShape: OperationShape): List = + orderedDecorators.flatMap { it.postprocessGenerateAdditionalStructures(operationShape) } + + override fun configMethods(codegenContext: ServerCodegenContext): List = + orderedDecorators.flatMap { it.configMethods(codegenContext) } companion object { fun fromClasspath( diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRootGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRootGenerator.kt index a1ea4b90f6..d02c552844 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRootGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerRootGenerator.kt @@ -31,6 +31,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output open class ServerRootGenerator( val protocol: ServerProtocol, private val codegenContext: ServerCodegenContext, + private val isConfigBuilderFallible: Boolean, ) { private val index = TopDownIndex.of(codegenContext.model) private val operations = index.getContainedOperations(codegenContext.serviceShape).toSortedSet( @@ -57,6 +58,8 @@ open class ServerRootGenerator( } .join("//!\n") + val unwrapConfigBuilder = if (isConfigBuilderFallible) ".expect(\"config failed to build\")" else "" + writer.rustTemplate( """ //! A fast and customizable Rust implementation of the $serviceName Smithy service. @@ -75,7 +78,10 @@ open class ServerRootGenerator( //! ## async fn dummy() { //! use $crateName::{$serviceName, ${serviceName}Config}; //! - //! ## let app = $serviceName::builder(${serviceName}Config::builder().build()).build_unchecked(); + //! ## let app = $serviceName::builder( + //! ## ${serviceName}Config::builder() + //! ## .build()$unwrapConfigBuilder + //! ## ).build_unchecked(); //! let server = app.into_make_service(); //! let bind: SocketAddr = "127.0.0.1:6969".parse() //! .expect("unable to parse the server bind address and port"); @@ -92,7 +98,10 @@ open class ServerRootGenerator( //! use $crateName::$serviceName; //! //! ## async fn dummy() { - //! ## let app = $serviceName::builder(${serviceName}Config::builder().build()).build_unchecked(); + //! ## let app = $serviceName::builder( + //! ## ${serviceName}Config::builder() + //! ## .build()$unwrapConfigBuilder + //! ## ).build_unchecked(); //! let handler = LambdaHandler::new(app); //! lambda_http::run(handler).await.unwrap(); //! ## } @@ -118,7 +127,7 @@ open class ServerRootGenerator( //! let http_plugins = HttpPlugins::new() //! .push(LoggingPlugin) //! .push(MetricsPlugin); - //! let config = ${serviceName}Config::builder().build(); + //! let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder; //! let builder: $builderName = $serviceName::builder(config); //! ``` //! @@ -183,13 +192,13 @@ open class ServerRootGenerator( //! //! ## Example //! - //! ```rust + //! ```rust,no_run //! ## use std::net::SocketAddr; //! use $crateName::{$serviceName, ${serviceName}Config}; //! //! ##[#{Tokio}::main] //! pub async fn main() { - //! let config = ${serviceName}Config::builder().build(); + //! let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder; //! let app = $serviceName::builder(config) ${builderFieldNames.values.joinToString("\n") { "//! .$it($it)" }} //! .build() @@ -236,6 +245,23 @@ open class ServerRootGenerator( fun render(rustWriter: RustWriter) { documentation(rustWriter) - rustWriter.rust("pub use crate::service::{$serviceName, ${serviceName}Config, ${serviceName}ConfigBuilder, ${serviceName}Builder, MissingOperationsError};") + // Only export config builder error if fallible. + val configErrorReExport = if (isConfigBuilderFallible) { + "${serviceName}ConfigError," + } else { + "" + } + rustWriter.rust( + """ + pub use crate::service::{ + $serviceName, + ${serviceName}Config, + ${serviceName}ConfigBuilder, + $configErrorReExport + ${serviceName}Builder, + MissingOperationsError + }; + """ + ) } } diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt index 97963cbb40..e1da585d05 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServerServiceGenerator.kt @@ -33,6 +33,7 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerRustModule.Output class ServerServiceGenerator( private val codegenContext: ServerCodegenContext, private val protocol: ServerProtocol, + private val isConfigBuilderFallible: Boolean, ) { private val runtimeConfig = codegenContext.runtimeConfig private val smithyHttpServer = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() @@ -107,6 +108,11 @@ class ServerServiceGenerator( val docHandler = DocHandlerGenerator(codegenContext, operationShape, "handler", "///") val handler = docHandler.docSignature() val handlerFixed = docHandler.docFixedSignature() + val unwrapConfigBuilder = if (isConfigBuilderFallible) { + ".expect(\"config failed to build\")" + } else { + "" + } rustTemplate( """ /// Sets the [`$structName`](crate::operation_shape::$structName) operation. @@ -123,7 +129,7 @@ class ServerServiceGenerator( /// #{Handler:W} /// - /// let config = ${serviceName}Config::builder().build(); + /// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder; /// let app = $serviceName::builder(config) /// .$fieldName(handler) /// /* Set other handlers */ @@ -186,7 +192,7 @@ class ServerServiceGenerator( /// #{HandlerFixed:W} /// - /// let config = ${serviceName}Config::builder().build(); + /// let config = ${serviceName}Config::builder().build()$unwrapConfigBuilder; /// let svc = #{Tower}::util::service_fn(handler); /// let app = $serviceName::builder(config) /// .${fieldName}_service(svc) diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt index 960f9d0df7..3aeaa92cc2 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGenerator.kt @@ -6,31 +6,89 @@ package software.amazon.smithy.rust.codegen.server.smithy.generators import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter +import software.amazon.smithy.rust.codegen.core.rustlang.Writable +import software.amazon.smithy.rust.codegen.core.rustlang.conditionalBlock +import software.amazon.smithy.rust.codegen.core.rustlang.docs +import software.amazon.smithy.rust.codegen.core.rustlang.join +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock +import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope import software.amazon.smithy.rust.codegen.core.util.toPascalCase import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +fun List.isBuilderFallible() = this.any { it.isRequired } + +// TODO Docs +data class ConfigMethod( + val name: String, + val docs: String, + val params: List, + val errorType: RuntimeType?, + val initializer: Initializer, + /** Whether the user must invoke the method or not. **/ + val isRequired: Boolean, +) { + fun requiredBuilderFlagName(): String { + check(isRequired) { + "Config method is not required so it shouldn't need a field in the builder tracking whether it has been configured" + } + return "${name}_configured" + } + + fun requiredErrorVariant(): String { + check(isRequired) { + "Config method is not required so it shouldn't need an error variant" + } + return "${name.toPascalCase()}NotConfigured" + } +} + +// TODO Docs +data class Initializer( + val code: Writable, + /** Ordered list of layers that will be applied. **/ + val layerBindings: List, + val httpPluginBindings: List, + val modelPluginBindings: List, +) + +data class Binding( + val name: String, + val ty: RuntimeType, +) + class ServiceConfigGenerator( codegenContext: ServerCodegenContext, + private val configMethods: List, ) { private val crateName = codegenContext.moduleUseName() - private val codegenScope = codegenContext.runtimeConfig.let { runtimeConfig -> - val smithyHttpServer = ServerCargoDependency.smithyHttpServer(runtimeConfig).toType() - arrayOf( - "Debug" to RuntimeType.Debug, - "SmithyHttpServer" to smithyHttpServer, - "PluginStack" to smithyHttpServer.resolve("plugin::PluginStack"), - "ModelMarker" to smithyHttpServer.resolve("plugin::ModelMarker"), - "HttpMarker" to smithyHttpServer.resolve("plugin::HttpMarker"), - "Tower" to RuntimeType.Tower, - "Stack" to RuntimeType.Tower.resolve("layer::util::Stack"), - ) - } + private val smithyHttpServer = ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType() + private val codegenScope = arrayOf( + *preludeScope, + "Debug" to RuntimeType.Debug, + "SmithyHttpServer" to smithyHttpServer, + "PluginStack" to smithyHttpServer.resolve("plugin::PluginStack"), + "ModelMarker" to smithyHttpServer.resolve("plugin::ModelMarker"), + "HttpMarker" to smithyHttpServer.resolve("plugin::HttpMarker"), + "Tower" to RuntimeType.Tower, + "Stack" to RuntimeType.Tower.resolve("layer::util::Stack"), + ) private val serviceName = codegenContext.serviceShape.id.name.toPascalCase() fun render(writer: RustWriter) { + val unwrapConfigBuilder = if (isBuilderFallible()) { + """ + /// .expect("config failed to build"); + """ + } else { + ";" + } + writer.rustTemplate( """ /// Configuration for the [`$serviceName`]. This is the central place where to register and @@ -50,7 +108,7 @@ class ServiceConfigGenerator( /// .http_plugin(authentication_plugin) /// // ...and right after deserialization, model plugins. /// .model_plugin(authorization_plugin) - /// .build(); + /// .build()$unwrapConfigBuilder /// ``` /// /// See the [`plugin`] system for details. @@ -74,6 +132,7 @@ class ServiceConfigGenerator( layers: #{Tower}::layer::util::Identity::new(), http_plugins: #{SmithyHttpServer}::plugin::IdentityPlugin, model_plugins: #{SmithyHttpServer}::plugin::IdentityPlugin, + #{BuilderRequiredMethodFlagsInit:W} } } } @@ -84,15 +143,21 @@ class ServiceConfigGenerator( pub(crate) layers: L, pub(crate) http_plugins: H, pub(crate) model_plugins: M, + #{BuilderRequiredMethodFlagDefinitions:W} } + + #{BuilderRequiredMethodError:W} impl ${serviceName}ConfigBuilder { + #{InjectedMethods:W} + /// Add a [`#{Tower}::Layer`] to the service. pub fn layer(self, layer: NewLayer) -> ${serviceName}ConfigBuilder<#{Stack}, H, M> { ${serviceName}ConfigBuilder { layers: #{Stack}::new(layer, self.layers), http_plugins: self.http_plugins, model_plugins: self.model_plugins, + #{BuilderRequiredMethodFlagsMove1:W} } } @@ -109,6 +174,7 @@ class ServiceConfigGenerator( layers: self.layers, http_plugins: #{PluginStack}::new(http_plugin, self.http_plugins), model_plugins: self.model_plugins, + #{BuilderRequiredMethodFlagsMove2:W} } } @@ -125,20 +191,203 @@ class ServiceConfigGenerator( layers: self.layers, http_plugins: self.http_plugins, model_plugins: #{PluginStack}::new(model_plugin, self.model_plugins), + #{BuilderRequiredMethodFlagsMove3:W} + } + } + + #{BuilderBuildMethod:W} + } + """, + *codegenScope, + "BuilderRequiredMethodFlagsInit" to builderRequiredMethodFlagsInit(), + "BuilderRequiredMethodFlagDefinitions" to builderRequiredMethodFlagsDefinitions(), + "BuilderRequiredMethodError" to builderRequiredMethodError(), + "InjectedMethods" to injectedMethods(), + "BuilderRequiredMethodFlagsMove1" to builderRequiredMethodFlagsMove(), + "BuilderRequiredMethodFlagsMove2" to builderRequiredMethodFlagsMove(), + "BuilderRequiredMethodFlagsMove3" to builderRequiredMethodFlagsMove(), + "BuilderBuildMethod" to builderBuildMethod(), + ) + } + + private fun isBuilderFallible() = configMethods.isBuilderFallible() + + private fun builderBuildRequiredMethodChecks() = configMethods.filter { it.isRequired }.map { + writable { + rust( + """ + if !self.${it.requiredBuilderFlagName()} { + return Err(${serviceName}ConfigError::${it.requiredErrorVariant()}); + } + """, + ) + } + }.join("\n") + + private fun builderRequiredMethodFlagsDefinitions() = configMethods.filter { it.isRequired }.map { + writable { rust("pub(crate) ${it.requiredBuilderFlagName()}: bool,") } + }.join("\n") + + private fun builderRequiredMethodFlagsInit() = configMethods.filter { it.isRequired }.map { + writable { rust("${it.requiredBuilderFlagName()}: false,") } + }.join("\n") + + /** + * + * If you can come up with a better function name please change it. + */ + private fun builderRequiredMethodFlagsMove() = configMethods.filter { it.isRequired }.map { + writable { rust("${it.requiredBuilderFlagName()}: self.${it.requiredBuilderFlagName()},") } + }.join("\n") + + private fun builderRequiredMethodError() = writable { + val variants = configMethods.filter { it.isRequired }.map { + writable { + rust( + """ + ##[error("service is not fully configured; invoke `${it.requiredBuilderFlagName()}` on the config builder")] + ${it.requiredErrorVariant()}, + """, + ) + } + } + if (isBuilderFallible()) { + rustTemplate( + """ + ##[derive(Debug, #{ThisError}::Error)] + pub enum ${serviceName}ConfigError { + #{Variants:W} + } + """, + "ThisError" to ServerCargoDependency.ThisError.toType(), + "Variants" to variants.join("\n"), + ) + } + } + + private fun injectedMethods() = configMethods.map { + writable { + val paramBindings = it.params.map { binding -> + writable { rustTemplate("${binding.name}: #{BindingTy},", "BindingTy" to binding.ty) } + }.join("\n") + + // This produces a nested type like: "S>", where + // - "S" denotes a "stack type" with two generic type parameters: the first is the top of the stack and the + // second is the rest of the stack. For example, `aws_smithy_http_server::plugin::PluginStack`. + // - "B" is the type of the "thing" that is added. + // - "T" is the generic type variable name used in the enclosing impl block. + fun List.stackReturnType(genericTypeVarName: String, stackType: RuntimeType): Writable = + this.fold(writable { rust(genericTypeVarName) }) { acc, next -> + writable { + rustTemplate( + "#{StackType}<#{Ty}, #{Acc:W}>", + "StackType" to stackType, + "Ty" to next.ty, + "Acc" to acc, + ) } } - /// Build the configuration. - pub fn build(self) -> super::${serviceName}Config { + val layersReturnTy = + it.initializer.layerBindings.stackReturnType("L", RuntimeType.Tower.resolve("layer::util::Stack")) + val httpPluginsReturnTy = + it.initializer.httpPluginBindings.stackReturnType("H", smithyHttpServer.resolve("plugin::PluginStack")) + val modelPluginsReturnTy = + it.initializer.modelPluginBindings.stackReturnType("M", smithyHttpServer.resolve("plugin::PluginStack")) + + val configBuilderReturnTy = writable { + rustTemplate( + """ + ${serviceName}ConfigBuilder< + #{LayersReturnTy:W}, + #{HttpPluginsReturnTy:W}, + #{ModelPluginsReturnTy:W}, + > + """, + "LayersReturnTy" to layersReturnTy, + "HttpPluginsReturnTy" to httpPluginsReturnTy, + "ModelPluginsReturnTy" to modelPluginsReturnTy, + ) + } + + val returnTy = if (it.errorType != null) { + writable { + rustTemplate( + "#{Result}<#{T:W}, #{E}>", + "T" to configBuilderReturnTy, + "E" to it.errorType, + *codegenScope, + ) + } + } else { + configBuilderReturnTy + } + + docs(it.docs) + rustBlockTemplate( + """ + pub fn ${it.name}( + ##[allow(unused_mut)] + mut self, + #{ParamBindings:W} + ) -> #{ReturnTy:W} + """, + "ReturnTy" to returnTy, + "ParamBindings" to paramBindings, + ) { + rustTemplate("#{InitializerCode:W}", "InitializerCode" to it.initializer.code) + + check(it.initializer.layerBindings.size + it.initializer.httpPluginBindings.size + it.initializer.modelPluginBindings.size > 0) { + "This method's initializer does not register any layers, HTTP plugins, or model plugins. It must register at least something!" + } + + if (it.isRequired) { + rust("self.${it.requiredBuilderFlagName()} = true;") + } + conditionalBlock("Ok(", ")", conditional = it.errorType != null) { + val registrations = (it.initializer.layerBindings.map { ".layer(${it.name})" } + + it.initializer.httpPluginBindings.map { ".http_plugin(${it.name})" } + + it.initializer.modelPluginBindings.map { ".model_plugin(${it.name})" }).joinToString("") + rust("self${registrations}") + } + } + } + }.join("\n\n") + + private fun builderBuildReturnType() = writable { + val t = "super::${serviceName}Config" + + if (isBuilderFallible()) { + rustTemplate("#{Result}<$t, ${serviceName}ConfigError>", *codegenScope) + } else { + rust(t) + } + } + + private fun builderBuildMethod() = writable { + rustBlockTemplate( + """ + /// Build the configuration. + pub fn build(self) -> #{BuilderBuildReturnTy:W} + """, + "BuilderBuildReturnTy" to builderBuildReturnType(), + ) { + rustTemplate( + "#{BuilderBuildRequiredMethodChecks:W}", + "BuilderBuildRequiredMethodChecks" to builderBuildRequiredMethodChecks(), + ) + + conditionalBlock("Ok(", ")", isBuilderFallible()) { + rust( + """ super::${serviceName}Config { layers: self.layers, http_plugins: self.http_plugins, model_plugins: self.model_plugins, } - } + """, + ) } - """, - *codegenScope, - ) + } } } diff --git a/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt new file mode 100644 index 0000000000..9358f8a985 --- /dev/null +++ b/codegen-server/src/test/kotlin/software/amazon/smithy/rust/codegen/server/smithy/generators/ServiceConfigGeneratorTest.kt @@ -0,0 +1,226 @@ +package software.amazon.smithy.rust.codegen.server.smithy.generators + +import org.junit.jupiter.api.Test +import software.amazon.smithy.rust.codegen.core.rustlang.rust +import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.core.rustlang.writable +import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType +import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel +import software.amazon.smithy.rust.codegen.core.testutil.testModule +import software.amazon.smithy.rust.codegen.core.testutil.unitTest +import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency +import software.amazon.smithy.rust.codegen.server.smithy.ServerCodegenContext +import software.amazon.smithy.rust.codegen.server.smithy.customize.ServerCodegenDecorator +import software.amazon.smithy.rust.codegen.server.smithy.testutil.serverIntegrationTest +import java.io.File + +internal class ServiceConfigGeneratorTest { + @Test + fun `it should inject an aws_auth method that configures an HTTP plugin and a model plugin`() { + val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel() + + val decorator = object : ServerCodegenDecorator { + override val name: String + get() = "AWSAuth pre-applied middleware decorator" + override val order: Byte + get() = -69 + + override fun configMethods(codegenContext: ServerCodegenContext): List { + val smithyHttpServer = ServerCargoDependency.smithyHttpServer(codegenContext.runtimeConfig).toType() + val codegenScope = arrayOf( + "SmithyHttpServer" to smithyHttpServer, + ) + return listOf( + ConfigMethod( + name = "aws_auth", + docs = "Docs", + params = listOf( + Binding("auth_spec", RuntimeType.String), + Binding("authorizer", RuntimeType.U64) + ), + errorType = RuntimeType.std.resolve("io::Error"), + initializer = Initializer( + code = writable { + rustTemplate( + """ + if authorizer != 69 { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 1")); + } + + if auth_spec.len() != 69 { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "failure 2")); + } + let authn_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; + let authz_plugin = #{SmithyHttpServer}::plugin::IdentityPlugin; + """, + *codegenScope, + ) + }, + layerBindings = emptyList(), + httpPluginBindings = listOf( + Binding( + "authn_plugin", + smithyHttpServer.resolve("plugin::IdentityPlugin"), + ), + ), + modelPluginBindings = listOf( + Binding( + "authz_plugin", + smithyHttpServer.resolve("plugin::IdentityPlugin"), + ), + ), + ), + isRequired = true, + ), + ) + } + } + + serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, rustCrate -> + rustCrate.testModule { + rust( + """ + use crate::{SimpleServiceConfig, SimpleServiceConfigError}; + use aws_smithy_http_server::plugin::IdentityPlugin; + use crate::server::plugin::PluginStack; + """ + ) + + unitTest("successful_config_initialization") { + rust( + """ + let _: SimpleServiceConfig< + tower::layer::util::Identity, + // One HTTP plugin has been applied. + PluginStack, + // One model plugin has been applied. + PluginStack, + > = SimpleServiceConfig::builder() + .aws_auth("a".repeat(69).to_owned(), 69) + .expect("failed to configure aws_auth") + .build() + .unwrap(); + """, + ) + } + + unitTest("wrong_aws_auth_auth_spec") { + rust( + """ + let actual_err = SimpleServiceConfig::builder() + .aws_auth("a".to_owned(), 69) + .unwrap_err(); + let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 2").to_string(); + assert_eq!(actual_err.to_string(), expected); + """ + ) + } + + unitTest("wrong_aws_auth_authorizer") { + rust( + """ + let actual_err = SimpleServiceConfig::builder() + .aws_auth("a".repeat(69).to_owned(), 6969) + .unwrap_err(); + let expected = std::io::Error::new(std::io::ErrorKind::Other, "failure 1").to_string(); + assert_eq!(actual_err.to_string(), expected); + """ + ) + } + + unitTest("aws_auth_not_configured") { + rust( + """ + let actual_err = SimpleServiceConfig::builder().build().unwrap_err(); + let expected = SimpleServiceConfigError::AwsAuthNotConfigured.to_string(); + assert_eq!(actual_err.to_string(), expected); + """ + ) + } + } + } + } + + @Test + fun `it should inject an method that applies three non-required layers`() { + val model = File("../codegen-core/common-test-models/simple.smithy").readText().asSmithyModel() + + val decorator = object : ServerCodegenDecorator { + override val name: String + get() = "ApplyThreeNonRequiredLayers" + override val order: Byte + get() = 69 + + override fun configMethods(codegenContext: ServerCodegenContext): List { + val identityLayer = RuntimeType.Tower.resolve("layer::util::Identity") + val codegenScope = arrayOf( + "Identity" to identityLayer, + ) + return listOf( + ConfigMethod( + name = "three_non_required_layers", + docs = "Docs", + params = emptyList(), + errorType = null, + initializer = Initializer( + code = writable { + rustTemplate( + """ + let layer1 = #{Identity}::new(); + let layer2 = #{Identity}::new(); + let layer3 = #{Identity}::new(); + """, + *codegenScope, + ) + }, + layerBindings = listOf( + Binding("layer1", identityLayer), + Binding("layer2", identityLayer), + Binding("layer3", identityLayer), + ), + httpPluginBindings = emptyList(), + modelPluginBindings = emptyList(), + ), + isRequired = false, + ), + ) + } + } + + serverIntegrationTest(model, additionalDecorators = listOf(decorator)) { _, rustCrate -> + rustCrate.testModule { + unitTest("successful_config_initialization_applying_the_three_layers") { + rust( + """ + let _: crate::SimpleServiceConfig< + // Three Tower layers have been applied. + tower::layer::util::Stack< + tower::layer::util::Identity, + tower::layer::util::Stack< + tower::layer::util::Identity, + tower::layer::util::Stack< + tower::layer::util::Identity, + tower::layer::util::Identity, + >, + >, + >, + aws_smithy_http_server::plugin::IdentityPlugin, + aws_smithy_http_server::plugin::IdentityPlugin, + > = crate::SimpleServiceConfig::builder() + .three_non_required_layers() + .build(); + """, + ) + } + + unitTest("successful_config_initialization_without_applying_the_three_layers") { + rust( + """ + crate::SimpleServiceConfig::builder().build(); + """, + ) + } + } + } + } +} diff --git a/rust-runtime/aws-smithy-http-server/src/plugin/stack.rs b/rust-runtime/aws-smithy-http-server/src/plugin/stack.rs index 6c96ebaca0..c42462ec52 100644 --- a/rust-runtime/aws-smithy-http-server/src/plugin/stack.rs +++ b/rust-runtime/aws-smithy-http-server/src/plugin/stack.rs @@ -4,6 +4,7 @@ */ use super::{HttpMarker, ModelMarker, Plugin}; +use std::fmt::Debug; /// A wrapper struct which composes an `Inner` and an `Outer` [`Plugin`]. /// @@ -13,6 +14,7 @@ use super::{HttpMarker, ModelMarker, Plugin}; /// [`HttpPlugins`](crate::plugin::HttpPlugins), and the primary tool for composing HTTP plugins is /// [`ModelPlugins`](crate::plugin::ModelPlugins); if you are an application writer, you should /// prefer composing plugins using these. +#[derive(Debug)] pub struct PluginStack { inner: Inner, outer: Outer,