Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add builder symbol/module resolution to symbol providers #2395

Merged
merged 8 commits into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions aws/rust-runtime/aws-config/src/profile/credentials/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,21 @@
* SPDX-License-Identifier: Apache-2.0
*/

use std::sync::Arc;

use aws_sdk_sts::operation::AssumeRole;
use aws_sdk_sts::{Config, Credentials};
use aws_types::region::Region;

use super::repr::{self, BaseProvider};

use crate::credential_process::CredentialProcessProvider;
use crate::profile::credentials::ProfileFileError;
use crate::provider_config::ProviderConfig;
use crate::sso::{SsoConfig, SsoCredentialsProvider};
use crate::sts;
use crate::web_identity_token::{StaticConfiguration, WebIdentityTokenCredentialsProvider};
use aws_credential_types::provider::{self, error::CredentialsError, ProvideCredentials};
use aws_sdk_sts::input::AssumeRoleInput;
use aws_sdk_sts::middleware::DefaultMiddleware;
use aws_sdk_sts::{Config, Credentials};
use aws_smithy_client::erase::DynConnector;

use aws_types::region::Region;
use std::fmt::Debug;
use std::sync::Arc;

#[derive(Debug)]
pub(super) struct AssumeRoleProvider {
Expand Down Expand Up @@ -51,7 +47,7 @@ impl AssumeRoleProvider {
.as_ref()
.cloned()
.unwrap_or_else(|| sts::util::default_session_name("assume-role-from-profile"));
let operation = AssumeRole::builder()
let operation = AssumeRoleInput::builder()
.role_arn(&self.role_arn)
.set_external_id(self.external_id.clone())
.role_session_name(session_name)
Expand Down
2 changes: 1 addition & 1 deletion aws/rust-runtime/aws-config/src/sso.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ async fn load_sso_credentials(
let config = aws_sdk_sso::Config::builder()
.region(sso_config.region.clone())
.build();
let operation = aws_sdk_sso::operation::GetRoleCredentials::builder()
let operation = aws_sdk_sso::input::GetRoleCredentialsInput::builder()
.role_name(&sso_config.role_name)
.access_token(&*token.access_token)
.account_id(&sso_config.account_id)
Expand Down
9 changes: 4 additions & 5 deletions aws/rust-runtime/aws-config/src/sts/assume_role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,18 @@

//! Assume credentials for a role through the AWS Security Token Service (STS).

use crate::provider_config::ProviderConfig;
use aws_credential_types::cache::CredentialsCache;
use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
use aws_sdk_sts::error::AssumeRoleError;
use aws_sdk_sts::input::AssumeRoleInput;
use aws_sdk_sts::middleware::DefaultMiddleware;
use aws_sdk_sts::model::PolicyDescriptorType;
use aws_sdk_sts::operation::AssumeRole;
use aws_smithy_client::erase::DynConnector;
use aws_smithy_http::result::SdkError;
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_types::region::Region;
use std::time::Duration;

use crate::provider_config::ProviderConfig;
use aws_smithy_types::error::display::DisplayErrorContext;
use tracing::Instrument;

/// Credentials provider that uses credentials provided by another provider to assume a role
Expand Down Expand Up @@ -225,7 +224,7 @@ impl AssumeRoleProviderBuilder {
.session_name
.unwrap_or_else(|| super::util::default_session_name("assume-role-provider"));

let operation = AssumeRole::builder()
let operation = AssumeRoleInput::builder()
.set_role_arn(Some(self.role_arn))
.set_external_id(self.external_id)
.set_role_session_name(Some(session_name))
Expand Down
2 changes: 1 addition & 1 deletion aws/rust-runtime/aws-config/src/web_identity_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ async fn load_credentials(
.region(region.clone())
.build();

let operation = aws_sdk_sts::operation::AssumeRoleWithWebIdentity::builder()
let operation = aws_sdk_sts::input::AssumeRoleWithWebIdentityInput::builder()
.role_arn(role_arn)
.role_session_name(session_name)
.web_identity_token(token)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.generators.BuilderGenerat
import software.amazon.smithy.rust.codegen.core.smithy.generators.StructureGenerator
import software.amazon.smithy.rust.codegen.core.smithy.generators.UnionGenerator
import software.amazon.smithy.rust.codegen.core.smithy.protocols.ProtocolGeneratorFactory
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.core.smithy.transformers.EventStreamNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.OperationNormalizer
import software.amazon.smithy.rust.codegen.core.smithy.transformers.RecursiveShapeBoxer
Expand Down Expand Up @@ -75,6 +74,12 @@ class ClientCodegenVisitor(
true -> ClientModuleProvider
else -> OldModuleSchemeClientModuleProvider
},
nameBuilderFor = { symbol ->
when (settings.codegenConfig.enableNewCrateOrganizationScheme) {
true -> "${symbol.name}Builder"
else -> "Builder"
}
},
)
val baseModel = baselineTransform(context.model)
val untransformedService = settings.getService(baseModel)
Expand All @@ -85,7 +90,7 @@ class ClientCodegenVisitor(
model = codegenDecorator.transformModel(untransformedService, baseModel)
// the model transformer _might_ change the service shape
val service = settings.getService(model)
symbolProvider = RustClientCodegenPlugin.baseSymbolProvider(model, service, rustSymbolProviderConfig)
symbolProvider = RustClientCodegenPlugin.baseSymbolProvider(settings, model, service, rustSymbolProviderConfig)

codegenContext = ClientCodegenContext(model, symbolProvider, service, protocol, settings, codegenDecorator)

Expand Down Expand Up @@ -187,9 +192,9 @@ class ClientCodegenVisitor(
* This function _does not_ generate any serializers
*/
override fun structureShape(shape: StructureShape) {
rustCrate.useShapeWriter(shape) {
when (val errorTrait = shape.getTrait<ErrorTrait>()) {
null -> {
when (val errorTrait = shape.getTrait<ErrorTrait>()) {
null -> {
rustCrate.useShapeWriter(shape) {
StructureGenerator(
model,
symbolProvider,
Expand All @@ -198,31 +203,30 @@ class ClientCodegenVisitor(
codegenDecorator.structureCustomizations(codegenContext, emptyList()),
).render()

if (!shape.hasTrait<SyntheticInputTrait>()) {
val builderGenerator =
BuilderGenerator(
codegenContext.model,
codegenContext.symbolProvider,
shape,
codegenDecorator.builderCustomizations(codegenContext, emptyList()),
)
builderGenerator.render(this)
implBlock(symbolProvider.toSymbol(shape)) {
builderGenerator.renderConvenienceMethod(this)
}
implBlock(symbolProvider.toSymbol(shape)) {
BuilderGenerator.renderConvenienceMethod(this, symbolProvider, shape)
}
}
else -> {
ErrorGenerator(
model,
symbolProvider,
this,

rustCrate.withModule(symbolProvider.moduleForBuilder(shape)) {
BuilderGenerator(
codegenContext.model,
codegenContext.symbolProvider,
shape,
errorTrait,
codegenDecorator.errorImplCustomizations(codegenContext, emptyList()),
).render()
codegenDecorator.builderCustomizations(codegenContext, emptyList()),
).render(this)
}
}
else -> {
ErrorGenerator(
rustCrate,
model,
symbolProvider,
shape,
errorTrait,
codegenDecorator.errorImplCustomizations(codegenContext, emptyList()),
).render()
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package software.amazon.smithy.rust.codegen.client.smithy

import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.Model
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.Shape
Expand All @@ -13,10 +14,12 @@ import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.ErrorTrait
import software.amazon.smithy.rust.codegen.core.rustlang.RustModule
import software.amazon.smithy.rust.codegen.core.rustlang.RustReservedWords
import software.amazon.smithy.rust.codegen.core.rustlang.Visibility
import software.amazon.smithy.rust.codegen.core.smithy.ModuleProvider
import software.amazon.smithy.rust.codegen.core.smithy.ModuleProviderContext
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.contextName
import software.amazon.smithy.rust.codegen.core.smithy.module
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticInputTrait
import software.amazon.smithy.rust.codegen.core.smithy.traits.SyntheticOutputTrait
import software.amazon.smithy.rust.codegen.core.util.UNREACHABLE
Expand Down Expand Up @@ -75,6 +78,9 @@ object ClientModuleProvider : ModuleProvider {
eventStream: UnionShape,
): RustModule.LeafModule = ClientRustModule.Error

override fun moduleForBuilder(context: ModuleProviderContext, shape: Shape, symbol: Symbol): RustModule.LeafModule =
RustModule.public("builders", parent = symbol.module(), documentation = "Builders")

private fun Shape.findOperation(model: Model): OperationShape {
val inputTrait = getTrait<SyntheticInputTrait>()
val outputTrait = getTrait<SyntheticOutputTrait>()
Expand Down Expand Up @@ -122,6 +128,17 @@ object OldModuleSchemeClientModuleProvider : ModuleProvider {
context: ModuleProviderContext,
eventStream: UnionShape,
): RustModule.LeafModule = ClientRustModule.Error

override fun moduleForBuilder(context: ModuleProviderContext, shape: Shape, symbol: Symbol): RustModule.LeafModule {
val builderNamespace = RustReservedWords.escapeIfNeeded(symbol.name.toSnakeCase())
return RustModule.new(
builderNamespace,
visibility = Visibility.PUBLIC,
parent = symbol.module(),
inline = true,
documentation = "See [`${symbol.name}`](${symbol.module().fullyQualifiedPath()}::${symbol.name}).",
)
}
}

// TODO(CrateReorganization): Remove when cleaning up `enableNewCrateOrganizationScheme`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,13 @@ class RustClientCodegenPlugin : ClientDecoratableBuildPlugin() {
* The Symbol provider is composed of a base [SymbolVisitor] which handles the core functionality, then is layered
* with other symbol providers, documented inline, to handle the full scope of Smithy types.
*/
fun baseSymbolProvider(model: Model, serviceShape: ServiceShape, rustSymbolProviderConfig: RustSymbolProviderConfig) =
SymbolVisitor(model, serviceShape = serviceShape, config = rustSymbolProviderConfig)
fun baseSymbolProvider(
settings: ClientRustSettings,
model: Model,
serviceShape: ServiceShape,
rustSymbolProviderConfig: RustSymbolProviderConfig,
) =
SymbolVisitor(settings, model, serviceShape = serviceShape, config = rustSymbolProviderConfig)
// Generate different types for EventStream shapes (e.g. transcribe streaming)
.let { EventStreamSymbolProvider(rustSymbolProviderConfig.runtimeConfig, it, CodegenTarget.CLIENT) }
// Generate `ByteStream` instead of `Blob` for streaming binary shapes (e.g. S3 GetObject)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.stripOuter
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.PANIC
import software.amazon.smithy.rust.codegen.core.util.findMemberWithTrait
Expand Down Expand Up @@ -98,7 +97,7 @@ class PaginatorGenerator private constructor(
"Input" to inputType,
"Output" to outputType,
"Error" to errorType,
"Builder" to operation.inputShape(model).builderSymbol(symbolProvider),
"Builder" to symbolProvider.symbolForBuilder(operation.inputShape(model)),

// SDK Types
"SdkError" to RuntimeType.sdkError(runtimeConfig),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.RustSymbolProvider
import software.amazon.smithy.rust.codegen.core.smithy.customize.writeCustomizations
import software.amazon.smithy.rust.codegen.core.smithy.expectRustMetadata
import software.amazon.smithy.rust.codegen.core.smithy.generators.builderSymbol
import software.amazon.smithy.rust.codegen.core.smithy.generators.setterName
import software.amazon.smithy.rust.codegen.core.smithy.rustType
import software.amazon.smithy.rust.codegen.core.util.inputShape
Expand Down Expand Up @@ -255,7 +254,7 @@ class FluentClientGenerator(
inner: #{Inner}
}
""",
"Inner" to input.builderSymbol(symbolProvider),
"Inner" to symbolProvider.symbolForBuilder(input),
"client" to RuntimeType.smithyClient(runtimeConfig),
"generics" to generics.decl,
"operation" to operationSymbol,
Expand Down
Loading