From 8dab44451140b3e29e0dfd0eecd9472aadd1052a Mon Sep 17 00:00:00 2001 From: Joshua Humphries Date: Wed, 17 Aug 2022 22:36:24 -0400 Subject: [PATCH] protoparse: match the way protoc populates aggregate values in uninterpreted options (#526) --- desc/protoparse/descriptor_protos.go | 37 ++++++++++++++++---- desc/protoparse/options_test.go | 4 +-- desc/protoparse/parser.go | 51 ---------------------------- desc/protoparse/parser_test.go | 15 ++++++-- 4 files changed, 46 insertions(+), 61 deletions(-) diff --git a/desc/protoparse/descriptor_protos.go b/desc/protoparse/descriptor_protos.go index aef22787..993ef253 100644 --- a/desc/protoparse/descriptor_protos.go +++ b/desc/protoparse/descriptor_protos.go @@ -105,16 +105,41 @@ func (r *parseResult) asUninterpretedOption(node *ast.OptionNode) *dpb.Uninterpr opt.StringValue = []byte(val) case ast.Identifier: opt.IdentifierValue = proto.String(string(val)) - case []*ast.MessageFieldNode: - var buf bytes.Buffer - aggToString(val, &buf) - aggStr := buf.String() - opt.AggregateValue = proto.String(aggStr) - //the grammar does not allow arrays here, so no case for []ast.ValueNode + default: + // the grammar does not allow arrays here, so the only possible case + // left should be []*ast.MessageFieldNode, which corresponds to an + // *ast.MessageLiteralNode + if n, ok := node.Val.(*ast.MessageLiteralNode); ok { + var buf bytes.Buffer + for i, el := range n.Elements { + flattenNode(r.root, el, &buf) + if len(n.Seps) > i && n.Seps[i] != nil { + buf.WriteRune(' ') + buf.WriteRune(n.Seps[i].Rune) + } + } + aggStr := buf.String() + opt.AggregateValue = proto.String(aggStr) + } + // TODO: else that reports an error or panics?? } return opt } +func flattenNode(f *ast.FileNode, n ast.Node, buf *bytes.Buffer) { + if cn, ok := n.(ast.CompositeNode); ok { + for _, ch := range cn.Children() { + flattenNode(f, ch, buf) + } + return + } + + if buf.Len() > 0 { + buf.WriteRune(' ') + } + buf.WriteString(n.(ast.TerminalNode).RawText()) +} + func (r *parseResult) asUninterpretedOptionName(parts []*ast.FieldReferenceNode) []*dpb.UninterpretedOption_NamePart { ret := make([]*dpb.UninterpretedOption_NamePart, len(parts)) for i, part := range parts { diff --git a/desc/protoparse/options_test.go b/desc/protoparse/options_test.go index d2c1b4ef..ecc89924 100644 --- a/desc/protoparse/options_test.go +++ b/desc/protoparse/options_test.go @@ -59,7 +59,7 @@ func TestOptionsInUnlinkedFiles(t *testing.T) { // field where default is uninterpretable contents: `enum TestEnum{ ZERO = 0; ONE = 1; } message Test { optional TestEnum uid = 1 [(must.link) = {foo: bar}, default = ONE, json_name = "UID", deprecated = true]; }`, uninterpreted: map[string]interface{}{ - "Test.uid:(must.link)": aggregate("{ foo: bar }"), + "Test.uid:(must.link)": aggregate("foo : bar"), "Test.uid:default": ident("ONE"), }, checkInterpreted: func(t *testing.T, fd *dpb.FileDescriptorProto) { @@ -108,7 +108,7 @@ func TestOptionsInUnlinkedFiles(t *testing.T) { // service options contents: `service Test { option deprecated = true; option (must.link) = {foo:1, foo:2, bar:3}; }`, uninterpreted: map[string]interface{}{ - "Test:(must.link)": aggregate("{ foo: 1 foo: 2 bar: 3 }"), + "Test:(must.link)": aggregate("foo : 1 , foo : 2 , bar : 3"), }, checkInterpreted: func(t *testing.T, fd *dpb.FileDescriptorProto) { testutil.Require(t, fd.GetService()[0].GetOptions().GetDeprecated()) diff --git a/desc/protoparse/parser.go b/desc/protoparse/parser.go index 817d0f0b..7af9866d 100644 --- a/desc/protoparse/parser.go +++ b/desc/protoparse/parser.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "io/ioutil" - "math" "os" "path/filepath" "sort" @@ -817,56 +816,6 @@ func checkTag(pos *SourcePos, v uint64, maxTag int32) error { return nil } -func aggToString(agg []*ast.MessageFieldNode, buf *bytes.Buffer) { - buf.WriteString("{") - for _, a := range agg { - buf.WriteString(" ") - buf.WriteString(a.Name.Value()) - if v, ok := a.Val.(*ast.MessageLiteralNode); ok { - aggToString(v.Elements, buf) - } else { - buf.WriteString(": ") - elementToString(a.Val.Value(), buf) - } - } - buf.WriteString(" }") -} - -func elementToString(v interface{}, buf *bytes.Buffer) { - switch v := v.(type) { - case bool, int64, uint64, ast.Identifier: - _, _ = fmt.Fprintf(buf, "%v", v) - case float64: - if math.IsInf(v, 1) { - buf.WriteString(": inf") - } else if math.IsInf(v, -1) { - buf.WriteString(": -inf") - } else if math.IsNaN(v) { - buf.WriteString(": nan") - } else { - _, _ = fmt.Fprintf(buf, ": %v", v) - } - case string: - buf.WriteRune('"') - writeEscapedBytes(buf, []byte(v)) - buf.WriteRune('"') - case []ast.ValueNode: - buf.WriteString(": [") - first := true - for _, e := range v { - if first { - first = false - } else { - buf.WriteString(", ") - } - elementToString(e.Value(), buf) - } - buf.WriteString("]") - case []*ast.MessageFieldNode: - aggToString(v, buf) - } -} - func writeEscapedBytes(buf *bytes.Buffer, b []byte) { for _, c := range b { switch c { diff --git a/desc/protoparse/parser_test.go b/desc/protoparse/parser_test.go index 84252b0c..3e4c0d75 100644 --- a/desc/protoparse/parser_test.go +++ b/desc/protoparse/parser_test.go @@ -226,11 +226,22 @@ func TestAggregateValueInUninterpretedOptions(t *testing.T) { testutil.Ok(t, err) fd := res.fd + // service TestTestService, method UserAuth; first option aggregateValue1 := *fd.Service[0].Method[0].Options.UninterpretedOption[0].AggregateValue - testutil.Eq(t, "{ authenticated: true permission{ action: LOGIN entity: \"client\" } }", aggregateValue1) + testutil.Eq(t, `authenticated : true permission : { action : LOGIN entity : "client" }`, aggregateValue1) + // service TestTestService, method Get; first option aggregateValue2 := *fd.Service[0].Method[1].Options.UninterpretedOption[0].AggregateValue - testutil.Eq(t, "{ authenticated: true permission{ action: READ entity: \"user\" } }", aggregateValue2) + testutil.Eq(t, `authenticated : true permission : { action : READ entity : "user" }`, aggregateValue2) + + // message Another; first option + aggregateValue3 := *fd.MessageType[4].Options.UninterpretedOption[0].AggregateValue + testutil.Eq(t, `foo : "abc" s < name : "foo" , id : 123 > , array : [ 1 , 2 , 3 ] , r : [ < name : "f" > , { name : "s" } , { id : 456 } ] ,`, aggregateValue3) + + // message Test.Nested._NestedNested; second option (rept) + // (Test.Nested is at index 1 instead of 0 because of implicit nested message from map field m) + aggregateValue4 := *fd.MessageType[1].NestedType[1].NestedType[0].Options.UninterpretedOption[1].AggregateValue + testutil.Eq(t, `foo : "goo" [ foo . bar . Test . Nested . _NestedNested . _garblez ] : "boo"`, aggregateValue4) } func TestParseFilesMessageComments(t *testing.T) {