Skip to content

Commit

Permalink
feat(codegen): generate paginators and waiters with default parameter…
Browse files Browse the repository at this point in the history
…s for all optional inputs (#959)
  • Loading branch information
aajtodd authored Sep 26, 2023
1 parent 39f85c3 commit 24f9441
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 44 deletions.
5 changes: 5 additions & 0 deletions .changes/4602d073-9393-4496-b63d-85535a2f631a.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"id": "4602d073-9393-4496-b63d-85535a2f631a",
"type": "feature",
"description": "Generate paginators and waiters with a default parameter when input shape has all optional members"
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fun String.shouldContain(expectedStart: String, expectedEnd: String) {
}

fun <T> List<T>.indexOfSublistOrNull(sublist: List<T>, startFrom: Int = 0): Int? =
drop(startFrom).windowed(sublist.size).indexOf(sublist)
drop(startFrom).windowed(sublist.size).indexOf(sublist).takeIf { it >= 0 }

/** Format a multi-line string suitable for comparison with codegen, defaults to one level of indention. */
fun String.formatForTest(indent: String = " ") =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ fun OperationIndex.operationSignature(

val hasOutputStream = outputShape.map { it.hasStreamingMember(model) }.orElse(false)
val inputParam = input.map {
if (includeOptionalDefault && inputShape.get().isOptional()) "input: $it = $it {}" else "input: $it"
if (includeOptionalDefault && inputShape.get().hasAllOptionalMembers) "input: $it = $it { }" else "input: $it"
}.orElse("")
val outputParam = output.map { ": $it" }.orElse("")

Expand Down Expand Up @@ -245,9 +245,10 @@ fun UnionShape.filterEventStreamErrors(model: Model): Collection<MemberShape> {
}

/**
* Test if a shape is optional.
* Test if a shape has all optional members (no member marked `@required`)
*/
fun Shape.isOptional(): Boolean = members().none { it.isRequired }
val Shape.hasAllOptionalMembers: Boolean
get() = members().none { it.isRequired }

/**
* Derive the input and output symbols for an operation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.model.SymbolProperty
import software.amazon.smithy.kotlin.codegen.model.expectShape
import software.amazon.smithy.kotlin.codegen.model.hasAllOptionalMembers
import software.amazon.smithy.kotlin.codegen.model.hasTrait
import software.amazon.smithy.kotlin.codegen.model.traits.PaginationTruncationMember
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
Expand All @@ -29,7 +30,6 @@ import software.amazon.smithy.model.shapes.MapShape
import software.amazon.smithy.model.shapes.OperationShape
import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.shapes.StructureShape
import software.amazon.smithy.model.traits.PaginatedTrait

/**
Expand All @@ -54,60 +54,52 @@ class PaginatorGenerator : KotlinIntegration {
?: throw CodegenException("Unexpectedly unable to get PaginationInfo from $service $paginatedOperation")
val paginationItemInfo = getItemDescriptorOrNull(paginationInfo, ctx)

renderPaginatorForOperation(writer, ctx, service, paginatedOperation, paginationInfo, paginationItemInfo)
renderPaginatorForOperation(ctx, writer, paginatedOperation, paginationInfo, paginationItemInfo)
}
}
}

// Render paginator(s) for operation
private fun renderPaginatorForOperation(
writer: KotlinWriter,
ctx: CodegenContext,
service: ServiceShape,
writer: KotlinWriter,
paginatedOperation: OperationShape,
paginationInfo: PaginationInfo,
itemDesc: ItemDescriptor?,
) {
val serviceSymbol = ctx.symbolProvider.toSymbol(service)
val outputSymbol = ctx.symbolProvider.toSymbol(paginationInfo.output)
val inputSymbol = ctx.symbolProvider.toSymbol(paginationInfo.input)
val cursorMember = ctx.model.getShape(paginationInfo.inputTokenMember.target).get()
val cursorSymbol = ctx.symbolProvider.toSymbol(cursorMember)

renderResponsePaginator(
ctx,
writer,
serviceSymbol,
paginatedOperation,
inputSymbol,
paginationInfo.output,
outputSymbol,
paginationInfo,
cursorSymbol,
)

// Optionally generate paginator when nested item is specified on the trait.
if (itemDesc != null) {
renderItemPaginator(
ctx,
writer,
service,
paginatedOperation,
itemDesc,
outputSymbol,
)
}
}

// Generate the paginator that iterates over responses
private fun renderResponsePaginator(
ctx: CodegenContext,
writer: KotlinWriter,
serviceSymbol: Symbol,
operationShape: OperationShape,
inputSymbol: Symbol,
outputShape: StructureShape,
outputSymbol: Symbol,
paginationInfo: PaginationInfo,
cursorSymbol: Symbol,
) {
val service = ctx.model.expectShape<ServiceShape>(ctx.settings.service)
val serviceSymbol = ctx.symbolProvider.toSymbol(service)
val outputShape = paginationInfo.output
val outputSymbol = ctx.symbolProvider.toSymbol(outputShape)
val inputSymbol = ctx.symbolProvider.toSymbol(paginationInfo.input)
val cursorMember = ctx.model.getShape(paginationInfo.inputTokenMember.target).get()
val cursorSymbol = ctx.symbolProvider.toSymbol(cursorMember)

val nextMarkerLiteral = paginationInfo.outputTokenMemberPath.joinToString(separator = "?.") {
it.defaultName()
}
Expand All @@ -124,6 +116,12 @@ class PaginatorGenerator : KotlinIntegration {
""".trimIndent()
val docReturn = "@return A [kotlinx.coroutines.flow.Flow] that can collect [${outputSymbol.name}]"

val inputParameter = if (paginationInfo.input.hasAllOptionalMembers) {
writer.format("initialRequest: #1T = #1T { }", inputSymbol)
} else {
writer.format("initialRequest: #T", inputSymbol)
}

writer.write("")
writer
.dokka(
Expand All @@ -135,11 +133,12 @@ class PaginatorGenerator : KotlinIntegration {
)
.addImportReferences(cursorSymbol, SymbolReference.ContextOption.DECLARE)
.withBlock(
"public fun #T.#LPaginated(initialRequest: #T): #T<#T> =",
"#L fun #T.#LPaginated(#L): #T<#T> =",
"",
ctx.settings.api.visibility,
serviceSymbol,
operationShape.defaultName(),
inputSymbol,
inputParameter,
ExternalTypes.KotlinxCoroutines.Flow,
outputSymbol,
) {
Expand Down Expand Up @@ -180,8 +179,9 @@ class PaginatorGenerator : KotlinIntegration {
""".trimMargin(),
)
.withBlock(
"public fun #T.#LPaginated(block: #T.Builder.() -> #T): #T<#T> =",
"#L fun #T.#LPaginated(block: #T.Builder.() -> #T): #T<#T> =",
"",
ctx.settings.api.visibility,
serviceSymbol,
operationShape.defaultName(),
inputSymbol,
Expand All @@ -195,12 +195,15 @@ class PaginatorGenerator : KotlinIntegration {

// Generate a paginator that iterates over the model-specified item
private fun renderItemPaginator(
ctx: CodegenContext,
writer: KotlinWriter,
serviceShape: ServiceShape,
operationShape: OperationShape,
itemDesc: ItemDescriptor,
outputSymbol: Symbol,
) {
val serviceShape = ctx.model.expectShape<ServiceShape>(ctx.settings.service)
val outputShape = ctx.model.expectShape(operationShape.outputShape)
val outputSymbol = ctx.symbolProvider.toSymbol(outputShape)

writer.write("")
writer.dokka(
"""
Expand All @@ -223,8 +226,9 @@ class PaginatorGenerator : KotlinIntegration {
itemDesc.targetMember.defaultName(serviceShape),
)
.withBlock(
"public fun #T<#T>.#L(): #T<#L> =",
"#L fun #T<#T>.#L(): #T<#L> =",
"",
ctx.settings.api.visibility,
ExternalTypes.KotlinxCoroutines.Flow,
outputSymbol,
itemDesc.itemLiteral,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
import software.amazon.smithy.kotlin.codegen.core.withBlock
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
import software.amazon.smithy.kotlin.codegen.model.hasAllOptionalMembers
import java.text.DecimalFormat
import java.text.DecimalFormatSymbols

Expand All @@ -34,12 +35,18 @@ private fun KotlinWriter.renderRetryStrategy(wi: WaiterInfo, asValName: String)
internal fun KotlinWriter.renderWaiter(wi: WaiterInfo) {
write("")
wi.waiter.documentation.ifPresent(::dokka)
val inputParameter = if (wi.input.hasAllOptionalMembers) {
format("request: #1T = #1T { }", wi.inputSymbol)
} else {
format("request: #T", wi.inputSymbol)
}
withBlock(
"public suspend fun #T.#L(request: #T): #T<#T> {",
"#L suspend fun #T.#L(#L): #T<#T> {",
"}",
wi.ctx.settings.api.visibility,
wi.serviceSymbol,
wi.methodName,
wi.inputSymbol,
inputParameter,
RuntimeTypes.Core.Retries.Outcome,
wi.outputSymbol,
) {
Expand All @@ -54,7 +61,8 @@ internal fun KotlinWriter.renderWaiter(wi: WaiterInfo) {
write("")
wi.waiter.documentation.ifPresent(this::dokka)
write(
"public suspend fun #T.#L(block: #T.Builder.() -> Unit): #T<#T> =",
"#L suspend fun #T.#L(block: #T.Builder.() -> Unit): #T<#T> =",
wi.ctx.settings.api.visibility,
wi.serviceSymbol,
wi.methodName,
wi.inputSymbol,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class PaginatorGeneratorTest {
* @param initialRequest A [ListFunctionsRequest] to start pagination
* @return A [kotlinx.coroutines.flow.Flow] that can collect [ListFunctionsResponse]
*/
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest): Flow<ListFunctionsResponse> =
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest = ListFunctionsRequest { }): Flow<ListFunctionsResponse> =
flow {
var cursor: kotlin.String? = null
var hasNextPage: Boolean = true
Expand Down Expand Up @@ -204,7 +204,7 @@ class PaginatorGeneratorTest {
* @param initialRequest A [ListFunctionsRequest] to start pagination
* @return A [kotlinx.coroutines.flow.Flow] that can collect [ListFunctionsResponse]
*/
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest): Flow<ListFunctionsResponse> =
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest = ListFunctionsRequest { }): Flow<ListFunctionsResponse> =
flow {
var cursor: kotlin.String? = null
var hasNextPage: Boolean = true
Expand Down Expand Up @@ -333,7 +333,7 @@ class PaginatorGeneratorTest {
val actual = testManifest.expectFileString("src/main/kotlin/smithy/kotlin/traits/paginators/Paginators.kt")

val expectedCode = """
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest): Flow<ListFunctionsResponse> =
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest = ListFunctionsRequest { }): Flow<ListFunctionsResponse> =
flow {
var cursor: kotlin.String? = null
var hasNextPage: Boolean = true
Expand All @@ -352,4 +352,76 @@ class PaginatorGeneratorTest {

actual.shouldContainOnlyOnceWithDiff(expectedCode)
}

@Test
fun testRenderPaginatorWithRequiredInputMembers() {
val testModelNoItem = """
namespace com.test
use aws.protocols#restJson1
service Lambda {
operations: [ListFunctions]
}
@paginated(
inputToken: "Marker",
outputToken: "NextMarker",
pageSize: "MaxItems"
)
@readonly
@http(method: "GET", uri: "/functions", code: 200)
operation ListFunctions {
input: ListFunctionsRequest,
output: ListFunctionsResponse
}
structure ListFunctionsRequest {
@required
@httpQuery("FunctionVersion")
FunctionVersion: String,
@httpQuery("Marker")
Marker: String,
@httpQuery("MasterRegion")
MasterRegion: String,
@httpQuery("MaxItems")
MaxItems: Integer
}
structure ListFunctionsResponse {
Functions: FunctionConfigurationList,
NextMarker: String
}
list FunctionConfigurationList {
member: FunctionConfiguration
}
structure FunctionConfiguration {
FunctionName: String
}
""".toSmithyModel()
val testContextNoItem = testModelNoItem.newTestContext("Lambda", "com.test")

val codegenContextNoItem = object : CodegenContext {
override val model: Model = testContextNoItem.generationCtx.model
override val symbolProvider: SymbolProvider = testContextNoItem.generationCtx.symbolProvider
override val settings: KotlinSettings = testContextNoItem.generationCtx.settings
override val protocolGenerator: ProtocolGenerator = testContextNoItem.generator
override val integrations: List<KotlinIntegration> = testContextNoItem.generationCtx.integrations
}

val unit = PaginatorGenerator()
unit.writeAdditionalFiles(codegenContextNoItem, testContextNoItem.generationCtx.delegator)

testContextNoItem.generationCtx.delegator.flushWriters()
val testManifest = testContextNoItem.generationCtx.delegator.fileManifest as MockManifest
val actual = testManifest.expectFileString("src/main/kotlin/com/test/paginators/Paginators.kt")

val expected = """
public fun TestClient.listFunctionsPaginated(initialRequest: ListFunctionsRequest): Flow<ListFunctionsResponse> =
""".trimIndent()

actual.shouldContainOnlyOnceWithDiff(expected)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ class ServiceClientGeneratorTest {
fun `it renders signatures correctly`() {
val expectedSignatures = listOf(
"public suspend fun getFoo(input: GetFooRequest): GetFooResponse",
"public suspend fun getFooNoRequired(input: GetFooNoRequiredRequest = GetFooNoRequiredRequest {}): GetFooNoRequiredResponse",
"public suspend fun getFooNoRequired(input: GetFooNoRequiredRequest = GetFooNoRequiredRequest { }): GetFooNoRequiredResponse",
"public suspend fun getFooSomeRequired(input: GetFooSomeRequiredRequest): GetFooSomeRequiredResponse",
"public suspend fun getFooNoInput(input: GetFooNoInputRequest = GetFooNoInputRequest {}): GetFooNoInputResponse",
"public suspend fun getFooNoInput(input: GetFooNoInputRequest = GetFooNoInputRequest { }): GetFooNoInputResponse",
"public suspend fun getFooNoOutput(input: GetFooNoOutputRequest): GetFooNoOutputResponse",
"public suspend fun getFooStreamingInput(input: GetFooStreamingInputRequest): GetFooStreamingInputResponse",
"public suspend fun <T> getFooStreamingOutput(input: GetFooStreamingOutputRequest, block: suspend (GetFooStreamingOutputResponse) -> T): T",
"public suspend fun <T> getFooStreamingOutputNoInput(input: GetFooStreamingOutputNoInputRequest = GetFooStreamingOutputNoInputRequest {}, block: suspend (GetFooStreamingOutputNoInputResponse) -> T): T",
"public suspend fun <T> getFooStreamingOutputNoInput(input: GetFooStreamingOutputNoInputRequest = GetFooStreamingOutputNoInputRequest { }, block: suspend (GetFooStreamingOutputNoInputResponse) -> T): T",
"public suspend fun getFooStreamingInputNoOutput(input: GetFooStreamingInputNoOutputRequest): GetFooStreamingInputNoOutputResponse",
)
expectedSignatures.forEach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package software.amazon.smithy.kotlin.codegen.rendering.waiters

import io.kotest.matchers.string.shouldContain
import io.kotest.matchers.string.shouldContainOnlyOnce
import software.amazon.smithy.build.MockManifest
import software.amazon.smithy.codegen.core.SymbolProvider
Expand Down Expand Up @@ -42,7 +43,7 @@ class ServiceWaitersGeneratorTest {
/**
* Wait until a foo exists
*/
public suspend fun TestClient.waitUntilFooExists(request: DescribeFooRequest): Outcome<DescribeFooResponse> {
public suspend fun TestClient.waitUntilFooExists(request: DescribeFooRequest = DescribeFooRequest { }): Outcome<DescribeFooResponse> {
""".trimIndent()
val methodFooter = """
val policy = AcceptorRetryPolicy(request, acceptors)
Expand All @@ -52,6 +53,17 @@ class ServiceWaitersGeneratorTest {
generated.shouldContain(methodHeader, methodFooter)
}

@Test
fun testWaiterSignatureWithRequiredInput() {
val methodHeader = """
/**
* Wait until a foo exists with required input
*/
public suspend fun TestClient.waitUntilFooRequiredExists(request: DescribeFooRequiredRequest): Outcome<DescribeFooRequiredResponse> {
""".trimIndent()
generated.shouldContainOnlyOnceWithDiff(methodHeader)
}

@Test
fun testConvenienceWaiterMethod() {
val expected = """
Expand Down Expand Up @@ -89,7 +101,7 @@ class ServiceWaitersGeneratorTest {
}
}
""".formatForTest()
generated.shouldContainOnlyOnce(expected)
generated.shouldContain(expected)
}

private fun generateService(modelResourceName: String): String {
Expand Down
Loading

0 comments on commit 24f9441

Please sign in to comment.