Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support event stream for RPC protocols #573

Merged
merged 3 commits into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,32 @@ public static boolean isEventStreamShape(Shape shape) {
return shape instanceof UnionShape && shape.hasTrait(StreamingTrait.class);
}

public static boolean hasEventStreamInput(GenerationContext context, OperationShape operation) {
Model model = context.getModel();
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
return eventStreamIndex.getInputInfo(operation).isPresent();
}

public static UnionShape getEventStreamInputShape(GenerationContext context, OperationShape operation) {
Model model = context.getModel();
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
EventStreamInfo eventStreamInfo = eventStreamIndex.getInputInfo(operation).get();
return eventStreamInfo.getEventStreamTarget().asUnionShape().get();
}

public static boolean hasEventStreamOutput(GenerationContext context, OperationShape operation) {
Model model = context.getModel();
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
return eventStreamIndex.getOutputInfo(operation).isPresent();
}

public static UnionShape getEventStreamOutputShape(GenerationContext context, OperationShape operation) {
Model model = context.getModel();
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
EventStreamInfo eventStreamInfo = eventStreamIndex.getOutputInfo(operation).get();
return eventStreamInfo.getEventStreamTarget().asUnionShape().get();
}

/**
* Generate eventstream serializers, and related serializers for events.
* @param context Code generation context instance.
Expand All @@ -74,13 +100,11 @@ public void generateEventStreamSerializers(

TopDownIndex topDownIndex = TopDownIndex.of(model);
Set<OperationShape> operations = topDownIndex.getContainedOperations(service);
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
TreeSet<UnionShape> eventUnionsToSerialize = new TreeSet<>();
TreeSet<StructureShape> eventShapesToMarshall = new TreeSet<>();
for (OperationShape operation : operations) {
if (eventStreamIndex.getInputInfo(operation).isPresent()) {
EventStreamInfo eventStreamInfo = eventStreamIndex.getInputInfo(operation).get();
UnionShape eventsUnion = eventStreamInfo.getEventStreamTarget().asUnionShape().get();
if (hasEventStreamInput(context, operation)) {
UnionShape eventsUnion = getEventStreamInputShape(context, operation);
eventUnionsToSerialize.add(eventsUnion);
Set<StructureShape> eventShapes = eventsUnion.members().stream()
.map(member -> model.expectShape(member.getTarget()).asStructureShape().get())
Expand Down Expand Up @@ -124,13 +148,11 @@ public void generateEventStreamDeserializers(

TopDownIndex topDownIndex = TopDownIndex.of(model);
Set<OperationShape> operations = topDownIndex.getContainedOperations(service);
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
TreeSet<UnionShape> eventUnionsToDeserialize = new TreeSet<>();
TreeSet<StructureShape> eventShapesToUnmarshall = new TreeSet<>();
for (OperationShape operation : operations) {
if (eventStreamIndex.getOutputInfo(operation).isPresent()) {
EventStreamInfo eventStreamInfo = eventStreamIndex.getOutputInfo(operation).get();
UnionShape eventsUnion = eventStreamInfo.getEventStreamTarget().asUnionShape().get();
if (hasEventStreamOutput(context, operation)) {
UnionShape eventsUnion = getEventStreamOutputShape(context, operation);
eventUnionsToDeserialize.add(eventsUnion);
Set<StructureShape> eventShapes = eventsUnion.members().stream()
.map(member -> model.expectShape(member.getTarget()).asStructureShape().get())
Expand Down Expand Up @@ -162,7 +184,7 @@ private void generateEventStreamSerializer(GenerationContext context, UnionShape
writer.openBlock("const $L = (\n"
+ " input: any,\n"
+ " context: $L\n"
+ "): any => {", "}", methodName, getEventStreamSerializerContextType(context, eventsUnion), () -> {
+ "): any => {", "}", methodName, getEventStreamSerdeContextType(context, eventsUnion), () -> {
writer.openBlock("const eventMarshallingVisitor = (event: any): __Message => $T.visit(event, {", "});",
eventsUnionSymbol, () -> {
eventsUnion.getAllMembers().forEach((memberName, memberShape) -> {
Expand All @@ -186,7 +208,7 @@ public String getEventSerFunctionName(GenerationContext context, Shape shape) {
return getSerFunctionName(context, shape) + "_event";
}

private String getEventStreamSerializerContextType(GenerationContext context, UnionShape eventsUnion) {
private String getEventStreamSerdeContextType(GenerationContext context, UnionShape eventsUnion) {
TypeScriptWriter writer = context.getWriter();
writer.addImport("SerdeContext", "__SerdeContext", TypeScriptDependency.AWS_SDK_TYPES.packageName);
String contextType = "__SerdeContext";
Expand Down Expand Up @@ -354,7 +376,7 @@ private void generateEventStreamDeserializer(GenerationContext context, UnionSha
Symbol eventsUnionSymbol = getSymbol(context, eventsUnion);
TypeScriptWriter writer = context.getWriter();
Model model = context.getModel();
String contextType = getEventStreamSerializerContextType(context, eventsUnion);
String contextType = getEventStreamSerdeContextType(context, eventsUnion);
writer.openBlock("const $L = (\n"
+ " output: any,\n"
+ " context: $L\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@
import java.util.Set;
import java.util.TreeSet;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.codegen.core.SymbolReference;
import software.amazon.smithy.model.knowledge.TopDownIndex;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.traits.EndpointTrait;
import software.amazon.smithy.typescript.codegen.ApplicationProtocol;
import software.amazon.smithy.typescript.codegen.CodegenUtils;
import software.amazon.smithy.typescript.codegen.TypeScriptDependency;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.utils.OptionalUtils;
Expand All @@ -46,6 +50,7 @@ public abstract class HttpRpcProtocolGenerator implements ProtocolGenerator {
private final Set<Shape> deserializingDocumentShapes = new TreeSet<>();
private final Set<StructureShape> deserializingErrorShapes = new TreeSet<>();
private final boolean isErrorCodeInBody;
private final EventStreamGenerator eventStreamGenerator = new EventStreamGenerator();

/**
* Creates a Http RPC protocol generator.
Expand Down Expand Up @@ -93,7 +98,29 @@ public final ApplicationProtocol getApplicationProtocol() {

@Override
public void generateSharedComponents(GenerationContext context) {
ServiceShape service = context.getService();
deserializingErrorShapes.forEach(error -> generateErrorDeserializer(context, error));
eventStreamGenerator.generateEventStreamSerializers(
context,
service,
getDocumentContentType(),
() -> {
TypeScriptWriter writer = context.getWriter();
writer.write("body = context.utf8Decoder(body);");
},
serializingDocumentShapes
);
// Error shapes that only referred in the error event of an eventstream
Set<StructureShape> errorEventShapes = new TreeSet<>();
eventStreamGenerator.generateEventStreamDeserializers(
context,
service,
errorEventShapes,
deserializingDocumentShapes,
isErrorCodeInBody
);
errorEventShapes.removeIf(deserializingErrorShapes::contains);
errorEventShapes.forEach(error -> generateErrorDeserializer(context, error));
generateDocumentBodyShapeSerializers(context, serializingDocumentShapes);
generateDocumentBodyShapeDeserializers(context, deserializingDocumentShapes);
HttpProtocolGeneratorUtils.generateMetadataDeserializer(context, getApplicationProtocol().getResponseType());
Expand Down Expand Up @@ -191,17 +218,17 @@ private void generateOperationSerializer(GenerationContext context, OperationSha

// Ensure that the request type is imported.
writer.addUseImports(requestType);
writer.addImport("SerdeContext", "__SerdeContext", "@aws-sdk/types");
writer.addImport("Endpoint", "__Endpoint", "@aws-sdk/types");
// e.g., serializeAws_restJson1_1ExecuteStatement
String methodName = ProtocolGenerator.getSerFunctionName(symbol, getName());
// Add the normalized input type.
Symbol inputType = symbol.expectProperty("inputType", Symbol.class);
String serdeContextType = CodegenUtils.getOperationSerializerContextType(writer, context.getModel(), operation);

writer.openBlock("export const $L = async(\n"
+ " input: $T,\n"
+ " context: __SerdeContext\n"
+ "): Promise<$T> => {", "}", methodName, inputType, requestType, () -> {
+ " context: $L\n"
+ "): Promise<$T> => {", "}", methodName, inputType, serdeContextType, requestType, () -> {
writeRequestHeaders(context, operation);
boolean hasRequestBody = writeRequestBody(context, operation);
boolean hasHostPrefix = operation.hasTrait(EndpointTrait.class);
Expand Down Expand Up @@ -238,13 +265,22 @@ private boolean writeRequestBody(GenerationContext context, OperationShape opera
// If there's an input present, we know it's a structure.
StructureShape inputShape = context.getModel().expectShape(operation.getInput().get())
.asStructureShape().get();

// Track input shapes so their serializers may be generated.
serializingDocumentShapes.add(inputShape);

// Write the default `body` property.
context.getWriter().write("let body: any;");
serializeInputDocument(context, operation, inputShape);
TypeScriptWriter writer = context.getWriter();
// Write the default `body` property.
writer.write("let body: any;");
if (EventStreamGenerator.hasEventStreamInput(context, operation)) {
// There must only one eventstream member in request structure.
MemberShape member = inputShape.members().stream().collect(Collectors.toList()).get(0);
Shape target = context.getModel().expectShape(member.getTarget());
Symbol targetSymbol = context.getSymbolProvider().toSymbol(target);
String serFunctionName = ProtocolGenerator.getSerFunctionName(targetSymbol, context.getProtocolName());
String memberName = member.getMemberName();
writer.write("body = $L(input.$L, context);", serFunctionName, memberName);
} else {
// Track input shapes so their serializers may be generated.
serializingDocumentShapes.add(inputShape);
serializeInputDocument(context, operation, inputShape);
}
return true;
}

Expand Down Expand Up @@ -341,18 +377,19 @@ private void generateOperationDeserializer(GenerationContext context, OperationS

// Ensure that the response type is imported.
writer.addUseImports(responseType);
writer.addImport("SerdeContext", "__SerdeContext", "@aws-sdk/types");
// e.g., deserializeAws_restJson1_1ExecuteStatement
String methodName = ProtocolGenerator.getDeserFunctionName(symbol, getName());
String errorMethodName = methodName + "Error";
String serdeContextType = CodegenUtils.getOperationDeserializerContextType(writer, context.getModel(),
operation);
// Add the normalized output type.
Symbol outputType = symbol.expectProperty("outputType", Symbol.class);

// Handle the general response.
writer.openBlock("export const $L = async(\n"
+ " output: $T,\n"
+ " context: __SerdeContext\n"
+ "): Promise<$T> => {", "}", methodName, responseType, outputType, () -> {
+ " context: $L\n"
+ "): Promise<$T> => {", "}", methodName, responseType, serdeContextType, outputType, () -> {
// Redirect error deserialization to the dispatcher
writer.openBlock("if (output.statusCode >= 300) {", "}", () -> {
writer.write("return $L(output, context);", errorMethodName);
Expand Down Expand Up @@ -425,17 +462,24 @@ private void readResponseBody(GenerationContext context, OperationShape operatio
OptionalUtils.ifPresentOrElse(
operation.getOutput(),
outputId -> {
// We only need to load the body and prepare a contents object if there is a response.
writer.write("const data: any = await parseBody(output.body, context)");
writer.write("let contents: any = {};");

// If there's an output present, we know it's a structure.
StructureShape outputShape = context.getModel().expectShape(outputId).asStructureShape().get();

// Track output shapes so their deserializers may be generated.
deserializingDocumentShapes.add(outputShape);

deserializeOutputDocument(context, operation, outputShape);
if (EventStreamGenerator.hasEventStreamOutput(context, operation)) {
// There must only one eventstream member in response structure.
MemberShape member = outputShape.members().stream().collect(Collectors.toList()).get(0);
Shape target = context.getModel().expectShape(member.getTarget());
Symbol targetSymbol = context.getSymbolProvider().toSymbol(target);
writer.write("const contents = { $L: $L(output.body, context) };", member.getMemberName(),
ProtocolGenerator.getDeserFunctionName(targetSymbol, context.getProtocolName()));
} else {
// We only need to load the body and prepare a contents object if there is a response.
writer.write("const data: any = await parseBody(output.body, context)");
writer.write("let contents: any = {};");
// Track output shapes so their deserializers may be generated.
deserializingDocumentShapes.add(outputShape);

deserializeOutputDocument(context, operation, outputShape);
}
},
() -> {
// If there is no output, the body still needs to be collected so the process can exit.
Expand Down