diff --git a/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/EventStreamGenerator.java b/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/EventStreamGenerator.java index 581625b9094..423ae2e4af9 100644 --- a/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/EventStreamGenerator.java +++ b/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/EventStreamGenerator.java @@ -53,6 +53,32 @@ public static boolean isEventStreamShape(Shape shape) { return shape instanceof UnionShape && shape.hasTrait(StreamingTrait.class); } + public static boolean hasEventStreamInput(GenerationContext context, OperationShape operation) { + Model model = context.getModel(); + EventStreamIndex eventStreamIndex = EventStreamIndex.of(model); + return eventStreamIndex.getInputInfo(operation).isPresent(); + } + + public static UnionShape getEventStreamInputShape(GenerationContext context, OperationShape operation) { + Model model = context.getModel(); + EventStreamIndex eventStreamIndex = EventStreamIndex.of(model); + EventStreamInfo eventStreamInfo = eventStreamIndex.getInputInfo(operation).get(); + return eventStreamInfo.getEventStreamTarget().asUnionShape().get(); + } + + public static boolean hasEventStreamOutput(GenerationContext context, OperationShape operation) { + Model model = context.getModel(); + EventStreamIndex eventStreamIndex = EventStreamIndex.of(model); + return eventStreamIndex.getOutputInfo(operation).isPresent(); + } + + public static UnionShape getEventStreamOutputShape(GenerationContext context, OperationShape operation) { + Model model = context.getModel(); + EventStreamIndex eventStreamIndex = EventStreamIndex.of(model); + EventStreamInfo eventStreamInfo = eventStreamIndex.getOutputInfo(operation).get(); + return eventStreamInfo.getEventStreamTarget().asUnionShape().get(); + } + /** * Generate eventstream serializers, and related serializers for events. * @param context Code generation context instance. @@ -74,13 +100,11 @@ public void generateEventStreamSerializers( TopDownIndex topDownIndex = TopDownIndex.of(model); Set operations = topDownIndex.getContainedOperations(service); - EventStreamIndex eventStreamIndex = EventStreamIndex.of(model); TreeSet eventUnionsToSerialize = new TreeSet<>(); TreeSet eventShapesToMarshall = new TreeSet<>(); for (OperationShape operation : operations) { - if (eventStreamIndex.getInputInfo(operation).isPresent()) { - EventStreamInfo eventStreamInfo = eventStreamIndex.getInputInfo(operation).get(); - UnionShape eventsUnion = eventStreamInfo.getEventStreamTarget().asUnionShape().get(); + if (hasEventStreamInput(context, operation)) { + UnionShape eventsUnion = getEventStreamInputShape(context, operation); eventUnionsToSerialize.add(eventsUnion); Set eventShapes = eventsUnion.members().stream() .map(member -> model.expectShape(member.getTarget()).asStructureShape().get()) @@ -124,13 +148,11 @@ public void generateEventStreamDeserializers( TopDownIndex topDownIndex = TopDownIndex.of(model); Set operations = topDownIndex.getContainedOperations(service); - EventStreamIndex eventStreamIndex = EventStreamIndex.of(model); TreeSet eventUnionsToDeserialize = new TreeSet<>(); TreeSet eventShapesToUnmarshall = new TreeSet<>(); for (OperationShape operation : operations) { - if (eventStreamIndex.getOutputInfo(operation).isPresent()) { - EventStreamInfo eventStreamInfo = eventStreamIndex.getOutputInfo(operation).get(); - UnionShape eventsUnion = eventStreamInfo.getEventStreamTarget().asUnionShape().get(); + if (hasEventStreamOutput(context, operation)) { + UnionShape eventsUnion = getEventStreamOutputShape(context, operation); eventUnionsToDeserialize.add(eventsUnion); Set eventShapes = eventsUnion.members().stream() .map(member -> model.expectShape(member.getTarget()).asStructureShape().get()) @@ -162,7 +184,7 @@ private void generateEventStreamSerializer(GenerationContext context, UnionShape writer.openBlock("const $L = (\n" + " input: any,\n" + " context: $L\n" - + "): any => {", "}", methodName, getEventStreamSerializerContextType(context, eventsUnion), () -> { + + "): any => {", "}", methodName, getEventStreamSerdeContextType(context, eventsUnion), () -> { writer.openBlock("const eventMarshallingVisitor = (event: any): __Message => $T.visit(event, {", "});", eventsUnionSymbol, () -> { eventsUnion.getAllMembers().forEach((memberName, memberShape) -> { @@ -186,7 +208,7 @@ public String getEventSerFunctionName(GenerationContext context, Shape shape) { return getSerFunctionName(context, shape) + "_event"; } - private String getEventStreamSerializerContextType(GenerationContext context, UnionShape eventsUnion) { + private String getEventStreamSerdeContextType(GenerationContext context, UnionShape eventsUnion) { TypeScriptWriter writer = context.getWriter(); writer.addImport("SerdeContext", "__SerdeContext", TypeScriptDependency.AWS_SDK_TYPES.packageName); String contextType = "__SerdeContext"; @@ -354,7 +376,7 @@ private void generateEventStreamDeserializer(GenerationContext context, UnionSha Symbol eventsUnionSymbol = getSymbol(context, eventsUnion); TypeScriptWriter writer = context.getWriter(); Model model = context.getModel(); - String contextType = getEventStreamSerializerContextType(context, eventsUnion); + String contextType = getEventStreamSerdeContextType(context, eventsUnion); writer.openBlock("const $L = (\n" + " output: any,\n" + " context: $L\n" diff --git a/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpRpcProtocolGenerator.java b/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpRpcProtocolGenerator.java index 0334c5d1fa4..b11e4bfd095 100644 --- a/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpRpcProtocolGenerator.java +++ b/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/integration/HttpRpcProtocolGenerator.java @@ -18,15 +18,19 @@ import java.util.Set; import java.util.TreeSet; import java.util.logging.Logger; +import java.util.stream.Collectors; import software.amazon.smithy.codegen.core.Symbol; import software.amazon.smithy.codegen.core.SymbolProvider; import software.amazon.smithy.codegen.core.SymbolReference; import software.amazon.smithy.model.knowledge.TopDownIndex; +import software.amazon.smithy.model.shapes.MemberShape; import software.amazon.smithy.model.shapes.OperationShape; +import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.shapes.Shape; import software.amazon.smithy.model.shapes.StructureShape; import software.amazon.smithy.model.traits.EndpointTrait; import software.amazon.smithy.typescript.codegen.ApplicationProtocol; +import software.amazon.smithy.typescript.codegen.CodegenUtils; import software.amazon.smithy.typescript.codegen.TypeScriptDependency; import software.amazon.smithy.typescript.codegen.TypeScriptWriter; import software.amazon.smithy.utils.OptionalUtils; @@ -46,6 +50,7 @@ public abstract class HttpRpcProtocolGenerator implements ProtocolGenerator { private final Set deserializingDocumentShapes = new TreeSet<>(); private final Set deserializingErrorShapes = new TreeSet<>(); private final boolean isErrorCodeInBody; + private final EventStreamGenerator eventStreamGenerator = new EventStreamGenerator(); /** * Creates a Http RPC protocol generator. @@ -93,7 +98,29 @@ public final ApplicationProtocol getApplicationProtocol() { @Override public void generateSharedComponents(GenerationContext context) { + ServiceShape service = context.getService(); deserializingErrorShapes.forEach(error -> generateErrorDeserializer(context, error)); + eventStreamGenerator.generateEventStreamSerializers( + context, + service, + getDocumentContentType(), + () -> { + TypeScriptWriter writer = context.getWriter(); + writer.write("body = context.utf8Decoder(body);"); + }, + serializingDocumentShapes + ); + // Error shapes that only referred in the error event of an eventstream + Set errorEventShapes = new TreeSet<>(); + eventStreamGenerator.generateEventStreamDeserializers( + context, + service, + errorEventShapes, + deserializingDocumentShapes, + isErrorCodeInBody + ); + errorEventShapes.removeIf(deserializingErrorShapes::contains); + errorEventShapes.forEach(error -> generateErrorDeserializer(context, error)); generateDocumentBodyShapeSerializers(context, serializingDocumentShapes); generateDocumentBodyShapeDeserializers(context, deserializingDocumentShapes); HttpProtocolGeneratorUtils.generateMetadataDeserializer(context, getApplicationProtocol().getResponseType()); @@ -191,17 +218,17 @@ private void generateOperationSerializer(GenerationContext context, OperationSha // Ensure that the request type is imported. writer.addUseImports(requestType); - writer.addImport("SerdeContext", "__SerdeContext", "@aws-sdk/types"); writer.addImport("Endpoint", "__Endpoint", "@aws-sdk/types"); // e.g., serializeAws_restJson1_1ExecuteStatement String methodName = ProtocolGenerator.getSerFunctionName(symbol, getName()); // Add the normalized input type. Symbol inputType = symbol.expectProperty("inputType", Symbol.class); + String serdeContextType = CodegenUtils.getOperationSerializerContextType(writer, context.getModel(), operation); writer.openBlock("export const $L = async(\n" + " input: $T,\n" - + " context: __SerdeContext\n" - + "): Promise<$T> => {", "}", methodName, inputType, requestType, () -> { + + " context: $L\n" + + "): Promise<$T> => {", "}", methodName, inputType, serdeContextType, requestType, () -> { writeRequestHeaders(context, operation); boolean hasRequestBody = writeRequestBody(context, operation); boolean hasHostPrefix = operation.hasTrait(EndpointTrait.class); @@ -238,13 +265,22 @@ private boolean writeRequestBody(GenerationContext context, OperationShape opera // If there's an input present, we know it's a structure. StructureShape inputShape = context.getModel().expectShape(operation.getInput().get()) .asStructureShape().get(); - - // Track input shapes so their serializers may be generated. - serializingDocumentShapes.add(inputShape); - - // Write the default `body` property. - context.getWriter().write("let body: any;"); - serializeInputDocument(context, operation, inputShape); + TypeScriptWriter writer = context.getWriter(); + // Write the default `body` property. + writer.write("let body: any;"); + if (EventStreamGenerator.hasEventStreamInput(context, operation)) { + // There must only one eventstream member in request structure. + MemberShape member = inputShape.members().stream().collect(Collectors.toList()).get(0); + Shape target = context.getModel().expectShape(member.getTarget()); + Symbol targetSymbol = context.getSymbolProvider().toSymbol(target); + String serFunctionName = ProtocolGenerator.getSerFunctionName(targetSymbol, context.getProtocolName()); + String memberName = member.getMemberName(); + writer.write("body = $L(input.$L, context);", serFunctionName, memberName); + } else { + // Track input shapes so their serializers may be generated. + serializingDocumentShapes.add(inputShape); + serializeInputDocument(context, operation, inputShape); + } return true; } @@ -341,18 +377,19 @@ private void generateOperationDeserializer(GenerationContext context, OperationS // Ensure that the response type is imported. writer.addUseImports(responseType); - writer.addImport("SerdeContext", "__SerdeContext", "@aws-sdk/types"); // e.g., deserializeAws_restJson1_1ExecuteStatement String methodName = ProtocolGenerator.getDeserFunctionName(symbol, getName()); String errorMethodName = methodName + "Error"; + String serdeContextType = CodegenUtils.getOperationDeserializerContextType(writer, context.getModel(), + operation); // Add the normalized output type. Symbol outputType = symbol.expectProperty("outputType", Symbol.class); // Handle the general response. writer.openBlock("export const $L = async(\n" + " output: $T,\n" - + " context: __SerdeContext\n" - + "): Promise<$T> => {", "}", methodName, responseType, outputType, () -> { + + " context: $L\n" + + "): Promise<$T> => {", "}", methodName, responseType, serdeContextType, outputType, () -> { // Redirect error deserialization to the dispatcher writer.openBlock("if (output.statusCode >= 300) {", "}", () -> { writer.write("return $L(output, context);", errorMethodName); @@ -425,17 +462,24 @@ private void readResponseBody(GenerationContext context, OperationShape operatio OptionalUtils.ifPresentOrElse( operation.getOutput(), outputId -> { - // We only need to load the body and prepare a contents object if there is a response. - writer.write("const data: any = await parseBody(output.body, context)"); - writer.write("let contents: any = {};"); - // If there's an output present, we know it's a structure. StructureShape outputShape = context.getModel().expectShape(outputId).asStructureShape().get(); - - // Track output shapes so their deserializers may be generated. - deserializingDocumentShapes.add(outputShape); - - deserializeOutputDocument(context, operation, outputShape); + if (EventStreamGenerator.hasEventStreamOutput(context, operation)) { + // There must only one eventstream member in response structure. + MemberShape member = outputShape.members().stream().collect(Collectors.toList()).get(0); + Shape target = context.getModel().expectShape(member.getTarget()); + Symbol targetSymbol = context.getSymbolProvider().toSymbol(target); + writer.write("const contents = { $L: $L(output.body, context) };", member.getMemberName(), + ProtocolGenerator.getDeserFunctionName(targetSymbol, context.getProtocolName())); + } else { + // We only need to load the body and prepare a contents object if there is a response. + writer.write("const data: any = await parseBody(output.body, context)"); + writer.write("let contents: any = {};"); + // Track output shapes so their deserializers may be generated. + deserializingDocumentShapes.add(outputShape); + + deserializeOutputDocument(context, operation, outputShape); + } }, () -> { // If there is no output, the body still needs to be collected so the process can exit.