Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add custom trait PaginationTruncationMember #625

Merged
merged 12 commits into from
Dec 1, 2023
13 changes: 13 additions & 0 deletions Sources/ClientRuntime/Pagination/PaginatorSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@ where Input.Token: Equatable {
let input: Input
let inputKey: KeyPath<Input, Input.Token?>?
let outputKey: KeyPath<Output, Input.Token?>
var isTruncatedKey: KeyPath<Output, Bool>?
let paginationFunction: (Input) async throws -> Output

public init(input: Input,
inputKey: KeyPath<Input, Input.Token?>? = nil,
outputKey: KeyPath<Output, Input.Token?>,
isTruncatedKey: KeyPath<Output, Bool>? = nil,
paginationFunction: @escaping (Input) async throws -> Output) {
self.input = input
self.inputKey = inputKey
self.outputKey = outputKey
self.isTruncatedKey = isTruncatedKey
self.paginationFunction = paginationFunction
}

Expand All @@ -45,6 +48,16 @@ where Input.Token: Equatable {
if token != nil && token == input[keyPath: sequence.inputKey!] {
break
}

// Use isTruncatedKey from the sequence to check if pagination should continue
if let isTruncatedKey = sequence.isTruncatedKey {
let isTruncated = output[keyPath: isTruncatedKey]
if !isTruncated {
// set token to nil to break out of the next iteration
token = nil
}
}

return output
}
return nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.traits.PaginatedTrait
import software.amazon.smithy.swift.codegen.core.CodegenContext
import software.amazon.smithy.swift.codegen.customtraits.PaginationTruncationMember
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.SwiftIntegration
import software.amazon.smithy.swift.codegen.model.SymbolProperty
import software.amazon.smithy.swift.codegen.model.defaultName
import software.amazon.smithy.swift.codegen.model.expectShape
import software.amazon.smithy.swift.codegen.model.hasTrait
import software.amazon.smithy.swift.codegen.model.isBoxed
Expand Down Expand Up @@ -55,7 +57,7 @@ class PaginatorGenerator : SwiftIntegration {
service: ServiceShape,
paginatedOperation: OperationShape,
paginationInfo: PaginationInfo,
itemDesc: ItemDescriptor?
itemDesc: ItemDescriptor?,
) {
val serviceSymbol = ctx.symbolProvider.toSymbol(service)
val outputSymbol = ctx.symbolProvider.toSymbol(paginationInfo.output)
Expand All @@ -72,7 +74,7 @@ class PaginatorGenerator : SwiftIntegration {
inputSymbol,
outputSymbol,
paginationInfo,
cursorSymbol
cursorSymbol,
)

// Optionally generate paginator when nested item is specified on the trait.
Expand All @@ -83,7 +85,7 @@ class PaginatorGenerator : SwiftIntegration {
paginatedOperation,
itemDesc,
inputSymbol,
outputSymbol
outputSymbol,
)
}
}
Expand All @@ -98,8 +100,10 @@ class PaginatorGenerator : SwiftIntegration {
inputSymbol: Symbol,
outputSymbol: Symbol,
paginationInfo: PaginationInfo,
cursorSymbol: Symbol
cursorSymbol: Symbol,
) {
val outputShape = paginationInfo.output

writer.addImport(SwiftDependency.CLIENT_RUNTIME.target)
val nextMarkerLiteral = paginationInfo.outputTokenMemberPath.joinToString(separator = "?.") {
it.toLowerCamelCase()
Expand All @@ -122,22 +126,41 @@ class PaginatorGenerator : SwiftIntegration {
this.write(docBody)
}
writer.openBlock(
"public func \$LPaginated(input: \$N) -> \$N<\$N, \$N> {", "}",
"public func \$LPaginated(input: \$N) -> \$N<\$N, \$N> {",
"}",
operationShape.toLowerCamelCase(),
inputSymbol,
ClientRuntimeTypes.Core.PaginatorSequence,
inputSymbol,
outputSymbol
outputSymbol,
) {
writer.write(
"return \$N<\$N, \$N>(input: input, inputKey: \\\$N.$markerLiteral, outputKey: \\\$N.$nextMarkerLiteral, paginationFunction: self.\$L(input:))",
ClientRuntimeTypes.Core.PaginatorSequence,
inputSymbol,
outputSymbol,
inputSymbol,
outputSymbol,
operationShape.toLowerCamelCase()
)
val isTruncatedFlag = outputShape
.members()
.firstOrNull { it.hasTrait(PaginationTruncationMember.ID) }
?.defaultName()

if (isTruncatedFlag != null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than an if-else for the entire return statement, can we construct a smaller string for the isTruncatedKey: param that is different whether or not the isTruncated flag is set?

This would save a lot of duplication

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah I tried that approach at first with StringBuilder and then reverted when it caused errors, Ill try again

writer.write(
"return \$N<\$N, \$N>(input: input, inputKey: \\\$N.$markerLiteral, outputKey: \\\$N.$nextMarkerLiteral, isTruncatedKey: \\\$N.$isTruncatedFlag, paginationFunction: self.\$L(input:))",
ClientRuntimeTypes.Core.PaginatorSequence,
inputSymbol,
outputSymbol,
inputSymbol,
outputSymbol,
outputSymbol,
operationShape.toLowerCamelCase(),
)
} else {
writer.write(
"return \$N<\$N, \$N>(input: input, inputKey: \\\$N.$markerLiteral, outputKey: \\\$N.$nextMarkerLiteral, paginationFunction: self.\$L(input:))",
ClientRuntimeTypes.Core.PaginatorSequence,
inputSymbol,
outputSymbol,
inputSymbol,
outputSymbol,
operationShape.toLowerCamelCase(),
)
}
}
}

Expand Down Expand Up @@ -212,7 +235,7 @@ private data class ItemDescriptor(
val collectionLiteral: String,
val itemLiteral: String,
val itemPathLiteral: String,
val itemSymbol: Symbol
val itemSymbol: Symbol,
)

/**
Expand All @@ -238,6 +261,6 @@ private fun getItemDescriptorOrNull(paginationInfo: PaginationInfo, ctx: Codegen
collectionLiteral,
itemLiteral,
itemPathLiteral,
ctx.symbolProvider.toSymbol(itemMember)
ctx.symbolProvider.toSymbol(itemMember),
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.swift.codegen.customtraits

import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.AnnotationTrait

/**
* Indicates the annotated member is a truncation indicator which conveys a non-standard termination condition for
* pagination.
*/
class PaginationTruncationMember(node: ObjectNode) : AnnotationTrait(ID, node) {
companion object {
val ID: ShapeId = ShapeId.from("software.amazon.smithy.swift.codegen.synthetic#paginationTruncationMember")
}

constructor() : this(Node.objectNode())
}
12 changes: 12 additions & 0 deletions smithy-swift-codegen/src/test/kotlin/PaginatorGeneratorTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,18 @@ class PaginatorGeneratorTest {
contents.shouldContainOnlyOnce(expectedCode)
}

@Test
fun testRenderPaginatorTruncatable() {
val context = setupTests("pagination-truncation.smithy", "software.amazon.smithy.swift.codegen.synthetic#Lambda")
val contents = getFileContents(context.manifest, "/Test/Paginators.swift")
val expected = """
public func listFunctionsTruncatedPaginated(input: ListFunctionsTruncatedInput) -> ClientRuntime.PaginatorSequence<ListFunctionsTruncatedInput, ListFunctionsTruncatedOutput> {
return ClientRuntime.PaginatorSequence<ListFunctionsTruncatedInput, ListFunctionsTruncatedOutput>(input: input, inputKey: \ListFunctionsTruncatedInput.marker, outputKey: \ListFunctionsTruncatedOutput.nextMarker, isTruncatedKey: \ListFunctionsTruncatedOutput.isTruncated, paginationFunction: self.listFunctionsTruncated(input:))
}
"""
contents.shouldContainOnlyOnce(expected)
}

private fun setupTests(smithyFile: String, serviceShapeId: String): TestContext {
val context = TestContext.initContextFrom(smithyFile, serviceShapeId, MockHttpRestJsonProtocolGenerator()) { model ->
model.defaultSettings(serviceShapeId, "Test", "2019-12-16", "Test")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
$version: "1.0"

namespace software.amazon.smithy.swift.codegen.synthetic

use aws.protocols#restJson1

@trait(selector: "*")
structure paginationTruncationMember { }

service Lambda {
operations: [ListFunctionsTruncated]
}

list FunctionConfigurationList {
member: FunctionConfiguration
}

structure FunctionConfiguration {
functionName: String
}

@paginated(
inputToken: "marker",
outputToken: "nextMarker",
pageSize: "maxItems"
)
@readonly
@http(method: "GET", uri: "/functions/truncated", code: 200)
operation ListFunctionsTruncated {
input: ListFunctionsRequestTruncated,
output: ListFunctionsResponseTruncated
}

structure ListFunctionsRequestTruncated {
@httpQuery("FunctionVersion")
functionVersion: String,
@httpQuery("Marker")
marker: String,
@httpQuery("MasterRegion")
masterRegion: String,
@httpQuery("MaxItems")
maxItems: Integer,
}

structure ListFunctionsResponseTruncated {
Functions: FunctionConfigurationList,
@paginationTruncationMember
IsTruncated: Boolean,
nextMarker: String
}
Loading