Skip to content

Commit

Permalink
feat: generate unified error dispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhe committed Jan 31, 2024
1 parent 280ef3a commit 3617f09
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 37 deletions.
2 changes: 2 additions & 0 deletions .changeset/shy-nails-wonder.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
---
---
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ public void generateResponseDeserializers(GenerationContext context) {

Set<OperationShape> containedOperations = new TreeSet<>(
topDownIndex.getContainedOperations(context.getService()));

for (OperationShape operation : containedOperations) {
OptionalUtils.ifPresentOrElse(
operation.getTrait(HttpTrait.class),
Expand All @@ -525,6 +526,18 @@ public void generateResponseDeserializers(GenerationContext context) {
"Unable to generate %s protocol response bindings for %s because it does not have an "
+ "http binding trait", getName(), operation.getId())));
}

SymbolReference responseType = getApplicationProtocol().getResponseType();
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateUnifiedErrorDispatcher(
context,
containedOperations.stream().toList(),
responseType,
this::writeErrorCodeParser,
isErrorCodeInBody,
this::getErrorBodyLocation,
this::getOperationErrors
);
deserializingErrorShapes.addAll(errorShapes);
}

private void generateOperationResponseSerializer(
Expand Down Expand Up @@ -2091,7 +2104,7 @@ private void generateOperationResponseDeserializer(
// e.g., deserializeAws_restJson1_1ExecuteStatement
String methodName = ProtocolGenerator.getDeserFunctionShortName(symbol);
String methodLongName = ProtocolGenerator.getDeserFunctionName(symbol, getName());
String errorMethodName = methodName + "Error";
String errorMethodName = "de_CommandError";
// Add the normalized output type.
Symbol outputType = symbol.expectProperty("outputType", Symbol.class);
String contextType = CodegenUtils.getOperationDeserializerContextType(context.getSettings(), writer,
Expand All @@ -2108,7 +2121,7 @@ private void generateOperationResponseDeserializer(
// status code that's not the modeled code (300 or higher). This allows for
// returning other 2XX codes that don't match the defined value.
writer.openBlock("if (output.statusCode !== $L && output.statusCode >= 300) {", "}", trait.getCode(),
() -> writer.write("return $L(output, context);", errorMethodName));
() -> writer.write("return $L(output, context) as any;", errorMethodName));

// Start deserializing the response.
writer.openBlock("const contents: any = map({", "});", () -> {
Expand All @@ -2129,11 +2142,6 @@ private void generateOperationResponseDeserializer(
writer.write("return contents;");
});
writer.write("");
// Write out the error deserialization dispatcher.
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
context, operation, responseType, this::writeErrorCodeParser,
isErrorCodeInBody, this::getErrorBodyLocation, this::getOperationErrors);
deserializingErrorShapes.addAll(errorShapes);
}

private void generateErrorDeserializer(GenerationContext context, StructureShape error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,51 +298,49 @@ public static void writeRetryableTrait(TypeScriptWriter writer, StructureShape e

/**
* Writes a function used to dispatch to the proper error deserializer
* for each error that the operation can return. The generated function
* for each error that any operation can return. The generated function
* assumes a deserialization function is generated for the structures
* returned.
*
* @param context The generation context.
* @param operation The operation to generate for.
* @param responseType The response type for the HTTP protocol.
* @param errorCodeGenerator A consumer
* @param shouldParseErrorBody Flag indicating whether need to parse response body in this dispatcher function
* @param bodyErrorLocationModifier A function that returns the location of an error in a body given a data source.
* @param operationErrorsToShapes A map of error names to their {@link ShapeId}.
* @return A set of all error structure shapes for the operation that were dispatched to.
*/
static Set<StructureShape> generateErrorDispatcher(
GenerationContext context,
OperationShape operation,
SymbolReference responseType,
Consumer<GenerationContext> errorCodeGenerator,
boolean shouldParseErrorBody,
BiFunction<GenerationContext, String, String> bodyErrorLocationModifier,
BiFunction<GenerationContext, OperationShape, Map<String, ShapeId>> operationErrorsToShapes
static Set<StructureShape> generateUnifiedErrorDispatcher(
GenerationContext context,
List<OperationShape> operations,
SymbolReference responseType,
Consumer<GenerationContext> errorCodeGenerator,
boolean shouldParseErrorBody,
BiFunction<GenerationContext, String, String> bodyErrorLocationModifier,
BiFunction<GenerationContext, List<OperationShape>, Map<String, ShapeId>> operationErrorsToShapes
) {
TypeScriptWriter writer = context.getWriter();
SymbolProvider symbolProvider = context.getSymbolProvider();
Set<StructureShape> errorShapes = new TreeSet<>();

Symbol symbol = symbolProvider.toSymbol(operation);
Symbol outputType = symbol.expectProperty("outputType", Symbol.class);
String errorMethodName = ProtocolGenerator.getDeserFunctionShortName(symbol) + "Error";
String errorMethodLongName = ProtocolGenerator.getDeserFunctionName(symbol, context.getProtocolName())
+ "Error";
String errorMethodName = "de_CommandError";
String errorMethodLongName = "deserialize_"
+ ProtocolGenerator.getSanitizedName(context.getProtocolName())
+ "CommandError";

writer.writeDocs(errorMethodLongName);
writer.openBlock("const $L = async(\n"
+ " output: $T,\n"
+ " context: __SerdeContext,\n"
+ "): Promise<$T> => {", "}", errorMethodName, responseType, outputType, () -> {
writer.openBlock("const $L = async(\n"
+ " output: $T,\n"
+ " context: __SerdeContext,\n"
+ "): Promise<unknown> => {", "}", errorMethodName, responseType, () -> {
// Prepare error response for parsing error code. If error code needs to be parsed from response body
// then we collect body and parse it to JS object, otherwise leave the response body as is.
if (shouldParseErrorBody) {
writer.openBlock("const parsedOutput: any = {", "};",
() -> {
writer.write("...output,");
writer.write("body: await parseErrorBody(output.body, context)");
});
() -> {
writer.write("...output,");
writer.write("body: await parseErrorBody(output.body, context)");
});
}

// Error responses must be at least BaseException interface
Expand Down Expand Up @@ -370,7 +368,8 @@ static Set<StructureShape> generateErrorDispatcher(
});
};

Map<String, ShapeId> operationNamesToShapes = operationErrorsToShapes.apply(context, operation);
Map<String, ShapeId> operationNamesToShapes = operationErrorsToShapes.apply(context, operations);

if (!operationNamesToShapes.isEmpty()) {
writer.openBlock("switch (errorCode) {", "}", () -> {
// Generate the case statement for each error, invoking the specific deserializer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,18 @@ public void generateResponseDeserializers(GenerationContext context) {
for (OperationShape operation : containedOperations) {
generateOperationDeserializer(context, operation);
}

SymbolReference responseType = getApplicationProtocol().getResponseType();
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateUnifiedErrorDispatcher(
context,
containedOperations.stream().toList(),
responseType,
this::writeErrorCodeParser,
isErrorCodeInBody,
this::getErrorBodyLocation,
this::getOperationErrors
);
deserializingErrorShapes.addAll(errorShapes);
}

private void generateOperationSerializer(GenerationContext context, OperationShape operation) {
Expand Down Expand Up @@ -465,12 +477,6 @@ private void generateOperationDeserializer(GenerationContext context, OperationS
writer.write("return response;");
});
writer.write("");

// Write out the error deserialization dispatcher.
Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
context, operation, responseType, this::writeErrorCodeParser,
isErrorCodeInBody, this::getErrorBodyLocation, this::getOperationErrors);
deserializingErrorShapes.addAll(errorShapes);
}

private void generateErrorDeserializer(GenerationContext context, StructureShape error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package software.amazon.smithy.typescript.codegen.integration;

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;
import software.amazon.smithy.codegen.core.CodegenException;
Expand Down Expand Up @@ -303,6 +304,23 @@ default Map<String, ShapeId> getOperationErrors(GenerationContext context, Opera
return HttpProtocolGeneratorUtils.getOperationErrors(context, operation);
}

/**
* Returns a map of error names to their {@link ShapeId}.
*
* @param context the generation context
* @param operations the operation shapes to retrieve errors for
* @return map of error names to {@link ShapeId}
*/
default Map<String, ShapeId> getOperationErrors(GenerationContext context, Collection<OperationShape> operations) {
Map<String, ShapeId> errors = new LinkedHashMap<>();
for (OperationShape operation : operations) {
errors.putAll(
getOperationErrors(context, operation)
);
}
return errors;
}

/**
* Context object used for service serialization and deserialization.
*/
Expand Down

0 comments on commit 3617f09

Please sign in to comment.