Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add custom trait PaginationTruncationMember #625

Merged
merged 12 commits into from
Dec 1, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public typealias DeserializeStep<OperationStackOutput> = MiddlewareStep<HttpCont

public let DeserializeStepId = "Deserialize"

public struct DeserializeStepHandler<OperationStackOutput, H: Handler>: Handler
public struct DeserializeStepHandler<OperationStackOutput, H: Handler>: Handler
where H.Context == HttpContext,
H.Input == SdkHttpRequest,
H.Output == OperationOutput<OperationStackOutput> {
Expand Down
2 changes: 1 addition & 1 deletion Sources/ClientRuntime/Networking/Http/SdkHttpClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class SdkHttpClient {
self.engine = engine
}

public func getHandler<OperationStackOutput>()
public func getHandler<OperationStackOutput>()
-> AnyHandler<SdkHttpRequest, OperationOutput<OperationStackOutput>, HttpContext> {

let clientHandler = ClientHandler<OperationStackOutput>(engine: engine)
Expand Down
13 changes: 13 additions & 0 deletions Sources/ClientRuntime/Pagination/PaginatorSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@ public struct PaginatorSequence<OperationStackInput: PaginateToken, OperationSta
let input: OperationStackInput
let inputKey: KeyPath<OperationStackInput, OperationStackInput.Token?>?
let outputKey: KeyPath<OperationStackOutput, OperationStackInput.Token?>
var isTruncatedKey: KeyPath<OperationStackOutput, Bool>?
let paginationFunction: (OperationStackInput) async throws -> OperationStackOutput

public init(input: OperationStackInput,
inputKey: KeyPath<OperationStackInput, OperationStackInput.Token?>? = nil,
outputKey: KeyPath<OperationStackOutput, OperationStackInput.Token?>,
isTruncatedKey: KeyPath<OperationStackOutput, Bool>? = nil,
paginationFunction: @escaping (OperationStackInput) async throws -> OperationStackOutput) {
self.input = input
self.inputKey = inputKey
self.outputKey = outputKey
self.isTruncatedKey = isTruncatedKey
self.paginationFunction = paginationFunction
}

Expand All @@ -45,6 +48,16 @@ public struct PaginatorSequence<OperationStackInput: PaginateToken, OperationSta
if token != nil && token == input[keyPath: sequence.inputKey!] {
break
}

// Use isTruncatedKey from the sequence to check if pagination should continue
if let isTruncatedKey = sequence.isTruncatedKey {
let isTruncated = output[keyPath: isTruncatedKey]
if !isTruncated {
// set token to nil to break out of the next iteration
token = nil
}
}

return output
}
return nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.traits.PaginatedTrait
import software.amazon.smithy.swift.codegen.core.CodegenContext
import software.amazon.smithy.swift.codegen.customtraits.PaginationTruncationMember
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.SwiftIntegration
import software.amazon.smithy.swift.codegen.model.SymbolProperty
import software.amazon.smithy.swift.codegen.model.defaultName
import software.amazon.smithy.swift.codegen.model.expectShape
import software.amazon.smithy.swift.codegen.model.hasTrait
import software.amazon.smithy.swift.codegen.model.isBoxed
Expand Down Expand Up @@ -55,7 +57,7 @@ class PaginatorGenerator : SwiftIntegration {
service: ServiceShape,
paginatedOperation: OperationShape,
paginationInfo: PaginationInfo,
itemDesc: ItemDescriptor?
itemDesc: ItemDescriptor?,
) {
val serviceSymbol = ctx.symbolProvider.toSymbol(service)
val outputSymbol = ctx.symbolProvider.toSymbol(paginationInfo.output)
Expand All @@ -72,7 +74,7 @@ class PaginatorGenerator : SwiftIntegration {
inputSymbol,
outputSymbol,
paginationInfo,
cursorSymbol
cursorSymbol,
)

// Optionally generate paginator when nested item is specified on the trait.
Expand All @@ -83,7 +85,7 @@ class PaginatorGenerator : SwiftIntegration {
paginatedOperation,
itemDesc,
inputSymbol,
outputSymbol
outputSymbol,
)
}
}
Expand All @@ -98,8 +100,10 @@ class PaginatorGenerator : SwiftIntegration {
inputSymbol: Symbol,
outputSymbol: Symbol,
paginationInfo: PaginationInfo,
cursorSymbol: Symbol
cursorSymbol: Symbol,
) {
val outputShape = paginationInfo.output

writer.addImport(SwiftDependency.CLIENT_RUNTIME.target)
val nextMarkerLiteral = paginationInfo.outputTokenMemberPath.joinToString(separator = "?.") {
it.toLowerCamelCase()
Expand All @@ -122,21 +126,26 @@ class PaginatorGenerator : SwiftIntegration {
this.write(docBody)
}
writer.openBlock(
"public func \$LPaginated(input: \$N) -> \$N<\$N, \$N> {", "}",
"public func \$LPaginated(input: \$N) -> \$N<\$N, \$N> {",
"}",
operationShape.toLowerCamelCase(),
inputSymbol,
ClientRuntimeTypes.Core.PaginatorSequence,
inputSymbol,
outputSymbol
outputSymbol,
) {
val isTruncatedFlag = outputShape
.members()
.firstOrNull { it.hasTrait(PaginationTruncationMember.ID) }
?.defaultName()

val isTruncatedPart = if (isTruncatedFlag != null) ", isTruncatedKey: \\.$isTruncatedFlag" else ""
writer.write(
"return \$N<\$N, \$N>(input: input, inputKey: \\\$N.$markerLiteral, outputKey: \\\$N.$nextMarkerLiteral, paginationFunction: self.\$L(input:))",
"return \$N<\$N, \$N>(input: input, inputKey: \\.$markerLiteral, outputKey: \\.$nextMarkerLiteral$isTruncatedPart, paginationFunction: self.\$L(input:))",
ClientRuntimeTypes.Core.PaginatorSequence,
inputSymbol,
outputSymbol,
inputSymbol,
outputSymbol,
operationShape.toLowerCamelCase()
operationShape.toLowerCamelCase(),
)
}
}
Expand Down Expand Up @@ -212,7 +221,7 @@ private data class ItemDescriptor(
val collectionLiteral: String,
val itemLiteral: String,
val itemPathLiteral: String,
val itemSymbol: Symbol
val itemSymbol: Symbol,
)

/**
Expand All @@ -238,6 +247,6 @@ private fun getItemDescriptorOrNull(paginationInfo: PaginationInfo, ctx: Codegen
collectionLiteral,
itemLiteral,
itemPathLiteral,
ctx.symbolProvider.toSymbol(itemMember)
ctx.symbolProvider.toSymbol(itemMember),
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.swift.codegen.customtraits

import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.AnnotationTrait

/**
* Indicates the annotated member is a truncation indicator which conveys a non-standard termination condition for
* pagination.
*/
class PaginationTruncationMember(node: ObjectNode) : AnnotationTrait(ID, node) {
companion object {
val ID: ShapeId = ShapeId.from("software.amazon.smithy.swift.codegen.synthetic#paginationTruncationMember")
}

constructor() : this(Node.objectNode())
}
18 changes: 15 additions & 3 deletions smithy-swift-codegen/src/test/kotlin/PaginatorGeneratorTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class PaginatorGeneratorTest {
/// - input: A `[ListFunctionsInput]` to start pagination
/// - Returns: An `AsyncSequence` that can iterate over `ListFunctionsOutput`
public func listFunctionsPaginated(input: ListFunctionsInput) -> ClientRuntime.PaginatorSequence<ListFunctionsInput, ListFunctionsOutput> {
return ClientRuntime.PaginatorSequence<ListFunctionsInput, ListFunctionsOutput>(input: input, inputKey: \ListFunctionsInput.marker, outputKey: \ListFunctionsOutput.nextMarker, paginationFunction: self.listFunctions(input:))
return ClientRuntime.PaginatorSequence<ListFunctionsInput, ListFunctionsOutput>(input: input, inputKey: \.marker, outputKey: \.nextMarker, paginationFunction: self.listFunctions(input:))
}
}

Expand Down Expand Up @@ -58,7 +58,7 @@ class PaginatorGeneratorTest {
/// - input: A `[ListFunctionsInput]` to start pagination
/// - Returns: An `AsyncSequence` that can iterate over `ListFunctionsOutput`
public func listFunctionsPaginated(input: ListFunctionsInput) -> ClientRuntime.PaginatorSequence<ListFunctionsInput, ListFunctionsOutput> {
return ClientRuntime.PaginatorSequence<ListFunctionsInput, ListFunctionsOutput>(input: input, inputKey: \ListFunctionsInput.marker, outputKey: \ListFunctionsOutput.nextMarker, paginationFunction: self.listFunctions(input:))
return ClientRuntime.PaginatorSequence<ListFunctionsInput, ListFunctionsOutput>(input: input, inputKey: \.marker, outputKey: \.nextMarker, paginationFunction: self.listFunctions(input:))
}
}

Expand Down Expand Up @@ -100,7 +100,7 @@ class PaginatorGeneratorTest {
/// - input: A `[PaginatedMapInput]` to start pagination
/// - Returns: An `AsyncSequence` that can iterate over `PaginatedMapOutput`
public func paginatedMapPaginated(input: PaginatedMapInput) -> ClientRuntime.PaginatorSequence<PaginatedMapInput, PaginatedMapOutput> {
return ClientRuntime.PaginatorSequence<PaginatedMapInput, PaginatedMapOutput>(input: input, inputKey: \PaginatedMapInput.nextToken, outputKey: \PaginatedMapOutput.inner?.token, paginationFunction: self.paginatedMap(input:))
return ClientRuntime.PaginatorSequence<PaginatedMapInput, PaginatedMapOutput>(input: input, inputKey: \.nextToken, outputKey: \.inner?.token, paginationFunction: self.paginatedMap(input:))
}
}

Expand All @@ -125,6 +125,18 @@ class PaginatorGeneratorTest {
contents.shouldContainOnlyOnce(expectedCode)
}

@Test
fun testRenderPaginatorTruncatable() {
val context = setupTests("pagination-truncation.smithy", "software.amazon.smithy.swift.codegen.synthetic#Lambda")
val contents = getFileContents(context.manifest, "/Test/Paginators.swift")
val expected = """
public func listFunctionsTruncatedPaginated(input: ListFunctionsTruncatedInput) -> ClientRuntime.PaginatorSequence<ListFunctionsTruncatedInput, ListFunctionsTruncatedOutput> {
return ClientRuntime.PaginatorSequence<ListFunctionsTruncatedInput, ListFunctionsTruncatedOutput>(input: input, inputKey: \.marker, outputKey: \.nextMarker, isTruncatedKey: \.isTruncated, paginationFunction: self.listFunctionsTruncated(input:))
}
"""
contents.shouldContainOnlyOnce(expected)
}

private fun setupTests(smithyFile: String, serviceShapeId: String): TestContext {
val context = TestContext.initContextFrom(smithyFile, serviceShapeId, MockHttpRestJsonProtocolGenerator()) { model ->
model.defaultSettings(serviceShapeId, "Test", "2019-12-16", "Test")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
$version: "1.0"

namespace software.amazon.smithy.swift.codegen.synthetic

use aws.protocols#restJson1

@trait(selector: "*")
structure paginationTruncationMember { }

service Lambda {
operations: [ListFunctionsTruncated]
}

list FunctionConfigurationList {
member: FunctionConfiguration
}

structure FunctionConfiguration {
functionName: String
}

@paginated(
inputToken: "marker",
outputToken: "nextMarker",
pageSize: "maxItems"
)
@readonly
@http(method: "GET", uri: "/functions/truncated", code: 200)
operation ListFunctionsTruncated {
input: ListFunctionsRequestTruncated,
output: ListFunctionsResponseTruncated
}

structure ListFunctionsRequestTruncated {
@httpQuery("FunctionVersion")
functionVersion: String,
@httpQuery("Marker")
marker: String,
@httpQuery("MasterRegion")
masterRegion: String,
@httpQuery("MaxItems")
maxItems: Integer,
}

structure ListFunctionsResponseTruncated {
Functions: FunctionConfigurationList,
@paginationTruncationMember
IsTruncated: Boolean,
nextMarker: String
}
Loading