Skip to content

Commit

Permalink
support event stream for RPC protocols (#573)
Browse files Browse the repository at this point in the history
* support eventstream response in RPC protocols

* support eventstream request in RPC protocols

Test by manuall update the kinesis model

* rebase to the eventstream generator interface change
  • Loading branch information
AllanZhengYP authored Jul 22, 2022
1 parent 95ef65a commit 0a3ca16
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,32 @@ public static boolean isEventStreamShape(Shape shape) {
return shape instanceof UnionShape && shape.hasTrait(StreamingTrait.class);
}

public static boolean hasEventStreamInput(GenerationContext context, OperationShape operation) {
Model model = context.getModel();
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
return eventStreamIndex.getInputInfo(operation).isPresent();
}

public static UnionShape getEventStreamInputShape(GenerationContext context, OperationShape operation) {
Model model = context.getModel();
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
EventStreamInfo eventStreamInfo = eventStreamIndex.getInputInfo(operation).get();
return eventStreamInfo.getEventStreamTarget().asUnionShape().get();
}

public static boolean hasEventStreamOutput(GenerationContext context, OperationShape operation) {
Model model = context.getModel();
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
return eventStreamIndex.getOutputInfo(operation).isPresent();
}

public static UnionShape getEventStreamOutputShape(GenerationContext context, OperationShape operation) {
Model model = context.getModel();
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
EventStreamInfo eventStreamInfo = eventStreamIndex.getOutputInfo(operation).get();
return eventStreamInfo.getEventStreamTarget().asUnionShape().get();
}

/**
* Generate eventstream serializers, and related serializers for events.
* @param context Code generation context instance.
Expand All @@ -74,13 +100,11 @@ public void generateEventStreamSerializers(

TopDownIndex topDownIndex = TopDownIndex.of(model);
Set<OperationShape> operations = topDownIndex.getContainedOperations(service);
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
TreeSet<UnionShape> eventUnionsToSerialize = new TreeSet<>();
TreeSet<StructureShape> eventShapesToMarshall = new TreeSet<>();
for (OperationShape operation : operations) {
if (eventStreamIndex.getInputInfo(operation).isPresent()) {
EventStreamInfo eventStreamInfo = eventStreamIndex.getInputInfo(operation).get();
UnionShape eventsUnion = eventStreamInfo.getEventStreamTarget().asUnionShape().get();
if (hasEventStreamInput(context, operation)) {
UnionShape eventsUnion = getEventStreamInputShape(context, operation);
eventUnionsToSerialize.add(eventsUnion);
Set<StructureShape> eventShapes = eventsUnion.members().stream()
.map(member -> model.expectShape(member.getTarget()).asStructureShape().get())
Expand Down Expand Up @@ -124,13 +148,11 @@ public void generateEventStreamDeserializers(

TopDownIndex topDownIndex = TopDownIndex.of(model);
Set<OperationShape> operations = topDownIndex.getContainedOperations(service);
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
TreeSet<UnionShape> eventUnionsToDeserialize = new TreeSet<>();
TreeSet<StructureShape> eventShapesToUnmarshall = new TreeSet<>();
for (OperationShape operation : operations) {
if (eventStreamIndex.getOutputInfo(operation).isPresent()) {
EventStreamInfo eventStreamInfo = eventStreamIndex.getOutputInfo(operation).get();
UnionShape eventsUnion = eventStreamInfo.getEventStreamTarget().asUnionShape().get();
if (hasEventStreamOutput(context, operation)) {
UnionShape eventsUnion = getEventStreamOutputShape(context, operation);
eventUnionsToDeserialize.add(eventsUnion);
Set<StructureShape> eventShapes = eventsUnion.members().stream()
.map(member -> model.expectShape(member.getTarget()).asStructureShape().get())
Expand Down Expand Up @@ -162,7 +184,7 @@ private void generateEventStreamSerializer(GenerationContext context, UnionShape
writer.openBlock("const $L = (\n"
+ " input: any,\n"
+ " context: $L\n"
+ "): any => {", "}", methodName, getEventStreamSerializerContextType(context, eventsUnion), () -> {
+ "): any => {", "}", methodName, getEventStreamSerdeContextType(context, eventsUnion), () -> {
writer.openBlock("const eventMarshallingVisitor = (event: any): __Message => $T.visit(event, {", "});",
eventsUnionSymbol, () -> {
eventsUnion.getAllMembers().forEach((memberName, memberShape) -> {
Expand All @@ -186,7 +208,7 @@ public String getEventSerFunctionName(GenerationContext context, Shape shape) {
return getSerFunctionName(context, shape) + "_event";
}

private String getEventStreamSerializerContextType(GenerationContext context, UnionShape eventsUnion) {
private String getEventStreamSerdeContextType(GenerationContext context, UnionShape eventsUnion) {
TypeScriptWriter writer = context.getWriter();
writer.addImport("SerdeContext", "__SerdeContext", TypeScriptDependency.AWS_SDK_TYPES.packageName);
String contextType = "__SerdeContext";
Expand Down Expand Up @@ -354,7 +376,7 @@ private void generateEventStreamDeserializer(GenerationContext context, UnionSha
Symbol eventsUnionSymbol = getSymbol(context, eventsUnion);
TypeScriptWriter writer = context.getWriter();
Model model = context.getModel();
String contextType = getEventStreamSerializerContextType(context, eventsUnion);
String contextType = getEventStreamSerdeContextType(context, eventsUnion);
writer.openBlock("const $L = (\n"
+ " output: any,\n"
+ " context: $L\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,19 @@
import java.util.Set;
import java.util.TreeSet;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.codegen.core.SymbolReference;
import software.amazon.smithy.model.knowledge.TopDownIndex;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.traits.EndpointTrait;
import software.amazon.smithy.typescript.codegen.ApplicationProtocol;
import software.amazon.smithy.typescript.codegen.CodegenUtils;
import software.amazon.smithy.typescript.codegen.TypeScriptDependency;
import software.amazon.smithy.typescript.codegen.TypeScriptWriter;
import software.amazon.smithy.utils.OptionalUtils;
Expand All @@ -46,6 +50,7 @@ public abstract class HttpRpcProtocolGenerator implements ProtocolGenerator {
private final Set<Shape> deserializingDocumentShapes = new TreeSet<>();
private final Set<StructureShape> deserializingErrorShapes = new TreeSet<>();
private final boolean isErrorCodeInBody;
private final EventStreamGenerator eventStreamGenerator = new EventStreamGenerator();

/**
* Creates a Http RPC protocol generator.
Expand Down Expand Up @@ -93,7 +98,29 @@ public final ApplicationProtocol getApplicationProtocol() {

@Override
public void generateSharedComponents(GenerationContext context) {
ServiceShape service = context.getService();
deserializingErrorShapes.forEach(error -> generateErrorDeserializer(context, error));
eventStreamGenerator.generateEventStreamSerializers(
context,
service,
getDocumentContentType(),
() -> {
TypeScriptWriter writer = context.getWriter();
writer.write("body = context.utf8Decoder(body);");
},
serializingDocumentShapes
);
// Error shapes that only referred in the error event of an eventstream
Set<StructureShape> errorEventShapes = new TreeSet<>();
eventStreamGenerator.generateEventStreamDeserializers(
context,
service,
errorEventShapes,
deserializingDocumentShapes,
isErrorCodeInBody
);
errorEventShapes.removeIf(deserializingErrorShapes::contains);
errorEventShapes.forEach(error -> generateErrorDeserializer(context, error));
generateDocumentBodyShapeSerializers(context, serializingDocumentShapes);
generateDocumentBodyShapeDeserializers(context, deserializingDocumentShapes);
HttpProtocolGeneratorUtils.generateMetadataDeserializer(context, getApplicationProtocol().getResponseType());
Expand Down Expand Up @@ -191,17 +218,17 @@ private void generateOperationSerializer(GenerationContext context, OperationSha

// Ensure that the request type is imported.
writer.addUseImports(requestType);
writer.addImport("SerdeContext", "__SerdeContext", "@aws-sdk/types");
writer.addImport("Endpoint", "__Endpoint", "@aws-sdk/types");
// e.g., serializeAws_restJson1_1ExecuteStatement
String methodName = ProtocolGenerator.getSerFunctionName(symbol, getName());
// Add the normalized input type.
Symbol inputType = symbol.expectProperty("inputType", Symbol.class);
String serdeContextType = CodegenUtils.getOperationSerializerContextType(writer, context.getModel(), operation);

writer.openBlock("export const $L = async(\n"
+ " input: $T,\n"
+ " context: __SerdeContext\n"
+ "): Promise<$T> => {", "}", methodName, inputType, requestType, () -> {
+ " context: $L\n"
+ "): Promise<$T> => {", "}", methodName, inputType, serdeContextType, requestType, () -> {
writeRequestHeaders(context, operation);
boolean hasRequestBody = writeRequestBody(context, operation);
boolean hasHostPrefix = operation.hasTrait(EndpointTrait.class);
Expand Down Expand Up @@ -238,13 +265,22 @@ private boolean writeRequestBody(GenerationContext context, OperationShape opera
// If there's an input present, we know it's a structure.
StructureShape inputShape = context.getModel().expectShape(operation.getInput().get())
.asStructureShape().get();

// Track input shapes so their serializers may be generated.
serializingDocumentShapes.add(inputShape);

// Write the default `body` property.
context.getWriter().write("let body: any;");
serializeInputDocument(context, operation, inputShape);
TypeScriptWriter writer = context.getWriter();
// Write the default `body` property.
writer.write("let body: any;");
if (EventStreamGenerator.hasEventStreamInput(context, operation)) {
// There must only one eventstream member in request structure.
MemberShape member = inputShape.members().stream().collect(Collectors.toList()).get(0);
Shape target = context.getModel().expectShape(member.getTarget());
Symbol targetSymbol = context.getSymbolProvider().toSymbol(target);
String serFunctionName = ProtocolGenerator.getSerFunctionName(targetSymbol, context.getProtocolName());
String memberName = member.getMemberName();
writer.write("body = $L(input.$L, context);", serFunctionName, memberName);
} else {
// Track input shapes so their serializers may be generated.
serializingDocumentShapes.add(inputShape);
serializeInputDocument(context, operation, inputShape);
}
return true;
}

Expand Down Expand Up @@ -341,18 +377,19 @@ private void generateOperationDeserializer(GenerationContext context, OperationS

// Ensure that the response type is imported.
writer.addUseImports(responseType);
writer.addImport("SerdeContext", "__SerdeContext", "@aws-sdk/types");
// e.g., deserializeAws_restJson1_1ExecuteStatement
String methodName = ProtocolGenerator.getDeserFunctionName(symbol, getName());
String errorMethodName = methodName + "Error";
String serdeContextType = CodegenUtils.getOperationDeserializerContextType(writer, context.getModel(),
operation);
// Add the normalized output type.
Symbol outputType = symbol.expectProperty("outputType", Symbol.class);

// Handle the general response.
writer.openBlock("export const $L = async(\n"
+ " output: $T,\n"
+ " context: __SerdeContext\n"
+ "): Promise<$T> => {", "}", methodName, responseType, outputType, () -> {
+ " context: $L\n"
+ "): Promise<$T> => {", "}", methodName, responseType, serdeContextType, outputType, () -> {
// Redirect error deserialization to the dispatcher
writer.openBlock("if (output.statusCode >= 300) {", "}", () -> {
writer.write("return $L(output, context);", errorMethodName);
Expand Down Expand Up @@ -425,17 +462,24 @@ private void readResponseBody(GenerationContext context, OperationShape operatio
OptionalUtils.ifPresentOrElse(
operation.getOutput(),
outputId -> {
// We only need to load the body and prepare a contents object if there is a response.
writer.write("const data: any = await parseBody(output.body, context)");
writer.write("let contents: any = {};");

// If there's an output present, we know it's a structure.
StructureShape outputShape = context.getModel().expectShape(outputId).asStructureShape().get();

// Track output shapes so their deserializers may be generated.
deserializingDocumentShapes.add(outputShape);

deserializeOutputDocument(context, operation, outputShape);
if (EventStreamGenerator.hasEventStreamOutput(context, operation)) {
// There must only one eventstream member in response structure.
MemberShape member = outputShape.members().stream().collect(Collectors.toList()).get(0);
Shape target = context.getModel().expectShape(member.getTarget());
Symbol targetSymbol = context.getSymbolProvider().toSymbol(target);
writer.write("const contents = { $L: $L(output.body, context) };", member.getMemberName(),
ProtocolGenerator.getDeserFunctionName(targetSymbol, context.getProtocolName()));
} else {
// We only need to load the body and prepare a contents object if there is a response.
writer.write("const data: any = await parseBody(output.body, context)");
writer.write("let contents: any = {};");
// Track output shapes so their deserializers may be generated.
deserializingDocumentShapes.add(outputShape);

deserializeOutputDocument(context, operation, outputShape);
}
},
() -> {
// If there is no output, the body still needs to be collected so the process can exit.
Expand Down

0 comments on commit 0a3ca16

Please sign in to comment.