Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generate unified error dispatcher #1150

Merged
merged 2 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .changeset/shy-nails-wonder.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
---
---
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ public void generateResponseDeserializers(GenerationContext context) {

Set<OperationShape> containedOperations = new TreeSet<>(
topDownIndex.getContainedOperations(context.getService()));

for (OperationShape operation : containedOperations) {
OptionalUtils.ifPresentOrElse(
operation.getTrait(HttpTrait.class),
Expand All @@ -525,6 +526,18 @@ public void generateResponseDeserializers(GenerationContext context) {
"Unable to generate %s protocol response bindings for %s because it does not have an "
+ "http binding trait", getName(), operation.getId())));
}

SymbolReference responseType = getApplicationProtocol().getResponseType();
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateUnifiedErrorDispatcher(
context,
containedOperations.stream().toList(),
responseType,
this::writeErrorCodeParser,
isErrorCodeInBody,
this::getErrorBodyLocation,
this::getOperationErrors
);
deserializingErrorShapes.addAll(errorShapes);
}

private void generateOperationResponseSerializer(
Expand Down Expand Up @@ -2091,7 +2104,7 @@ private void generateOperationResponseDeserializer(
// e.g., deserializeAws_restJson1_1ExecuteStatement
String methodName = ProtocolGenerator.getDeserFunctionShortName(symbol);
String methodLongName = ProtocolGenerator.getDeserFunctionName(symbol, getName());
String errorMethodName = methodName + "Error";
String errorMethodName = "de_CommandError";
// Add the normalized output type.
Symbol outputType = symbol.expectProperty("outputType", Symbol.class);
String contextType = CodegenUtils.getOperationDeserializerContextType(context.getSettings(), writer,
Expand All @@ -2108,7 +2121,7 @@ private void generateOperationResponseDeserializer(
// status code that's not the modeled code (300 or higher). This allows for
// returning other 2XX codes that don't match the defined value.
writer.openBlock("if (output.statusCode !== $L && output.statusCode >= 300) {", "}", trait.getCode(),
() -> writer.write("return $L(output, context);", errorMethodName));
() -> writer.write("return $L(output, context) as any;", errorMethodName));
syall marked this conversation as resolved.
Show resolved Hide resolved

// Start deserializing the response.
writer.openBlock("const contents: any = map({", "});", () -> {
Expand All @@ -2129,11 +2142,6 @@ private void generateOperationResponseDeserializer(
writer.write("return contents;");
});
writer.write("");
// Write out the error deserialization dispatcher.
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
context, operation, responseType, this::writeErrorCodeParser,
isErrorCodeInBody, this::getErrorBodyLocation, this::getOperationErrors);
deserializingErrorShapes.addAll(errorShapes);
}

private void generateErrorDeserializer(GenerationContext context, StructureShape error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,51 +298,49 @@ public static void writeRetryableTrait(TypeScriptWriter writer, StructureShape e

/**
* Writes a function used to dispatch to the proper error deserializer
* for each error that the operation can return. The generated function
* for each error that any operation can return. The generated function
* assumes a deserialization function is generated for the structures
* returned.
*
* @param context The generation context.
* @param operation The operation to generate for.
* @param responseType The response type for the HTTP protocol.
* @param errorCodeGenerator A consumer
* @param shouldParseErrorBody Flag indicating whether need to parse response body in this dispatcher function
* @param bodyErrorLocationModifier A function that returns the location of an error in a body given a data source.
* @param operationErrorsToShapes A map of error names to their {@link ShapeId}.
* @return A set of all error structure shapes for the operation that were dispatched to.
*/
static Set<StructureShape> generateErrorDispatcher(
GenerationContext context,
OperationShape operation,
SymbolReference responseType,
Consumer<GenerationContext> errorCodeGenerator,
boolean shouldParseErrorBody,
BiFunction<GenerationContext, String, String> bodyErrorLocationModifier,
BiFunction<GenerationContext, OperationShape, Map<String, ShapeId>> operationErrorsToShapes
static Set<StructureShape> generateUnifiedErrorDispatcher(
GenerationContext context,
List<OperationShape> operations,
SymbolReference responseType,
Consumer<GenerationContext> errorCodeGenerator,
boolean shouldParseErrorBody,
BiFunction<GenerationContext, String, String> bodyErrorLocationModifier,
BiFunction<GenerationContext, List<OperationShape>, Map<String, ShapeId>> operationErrorsToShapes
) {
TypeScriptWriter writer = context.getWriter();
SymbolProvider symbolProvider = context.getSymbolProvider();
Set<StructureShape> errorShapes = new TreeSet<>();

Symbol symbol = symbolProvider.toSymbol(operation);
Symbol outputType = symbol.expectProperty("outputType", Symbol.class);
String errorMethodName = ProtocolGenerator.getDeserFunctionShortName(symbol) + "Error";
String errorMethodLongName = ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName())
+ "Error";
String errorMethodName = "de_CommandError";
String errorMethodLongName = "deserialize_"
+ ProtocolGenerator.getSanitizedName(context.getProtocolName())
+ "CommandError";

writer.writeDocs(errorMethodLongName);
writer.openBlock("const $L = async(\n"
+ " output: $T,\n"
+ " context: __SerdeContext,\n"
+ "): Promise<$T> => {", "}", errorMethodName, responseType, outputType, () -> {
writer.openBlock("const $L = async(\n"
+ " output: $T,\n"
+ " context: __SerdeContext,\n"
+ "): Promise<unknown> => {", "}", errorMethodName, responseType, () -> {
syall marked this conversation as resolved.
Show resolved Hide resolved
// Prepare error response for parsing error code. If error code needs to be parsed from response body
// then we collect body and parse it to JS object, otherwise leave the response body as is.
if (shouldParseErrorBody) {
writer.openBlock("const parsedOutput: any = {", "};",
() -> {
writer.write("...output,");
writer.write("body: await parseErrorBody(output.body, context)");
});
() -> {
writer.write("...output,");
writer.write("body: await parseErrorBody(output.body, context)");
});
}

// Error responses must be at least BaseException interface
Expand Down Expand Up @@ -370,7 +368,8 @@ static Set<StructureShape> generateErrorDispatcher(
});
};

Map<String, ShapeId> operationNamesToShapes = operationErrorsToShapes.apply(context, operation);
Map<String, ShapeId> operationNamesToShapes = operationErrorsToShapes.apply(context, operations);

if (!operationNamesToShapes.isEmpty()) {
writer.openBlock("switch (errorCode) {", "}", () -> {
// Generate the case statement for each error, invoking the specific deserializer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,18 @@ public void generateResponseDeserializers(GenerationContext context) {
for (OperationShape operation : containedOperations) {
generateOperationDeserializer(context, operation);
}

SymbolReference responseType = getApplicationProtocol().getResponseType();
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateUnifiedErrorDispatcher(
context,
containedOperations.stream().toList(),
responseType,
this::writeErrorCodeParser,
isErrorCodeInBody,
this::getErrorBodyLocation,
this::getOperationErrors
);
deserializingErrorShapes.addAll(errorShapes);
}

private void generateOperationSerializer(GenerationContext context, OperationShape operation) {
Expand Down Expand Up @@ -435,7 +447,7 @@ private void generateOperationDeserializer(GenerationContext context, OperationS
// e.g., deserializeAws_restJson1_1ExecuteStatement
String methodName = ProtocolGenerator.getDeserFunctionShortName(symbol);
String methodLongName = ProtocolGenerator.getDeserFunctionName(symbol, getName());
String errorMethodName = methodName + "Error";
String errorMethodName = "de_CommandError";
String serdeContextType = CodegenUtils.getOperationDeserializerContextType(context.getSettings(), writer,
context.getModel(), operation);
// Add the normalized output type.
Expand All @@ -449,7 +461,7 @@ private void generateOperationDeserializer(GenerationContext context, OperationS
+ "): Promise<$T> => {", "}", methodName, responseType, serdeContextType, outputType, () -> {
// Redirect error deserialization to the dispatcher
writer.openBlock("if (output.statusCode >= 300) {", "}", () -> {
writer.write("return $L(output, context);", errorMethodName);
writer.write("return $L(output, context) as any;", errorMethodName);
});

// Start deserializing the response.
Expand All @@ -465,12 +477,6 @@ private void generateOperationDeserializer(GenerationContext context, OperationS
writer.write("return response;");
});
writer.write("");

// Write out the error deserialization dispatcher.
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
context, operation, responseType, this::writeErrorCodeParser,
isErrorCodeInBody, this::getErrorBodyLocation, this::getOperationErrors);
deserializingErrorShapes.addAll(errorShapes);
}

private void generateErrorDeserializer(GenerationContext context, StructureShape error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package software.amazon.smithy.typescript.codegen.integration;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;
import software.amazon.smithy.codegen.core.CodegenException;
Expand Down Expand Up @@ -303,6 +304,23 @@ default Map<String, ShapeId> getOperationErrors(GenerationContext context, Opera
return HttpProtocolGeneratorUtils.getOperationErrors(context, operation);
}

/**
* Returns a map of error names to their {@link ShapeId}.
*
* @param context the generation context
* @param operations the operation shapes to retrieve errors for
* @return map of error names to {@link ShapeId}
*/
default Map<String, ShapeId> getOperationErrors(GenerationContext context, Collection<OperationShape> operations) {
Map<String, ShapeId> errors = new LinkedHashMap<>();
for (OperationShape operation : operations) {
errors.putAll(
getOperationErrors(context, operation)
);
}
return errors;
}

/**
* Context object used for service serialization and deserialization.
*/
Expand Down
Loading