Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore SmokeTestsDecoratorTest #3811

Merged
merged 6 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ import software.amazon.smithy.rust.codegen.core.rustlang.Attribute
import software.amazon.smithy.rust.codegen.core.rustlang.Attribute.Companion.cfg
import software.amazon.smithy.rust.codegen.core.rustlang.AttributeKind
import software.amazon.smithy.rust.codegen.core.rustlang.RustWriter
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.containerDocs
import software.amazon.smithy.rust.codegen.core.rustlang.docs
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext
import software.amazon.smithy.rust.codegen.core.smithy.PublicImportSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
Expand Down Expand Up @@ -67,66 +70,83 @@ class SmokeTestsDecorator : ClientCodegenDecorator {
rustCrate: RustCrate,
) {
// Get all operations with smoke tests
val smokeTestedOperations =
codegenContext.model.getOperationShapesWithTrait(SmokeTestsTrait::class.java).toList()
val smokeTestedOperations = operationToTestCases(codegenContext.model)
val supportedTests =
smokeTestedOperations.map { operationShape ->
smokeTestedOperations.map { (operationShape, testCases) ->
// filter out unsupported smoke tests, logging a warning for each one, and sort the remaining tests by
// case ID. This ensures deterministic rendering, meaning the test methods are always rendered in a
// consistent order.
val testCases =
operationShape.expectTrait<SmokeTestsTrait>().testCases.filter { smokeTestCase ->
isSmokeTestSupported(smokeTestCase)
}.sortedBy { smokeTestCase -> smokeTestCase.id }
testCases.filter { smokeTestCase ->
isSmokeTestSupported(smokeTestCase)
}.sortedBy { smokeTestCase -> smokeTestCase.id }

operationShape to testCases
}
// filter out operations with no supported smoke tests
.filter { (_, testCases) -> testCases.isNotEmpty() }
// Similar to sorting test cases above, sort operations by name to ensure consistent ordering.
.sortedBy { (operationShape, _) -> operationShape.id.name }

// Return if there are no supported smoke tests across all operations
if (supportedTests.isEmpty()) return

rustCrate.integrationTest("smoketests") {
// Don't run the tests in this module unless `RUSTFLAGS="--cfg smoketests"` is passed.
Attribute(cfg("smoketests")).render(this, AttributeKind.Inner)

containerDocs(
"""
The tests in this module run against live AWS services. As such,
they are disabled by default. To enable them, run the tests with

```sh
RUSTFLAGS="--cfg smoketests" cargo test.
```""",
)

val model = codegenContext.model
val moduleUseName = codegenContext.moduleUseName()
rust("use $moduleUseName::{ Client, config };")
renderPrologue(codegenContext.moduleUseName(), this)

for ((operationShape, testCases) in supportedTests) {
val operationName = operationShape.id.name.toSnakeCase()
val operationInput = operationShape.inputShape(model)

docs("Smoke tests for the `$operationName` operation")

docs("Smoke tests for the `${operationShape.id.name.toSnakeCase()}` operation")
val instantiator =
SmokeTestsInstantiator(
codegenContext, operationShape,
configBuilderInitializer = { ->
writable {
rustTemplate(
"""
let config = #{awsConfig}::load_defaults(config::BehaviorVersion::latest()).await;
let conf = config::Config::from(&config).to_builder()
""",
"awsConfig" to AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toType(),
)
}
},
)
for (testCase in testCases) {
Attribute.TokioTest.render(this)
this.rustBlock("async fn test_${testCase.id.toSnakeCase()}()") {
val instantiator = SmokeTestsInstantiator(codegenContext)
instantiator.renderConf(this, testCase)
rust("let client = Client::from_conf(conf);")
instantiator.renderInput(this, operationShape, operationInput, testCase.params)
instantiator.renderExpectation(this, model, testCase.expectation)
instantiator.render(this, testCase)
}
}
}
}
}
}

fun renderPrologue(
moduleUseName: String,
writer: RustWriter,
) = writer.apply {
// Don't run the tests in this module unless `RUSTFLAGS="--cfg smoketests"` is passed.
Attribute(cfg("smoketests")).render(this, AttributeKind.Inner)

containerDocs(
"""
The tests in this module run against live AWS services. As such,
they are disabled by default. To enable them, run the tests with

```sh
RUSTFLAGS="--cfg smoketests" cargo test.
```
""",
)

rust("use $moduleUseName::{Client, config};")
}

fun operationToTestCases(model: Model) =
model.getOperationShapesWithTrait(SmokeTestsTrait::class.java).toList().map { operationShape ->
operationShape to operationShape.expectTrait<SmokeTestsTrait>().testCases
}

class SmokeTestsBuilderKindBehavior(val codegenContext: CodegenContext) : Instantiator.BuilderKindBehavior {
override fun hasFallibleBuilder(shape: StructureShape): Boolean =
BuilderGenerator.hasFallibleBuilder(shape, codegenContext.symbolProvider)
Expand All @@ -136,89 +156,99 @@ class SmokeTestsBuilderKindBehavior(val codegenContext: CodegenContext) : Instan
override fun doesSetterTakeInOption(memberShape: MemberShape): Boolean = true
}

class SmokeTestsInstantiator(private val codegenContext: ClientCodegenContext) : Instantiator(
PublicImportSymbolProvider(codegenContext.symbolProvider, codegenContext.moduleUseName()),
codegenContext.model,
codegenContext.runtimeConfig,
SmokeTestsBuilderKindBehavior(codegenContext),
) {
fun renderConf(
class SmokeTestsInstantiator(
codegenContext: ClientCodegenContext, private val operationShape: OperationShape,
private val configBuilderInitializer: () -> Writable,
) : Instantiator(
PublicImportSymbolProvider(codegenContext.symbolProvider, codegenContext.moduleUseName()),
codegenContext.model,
codegenContext.runtimeConfig,
SmokeTestsBuilderKindBehavior(codegenContext),
) {
private val model = codegenContext.model
private val symbolProvider = codegenContext.symbolProvider

fun render(
writer: RustWriter,
testCase: SmokeTestCase,
) {
writer.rust(
"let config = #{T}::load_defaults(config::BehaviorVersion::latest()).await;",
AwsCargoDependency.awsConfig(codegenContext.runtimeConfig).toType(),
)
writer.rust("let conf = config::Config::from(&config).to_builder()")
writer.indent()
) = writer.apply {
renderConf(this, testCase)
rust("let client = Client::from_conf(conf);")
renderInput(this, testCase.params)
renderExpectation(this, testCase.expectation)
}

private fun renderConf(
writer: RustWriter,
testCase: SmokeTestCase,
) = writer.apply {
rustTemplate("#{config_builder_initializer}", "config_builder_initializer" to configBuilderInitializer())
indent()

// TODO(https://github.com/smithy-lang/smithy-rs/issues/3776) Once Account ID routing is supported,
// reflect the config setting here, especially to disable it if needed, as it is enabled by default in
// `AwsVendorParams`.

val vendorParams = AwsSmokeTestModel.getAwsVendorParams(testCase)
vendorParams.orNull()?.let { params ->
writer.rust(".region(config::Region::new(${params.region.dq()}))")
writer.rust(".use_dual_stack(${params.useDualstack()})")
writer.rust(".use_fips(${params.useFips()})")
params.uri.orNull()?.let { writer.rust(".endpoint_url($it)") }
rust(".region(config::Region::new(${params.region.dq()}))")
rust(".use_dual_stack(${params.useDualstack()})")
rust(".use_fips(${params.useFips()})")
params.uri.orNull()?.let { rust(".endpoint_url($it)") }
}

val s3VendorParams = AwsSmokeTestModel.getS3VendorParams(testCase)
s3VendorParams.orNull()?.let { params ->
writer.rust(".accelerate_(${params.useAccelerate()})")
writer.rust(".force_path_style_(${params.forcePathStyle()})")
writer.rust(".use_arn_region(${params.useArnRegion()})")
writer.rust(".disable_multi_region_access_points(${params.useMultiRegionAccessPoints().not()})")
rust(".accelerate_(${params.useAccelerate()})")
rust(".force_path_style_(${params.forcePathStyle()})")
rust(".use_arn_region(${params.useArnRegion()})")
rust(".disable_multi_region_access_points(${params.useMultiRegionAccessPoints().not()})")
}

writer.rust(".build();")
writer.dedent()
rust(".build();")
dedent()
}

fun renderInput(
private fun renderInput(
writer: RustWriter,
operationShape: OperationShape,
inputShape: StructureShape,
data: Optional<ObjectNode>,
headers: Map<String, String> = mapOf(),
ctx: Ctx = Ctx(),
) {
) = writer.apply {
ysaito1001 marked this conversation as resolved.
Show resolved Hide resolved
val operationBuilderName =
FluentClientGenerator.clientOperationFnName(operationShape, codegenContext.symbolProvider)
FluentClientGenerator.clientOperationFnName(operationShape, symbolProvider)
val inputShape = operationShape.inputShape(model)

writer.rust("let res = client.$operationBuilderName()")
writer.indent()
rust("let res = client.$operationBuilderName()")
indent()
data.orNull()?.let {
renderStructureMembers(writer, inputShape, it, headers, ctx)
}
writer.rust(".send().await;")
writer.dedent()
rust(".send().await;")
dedent()
ysaito1001 marked this conversation as resolved.
Show resolved Hide resolved
}

fun renderExpectation(
private fun renderExpectation(
writer: RustWriter,
model: Model,
expectation: Expectation,
) {
) = writer.apply {
if (expectation.isSuccess) {
writer.rust("""res.expect("request should succeed");""")
rust("""res.expect("request should succeed");""")
} else if (expectation.isFailure) {
val expectedErrShape = expectation.failure.orNull()?.errorId?.orNull()
println(expectedErrShape)
if (expectedErrShape != null) {
val failureShape = model.expectShape(expectedErrShape)
val errName = codegenContext.symbolProvider.toSymbol(failureShape).name.toSnakeCase()
writer.rust(
val errName = symbolProvider.toSymbol(failureShape).name.toSnakeCase()
rust(
"""
let err = res.expect_err("request should fail");
let err = err.into_service_error();
assert!(err.is_$errName())
""",
)
} else {
writer.rust("""res.expect_err("request should fail");""")
rust("""res.expect_err("request should fail");""")
}
}
}
Expand Down
Loading