Skip to content

Commit

Permalink
feat: add custom trait PaginationTruncationMember (#625)
Browse files Browse the repository at this point in the history
  • Loading branch information
dayaffe authored Dec 1, 2023
1 parent af13f00 commit 794b36b
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public typealias DeserializeStep<OperationStackOutput> = MiddlewareStep<HttpCont

public let DeserializeStepId = "Deserialize"

public struct DeserializeStepHandler<OperationStackOutput, H: Handler>: Handler
public struct DeserializeStepHandler<OperationStackOutput, H: Handler>: Handler
where H.Context == HttpContext,
H.Input == SdkHttpRequest,
H.Output == OperationOutput<OperationStackOutput> {
Expand Down
2 changes: 1 addition & 1 deletion Sources/ClientRuntime/Networking/Http/SdkHttpClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public class SdkHttpClient {
self.engine = engine
}

public func getHandler<OperationStackOutput>()
public func getHandler<OperationStackOutput>()
-> AnyHandler<SdkHttpRequest, OperationOutput<OperationStackOutput>, HttpContext> {

let clientHandler = ClientHandler<OperationStackOutput>(engine: engine)
Expand Down
13 changes: 13 additions & 0 deletions Sources/ClientRuntime/Pagination/PaginatorSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@ public struct PaginatorSequence<OperationStackInput: PaginateToken, OperationSta
let input: OperationStackInput
let inputKey: KeyPath<OperationStackInput, OperationStackInput.Token?>?
let outputKey: KeyPath<OperationStackOutput, OperationStackInput.Token?>
var isTruncatedKey: KeyPath<OperationStackOutput, Bool>?
let paginationFunction: (OperationStackInput) async throws -> OperationStackOutput

public init(input: OperationStackInput,
inputKey: KeyPath<OperationStackInput, OperationStackInput.Token?>? = nil,
outputKey: KeyPath<OperationStackOutput, OperationStackInput.Token?>,
isTruncatedKey: KeyPath<OperationStackOutput, Bool>? = nil,
paginationFunction: @escaping (OperationStackInput) async throws -> OperationStackOutput) {
self.input = input
self.inputKey = inputKey
self.outputKey = outputKey
self.isTruncatedKey = isTruncatedKey
self.paginationFunction = paginationFunction
}

Expand All @@ -45,6 +48,16 @@ public struct PaginatorSequence<OperationStackInput: PaginateToken, OperationSta
if token != nil && token == input[keyPath: sequence.inputKey!] {
break
}

// Use isTruncatedKey from the sequence to check if pagination should continue
if let isTruncatedKey = sequence.isTruncatedKey {
let isTruncated = output[keyPath: isTruncatedKey]
if !isTruncated {
// set token to nil to break out of the next iteration
token = nil
}
}

return output
}
return nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ import software.amazon.smithy.model.shapes.ServiceShape
import software.amazon.smithy.model.shapes.Shape
import software.amazon.smithy.model.traits.PaginatedTrait
import software.amazon.smithy.swift.codegen.core.CodegenContext
import software.amazon.smithy.swift.codegen.customtraits.PaginationTruncationMember
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.SwiftIntegration
import software.amazon.smithy.swift.codegen.model.SymbolProperty
import software.amazon.smithy.swift.codegen.model.defaultName
import software.amazon.smithy.swift.codegen.model.expectShape
import software.amazon.smithy.swift.codegen.model.hasTrait
import software.amazon.smithy.swift.codegen.model.isBoxed
Expand Down Expand Up @@ -55,7 +57,7 @@ class PaginatorGenerator : SwiftIntegration {
service: ServiceShape,
paginatedOperation: OperationShape,
paginationInfo: PaginationInfo,
itemDesc: ItemDescriptor?
itemDesc: ItemDescriptor?,
) {
val serviceSymbol = ctx.symbolProvider.toSymbol(service)
val outputSymbol = ctx.symbolProvider.toSymbol(paginationInfo.output)
Expand All @@ -72,7 +74,7 @@ class PaginatorGenerator : SwiftIntegration {
inputSymbol,
outputSymbol,
paginationInfo,
cursorSymbol
cursorSymbol,
)

// Optionally generate paginator when nested item is specified on the trait.
Expand All @@ -83,7 +85,7 @@ class PaginatorGenerator : SwiftIntegration {
paginatedOperation,
itemDesc,
inputSymbol,
outputSymbol
outputSymbol,
)
}
}
Expand All @@ -98,8 +100,10 @@ class PaginatorGenerator : SwiftIntegration {
inputSymbol: Symbol,
outputSymbol: Symbol,
paginationInfo: PaginationInfo,
cursorSymbol: Symbol
cursorSymbol: Symbol,
) {
val outputShape = paginationInfo.output

writer.addImport(SwiftDependency.CLIENT_RUNTIME.target)
val nextMarkerLiteral = paginationInfo.outputTokenMemberPath.joinToString(separator = "?.") {
it.toLowerCamelCase()
Expand All @@ -122,21 +126,26 @@ class PaginatorGenerator : SwiftIntegration {
this.write(docBody)
}
writer.openBlock(
"public func \$LPaginated(input: \$N) -> \$N<\$N, \$N> {", "}",
"public func \$LPaginated(input: \$N) -> \$N<\$N, \$N> {",
"}",
operationShape.toLowerCamelCase(),
inputSymbol,
ClientRuntimeTypes.Core.PaginatorSequence,
inputSymbol,
outputSymbol
outputSymbol,
) {
val isTruncatedFlag = outputShape
.members()
.firstOrNull { it.hasTrait(PaginationTruncationMember.ID) }
?.defaultName()

val isTruncatedPart = if (isTruncatedFlag != null) ", isTruncatedKey: \\.$isTruncatedFlag" else ""
writer.write(
"return \$N<\$N, \$N>(input: input, inputKey: \\\$N.$markerLiteral, outputKey: \\\$N.$nextMarkerLiteral, paginationFunction: self.\$L(input:))",
"return \$N<\$N, \$N>(input: input, inputKey: \\.$markerLiteral, outputKey: \\.$nextMarkerLiteral$isTruncatedPart, paginationFunction: self.\$L(input:))",
ClientRuntimeTypes.Core.PaginatorSequence,
inputSymbol,
outputSymbol,
inputSymbol,
outputSymbol,
operationShape.toLowerCamelCase()
operationShape.toLowerCamelCase(),
)
}
}
Expand Down Expand Up @@ -212,7 +221,7 @@ private data class ItemDescriptor(
val collectionLiteral: String,
val itemLiteral: String,
val itemPathLiteral: String,
val itemSymbol: Symbol
val itemSymbol: Symbol,
)

/**
Expand All @@ -238,6 +247,6 @@ private fun getItemDescriptorOrNull(paginationInfo: PaginationInfo, ctx: Codegen
collectionLiteral,
itemLiteral,
itemPathLiteral,
ctx.symbolProvider.toSymbol(itemMember)
ctx.symbolProvider.toSymbol(itemMember),
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.swift.codegen.customtraits

import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.node.ObjectNode
import software.amazon.smithy.model.shapes.ShapeId
import software.amazon.smithy.model.traits.AnnotationTrait

/**
* Indicates the annotated member is a truncation indicator which conveys a non-standard termination condition for
* pagination.
*/
class PaginationTruncationMember(node: ObjectNode) : AnnotationTrait(ID, node) {
companion object {
val ID: ShapeId = ShapeId.from("software.amazon.smithy.swift.codegen.synthetic#paginationTruncationMember")
}

constructor() : this(Node.objectNode())
}
18 changes: 15 additions & 3 deletions smithy-swift-codegen/src/test/kotlin/PaginatorGeneratorTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class PaginatorGeneratorTest {
/// - input: A `[ListFunctionsInput]` to start pagination
/// - Returns: An `AsyncSequence` that can iterate over `ListFunctionsOutput`
public func listFunctionsPaginated(input: ListFunctionsInput) -> ClientRuntime.PaginatorSequence<ListFunctionsInput, ListFunctionsOutput> {
return ClientRuntime.PaginatorSequence<ListFunctionsInput, ListFunctionsOutput>(input: input, inputKey: \ListFunctionsInput.marker, outputKey: \ListFunctionsOutput.nextMarker, paginationFunction: self.listFunctions(input:))
return ClientRuntime.PaginatorSequence<ListFunctionsInput, ListFunctionsOutput>(input: input, inputKey: \.marker, outputKey: \.nextMarker, paginationFunction: self.listFunctions(input:))
}
}
Expand Down Expand Up @@ -58,7 +58,7 @@ class PaginatorGeneratorTest {
/// - input: A `[ListFunctionsInput]` to start pagination
/// - Returns: An `AsyncSequence` that can iterate over `ListFunctionsOutput`
public func listFunctionsPaginated(input: ListFunctionsInput) -> ClientRuntime.PaginatorSequence<ListFunctionsInput, ListFunctionsOutput> {
return ClientRuntime.PaginatorSequence<ListFunctionsInput, ListFunctionsOutput>(input: input, inputKey: \ListFunctionsInput.marker, outputKey: \ListFunctionsOutput.nextMarker, paginationFunction: self.listFunctions(input:))
return ClientRuntime.PaginatorSequence<ListFunctionsInput, ListFunctionsOutput>(input: input, inputKey: \.marker, outputKey: \.nextMarker, paginationFunction: self.listFunctions(input:))
}
}
Expand Down Expand Up @@ -100,7 +100,7 @@ class PaginatorGeneratorTest {
/// - input: A `[PaginatedMapInput]` to start pagination
/// - Returns: An `AsyncSequence` that can iterate over `PaginatedMapOutput`
public func paginatedMapPaginated(input: PaginatedMapInput) -> ClientRuntime.PaginatorSequence<PaginatedMapInput, PaginatedMapOutput> {
return ClientRuntime.PaginatorSequence<PaginatedMapInput, PaginatedMapOutput>(input: input, inputKey: \PaginatedMapInput.nextToken, outputKey: \PaginatedMapOutput.inner?.token, paginationFunction: self.paginatedMap(input:))
return ClientRuntime.PaginatorSequence<PaginatedMapInput, PaginatedMapOutput>(input: input, inputKey: \.nextToken, outputKey: \.inner?.token, paginationFunction: self.paginatedMap(input:))
}
}
Expand All @@ -125,6 +125,18 @@ class PaginatorGeneratorTest {
contents.shouldContainOnlyOnce(expectedCode)
}

@Test
fun testRenderPaginatorTruncatable() {
val context = setupTests("pagination-truncation.smithy", "software.amazon.smithy.swift.codegen.synthetic#Lambda")
val contents = getFileContents(context.manifest, "/Test/Paginators.swift")
val expected = """
public func listFunctionsTruncatedPaginated(input: ListFunctionsTruncatedInput) -> ClientRuntime.PaginatorSequence<ListFunctionsTruncatedInput, ListFunctionsTruncatedOutput> {
return ClientRuntime.PaginatorSequence<ListFunctionsTruncatedInput, ListFunctionsTruncatedOutput>(input: input, inputKey: \.marker, outputKey: \.nextMarker, isTruncatedKey: \.isTruncated, paginationFunction: self.listFunctionsTruncated(input:))
}
"""
contents.shouldContainOnlyOnce(expected)
}

private fun setupTests(smithyFile: String, serviceShapeId: String): TestContext {
val context = TestContext.initContextFrom(smithyFile, serviceShapeId, MockHttpRestJsonProtocolGenerator()) { model ->
model.defaultSettings(serviceShapeId, "Test", "2019-12-16", "Test")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
$version: "1.0"

namespace software.amazon.smithy.swift.codegen.synthetic

use aws.protocols#restJson1

@trait(selector: "*")
structure paginationTruncationMember { }

service Lambda {
operations: [ListFunctionsTruncated]
}

list FunctionConfigurationList {
member: FunctionConfiguration
}

structure FunctionConfiguration {
functionName: String
}

@paginated(
inputToken: "marker",
outputToken: "nextMarker",
pageSize: "maxItems"
)
@readonly
@http(method: "GET", uri: "/functions/truncated", code: 200)
operation ListFunctionsTruncated {
input: ListFunctionsRequestTruncated,
output: ListFunctionsResponseTruncated
}

structure ListFunctionsRequestTruncated {
@httpQuery("FunctionVersion")
functionVersion: String,
@httpQuery("Marker")
marker: String,
@httpQuery("MasterRegion")
masterRegion: String,
@httpQuery("MaxItems")
maxItems: Integer,
}

structure ListFunctionsResponseTruncated {
Functions: FunctionConfigurationList,
@paginationTruncationMember
IsTruncated: Boolean,
nextMarker: String
}

0 comments on commit 794b36b

Please sign in to comment.