diff --git a/Sources/ClientRuntime/Auth/HTTPAuthAPI/AuthOption.swift b/Sources/ClientRuntime/Auth/HTTPAuthAPI/AuthOption.swift index b4b3a0d96..8349f3075 100644 --- a/Sources/ClientRuntime/Auth/HTTPAuthAPI/AuthOption.swift +++ b/Sources/ClientRuntime/Auth/HTTPAuthAPI/AuthOption.swift @@ -7,6 +7,16 @@ public struct AuthOption { let schemeID: String - var identityProperties: Attributes - var signingProperties: Attributes + public var identityProperties: Attributes + public var signingProperties: Attributes + + public init ( + schemeID: String, + identityProperties: Attributes = Attributes(), + signingProperties: Attributes = Attributes() + ) { + self.schemeID = schemeID + self.identityProperties = identityProperties + self.signingProperties = signingProperties + } } diff --git a/Sources/ClientRuntime/Auth/HTTPAuthAPI/AuthSchemeResolver.swift b/Sources/ClientRuntime/Auth/HTTPAuthAPI/AuthSchemeResolver.swift index 93d7d60da..980000198 100644 --- a/Sources/ClientRuntime/Auth/HTTPAuthAPI/AuthSchemeResolver.swift +++ b/Sources/ClientRuntime/Auth/HTTPAuthAPI/AuthSchemeResolver.swift @@ -6,6 +6,6 @@ // public protocol AuthSchemeResolver { - func resolveAuthScheme(params: AuthSchemeResolverParameters) -> [AuthOption] + func resolveAuthScheme(params: AuthSchemeResolverParameters) throws -> [AuthOption] func constructParameters(context: HttpContext) throws -> AuthSchemeResolverParameters } diff --git a/Sources/ClientRuntime/Middleware/Attribute.swift b/Sources/ClientRuntime/Middleware/Attribute.swift index b7c1261f5..b8c6ea5db 100644 --- a/Sources/ClientRuntime/Middleware/Attribute.swift +++ b/Sources/ClientRuntime/Middleware/Attribute.swift @@ -39,4 +39,8 @@ public struct Attributes { public mutating func remove(key: AttributeKey) { attributes.removeValue(forKey: key.name) } + + public func getSize() -> Int { + return attributes.count + } } diff --git a/Sources/ClientRuntime/Networking/Http/HttpContext.swift b/Sources/ClientRuntime/Networking/Http/HttpContext.swift index 618e64996..1d4071d89 100644 --- a/Sources/ClientRuntime/Networking/Http/HttpContext.swift +++ b/Sources/ClientRuntime/Networking/Http/HttpContext.swift @@ -67,6 +67,10 @@ public class HttpContext: MiddlewareContext { return attributes.get(key: AttributeKeys.method)! } + public func getOperation() -> String? { + return attributes.get(key: AttributeKeys.operation) + } + /// The partition ID to be used for this context. /// /// Requests made with the same partition ID will be grouped together for retry throttling purposes. @@ -79,6 +83,10 @@ public class HttpContext: MiddlewareContext { return attributes.get(key: AttributeKeys.path)! } + public func getRegion() -> String? { + return attributes.get(key: AttributeKeys.region) + } + public func getSelectedAuthScheme() -> SelectedAuthScheme? { return attributes.get(key: AttributeKeys.selectedAuthScheme) } @@ -87,6 +95,14 @@ public class HttpContext: MiddlewareContext { return attributes.get(key: AttributeKeys.serviceName)! } + public func getSigningName() -> String? { + return attributes.get(key: AttributeKeys.signingName) + } + + public func getSigningRegion() -> String? { + return attributes.get(key: AttributeKeys.signingRegion) + } + public func isBidirectionalStreamingEnabled() -> Bool { return attributes.get(key: AttributeKeys.bidirectionalStreaming) ?? false } @@ -128,6 +144,14 @@ public class HttpContextBuilder { return self } + @discardableResult + public func withAuthSchemes(value: [AuthScheme]) -> HttpContextBuilder { + for scheme in value { + self.withAuthScheme(value: scheme) + } + return self + } + @discardableResult public func withDecoder(value: ResponseDecoder) -> HttpContextBuilder { self.attributes.set(key: AttributeKeys.decoder, value: value) @@ -202,6 +226,12 @@ public class HttpContextBuilder { return self } + @discardableResult + public func withRegion(value: String?) -> HttpContextBuilder { + self.attributes.set(key: AttributeKeys.region, value: value) + return self + } + @discardableResult public func withResponse(value: HttpResponse) -> HttpContextBuilder { self.response = value @@ -220,6 +250,18 @@ public class HttpContextBuilder { return self } + @discardableResult + public func withSigningName(value: String) -> HttpContextBuilder { + self.attributes.set(key: AttributeKeys.signingName, value: value) + return self + } + + @discardableResult + public func withSigningRegion(value: String?) -> HttpContextBuilder { + self.attributes.set(key: AttributeKeys.signingRegion, value: value) + return self + } + public func build() -> HttpContext { return HttpContext(attributes: attributes) } @@ -244,9 +286,12 @@ public enum AttributeKeys { public static let operation = AttributeKey(name: "Operation") public static let partitionId = AttributeKey(name: "PartitionID") public static let path = AttributeKey(name: "Path") + public static let region = AttributeKey(name: "Region") public static let selectedAuthScheme = AttributeKey(name: "SelectedAuthScheme") public static let serviceName = AttributeKey(name: "ServiceName") + public static let signingName = AttributeKey(name: "SigningName") + public static let signingRegion = AttributeKey(name: "SigningRegion") // The attribute key used to store a credentials provider configured on service client config onto middleware context. - public static let awsIdResolver = AttributeKey(name: "AWSIDResolver") + public static let awsIdResolver = AttributeKey(name: "\(IdentityKind.aws)") } diff --git a/Sources/ClientRuntime/Networking/Http/Middlewares/AuthSchemeMiddleware.swift b/Sources/ClientRuntime/Networking/Http/Middlewares/AuthSchemeMiddleware.swift index f00589dbc..63ce07827 100644 --- a/Sources/ClientRuntime/Networking/Http/Middlewares/AuthSchemeMiddleware.swift +++ b/Sources/ClientRuntime/Networking/Http/Middlewares/AuthSchemeMiddleware.swift @@ -32,10 +32,11 @@ public struct AuthSchemeMiddleware 0 else { throw ClientError.authError("No identity resolver has been configured on the service.") } let identityResolverConfig = DefaultIdentityResolverConfiguration(configuredIdResolvers: identityResolvers) @@ -94,7 +95,7 @@ public struct AuthSchemeMiddleware [AuthOption] {", + "}", + ServiceTypes.AuthSchemeResolverParams + ) { + // Return value of array of auth options + write("var validAuthOptions = Array()") + + // Cast params to service specific params object + openBlock( + "guard let serviceParams = params as? \$L else {", + "}", + serviceParamsName + ) { + write("throw ClientError.authError(\"Service specific auth scheme parameters type must be passed to auth scheme resolver.\")") + } + + renderSwitchBlock(serviceIndex, ctx, this) + } + } + } + + private fun renderSwitchBlock( + serviceIndex: ServiceIndex, + ctx: ProtocolGenerator.GenerationContext, + writer: SwiftWriter + ) { + writer.apply { + // Switch block for iterating over operation name cases + openBlock("switch serviceParams.operation {", "}") { + // Handle each operation name case + val operations = ctx.service.operations + operations.filter { op -> + val opShape = ctx.model.getShape(op).get() as OperationShape + opShape.hasTrait(AuthTrait::class.java) || + opShape.hasTrait(OptionalAuthTrait::class.java) || + opShape.hasTrait(UnsignedPayloadTrait::class.java) + }.forEach { op -> + val opName = op.name.toLowerCamelCase() + val sdkId = getSdkId(ctx) + val validSchemesForOp = serviceIndex.getEffectiveAuthSchemes( + ctx.service, op, ServiceIndex.AuthSchemeMode.NO_AUTH_AWARE + ) + renderOperationSwitchCase( + sdkId, + ctx.model.getShape(op).get() as OperationShape, + opName, + validSchemesForOp, + writer + ) + } + // Handle default case, where operations default to auth schemes defined on service shape + val validSchemesForService = serviceIndex.getEffectiveAuthSchemes(ctx.service, ServiceIndex.AuthSchemeMode.NO_AUTH_AWARE) + renderDefaultSwitchCase(getSdkId(ctx), validSchemesForService, writer) + } + + // Return result + write("return validAuthOptions") + } + } + + private fun renderOperationSwitchCase(sdkId: String, opShape: OperationShape, opName: String, schemes: Map, writer: SwiftWriter) { + writer.apply { + write("case \"$opName\":") + indent() + schemes.forEach { + if (it.key == SigV4Trait.ID) { + write("var sigV4Option = AuthOption(schemeID: \"${it.key}\")") + write("sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: \"${(it.value as SigV4Trait).name}\")") + openBlock("guard let region = serviceParams.region else {", "}") { + val errorMessage = "\"Missing region in auth scheme parameters for SigV4 auth scheme.\"" + write("throw ClientError.authError($errorMessage)") + } + write("sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: region)") + write("validAuthOptions.append(sigV4Option)") + } else { + write("validAuthOptions.append(AuthOption(schemeID: \"${it.key}\"))") + } + } + dedent() + } + } + + private fun renderDefaultSwitchCase(sdkId: String, schemes: Map, writer: SwiftWriter) { + writer.apply { + write("default:") + indent() + schemes.forEach { + if (it.key == SigV4Trait.ID) { + write("var sigV4Option = AuthOption(schemeID: \"${it.key}\")") + write("sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: \"${(it.value as SigV4Trait).name}\")") + openBlock("guard let region = serviceParams.region else {", "}") { + val errorMessage = "\"Missing region in auth scheme parameters for SigV4 auth scheme.\"" + write("throw ClientError.authError($errorMessage)") + } + write("sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: region)") + write("validAuthOptions.append(sigV4Option)") + } else { + write("validAuthOptions.append(AuthOption(schemeID: \"${it.key}\"))") + } + } + dedent() + } + } + + private fun renderConstructParametersMethod( + hasSigV4: Boolean, + returnTypeName: String, + writer: SwiftWriter + ) { + writer.apply { + openBlock( + "public func constructParameters(context: HttpContext) throws -> \$L {", + "}", + ServiceTypes.AuthSchemeResolverParams + ) { + openBlock("guard let opName = context.getOperation() else {", "}") { + write("throw ClientError.dataNotFound(\"Operation name not configured in middleware context for auth scheme resolver params construction.\")") + } + if (hasSigV4) { + write("let opRegion = context.getRegion()") + write("return $returnTypeName(operation: opName, region: opRegion)") + } else { + write("return $returnTypeName(operation: opName)") + } + } + } + } + + // Utility function for returning sdkId from generation context + fun getSdkId(ctx: ProtocolGenerator.GenerationContext): String { + return if (ctx.service.hasTrait(ServiceTrait::class.java)) + ctx.service.getTrait(ServiceTrait::class.java).get().sdkId.clientName() + else ctx.settings.sdkId.clientName() + } +} diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/ClientRuntimeTypes.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/ClientRuntimeTypes.kt index d4ae215c7..79f21a629 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/ClientRuntimeTypes.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/ClientRuntimeTypes.kt @@ -76,6 +76,8 @@ object ClientRuntimeTypes { val SerializableBodyMiddleware = runtimeSymbol("SerializableBodyMiddleware") val RetryMiddleware = runtimeSymbol("RetryMiddleware") val NoopHandler = runtimeSymbol("NoopHandler") + val SigningMiddleware = runtimeSymbol("SignerMiddleware") + val AuthSchemeMiddleware = runtimeSymbol("AuthSchemeMiddleware") object Providers { val URLPathProvider = runtimeSymbol("URLPathProvider") @@ -108,6 +110,8 @@ object ClientRuntimeTypes { val DateFormatter = runtimeSymbol("DateFormatter") val PaginateToken = runtimeSymbol("PaginateToken") val PaginatorSequence = runtimeSymbol("PaginatorSequence") + val AuthSchemeResolver = runtimeSymbol("AuthSchemeResolver") + val AuthSchemeResolverParameters = runtimeSymbol("AuthSchemeResolverParameters") } } diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/DefaultHttpProtocolCustomizations.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/DefaultHttpProtocolCustomizations.kt index 83cdbc3f5..b897882d9 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/DefaultHttpProtocolCustomizations.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/DefaultHttpProtocolCustomizations.kt @@ -6,6 +6,7 @@ package software.amazon.smithy.swift.codegen.integration import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.swift.codegen.AuthSchemeResolverGenerator import software.amazon.smithy.swift.codegen.SwiftWriter abstract class DefaultHttpProtocolCustomizations : HttpProtocolCustomizable { @@ -25,4 +26,8 @@ abstract class DefaultHttpProtocolCustomizations : HttpProtocolCustomizable { ) { // Default implementation is no-op } + + override fun renderInternals(ctx: ProtocolGenerator.GenerationContext) { + AuthSchemeResolverGenerator().render(ctx) + } } diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpBindingProtocolGenerator.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpBindingProtocolGenerator.kt index 0d1a7af38..ec4ad1b88 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpBindingProtocolGenerator.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpBindingProtocolGenerator.kt @@ -38,6 +38,7 @@ import software.amazon.smithy.swift.codegen.SwiftTypes import software.amazon.smithy.swift.codegen.SwiftWriter import software.amazon.smithy.swift.codegen.integration.codingKeys.CodingKeysGenerator import software.amazon.smithy.swift.codegen.integration.httpResponse.HttpResponseGeneratable +import software.amazon.smithy.swift.codegen.integration.middlewares.AuthSchemeMiddleware import software.amazon.smithy.swift.codegen.integration.middlewares.ContentLengthMiddleware import software.amazon.smithy.swift.codegen.integration.middlewares.ContentMD5Middleware import software.amazon.smithy.swift.codegen.integration.middlewares.ContentTypeMiddleware @@ -50,6 +51,7 @@ import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInp import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputUrlHostMiddleware import software.amazon.smithy.swift.codegen.integration.middlewares.OperationInputUrlPathMiddleware import software.amazon.smithy.swift.codegen.integration.middlewares.RetryMiddleware +import software.amazon.smithy.swift.codegen.integration.middlewares.SigningMiddleware import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.HttpBodyMiddleware import software.amazon.smithy.swift.codegen.integration.middlewares.providers.HttpHeaderProvider import software.amazon.smithy.swift.codegen.integration.middlewares.providers.HttpQueryItemProvider @@ -439,8 +441,12 @@ abstract class HttpBindingProtocolGenerator : ProtocolGenerator { operationMiddleware.appendMiddleware(operation, LoggingMiddleware(ctx.model, ctx.symbolProvider)) operationMiddleware.appendMiddleware(operation, RetryMiddleware(ctx.model, ctx.symbolProvider, retryErrorInfoProviderSymbol)) + operationMiddleware.appendMiddleware(operation, SigningMiddleware(ctx.model, ctx.symbolProvider)) + addProtocolSpecificMiddleware(ctx, operation) + operationMiddleware.appendMiddleware(operation, AuthSchemeMiddleware(ctx.model, ctx.symbolProvider)) + for (integration in ctx.integrations) { integration.customizeMiddleware(ctx, operation, operationMiddleware) } diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolServiceClient.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolServiceClient.kt index 0c298d25b..030de1a32 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolServiceClient.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolServiceClient.kt @@ -5,12 +5,14 @@ package software.amazon.smithy.swift.codegen.integration +import software.amazon.smithy.aws.traits.auth.SigV4Trait import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.knowledge.ServiceIndex import software.amazon.smithy.swift.codegen.ClientRuntimeTypes import software.amazon.smithy.swift.codegen.SwiftWriter open class HttpProtocolServiceClient( - ctx: ProtocolGenerator.GenerationContext, + private val ctx: ProtocolGenerator.GenerationContext, private val writer: SwiftWriter, private val properties: List, private val serviceConfig: ServiceConfig @@ -38,7 +40,12 @@ open class HttpProtocolServiceClient( } prop.renderInitialization(writer, "config") } - + // Render auth schemes + writer.write("var modeledAuthSchemes: [ClientRuntime.AuthScheme] = Array()") + if (ServiceIndex(ctx.model).getEffectiveAuthSchemes(ctx.service).contains(SigV4Trait.ID)) { + writer.write("modeledAuthSchemes.append(SigV4AuthScheme())") + } + writer.write("config.authSchemes = config.authSchemes ?? modeledAuthSchemes") writer.write("self.config = config") } writer.write("") diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolTestGenerator.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolTestGenerator.kt index b6385824f..d4fae0b25 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolTestGenerator.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HttpProtocolTestGenerator.kt @@ -74,8 +74,9 @@ class HttpProtocolTestGenerator( cloned.removeMiddleware(operation, MiddlewareStep.INITIALIZESTEP, "OperationInputUrlHostMiddleware") cloned.removeMiddleware(operation, MiddlewareStep.BUILDSTEP, "EndpointResolverMiddleware") cloned.removeMiddleware(operation, MiddlewareStep.BUILDSTEP, "UserAgentMiddleware") + cloned.removeMiddleware(operation, MiddlewareStep.BUILDSTEP, "AuthSchemeMiddleware") cloned.removeMiddleware(operation, MiddlewareStep.FINALIZESTEP, "RetryMiddleware") - cloned.removeMiddleware(operation, MiddlewareStep.FINALIZESTEP, "AWSSigningMiddleware") // causes tests to halt :( + cloned.removeMiddleware(operation, MiddlewareStep.FINALIZESTEP, "SigningMiddleware") cloned.removeMiddleware(operation, MiddlewareStep.DESERIALIZESTEP, "DeserializeMiddleware") cloned.removeMiddleware(operation, MiddlewareStep.DESERIALIZESTEP, "LoggingMiddleware") diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/ServiceConfig.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/ServiceConfig.kt index f96df2b25..8d06eb03a 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/ServiceConfig.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/ServiceConfig.kt @@ -8,6 +8,7 @@ package software.amazon.smithy.swift.codegen.integration import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.swift.codegen.ClientRuntimeTypes import software.amazon.smithy.swift.codegen.SwiftWriter +import software.amazon.smithy.swift.codegen.model.buildSymbol import software.amazon.smithy.swift.codegen.utils.toUpperCamelCase /** @@ -54,3 +55,13 @@ abstract class ServiceConfig(val writer: SwiftWriter, val clientName: String, va open fun serviceConfigProperties(): List = listOf() } + +object ServiceTypes { + val AuthSchemes = symbol("[ClientRuntime.AuthScheme]") + val AuthSchemeResolver = symbol("ClientRuntime.AuthSchemeResolver") + val AuthSchemeResolverParams = symbol("ClientRuntime.AuthSchemeResolverParameters") + + private fun symbol(name: String): Symbol = buildSymbol { + this.name = name + } +} diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/middlewares/AuthSchemeMiddleware.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/middlewares/AuthSchemeMiddleware.kt new file mode 100644 index 000000000..4ecbb0c49 --- /dev/null +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/middlewares/AuthSchemeMiddleware.kt @@ -0,0 +1,35 @@ +package software.amazon.smithy.swift.codegen.integration.middlewares + +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.swift.codegen.ClientRuntimeTypes +import software.amazon.smithy.swift.codegen.SwiftWriter +import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils +import software.amazon.smithy.swift.codegen.middleware.MiddlewarePosition +import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable +import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep + +class AuthSchemeMiddleware( + val model: Model, + val symbolProvider: SymbolProvider +) : MiddlewareRenderable { + override val name = "AuthSchemeMiddleware" + + override val middlewareStep = MiddlewareStep.BUILDSTEP + + override val position = MiddlewarePosition.BEFORE + + override fun render( + writer: SwiftWriter, + op: OperationShape, + operationStackName: String + ) { + val output = MiddlewareShapeUtils.outputSymbol(symbolProvider, model, op) + val outputError = MiddlewareShapeUtils.outputErrorSymbol(op) + writer.write( + "$operationStackName.${middlewareStep.stringValue()}.intercept(position: ${position.stringValue()}, middleware: \$N<\$N, \$N>())", + ClientRuntimeTypes.Middleware.AuthSchemeMiddleware, output, outputError + ) + } +} diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/middlewares/SigningMiddleware.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/middlewares/SigningMiddleware.kt new file mode 100644 index 000000000..98024fce9 --- /dev/null +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/middlewares/SigningMiddleware.kt @@ -0,0 +1,35 @@ +package software.amazon.smithy.swift.codegen.integration.middlewares + +import software.amazon.smithy.codegen.core.SymbolProvider +import software.amazon.smithy.model.Model +import software.amazon.smithy.model.shapes.OperationShape +import software.amazon.smithy.swift.codegen.ClientRuntimeTypes +import software.amazon.smithy.swift.codegen.SwiftWriter +import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils +import software.amazon.smithy.swift.codegen.middleware.MiddlewarePosition +import software.amazon.smithy.swift.codegen.middleware.MiddlewareRenderable +import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep + +class SigningMiddleware( + val model: Model, + val symbolProvider: SymbolProvider +) : MiddlewareRenderable { + override val name = "SigningMiddleware" + + override val middlewareStep = MiddlewareStep.FINALIZESTEP + + override val position = MiddlewarePosition.BEFORE + + override fun render( + writer: SwiftWriter, + op: OperationShape, + operationStackName: String + ) { + val output = MiddlewareShapeUtils.outputSymbol(symbolProvider, model, op) + val outputError = MiddlewareShapeUtils.outputErrorSymbol(op) + writer.write( + "$operationStackName.${middlewareStep.stringValue()}.intercept(position: ${position.stringValue()}, middleware: \$N<\$N, \$N>())", + ClientRuntimeTypes.Middleware.SigningMiddleware, output, outputError + ) + } +} diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/MiddlewareExecutionGenerator.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/MiddlewareExecutionGenerator.kt index c81302a98..d77897603 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/MiddlewareExecutionGenerator.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/middleware/MiddlewareExecutionGenerator.kt @@ -51,6 +51,8 @@ class MiddlewareExecutionGenerator( writer.write(" .withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator)") writer.write(" .withLogger(value: config.logger)") writer.write(" .withPartitionID(value: config.partitionID)") + writer.write(" .withAuthSchemes(value: config.authSchemes!)") + writer.write(" .withAuthSchemeResolver(value: config.serviceSpecific.authSchemeResolver)") val serviceShape = ctx.service httpProtocolCustomizable.renderContextAttributes(ctx, writer, serviceShape, op) diff --git a/smithy-swift-codegen/src/test/kotlin/AuthSchemeResolverGeneratorTests.kt b/smithy-swift-codegen/src/test/kotlin/AuthSchemeResolverGeneratorTests.kt new file mode 100644 index 000000000..1f3c73727 --- /dev/null +++ b/smithy-swift-codegen/src/test/kotlin/AuthSchemeResolverGeneratorTests.kt @@ -0,0 +1,115 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +package software.amazon.smithy.aws.swift.codegen + +import MockHttpRestJsonProtocolGenerator +import TestContext +import defaultSettings +import getFileContents +import io.kotest.matchers.string.shouldContainOnlyOnce +import org.junit.jupiter.api.Test +import shouldSyntacticSanityCheck + +class AuthSchemeResolverGeneratorTests { + @Test + fun `test auth scheme resolver generation`() { + val context = setupTests("auth-scheme-resolver-generator-test.smithy", "com.test#Example") + val contents = getFileContents(context.manifest, "/Example/AuthSchemeResolver.swift") + contents.shouldSyntacticSanityCheck() + contents.shouldContainOnlyOnce( + """ + public struct ExampleAuthSchemeResolverParameters: ClientRuntime.AuthSchemeResolverParameters { + public let operation: String + // Region is used for SigV4 auth scheme + public let region: String? + } + + public protocol ExampleAuthSchemeResolver: ClientRuntime.AuthSchemeResolver { + // Intentionally empty. + // This is the parent protocol that all auth scheme resolver implementations of + // the service Example must conform to. + } + + public struct DefaultExampleAuthSchemeResolver: ExampleAuthSchemeResolver { + public func resolveAuthScheme(params: ClientRuntime.AuthSchemeResolverParameters) throws -> [AuthOption] { + var validAuthOptions = Array() + guard let serviceParams = params as? ExampleAuthSchemeResolverParameters else { + throw ClientError.authError("Service specific auth scheme parameters type must be passed to auth scheme resolver.") + } + switch serviceParams.operation { + case "onlyHttpApiKeyAuth": + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpApiKeyAuth")) + case "onlyHttpApiKeyAuthOptional": + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpApiKeyAuth")) + validAuthOptions.append(AuthOption(schemeID: "smithy.api#noAuth")) + case "onlyHttpBearerAuth": + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpBearerAuth")) + case "onlyHttpBearerAuthOptional": + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpBearerAuth")) + validAuthOptions.append(AuthOption(schemeID: "smithy.api#noAuth")) + case "onlyHttpApiKeyAndBearerAuth": + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpApiKeyAuth")) + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpBearerAuth")) + case "onlyHttpApiKeyAndBearerAuthReversed": + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpBearerAuth")) + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpApiKeyAuth")) + case "onlySigv4Auth": + var sigV4Option = AuthOption(schemeID: "aws.auth#sigv4") + sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: "weather") + guard let region = serviceParams.region else { + throw ClientError.authError("Missing region in auth scheme parameters for SigV4 auth scheme.") + } + sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: region) + validAuthOptions.append(sigV4Option) + case "onlySigv4AuthOptional": + var sigV4Option = AuthOption(schemeID: "aws.auth#sigv4") + sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: "weather") + guard let region = serviceParams.region else { + throw ClientError.authError("Missing region in auth scheme parameters for SigV4 auth scheme.") + } + sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: region) + validAuthOptions.append(sigV4Option) + validAuthOptions.append(AuthOption(schemeID: "smithy.api#noAuth")) + case "onlyCustomAuth": + validAuthOptions.append(AuthOption(schemeID: "com.test#customAuth")) + case "onlyCustomAuthOptional": + validAuthOptions.append(AuthOption(schemeID: "com.test#customAuth")) + validAuthOptions.append(AuthOption(schemeID: "smithy.api#noAuth")) + default: + var sigV4Option = AuthOption(schemeID: "aws.auth#sigv4") + sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: "weather") + guard let region = serviceParams.region else { + throw ClientError.authError("Missing region in auth scheme parameters for SigV4 auth scheme.") + } + sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: region) + validAuthOptions.append(sigV4Option) + } + return validAuthOptions + } + + public func constructParameters(context: HttpContext) throws -> ClientRuntime.AuthSchemeResolverParameters { + guard let opName = context.getOperation() else { + throw ClientError.dataNotFound("Operation name not configured in middleware context for auth scheme resolver params construction.") + } + let opRegion = context.getRegion() + return ExampleAuthSchemeResolverParameters(operation: opName, region: opRegion) + } + } + """.trimIndent() + ) + } + + private fun setupTests(smithyFile: String, serviceShapeId: String): TestContext { + val context = TestContext.initContextFrom(smithyFile, serviceShapeId, MockHttpRestJsonProtocolGenerator()) { model -> + model.defaultSettings(serviceShapeId, "Example", "2023-11-02", "Example") + } + context.generator.initializeMiddleware(context.generationCtx) + context.generator.generateProtocolClient(context.generationCtx) + context.generator.generateSerializers(context.generationCtx) + context.generationCtx.delegator.flushWriters() + return context + } +} diff --git a/smithy-swift-codegen/src/test/kotlin/ContentMd5MiddlewareTests.kt b/smithy-swift-codegen/src/test/kotlin/ContentMd5MiddlewareTests.kt index cee419a12..1931499d8 100644 --- a/smithy-swift-codegen/src/test/kotlin/ContentMd5MiddlewareTests.kt +++ b/smithy-swift-codegen/src/test/kotlin/ContentMd5MiddlewareTests.kt @@ -25,6 +25,8 @@ class ContentMd5MiddlewareTests { .withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator) .withLogger(value: config.logger) .withPartitionID(value: config.partitionID) + .withAuthSchemes(value: config.authSchemes!) + .withAuthSchemeResolver(value: config.serviceSpecific.authSchemeResolver) .build() var operation = ClientRuntime.OperationStack(id: "idempotencyTokenWithStructure") operation.initializeStep.intercept(position: .after, id: "IdempotencyTokenMiddleware") { (context, input, next) -> ClientRuntime.OperationOutput in @@ -38,10 +40,12 @@ class ContentMd5MiddlewareTests { operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLPathMiddleware()) operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLHostMiddleware()) operation.buildStep.intercept(position: .before, middleware: ClientRuntime.ContentMD5Middleware()) + operation.buildStep.intercept(position: .before, middleware: ClientRuntime.AuthSchemeMiddleware()) operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware(contentType: "application/xml")) operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.SerializableBodyMiddleware(xmlName: "IdempotencyToken")) operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware()) operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryMiddleware(options: config.retryStrategyOptions)) + operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.SignerMiddleware()) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware()) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.LoggerMiddleware(clientLogMode: config.clientLogMode)) let result = try await operation.handleMiddleware(context: context, input: input, next: client.getHandler()) diff --git a/smithy-swift-codegen/src/test/kotlin/HttpProtocolClientGeneratorTests.kt b/smithy-swift-codegen/src/test/kotlin/HttpProtocolClientGeneratorTests.kt index 1e46c20a7..fd8c8aaac 100644 --- a/smithy-swift-codegen/src/test/kotlin/HttpProtocolClientGeneratorTests.kt +++ b/smithy-swift-codegen/src/test/kotlin/HttpProtocolClientGeneratorTests.kt @@ -33,6 +33,8 @@ class HttpProtocolClientGeneratorTests { decoder.dateDecodingStrategy = .secondsSince1970 decoder.nonConformingFloatDecodingStrategy = .convertFromString(positiveInfinity: "Infinity", negativeInfinity: "-Infinity", nan: "NaN") self.decoder = config.decoder ?? decoder + var modeledAuthSchemes: [ClientRuntime.AuthScheme] = Array() + config.authSchemes = config.authSchemes ?? modeledAuthSchemes self.config = config } @@ -123,6 +125,8 @@ class HttpProtocolClientGeneratorTests { .withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator) .withLogger(value: config.logger) .withPartitionID(value: config.partitionID) + .withAuthSchemes(value: config.authSchemes!) + .withAuthSchemeResolver(value: config.serviceSpecific.authSchemeResolver) .build() var operation = ClientRuntime.OperationStack(id: "allocateWidget") operation.initializeStep.intercept(position: .after, id: "IdempotencyTokenMiddleware") { (context, input, next) -> ClientRuntime.OperationOutput in @@ -135,10 +139,12 @@ class HttpProtocolClientGeneratorTests { } operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLPathMiddleware()) operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLHostMiddleware()) + operation.buildStep.intercept(position: .before, middleware: ClientRuntime.AuthSchemeMiddleware()) operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware(contentType: "application/json")) operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.SerializableBodyMiddleware(xmlName: "AllocateWidgetInput")) operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware()) operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryMiddleware(options: config.retryStrategyOptions)) + operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.SignerMiddleware()) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware()) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.LoggerMiddleware(clientLogMode: config.clientLogMode)) let result = try await operation.handleMiddleware(context: context, input: input, next: client.getHandler()) diff --git a/smithy-swift-codegen/src/test/kotlin/IdempotencyTokenTraitTests.kt b/smithy-swift-codegen/src/test/kotlin/IdempotencyTokenTraitTests.kt index 4ec4980a1..661568a78 100644 --- a/smithy-swift-codegen/src/test/kotlin/IdempotencyTokenTraitTests.kt +++ b/smithy-swift-codegen/src/test/kotlin/IdempotencyTokenTraitTests.kt @@ -25,6 +25,8 @@ class IdempotencyTokenTraitTests { .withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator) .withLogger(value: config.logger) .withPartitionID(value: config.partitionID) + .withAuthSchemes(value: config.authSchemes!) + .withAuthSchemeResolver(value: config.serviceSpecific.authSchemeResolver) .build() var operation = ClientRuntime.OperationStack(id: "idempotencyTokenWithStructure") operation.initializeStep.intercept(position: .after, id: "IdempotencyTokenMiddleware") { (context, input, next) -> ClientRuntime.OperationOutput in @@ -37,10 +39,12 @@ class IdempotencyTokenTraitTests { } operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLPathMiddleware()) operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLHostMiddleware()) + operation.buildStep.intercept(position: .before, middleware: ClientRuntime.AuthSchemeMiddleware()) operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware(contentType: "application/xml")) operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.SerializableBodyMiddleware(xmlName: "IdempotencyToken")) operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware()) operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryMiddleware(options: config.retryStrategyOptions)) + operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.SignerMiddleware()) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware()) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.LoggerMiddleware(clientLogMode: config.clientLogMode)) let result = try await operation.handleMiddleware(context: context, input: input, next: client.getHandler()) diff --git a/smithy-swift-codegen/src/test/resources/auth-scheme-resolver-generator-test.smithy b/smithy-swift-codegen/src/test/resources/auth-scheme-resolver-generator-test.smithy new file mode 100644 index 000000000..9bca46df4 --- /dev/null +++ b/smithy-swift-codegen/src/test/resources/auth-scheme-resolver-generator-test.smithy @@ -0,0 +1,85 @@ +$version: "2.0" + +namespace com.test + +use aws.auth#sigv4 +use aws.protocols#restJson1 + +@authDefinition +@trait +structure customAuth {} + +@restJson1 +@httpApiKeyAuth(name: "X-Api-Key", in: "header") +@httpBearerAuth +@sigv4(name: "weather") +@customAuth +@auth([sigv4]) +service Example { + version: "2023-11-02" + operations: [ + // experimentalIdentityAndAuth + OnlyHttpApiKeyAuth + OnlyHttpApiKeyAuthOptional + OnlyHttpBearerAuth + OnlyHttpBearerAuthOptional + OnlyHttpApiKeyAndBearerAuth + OnlyHttpApiKeyAndBearerAuthReversed + OnlySigv4Auth + OnlySigv4AuthOptional + OnlyCustomAuth + OnlyCustomAuthOptional + SameAsService + ] +} + +@http(method: "GET", uri: "/OnlyHttpApiKeyAuth") +@auth([httpApiKeyAuth]) +operation OnlyHttpApiKeyAuth {} + +@http(method: "GET", uri: "/OnlyHttpBearerAuth") +@auth([httpBearerAuth]) +operation OnlyHttpBearerAuth {} + +@http(method: "GET", uri: "/OnlySigv4Auth") +@auth([sigv4]) +operation OnlySigv4Auth {} + +@http(method: "GET", uri: "/OnlyHttpApiKeyAndBearerAuth") +@auth([httpApiKeyAuth, httpBearerAuth]) +operation OnlyHttpApiKeyAndBearerAuth {} + +@http(method: "GET", uri: "/OnlyHttpApiKeyAndBearerAuthReversed") +@auth([httpBearerAuth, httpApiKeyAuth]) +operation OnlyHttpApiKeyAndBearerAuthReversed {} + +@http(method: "GET", uri: "/OnlyHttpApiKeyAuthOptional") +@auth([httpApiKeyAuth]) +@optionalAuth +operation OnlyHttpApiKeyAuthOptional {} + +@http(method: "GET", uri: "/OnlyHttpBearerAuthOptional") +@auth([httpBearerAuth]) +@optionalAuth +operation OnlyHttpBearerAuthOptional {} + +@http(method: "GET", uri: "/OnlySigv4AuthOptional") +@auth([sigv4]) +@optionalAuth +operation OnlySigv4AuthOptional {} + +@http(method: "GET", uri: "/OnlyCustomAuth") +@auth([customAuth]) +operation OnlyCustomAuth {} + +@http(method: "GET", uri: "/OnlyCustomAuthOptional") +@auth([customAuth]) +@optionalAuth +operation OnlyCustomAuthOptional {} + +@http(method: "GET", uri: "/SameAsService") +operation SameAsService { + output := { + service: String + } +} \ No newline at end of file