Skip to content

Commit

Permalink
Fix request id accessor when it is marked as required
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoh committed Sep 21, 2023
1 parent e75ad16 commit 99de898
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
9 changes: 5 additions & 4 deletions aws/sdk-adhoc-test/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ val allCodegenTests = listOf(
,
"codegen": {
"includeFluentClient": false,
"nullabilityCheckMode": \"${getNullabilityCheckMode()}\"
"nullabilityCheckMode": "${getNullabilityCheckMode()}"
},
"customizationConfig": {
"awsSdk": {
Expand All @@ -66,7 +66,7 @@ val allCodegenTests = listOf(
,
"codegen": {
"includeFluentClient": false,
"nullabilityCheckMode": ${getNullabilityCheckMode()}
"nullabilityCheckMode": "${getNullabilityCheckMode()}"
},
"customizationConfig": {
"awsSdk": {
Expand All @@ -83,11 +83,12 @@ val allCodegenTests = listOf(
,
"codegen": {
"includeFluentClient": false,
"nullabilityCheckMode": ${getNullabilityCheckMode()}
"nullabilityCheckMode": "${getNullabilityCheckMode()}"
},
"customizationConfig": {
"awsSdk": {
"generateReadme": false
"generateReadme": false,
"requireEndpointResolver": false
}
}
""",
Expand Down
6 changes: 5 additions & 1 deletion aws/sdk-adhoc-test/models/required-value-test.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ $version: "1.0"

namespace com.amazonaws.testservice

use aws.api#service
use aws.protocols#restJson1

@restJson1
@title("Test Service")
@service(sdkId: "Test")
Expand All @@ -10,11 +13,12 @@ service RequiredValues {
operations: [TestOperation]
}

@http(method: "GET", uri: "/")
operation TestOperation {
errors: [Error]
}

@error
@error("client")
structure Error {
@required
requestId: String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

package software.amazon.smithy.rustsdk

import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
Expand All @@ -19,13 +21,15 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderSection
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureSection
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.error.ErrorImplSection
import software.amazon.smithy.rust.codegen.core.smithy.isOptional
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.core.util.hasTrait

Expand Down Expand Up @@ -72,6 +76,11 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator {
}
}

open fun asMemberShape(container: StructureShape): MemberShape? {
println(container.members())
return container.members().firstOrNull { member -> member.memberName.lowercase() == "requestid" }
}

private inner class RequestIdOperationCustomization(private val codegenContext: ClientCodegenContext) :
OperationCustomization() {
override fun section(section: OperationSection): Writable = writable {
Expand All @@ -82,19 +91,22 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator {
"apply_to_error" to applyToError(codegenContext),
)
}

is OperationSection.MutateOutput -> {
rust(
"output._set_$fieldName(#T::$accessorFunctionName(${section.responseHeadersName}).map(str::to_string));",
accessorTrait(codegenContext),
)
}

is OperationSection.BeforeParseResponse -> {
rustTemplate(
"#{tracing}::debug!($fieldName = ?#{trait}::$accessorFunctionName(${section.responseName}));",
"tracing" to RuntimeType.Tracing,
"trait" to accessorTrait(codegenContext),
)
}

else -> {}
}
}
Expand Down Expand Up @@ -123,8 +135,17 @@ abstract class BaseRequestIdDecorator : ClientCodegenDecorator {
rustBlock("fn $accessorFunctionName(&self) -> Option<&str>") {
rustBlock("match self") {
section.allErrors.forEach { error ->
val optional = asMemberShape(error)?.let { member ->
codegenContext.symbolProvider.toSymbol(member).isOptional()
} ?: true
val wrapped = writable {
when (optional) {
false -> rustTemplate("#{Some}(e.$accessorFunctionName())", *preludeScope)
true -> rustTemplate("e.$accessorFunctionName()")
}
}
val sym = codegenContext.symbolProvider.toSymbol(error)
rust("Self::${sym.name}(e) => e.$accessorFunctionName(),")
rust("Self::${sym.name}(e) => #T,", wrapped)
}
rust("Self::Unhandled(e) => e.$accessorFunctionName(),")
}
Expand Down

0 comments on commit 99de898

Please sign in to comment.