Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent test dependencies from leaking into production #2264

Merged
merged 13 commits into from
Feb 6, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
import software.amazon.smithy.rust.codegen.core.testutil.testDependenciesOnly
import java.nio.file.Files
import java.nio.file.Paths
import kotlin.io.path.absolute
Expand Down Expand Up @@ -72,7 +73,7 @@ class IntegrationTestDependencies(
private val hasBenches: Boolean,
) : LibRsCustomization() {
override fun section(section: LibRsSection) = when (section) {
is LibRsSection.Body -> writable {
is LibRsSection.Body -> testDependenciesOnly {
if (hasTests) {
val smithyClient = CargoDependency.smithyClient(runtimeConfig)
.copy(features = setOf("test-util"), scope = DependencyScope.Dev)
Expand All @@ -81,7 +82,7 @@ class IntegrationTestDependencies(
addDependency(SerdeJson)
addDependency(Tokio)
addDependency(FuturesUtil)
addDependency(Tracing)
addDependency(Tracing.copy(scope = DependencyScope.Dev))
addDependency(TracingSubscriber)
}
if (hasBenches) {
Expand All @@ -91,6 +92,7 @@ class IntegrationTestDependencies(
serviceSpecific.section(section)(this)
}
}

else -> emptySection
}

Expand All @@ -115,7 +117,7 @@ class S3TestDependencies : LibRsCustomization() {
writable {
addDependency(AsyncStd)
addDependency(BytesUtils)
addDependency(FastRand)
addDependency(FastRand.copy(scope = DependencyScope.Dev))
addDependency(HdrHistogram)
addDependency(Smol)
addDependency(TempFile)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointTypesG
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.AttributeKind
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.escape
import software.amazon.smithy.rust.codegen.core.rustlang.join
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.PublicImportSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
Expand Down Expand Up @@ -146,8 +146,7 @@ class OperationInputTestGenerator(_ctx: ClientCodegenContext, private val test:
let _result = dbg!(#{invoke_operation});
#{assertion}
""",
"capture_request" to CargoDependency.smithyClient(runtimeConfig)
.withFeature("test-util").toType().resolve("test_connection::capture_request"),
"capture_request" to RuntimeType.captureRequest(runtimeConfig),
"conf" to config(testOperationInput),
"invoke_operation" to operationInvocation(testOperationInput),
"assertion" to writable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
package software.amazon.smithy.rustsdk

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
import software.amazon.smithy.rust.codegen.core.testutil.tokioTest
Expand Down Expand Up @@ -96,8 +96,7 @@ class EndpointsCredentialsTest {
let auth_header = req.headers().get("AUTHORIZATION").unwrap().to_str().unwrap();
assert!(auth_header.contains("/us-west-2/foobaz/aws4_request"), "{}", auth_header);
""",
"capture_request" to CargoDependency.smithyClient(context.runtimeConfig)
.withFeature("test-util").toType().resolve("test_connection::capture_request"),
"capture_request" to RuntimeType.captureRequest(context.runtimeConfig),
"Credentials" to AwsCargoDependency.awsCredentialTypes(context.runtimeConfig)
.withFeature("test-util").toType().resolve("Credentials"),
"Region" to AwsRuntimeType.awsTypes(context.runtimeConfig).resolve("region::Region"),
Expand All @@ -120,8 +119,7 @@ class EndpointsCredentialsTest {
let auth_header = req.headers().get("AUTHORIZATION").unwrap().to_str().unwrap();
assert!(auth_header.contains("/region-custom-auth/name-custom-auth/aws4_request"), "{}", auth_header);
""",
"capture_request" to CargoDependency.smithyClient(context.runtimeConfig)
.withFeature("test-util").toType().resolve("test_connection::capture_request"),
"capture_request" to RuntimeType.captureRequest(context.runtimeConfig),
"Credentials" to AwsCargoDependency.awsCredentialTypes(context.runtimeConfig)
.withFeature("test-util").toType().resolve("Credentials"),
"Region" to AwsRuntimeType.awsTypes(context.runtimeConfig).resolve("region::Region"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.symbol
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.derive
import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustType
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
Expand Down Expand Up @@ -59,7 +58,7 @@ val EndpointTests = RustModule.new(
documentation = "Generated endpoint tests",
parent = EndpointsModule,
inline = true,
).copy(rustMetadata = RustMetadata.TestModule)
).cfgTest()

// stdlib is isolated because it contains code generated names of stdlib functions–we want to ensure we avoid clashing
val EndpointsStdLib = RustModule.private("endpoint_lib", "Endpoints standard library functions")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import software.amazon.smithy.rust.codegen.client.smithy.endpoint.EndpointCustom
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.Types
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.rustName
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.escape
Expand Down Expand Up @@ -48,8 +47,7 @@ internal class EndpointTestGenerator(
"Error" to types.resolveEndpointError,
"Document" to RuntimeType.document(runtimeConfig),
"HashMap" to RuntimeType.HashMap,
"capture_request" to CargoDependency.smithyClient(runtimeConfig)
.withFeature("test-util").toType().resolve("test_connection::capture_request"),
"capture_request" to RuntimeType.captureRequest(runtimeConfig),
)

private val instantiator = clientInstantiator(codegenContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.generators.clientInstantiator
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.allow
import software.amazon.smithy.rust.codegen.core.rustlang.RustMetadata
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.escape
import software.amazon.smithy.rust.codegen.core.rustlang.rust
Expand Down Expand Up @@ -91,14 +89,10 @@ class ProtocolTestGenerator(
if (allTests.isNotEmpty()) {
val operationName = operationSymbol.name
val testModuleName = "${operationName.toSnakeCase()}_request_test"
val moduleMeta = RustMetadata(
visibility = Visibility.PRIVATE,
additionalAttributes = listOf(
Attribute.CfgTest,
Attribute(allow("unreachable_code", "unused_variables")),
),
val additionalAttributes = listOf(
Attribute(allow("unreachable_code", "unused_variables")),
)
writer.withInlineModule(RustModule.LeafModule(testModuleName, moduleMeta, inline = true)) {
writer.withInlineModule(RustModule.inlineTests(testModuleName, additionalAttributes = additionalAttributes)) {
renderAllTestCases(allTests)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ data class CargoDependency(
) : RustDependency(name) {
val key: Triple<String, DependencyLocation, DependencyScope> get() = Triple(name, location, scope)

val nameAndLocation: Pair<String, DependencyLocation> = name to location

fun withFeature(feature: String): CargoDependency {
return copy(features = features.toMutableSet().apply { add(feature) })
}
Expand Down Expand Up @@ -190,7 +192,7 @@ data class CargoDependency(
val OnceCell: CargoDependency = CargoDependency("once_cell", CratesIo("1.16"))
val Url: CargoDependency = CargoDependency("url", CratesIo("2.3.1"))
val Bytes: CargoDependency = CargoDependency("bytes", CratesIo("1.0.0"))
val BytesUtils: CargoDependency = CargoDependency("bytes-utils", CratesIo("0.1.0"))
val BytesUtils: CargoDependency = CargoDependency("bytes-utils", CratesIo("0.1.0"), DependencyScope.Dev)
val FastRand: CargoDependency = CargoDependency("fastrand", CratesIo("1.8.0"))
val Hex: CargoDependency = CargoDependency("hex", CratesIo("0.4.3"))
val Http: CargoDependency = CargoDependency("http", CratesIo("0.2.0"))
Expand Down Expand Up @@ -220,7 +222,12 @@ data class CargoDependency(
val Smol: CargoDependency = CargoDependency("smol", CratesIo("1.2.0"), DependencyScope.Dev)
val TempFile: CargoDependency = CargoDependency("tempfile", CratesIo("3.2.0"), DependencyScope.Dev)
val Tokio: CargoDependency =
CargoDependency("tokio", CratesIo("1.8.4"), DependencyScope.Dev, features = setOf("macros", "test-util", "rt-multi-thread"))
CargoDependency(
"tokio",
CratesIo("1.8.4"),
DependencyScope.Dev,
features = setOf("macros", "test-util", "rt-multi-thread"),
)
val TracingAppender: CargoDependency = CargoDependency(
"tracing-appender",
CratesIo("0.2.2"),
Expand All @@ -236,12 +243,16 @@ data class CargoDependency(
fun smithyAsync(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-async")
fun smithyChecksums(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-checksums")
fun smithyClient(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-client")
fun smithyClientTestUtil(runtimeConfig: RuntimeConfig) =
smithyClient(runtimeConfig).copy(scope = DependencyScope.Dev).withFeature("test-util")

fun smithyEventStream(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-eventstream")
fun smithyHttp(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http")
fun smithyHttpTower(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-http-tower")
fun smithyJson(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-json")
fun smithyProtocolTestHelpers(runtimeConfig: RuntimeConfig) =
runtimeConfig.smithyRuntimeCrate("smithy-protocol-test", scope = DependencyScope.Dev)

fun smithyQuery(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-query")
fun smithyTypes(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-types")
fun smithyXml(runtimeConfig: RuntimeConfig) = runtimeConfig.smithyRuntimeCrate("smithy-xml")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ sealed class RustModule {
val documentation: String? = null,
val parent: RustModule = LibRs,
val inline: Boolean = false,
/* module is a cfg(test) module */
val tests: Boolean = false,
) : RustModule() {

init {
check(!name.contains("::")) {
"Module names CANNOT contain `::`—modules must be nested with parent (name was: `$name`)"
Expand All @@ -45,6 +48,12 @@ sealed class RustModule {
"Module `$name` cannot be a module name—it is a reserved word."
}
}

/** Convert a module into a module gated with `#[cfg(test)]` */
fun cfgTest(): LeafModule = this.copy(
rustMetadata = rustMetadata.copy(additionalAttributes = rustMetadata.additionalAttributes + Attribute.CfgTest),
tests = true,
)
}

companion object {
Expand Down Expand Up @@ -78,12 +87,36 @@ sealed class RustModule {
fun pubCrate(name: String, documentation: String? = null, parent: RustModule): LeafModule =
new(name, visibility = Visibility.PUBCRATE, documentation = documentation, inline = false, parent = parent)

fun inlineTests(
name: String = "test",
parent: RustModule = LibRs,
additionalAttributes: List<Attribute> = listOf(),
) = new(
name,
Visibility.PRIVATE,
inline = true,
additionalAttributes = additionalAttributes,
parent = parent,
).cfgTest()

/* Common modules used across client, server and tests */
val Config = public("config", documentation = "Configuration for the service.")
val Error = public("error", documentation = "All error types that operations can return. Documentation on these types is copied from the model.")
val Model = public("model", documentation = "Data structures used by operation inputs/outputs. Documentation on these types is copied from the model.")
val Input = public("input", documentation = "Input structures for operations. Documentation on these types is copied from the model.")
val Output = public("output", documentation = "Output structures for operations. Documentation on these types is copied from the model.")
val Error = public(
"error",
documentation = "All error types that operations can return. Documentation on these types is copied from the model.",
)
val Model = public(
"model",
documentation = "Data structures used by operation inputs/outputs. Documentation on these types is copied from the model.",
)
val Input = public(
"input",
documentation = "Input structures for operations. Documentation on these types is copied from the model.",
)
val Output = public(
"output",
documentation = "Output structures for operations. Documentation on these types is copied from the model.",
)
val Types = public("types", documentation = "Data primitives referenced by other data types.")

/**
Expand Down
Loading