diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java index 87b63e9e3..b372ba3e9 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/SmithyGoDependency.java @@ -39,6 +39,7 @@ public final class SmithyGoDependency { public static final GoDependency CRYPTORAND = stdlib("crypto/rand", "cryptorand"); public static final GoDependency TESTING = stdlib("testing"); public static final GoDependency ERRORS = stdlib("errors"); + public static final GoDependency XML = stdlib("encoding/xml"); public static final GoDependency SMITHY = smithy(null, "smithy"); public static final GoDependency SMITHY_HTTP_TRANSPORT = smithy("transport/http", "smithyhttp"); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/DocumentShapeDeserVisitor.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/DocumentShapeDeserVisitor.java index f21bbea7b..129016af2 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/DocumentShapeDeserVisitor.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/DocumentShapeDeserVisitor.java @@ -41,8 +41,8 @@ /** * Visitor to generate deserialization functions for shapes in protocol document bodies. * - * Visitor methods for aggregate types are final and will generate functions that dispatch - * their loading from the body to the matching abstract method. + * Visitor methods for aggregate types except maps and collections are final and will + * generate functions that dispatch their loading from the body to the matching abstract method. * * Visitor methods for all other types will default to not generating deserialization * functions. This may be overwritten by downstream implementations if the protocol requires @@ -472,7 +472,7 @@ public final Void documentShape(DocumentShape shape) { * @return null */ @Override - public final Void listShape(ListShape shape) { + public Void listShape(ListShape shape) { generateDeserFunction(shape, (c, s) -> deserializeCollection(c, s.asListShape().get())); return null; } @@ -484,7 +484,7 @@ public final Void listShape(ListShape shape) { * @return null */ @Override - public final Void mapShape(MapShape shape) { + public Void mapShape(MapShape shape) { generateDeserFunction(shape, (c, s) -> deserializeMap(c, s.asMapShape().get())); return null; } @@ -496,7 +496,7 @@ public final Void mapShape(MapShape shape) { * @return null */ @Override - public final Void setShape(SetShape shape) { + public Void setShape(SetShape shape) { generateDeserFunction(shape, (c, s) -> deserializeCollection(c, s.asSetShape().get())); return null; } diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java index 852a9a2e4..4740235ed 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpBindingProtocolGenerator.java @@ -280,6 +280,7 @@ private void generateOperationDeserializerMiddleware(GenerationContext context, ApplicationProtocol applicationProtocol = getApplicationProtocol(); Symbol responseType = applicationProtocol.getResponseType(); GoWriter goWriter = context.getWriter(); + String errorFunctionName = ProtocolGenerator.getOperationErrorDeserFunctionName( operation, context.getProtocolName()); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolGeneratorUtils.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolGeneratorUtils.java index da0004481..96cc8c02b 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolGeneratorUtils.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpProtocolGeneratorUtils.java @@ -45,7 +45,7 @@ private HttpProtocolGeneratorUtils() {} * and {@code errorMessage} variables from the http response. * @return A set of all error structure shapes for the operation that were dispatched to. */ - static Set generateErrorDispatcher( + static Set generateErrorDispatcher( GenerationContext context, OperationShape operation, Symbol responseType, @@ -78,13 +78,14 @@ static Set generateErrorDispatcher( // Dispatch to the message/code generator to try to get the specific code and message. errorMessageCodeGenerator.accept(context); - writer.openBlock("switch errorCode {", "}", () -> { + writer.openBlock("switch {", "}", () -> { new TreeSet<>(operation.getErrors()).forEach(errorId -> { StructureShape error = context.getModel().expectShape(errorId).asStructureShape().get(); errorShapes.add(error); String errorDeserFunctionName = ProtocolGenerator.getErrorDeserFunctionName( error, context.getProtocolName()); - writer.openBlock("case $S:", "", errorId.getName(), () -> { + writer.addUseImports(SmithyGoDependency.STRINGS); + writer.openBlock("case strings.EqualFold($S, errorCode):", "", errorId.getName(), () -> { writer.write("return $L(response, errorBody)", errorDeserFunctionName); }); }); diff --git a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpRpcProtocolGenerator.java b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpRpcProtocolGenerator.java index b43af02f3..0c0dd247b 100644 --- a/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpRpcProtocolGenerator.java +++ b/codegen/smithy-go-codegen/src/main/java/software/amazon/smithy/go/codegen/integration/HttpRpcProtocolGenerator.java @@ -120,7 +120,7 @@ private void generateOperationSerializer(GenerationContext context, OperationSha writer.write("request.Request.URL.Path = $S", getOperationPath(context, operation)); writer.write("request.Request.Method = \"POST\""); writer.write("httpBindingEncoder, err := httpbinding.NewEncoder(request.URL.Path, " - + "request.URL.RawQuery, request.Header)"); + + "request.URL.RawQuery, request.Header)"); writer.openBlock("if err != nil {", "}", () -> { writer.write("return out, metadata, &smithy.SerializationError{Err: err}"); }); @@ -160,16 +160,16 @@ private void writeRequestHeaders(GenerationContext context, OperationShape opera *
  • {@code ctx: context.Context}: a type containing context and tools for type serde.
  • * * - * @param context The generation context. + * @param context The generation context. * @param operation The operation being generated. - * @param writer The writer to use. + * @param writer The writer to use. */ protected void writeDefaultHeaders(GenerationContext context, OperationShape operation, GoWriter writer) {} /** * Provides the request path for the operation. * - * @param context The generation context. + * @param context The generation context. * @param operation The operation being generated. * @return The path to send HTTP requests to. */ @@ -185,7 +185,7 @@ protected void writeDefaultHeaders(GenerationContext context, OperationShape ope *
  • {@code ctx: context.Context}: a type containing context and tools for type serde.
  • * * - * @param context The generation context. + * @param context The generation context. * @param operation The operation to serialize for. */ protected abstract void serializeInputDocument(GenerationContext context, OperationShape operation); @@ -204,7 +204,7 @@ public void generateSharedDeserializerComponents(GenerationContext context) { * {@code deserializeOutputDocument}. * * @param context The generation context. - * @param shapes The shapes to generate deserialization for. + * @param shapes The shapes to generate deserialization for. */ protected abstract void generateDocumentBodyShapeDeserializers(GenerationContext context, Set shapes); @@ -280,7 +280,7 @@ private void generateOperationDeserializer(GenerationContext context, OperationS *
  • {@code ctx: context.Context}: a type containing context and tools for type serde.
  • * * - * @param context The generation context + * @param context The generation context * @param operation The operation to deserialize for. */ protected abstract void deserializeOutputDocument(GenerationContext context, OperationShape operation); @@ -306,7 +306,7 @@ private void generateErrorDeserializer(GenerationContext context, StructureShape * * * @param context The generation context. - * @param shape The error shape. + * @param shape The error shape. */ protected abstract void deserializeError(GenerationContext context, StructureShape shape); diff --git a/xml/xml_decoder.go b/xml/xml_decoder.go new file mode 100644 index 000000000..b8fbf2663 --- /dev/null +++ b/xml/xml_decoder.go @@ -0,0 +1,87 @@ +package xml + +import ( + "encoding/xml" + "fmt" +) + +// NodeDecoder is a XML decoder wrapper that is responsible to decoding +// a single XML Node element and it's nested member elements. This wrapper decoder +// takes in the start element of the top level node being decoded. +type NodeDecoder struct { + Decoder *xml.Decoder + StartEl xml.StartElement +} + +// WrapNodeDecoder returns an initialized XMLNodeDecoder +func WrapNodeDecoder(decoder *xml.Decoder, startEl xml.StartElement) NodeDecoder { + return NodeDecoder{ + Decoder: decoder, + StartEl: startEl, + } +} + +// Token on a Node Decoder returns a xml StartElement. It returns a boolean that indicates the +// a token is the node decoder's end node token; and an error which indicates any error +// that occurred while retrieving the start element +func (d NodeDecoder) Token() (t xml.StartElement, done bool, err error) { + for { + token, e := d.Decoder.Token() + if e != nil { + return t, done, e + } + + // check if we reach end of the node being decoded + if el, ok := token.(xml.EndElement); ok { + return t, el == d.StartEl.End(), err + } + + if t, ok := token.(xml.StartElement); ok { + return t, false, err + } + + // skip token if it is a comment or preamble or empty space value due to indentation + // or if it's a value and is not expected + } + + return +} + +// Value provides an abstraction to retrieve char data value within an xml element. +// The method will return an error if it encounters a nested xml element instead of char data. +// This method should only be used to retrieve simple type or blob shape values as []byte. +func (d NodeDecoder) Value() (c []byte, done bool, err error) { + t, e := d.Decoder.Token() + if e != nil { + return c, done, e + } + + // check if token is of type charData + if ev, ok := t.(xml.CharData); ok { + return ev, done, err + } + + if ev, ok := t.(xml.EndElement); ok { + if ev == d.StartEl.End() { + return c, true, err + } + } + + return c, done, fmt.Errorf("expected value for %v element, got %T type %v instead", d.StartEl.Name.Local, t, t) +} + +// FetchRootElement takes in a decoder and returns the first start element within the xml body. +// This function is useful in fetching the start element of an XML response and ignore the +// comments and preamble +func FetchRootElement(decoder *xml.Decoder) (startElement xml.StartElement, err error) { + for { + t, e := decoder.Token() + if e != nil { + return startElement, e + } + + if startElement, ok := t.(xml.StartElement); ok { + return startElement, err + } + } +} diff --git a/xml/xml_decoder_test.go b/xml/xml_decoder_test.go new file mode 100644 index 000000000..3b4bfacb9 --- /dev/null +++ b/xml/xml_decoder_test.go @@ -0,0 +1,305 @@ +package xml + +import ( + "bytes" + "encoding/xml" + "io" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestXMLNodeDecoder_Token(t *testing.T) { + cases := map[string]struct { + responseBody io.Reader + expectedStartElement xml.StartElement + expectedDone bool + expectedError string + }{ + "simple success case": { + responseBody: bytes.NewReader([]byte(`abc`)), + expectedStartElement: xml.StartElement{ + Name: xml.Name{ + Local: "", + }, + }, + expectedDone: true, + }, + "no value": { + responseBody: bytes.NewReader([]byte(``)), + expectedDone: true, + }, + "empty body": { + responseBody: bytes.NewReader([]byte(``)), + expectedError: "EOF", + }, + "with indentation": { + responseBody: bytes.NewReader([]byte(` abc`)), + expectedStartElement: xml.StartElement{ + Name: xml.Name{ + Local: "Struct", + }, + Attr: []xml.Attr{}, + }, + }, + "with comment and indentation": { + responseBody: bytes.NewReader([]byte(` abc`)), + expectedStartElement: xml.StartElement{ + Name: xml.Name{ + Local: "Struct", + }, + Attr: []xml.Attr{}, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + xmlDecoder := xml.NewDecoder(c.responseBody) + st, err := FetchRootElement(xmlDecoder) + if err != nil { + if len(c.expectedError) == 0 { + t.Fatalf("Expected no error, got %v", err) + } + + if e, a := c.expectedError, err; !strings.Contains(err.Error(), c.expectedError) { + t.Fatalf("expected error to contain %v, found %v", e, a.Error()) + } + } + nodeDecoder := WrapNodeDecoder(xmlDecoder, st) + token, done, err := nodeDecoder.Token() + if err != nil { + if len(c.expectedError) == 0 { + t.Fatalf("Expected no error, got %v", err) + } + + if e, a := c.expectedError, err; !strings.Contains(err.Error(), c.expectedError) { + t.Fatalf("expected error to contain %v, found %v", e, a.Error()) + } + } + + if e, a := c.expectedDone, done; e != a { + t.Fatalf("expected a valid end element token for the xml document, got none") + } + + if diff := cmp.Diff(c.expectedStartElement, token); len(diff) != 0 { + t.Fatalf("Found diff : (-expected,+actual), \n %v", diff) + } + }) + } +} + +func TestXMLNodeDecoder_TokenExample(t *testing.T) { + responseBody := bytes.NewReader([]byte(`abc`)) + + xmlDecoder := xml.NewDecoder(responseBody) + // Fetches tag as start element. + st, err := FetchRootElement(xmlDecoder) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + // nodeDecoder will track tag as root node of the document + nodeDecoder := WrapNodeDecoder(xmlDecoder, st) + + // Retrieves tag + token, done, err := nodeDecoder.Token() + if err != nil { + t.Fatalf("Expected no error, got %v", err) + + } + if diff := cmp.Diff(token, xml.StartElement{Name: xml.Name{Local: "Response"}, Attr: []xml.Attr{}}); len(diff) != 0 { + t.Fatalf("Found diff : (-expected,+actual), \n %v", diff) + } + if done { + t.Fatalf("expected decoding to not be done yet") + } + + // Skips the value and gets that is the end token of previously retrieved tag. + // The way node decoder works it only keeps track of the root start tag using which it was initialized. + // Here is used to initialize, while is end element corresponding to already read + // tag. We won't be done until we receive + token, done, err = nodeDecoder.Token() + if err != nil { + t.Fatalf("Expected no error, got %v", err) + + } + if diff := cmp.Diff(token, xml.StartElement{Name: xml.Name{Local: ""}, Attr: nil}); len(diff) != 0 { + t.Fatalf("Found diff : (-expected,+actual), \n %v", diff) + } + if done { + t.Fatalf("expected decoding to not be done yet") + } + + // Retrieves end element tag corresponding to tag. + // Since we got the end element that corresponds to the start element being track, we are done decoding. + token, done, err = nodeDecoder.Token() + if err != nil { + t.Fatalf("Expected no error, got %v", err) + + } + if diff := cmp.Diff(token, xml.StartElement{Name: xml.Name{Local: ""}, Attr: nil}); len(diff) != 0 { + t.Fatalf("Found diff : (-expected,+actual), \n %v", diff) + } + if !done { + t.Fatalf("expected decoding to be done as we fetched the end element ") + } +} + +func TestXMLNodeDecoder_Value(t *testing.T) { + cases := map[string]struct { + responseBody io.Reader + expectedValue []byte + expectedDone bool + expectedError string + }{ + "simple success case": { + responseBody: bytes.NewReader([]byte(`abc`)), + expectedValue: []byte(`abc`), + }, + "no value": { + responseBody: bytes.NewReader([]byte(``)), + expectedDone: true, + }, + "empty body": { + responseBody: bytes.NewReader([]byte(``)), + expectedError: "EOF", + }, + "start element retrieved": { + responseBody: bytes.NewReader([]byte(`abc`)), + expectedError: "expected value for Response element, got xml.StartElement type", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + xmlDecoder := xml.NewDecoder(c.responseBody) + st, err := FetchRootElement(xmlDecoder) + if err != nil { + if len(c.expectedError) == 0 { + t.Fatalf("Expected no error, got %v", err) + } + + if e, a := c.expectedError, err; !strings.Contains(err.Error(), c.expectedError) { + t.Fatalf("expected error to contain %v, found %v", e, a.Error()) + } + } + nodeDecoder := WrapNodeDecoder(xmlDecoder, st) + token, done, err := nodeDecoder.Value() + if err != nil { + if len(c.expectedError) == 0 { + t.Fatalf("Expected no error, got %v", err) + } + + if e, a := c.expectedError, err; !strings.Contains(err.Error(), c.expectedError) { + t.Fatalf("expected error to contain %v, found %v", e, a.Error()) + } + } + + if e, a := c.expectedDone, done; e != a { + t.Fatalf("expected a valid end element token for the xml document, got none") + } + + if diff := cmp.Diff(c.expectedValue, token); len(diff) != 0 { + t.Fatalf("Found diff : (-expected,+actual), \n %v", diff) + } + }) + } +} + +func Test_FetchXMLRootElement(t *testing.T) { + cases := map[string]struct { + responseBody io.Reader + expectedStartElement xml.StartElement + expectedError string + }{ + "simple success case": { + responseBody: bytes.NewReader([]byte(`abc`)), + expectedStartElement: xml.StartElement{ + Name: xml.Name{ + Local: "Response", + }, + Attr: []xml.Attr{}, + }, + }, + "empty body": { + responseBody: bytes.NewReader([]byte(``)), + expectedError: "EOF", + }, + "with indentation": { + responseBody: bytes.NewReader([]byte(` + + Sender + InvalidGreeting + Hi + setting + + foo-id +`)), + expectedStartElement: xml.StartElement{ + Name: xml.Name{ + Local: "ErrorResponse", + }, + Attr: []xml.Attr{}, + }, + }, + "with preamble": { + responseBody: bytes.NewReader([]byte(` + + + Sender + InvalidGreeting + Hi + setting + + foo-id +`)), + expectedStartElement: xml.StartElement{ + Name: xml.Name{ + Local: "ErrorResponse", + }, + Attr: []xml.Attr{}, + }, + }, + "with comments": { + responseBody: bytes.NewReader([]byte(` + + + + Sender + InvalidGreeting + Hi + setting + + foo-id +`)), + expectedStartElement: xml.StartElement{ + Name: xml.Name{ + Local: "ErrorResponse", + }, + Attr: []xml.Attr{}, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + decoder := xml.NewDecoder(c.responseBody) + st, err := FetchRootElement(decoder) + if err != nil { + if len(c.expectedError) == 0 { + t.Fatalf("Expected no error, got %v", err) + } + + if e, a := c.expectedError, err; !strings.Contains(err.Error(), c.expectedError) { + t.Fatalf("expected error to contain %v, found %v", e, a.Error()) + } + } + + if diff := cmp.Diff(c.expectedStartElement, st); len(diff) != 0 { + t.Fatalf("Found diff : (-expected,+actual), \n %v", diff) + } + }) + } +} diff --git a/xml/xml_utils.go b/xml/xml_utils.go new file mode 100644 index 000000000..41fe18e57 --- /dev/null +++ b/xml/xml_utils.go @@ -0,0 +1,44 @@ +package xml + +import ( + "encoding/xml" + "fmt" + "io" + "io/ioutil" +) + +// GetResponseErrorCode returns the error code from an xml error response body +func GetResponseErrorCode(r io.Reader, noErrorWrapping bool) (string, error) { + rb, err := ioutil.ReadAll(r) + if err != nil { + return "", err + } + + if noErrorWrapping { + var errResponse errorBody + err := xml.Unmarshal(rb, &errResponse) + if err != nil { + return "", fmt.Errorf("error while fetching xml error response code: %w", err) + } + return errResponse.Code, err + } + + var errResponse errorResponse + if err := xml.Unmarshal(rb, &errResponse); err != nil { + return "", fmt.Errorf("error while fetching xml error response code: %w", err) + } + return errResponse.Err.Code, nil +} + +// errorResponse represents the outer error response body +// i.e. ... +type errorResponse struct { + Err errorBody `xml:"Error"` +} + +// errorBody represents the inner error body is wrapped by tag +// eg. if error response is ... +// here errorBody represents ... +type errorBody struct { + Code string `xml:"Code"` +} diff --git a/xml/xml_utils_test.go b/xml/xml_utils_test.go new file mode 100644 index 000000000..7ae9abdf2 --- /dev/null +++ b/xml/xml_utils_test.go @@ -0,0 +1,53 @@ +package xml + +import ( + "bytes" + "io" + "strings" + "testing" +) + +func TestGetResponseErrorCode(t *testing.T) { + cases := map[string]struct { + errorResponse io.Reader + noErrorWrappingEnabled bool + expectedErrorCode string + }{ + "no error wrapping enabled": { + errorResponse: bytes.NewReader([]byte(` + + Sender + InvalidGreeting + Hi + setting + + foo-id +`)), + expectedErrorCode: "InvalidGreeting", + }, + "no error wrapping disabled": { + errorResponse: bytes.NewReader([]byte(` + Sender + InvalidGreeting + Hi + setting + foo-id +`)), + noErrorWrappingEnabled: true, + expectedErrorCode: "InvalidGreeting", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + errorcode, err := GetResponseErrorCode(c.errorResponse, c.noErrorWrappingEnabled) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if e, a := c.expectedErrorCode, errorcode; !strings.EqualFold(e, a) { + t.Fatalf("expected %v, got %v", e, a) + } + }) + } +}