diff --git a/.changes/4602d073-9393-4496-b63d-85535a2f631a.json b/.changes/4602d073-9393-4496-b63d-85535a2f631a.json new file mode 100644 index 000000000..714f1c7f9 --- /dev/null +++ b/.changes/4602d073-9393-4496-b63d-85535a2f631a.json @@ -0,0 +1,5 @@ +{ + "id": "4602d073-9393-4496-b63d-85535a2f631a", + "type": "feature", + "description": "Generate paginators and waiters with a default parameter when input shape has all optional members" +} \ No newline at end of file diff --git a/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/MiscTestUtils.kt b/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/MiscTestUtils.kt index d6cdc64b2..f79840c0b 100644 --- a/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/MiscTestUtils.kt +++ b/codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/MiscTestUtils.kt @@ -59,7 +59,7 @@ fun String.shouldContain(expectedStart: String, expectedEnd: String) { } fun List.indexOfSublistOrNull(sublist: List, startFrom: Int = 0): Int? = - drop(startFrom).windowed(sublist.size).indexOf(sublist) + drop(startFrom).windowed(sublist.size).indexOf(sublist).takeIf { it >= 0 } /** Format a multi-line string suitable for comparison with codegen, defaults to one level of indention. */ fun String.formatForTest(indent: String = " ") = diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/model/ShapeExt.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/model/ShapeExt.kt index 395586daa..a34114158 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/model/ShapeExt.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/model/ShapeExt.kt @@ -130,7 +130,7 @@ fun OperationIndex.operationSignature( val hasOutputStream = outputShape.map { it.hasStreamingMember(model) }.orElse(false) val inputParam = input.map { - if (includeOptionalDefault && inputShape.get().isOptional()) "input: $it = $it {}" else "input: $it" + if (includeOptionalDefault && inputShape.get().hasAllOptionalMembers) "input: $it = $it { }" else "input: $it" }.orElse("") val outputParam = output.map { ": $it" }.orElse("") @@ -245,9 +245,10 @@ fun UnionShape.filterEventStreamErrors(model: Model): Collection { } /** - * Test if a shape is optional. + * Test if a shape has all optional members (no member marked `@required`) */ -fun Shape.isOptional(): Boolean = members().none { it.isRequired } +val Shape.hasAllOptionalMembers: Boolean + get() = members().none { it.isRequired } /** * Derive the input and output symbols for an operation. diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/PaginatorGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/PaginatorGenerator.kt index 5393ccc4e..059c423f6 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/PaginatorGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/PaginatorGenerator.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes import software.amazon.smithy.kotlin.codegen.model.SymbolProperty import software.amazon.smithy.kotlin.codegen.model.expectShape +import software.amazon.smithy.kotlin.codegen.model.hasAllOptionalMembers import software.amazon.smithy.kotlin.codegen.model.hasTrait import software.amazon.smithy.kotlin.codegen.model.traits.PaginationTruncationMember import software.amazon.smithy.kotlin.codegen.utils.getOrNull @@ -29,7 +30,6 @@ import software.amazon.smithy.model.shapes.MapShape import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape import software.amazon.smithy.model.shapes.Shape -import software.amazon.smithy.model.shapes.StructureShape import software.amazon.smithy.model.traits.PaginatedTrait /** @@ -54,60 +54,52 @@ class PaginatorGenerator : KotlinIntegration { ?: throw CodegenException("Unexpectedly unable to get PaginationInfo from $service $paginatedOperation") val paginationItemInfo = getItemDescriptorOrNull(paginationInfo, ctx) - renderPaginatorForOperation(writer, ctx, service, paginatedOperation, paginationInfo, paginationItemInfo) + renderPaginatorForOperation(ctx, writer, paginatedOperation, paginationInfo, paginationItemInfo) } } } // Render paginator(s) for operation private fun renderPaginatorForOperation( - writer: KotlinWriter, ctx: CodegenContext, - service: ServiceShape, + writer: KotlinWriter, paginatedOperation: OperationShape, paginationInfo: PaginationInfo, itemDesc: ItemDescriptor?, ) { - val serviceSymbol = ctx.symbolProvider.toSymbol(service) - val outputSymbol = ctx.symbolProvider.toSymbol(paginationInfo.output) - val inputSymbol = ctx.symbolProvider.toSymbol(paginationInfo.input) - val cursorMember = ctx.model.getShape(paginationInfo.inputTokenMember.target).get() - val cursorSymbol = ctx.symbolProvider.toSymbol(cursorMember) - renderResponsePaginator( + ctx, writer, - serviceSymbol, paginatedOperation, - inputSymbol, - paginationInfo.output, - outputSymbol, paginationInfo, - cursorSymbol, ) // Optionally generate paginator when nested item is specified on the trait. if (itemDesc != null) { renderItemPaginator( + ctx, writer, - service, paginatedOperation, itemDesc, - outputSymbol, ) } } // Generate the paginator that iterates over responses private fun renderResponsePaginator( + ctx: CodegenContext, writer: KotlinWriter, - serviceSymbol: Symbol, operationShape: OperationShape, - inputSymbol: Symbol, - outputShape: StructureShape, - outputSymbol: Symbol, paginationInfo: PaginationInfo, - cursorSymbol: Symbol, ) { + val service = ctx.model.expectShape(ctx.settings.service) + val serviceSymbol = ctx.symbolProvider.toSymbol(service) + val outputShape = paginationInfo.output + val outputSymbol = ctx.symbolProvider.toSymbol(outputShape) + val inputSymbol = ctx.symbolProvider.toSymbol(paginationInfo.input) + val cursorMember = ctx.model.getShape(paginationInfo.inputTokenMember.target).get() + val cursorSymbol = ctx.symbolProvider.toSymbol(cursorMember) + val nextMarkerLiteral = paginationInfo.outputTokenMemberPath.joinToString(separator = "?.") { it.defaultName() } @@ -124,6 +116,12 @@ class PaginatorGenerator : KotlinIntegration { """.trimIndent() val docReturn = "@return A [kotlinx.coroutines.flow.Flow] that can collect [${outputSymbol.name}]" + val inputParameter = if (paginationInfo.input.hasAllOptionalMembers) { + writer.format("initialRequest: #1T = #1T { }", inputSymbol) + } else { + writer.format("initialRequest: #T", inputSymbol) + } + writer.write("") writer .dokka( @@ -135,11 +133,12 @@ class PaginatorGenerator : KotlinIntegration { ) .addImportReferences(cursorSymbol, SymbolReference.ContextOption.DECLARE) .withBlock( - "public fun #T.#LPaginated(initialRequest: #T): #T<#T> =", + "#L fun #T.#LPaginated(#L): #T<#T> =", "", + ctx.settings.api.visibility, serviceSymbol, operationShape.defaultName(), - inputSymbol, + inputParameter, ExternalTypes.KotlinxCoroutines.Flow, outputSymbol, ) { @@ -180,8 +179,9 @@ class PaginatorGenerator : KotlinIntegration { """.trimMargin(), ) .withBlock( - "public fun #T.#LPaginated(block: #T.Builder.() -> #T): #T<#T> =", + "#L fun #T.#LPaginated(block: #T.Builder.() -> #T): #T<#T> =", "", + ctx.settings.api.visibility, serviceSymbol, operationShape.defaultName(), inputSymbol, @@ -195,12 +195,15 @@ class PaginatorGenerator : KotlinIntegration { // Generate a paginator that iterates over the model-specified item private fun renderItemPaginator( + ctx: CodegenContext, writer: KotlinWriter, - serviceShape: ServiceShape, operationShape: OperationShape, itemDesc: ItemDescriptor, - outputSymbol: Symbol, ) { + val serviceShape = ctx.model.expectShape(ctx.settings.service) + val outputShape = ctx.model.expectShape(operationShape.outputShape) + val outputSymbol = ctx.symbolProvider.toSymbol(outputShape) + writer.write("") writer.dokka( """ @@ -223,8 +226,9 @@ class PaginatorGenerator : KotlinIntegration { itemDesc.targetMember.defaultName(serviceShape), ) .withBlock( - "public fun #T<#T>.#L(): #T<#L> =", + "#L fun #T<#T>.#L(): #T<#L> =", "", + ctx.settings.api.visibility, ExternalTypes.KotlinxCoroutines.Flow, outputSymbol, itemDesc.itemLiteral, diff --git a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/WaiterGenerator.kt b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/WaiterGenerator.kt index 40f3525a8..552a352a1 100644 --- a/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/WaiterGenerator.kt +++ b/codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/WaiterGenerator.kt @@ -9,6 +9,7 @@ import software.amazon.smithy.kotlin.codegen.core.KotlinWriter import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes import software.amazon.smithy.kotlin.codegen.core.withBlock import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes +import software.amazon.smithy.kotlin.codegen.model.hasAllOptionalMembers import java.text.DecimalFormat import java.text.DecimalFormatSymbols @@ -34,12 +35,18 @@ private fun KotlinWriter.renderRetryStrategy(wi: WaiterInfo, asValName: String) internal fun KotlinWriter.renderWaiter(wi: WaiterInfo) { write("") wi.waiter.documentation.ifPresent(::dokka) + val inputParameter = if (wi.input.hasAllOptionalMembers) { + format("request: #1T = #1T { }", wi.inputSymbol) + } else { + format("request: #T", wi.inputSymbol) + } withBlock( - "public suspend fun #T.#L(request: #T): #T<#T> {", + "#L suspend fun #T.#L(#L): #T<#T> {", "}", + wi.ctx.settings.api.visibility, wi.serviceSymbol, wi.methodName, - wi.inputSymbol, + inputParameter, RuntimeTypes.Core.Retries.Outcome, wi.outputSymbol, ) { @@ -54,7 +61,8 @@ internal fun KotlinWriter.renderWaiter(wi: WaiterInfo) { write("") wi.waiter.documentation.ifPresent(this::dokka) write( - "public suspend fun #T.#L(block: #T.Builder.() -> Unit): #T<#T> =", + "#L suspend fun #T.#L(block: #T.Builder.() -> Unit): #T<#T> =", + wi.ctx.settings.api.visibility, wi.serviceSymbol, wi.methodName, wi.inputSymbol, diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/PaginatorGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/PaginatorGeneratorTest.kt index 7fcdfb2eb..99e2ba1b1 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/PaginatorGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/PaginatorGeneratorTest.kt @@ -93,7 +93,7 @@ class PaginatorGeneratorTest { * @param initialRequest A [ListFunctionsRequest] to start pagination * @return A [kotlinx.coroutines.flow.Flow] that can collect [ListFunctionsResponse] */ - public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest): Flow = + public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest = ListFunctionsRequest { }): Flow = flow { var cursor: kotlin.String? = null var hasNextPage: Boolean = true @@ -204,7 +204,7 @@ class PaginatorGeneratorTest { * @param initialRequest A [ListFunctionsRequest] to start pagination * @return A [kotlinx.coroutines.flow.Flow] that can collect [ListFunctionsResponse] */ - public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest): Flow = + public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest = ListFunctionsRequest { }): Flow = flow { var cursor: kotlin.String? = null var hasNextPage: Boolean = true @@ -333,7 +333,7 @@ class PaginatorGeneratorTest { val actual = testManifest.expectFileString("src/main/kotlin/smithy/kotlin/traits/paginators/Paginators.kt") val expectedCode = """ - public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest): Flow = + public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest = ListFunctionsRequest { }): Flow = flow { var cursor: kotlin.String? = null var hasNextPage: Boolean = true @@ -352,4 +352,76 @@ class PaginatorGeneratorTest { actual.shouldContainOnlyOnceWithDiff(expectedCode) } + + @Test + fun testRenderPaginatorWithRequiredInputMembers() { + val testModelNoItem = """ + namespace com.test + + use aws.protocols#restJson1 + + service Lambda { + operations: [ListFunctions] + } + + @paginated( + inputToken: "Marker", + outputToken: "NextMarker", + pageSize: "MaxItems" + ) + @readonly + @http(method: "GET", uri: "/functions", code: 200) + operation ListFunctions { + input: ListFunctionsRequest, + output: ListFunctionsResponse + } + + structure ListFunctionsRequest { + @required + @httpQuery("FunctionVersion") + FunctionVersion: String, + @httpQuery("Marker") + Marker: String, + @httpQuery("MasterRegion") + MasterRegion: String, + @httpQuery("MaxItems") + MaxItems: Integer + } + + structure ListFunctionsResponse { + Functions: FunctionConfigurationList, + NextMarker: String + } + + list FunctionConfigurationList { + member: FunctionConfiguration + } + + structure FunctionConfiguration { + FunctionName: String + } + """.toSmithyModel() + val testContextNoItem = testModelNoItem.newTestContext("Lambda", "com.test") + + val codegenContextNoItem = object : CodegenContext { + override val model: Model = testContextNoItem.generationCtx.model + override val symbolProvider: SymbolProvider = testContextNoItem.generationCtx.symbolProvider + override val settings: KotlinSettings = testContextNoItem.generationCtx.settings + override val protocolGenerator: ProtocolGenerator = testContextNoItem.generator + override val integrations: List = testContextNoItem.generationCtx.integrations + } + + val unit = PaginatorGenerator() + unit.writeAdditionalFiles(codegenContextNoItem, testContextNoItem.generationCtx.delegator) + + testContextNoItem.generationCtx.delegator.flushWriters() + val testManifest = testContextNoItem.generationCtx.delegator.fileManifest as MockManifest + val actual = testManifest.expectFileString("src/main/kotlin/com/test/paginators/Paginators.kt") + + val expected = """ + public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest): Flow = + """.trimIndent() + + actual.shouldContainOnlyOnceWithDiff(expected) + } } diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ServiceClientGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ServiceClientGeneratorTest.kt index 41576e571..a153fbef8 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ServiceClientGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/ServiceClientGeneratorTest.kt @@ -38,13 +38,13 @@ class ServiceClientGeneratorTest { fun `it renders signatures correctly`() { val expectedSignatures = listOf( "public suspend fun getFoo(input: GetFooRequest): GetFooResponse", - "public suspend fun getFooNoRequired(input: GetFooNoRequiredRequest = GetFooNoRequiredRequest {}): GetFooNoRequiredResponse", + "public suspend fun getFooNoRequired(input: GetFooNoRequiredRequest = GetFooNoRequiredRequest { }): GetFooNoRequiredResponse", "public suspend fun getFooSomeRequired(input: GetFooSomeRequiredRequest): GetFooSomeRequiredResponse", - "public suspend fun getFooNoInput(input: GetFooNoInputRequest = GetFooNoInputRequest {}): GetFooNoInputResponse", + "public suspend fun getFooNoInput(input: GetFooNoInputRequest = GetFooNoInputRequest { }): GetFooNoInputResponse", "public suspend fun getFooNoOutput(input: GetFooNoOutputRequest): GetFooNoOutputResponse", "public suspend fun getFooStreamingInput(input: GetFooStreamingInputRequest): GetFooStreamingInputResponse", "public suspend fun getFooStreamingOutput(input: GetFooStreamingOutputRequest, block: suspend (GetFooStreamingOutputResponse) -> T): T", - "public suspend fun getFooStreamingOutputNoInput(input: GetFooStreamingOutputNoInputRequest = GetFooStreamingOutputNoInputRequest {}, block: suspend (GetFooStreamingOutputNoInputResponse) -> T): T", + "public suspend fun getFooStreamingOutputNoInput(input: GetFooStreamingOutputNoInputRequest = GetFooStreamingOutputNoInputRequest { }, block: suspend (GetFooStreamingOutputNoInputResponse) -> T): T", "public suspend fun getFooStreamingInputNoOutput(input: GetFooStreamingInputNoOutputRequest): GetFooStreamingInputNoOutputResponse", ) expectedSignatures.forEach { diff --git a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/ServiceWaitersGeneratorTest.kt b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/ServiceWaitersGeneratorTest.kt index 3738b5c8b..7607734f7 100644 --- a/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/ServiceWaitersGeneratorTest.kt +++ b/codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/rendering/waiters/ServiceWaitersGeneratorTest.kt @@ -5,6 +5,7 @@ package software.amazon.smithy.kotlin.codegen.rendering.waiters +import io.kotest.matchers.string.shouldContain import io.kotest.matchers.string.shouldContainOnlyOnce import software.amazon.smithy.build.MockManifest import software.amazon.smithy.codegen.core.SymbolProvider @@ -42,7 +43,7 @@ class ServiceWaitersGeneratorTest { /** * Wait until a foo exists */ - public suspend fun TestClient.waitUntilFooExists(request: DescribeFooRequest): Outcome { + public suspend fun TestClient.waitUntilFooExists(request: DescribeFooRequest = DescribeFooRequest { }): Outcome { """.trimIndent() val methodFooter = """ val policy = AcceptorRetryPolicy(request, acceptors) @@ -52,6 +53,17 @@ class ServiceWaitersGeneratorTest { generated.shouldContain(methodHeader, methodFooter) } + @Test + fun testWaiterSignatureWithRequiredInput() { + val methodHeader = """ + /** + * Wait until a foo exists with required input + */ + public suspend fun TestClient.waitUntilFooRequiredExists(request: DescribeFooRequiredRequest): Outcome { + """.trimIndent() + generated.shouldContainOnlyOnceWithDiff(methodHeader) + } + @Test fun testConvenienceWaiterMethod() { val expected = """ @@ -89,7 +101,7 @@ class ServiceWaitersGeneratorTest { } } """.formatForTest() - generated.shouldContainOnlyOnce(expected) + generated.shouldContain(expected) } private fun generateService(modelResourceName: String): String { diff --git a/codegen/smithy-kotlin-codegen/src/test/resources/software/amazon/smithy/kotlin/codegen/simple-service-with-waiter.smithy b/codegen/smithy-kotlin-codegen/src/test/resources/software/amazon/smithy/kotlin/codegen/simple-service-with-waiter.smithy index 173c9cd65..395577710 100644 --- a/codegen/smithy-kotlin-codegen/src/test/resources/software/amazon/smithy/kotlin/codegen/simple-service-with-waiter.smithy +++ b/codegen/smithy-kotlin-codegen/src/test/resources/software/amazon/smithy/kotlin/codegen/simple-service-with-waiter.smithy @@ -5,7 +5,10 @@ use smithy.waiters#waitable service Test { version: "1.0.0", - operations: [DescribeFoo] + operations: [ + DescribeFoo, + DescribeFooRequired, + ] } @waitable( @@ -33,6 +36,31 @@ operation DescribeFoo { errors: [NotFound, UnknownError] } +@waitable( + FooRequiredExists: { + documentation: "Wait until a foo exists with required input", + acceptors: [ + { + state: "success", + matcher: { + success: true + } + }, + { + state: "retry", + matcher: { + errorType: "NotFound" + } + } + ] + } +) +operation DescribeFooRequired { + input: DescribeFooRequiredInput, + output: DescribeFooOutput, + errors: [NotFound, UnknownError] +} + structure DescribeFooInput { id: String } @@ -41,6 +69,12 @@ structure DescribeFooOutput { name: String } +structure DescribeFooRequiredInput { + @required + id: String +} + + @error("client") structure NotFound {}