From e52bab97a7adafa869f1b0ba0385b9996587807e Mon Sep 17 00:00:00 2001 From: Joshua Humphries Date: Mon, 11 May 2020 09:32:44 -0400 Subject: [PATCH] work around the fact that v2 messages don't have an XXX_unrecognized field (#319) --- desc/protoparse/parser_test.go | 48 +++++++++++++++++++ dynamic/dynamic_message.go | 26 +++++----- go.sum | 1 + internal/unrecognized.go | 86 ++++++++++++++++++++++++++++++++++ 4 files changed, 147 insertions(+), 14 deletions(-) create mode 100644 internal/unrecognized.go diff --git a/desc/protoparse/parser_test.go b/desc/protoparse/parser_test.go index 53cea3c7..0ac8abce 100644 --- a/desc/protoparse/parser_test.go +++ b/desc/protoparse/parser_test.go @@ -9,9 +9,12 @@ import ( "sort" "testing" + "github.com/golang/protobuf/proto" dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" + "github.com/jhump/protoreflect/codec" "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/internal" "github.com/jhump/protoreflect/internal/testutil" ) @@ -337,3 +340,48 @@ message Foo { comment := fds[0].GetMessageTypes()[0].GetFields()[0].GetSourceInfo().GetLeadingComments() testutil.Eq(t, " leading comments\n", comment) } + +func TestParseCustomOptions(t *testing.T) { + accessor := FileContentsFromMap(map[string]string{ + "test.proto": ` +syntax = "proto3"; +import "google/protobuf/descriptor.proto"; +extend google.protobuf.MessageOptions { + string foo = 30303; + int64 bar = 30304; +} +message Foo { + option (.foo) = "foo"; + option (bar) = 123; +} +`, + }) + + p := Parser{ + Accessor: accessor, + IncludeSourceCodeInfo: true, + } + fds, err := p.ParseFiles("test.proto") + testutil.Ok(t, err) + + md := fds[0].GetMessageTypes()[0] + opts := md.GetMessageOptions() + data := internal.GetUnrecognized(opts) + buf := codec.NewBuffer(data) + + tag, wt, err := buf.DecodeTagAndWireType() + testutil.Ok(t, err) + testutil.Eq(t, int32(30303), tag) + testutil.Eq(t, int8(proto.WireBytes), wt) + fieldData, err := buf.DecodeRawBytes(false) + testutil.Ok(t, err) + testutil.Eq(t, "foo", string(fieldData)) + + tag, wt, err = buf.DecodeTagAndWireType() + testutil.Ok(t, err) + testutil.Eq(t, int32(30304), tag) + testutil.Eq(t, int8(proto.WireVarint), wt) + fieldVal, err := buf.DecodeVarint() + testutil.Ok(t, err) + testutil.Eq(t, uint64(123), fieldVal) +} diff --git a/dynamic/dynamic_message.go b/dynamic/dynamic_message.go index 7eb03b5e..513971e7 100644 --- a/dynamic/dynamic_message.go +++ b/dynamic/dynamic_message.go @@ -14,6 +14,7 @@ import ( "github.com/jhump/protoreflect/codec" "github.com/jhump/protoreflect/desc" + "github.com/jhump/protoreflect/internal" ) // ErrUnknownTagNumber is an error that is returned when an operation refers @@ -2194,13 +2195,9 @@ func (m *Message) mergeInto(pm proto.Message, deterministic bool) error { // track tags for which the dynamic message has data but the given // message doesn't know about it - u := target.FieldByName("XXX_unrecognized") - var unknownTags map[int32]struct{} - if u.IsValid() && u.Type() == typeOfBytes { - unknownTags = map[int32]struct{}{} - for tag := range m.values { - unknownTags[tag] = struct{}{} - } + unknownTags := map[int32]struct{}{} + for tag := range m.values { + unknownTags[tag] = struct{}{} } // check that we can successfully do the merge @@ -2305,7 +2302,6 @@ func (m *Message) mergeInto(pm proto.Message, deterministic bool) error { // if we have fields that the given message doesn't know about, add to its unknown fields if len(unknownTags) > 0 { - ub := u.Interface().([]byte) var b codec.Buffer b.SetDeterministic(deterministic) if deterministic { @@ -2332,8 +2328,8 @@ func (m *Message) mergeInto(pm proto.Message, deterministic bool) error { } } } - ub = append(ub, b.Bytes()...) - u.Set(reflect.ValueOf(ub)) + + internal.SetUnrecognized(pm, b.Bytes()) } // finally, convey unknown fields into the given message by letting it unmarshal them @@ -2562,13 +2558,15 @@ func (m *Message) mergeFrom(pm proto.Message) error { // now actually perform the merge for fd, v := range values { - mergeField(m, fd, v) + if err := mergeField(m, fd, v); err != nil { + return err + } } - u := src.FieldByName("XXX_unrecognized") - if u.IsValid() && u.Type() == typeOfBytes { + data := internal.GetUnrecognized(pm) + if len(data) > 0 { // ignore any error returned: pulling in unknown fields is best-effort - _ = m.UnmarshalMerge(u.Interface().([]byte)) + _ = m.UnmarshalMerge(data) } // lastly, also extract any unknown extensions the message may have (unknown extensions diff --git a/go.sum b/go.sum index 16be44c7..6a947451 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/jhump/protobuf v2.6.1+incompatible h1:45MMYOwWuoSQB+pFVwZuoKix+13fjYaGOJnM4gNdIH4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= diff --git a/internal/unrecognized.go b/internal/unrecognized.go new file mode 100644 index 00000000..c903d4b2 --- /dev/null +++ b/internal/unrecognized.go @@ -0,0 +1,86 @@ +package internal + +import ( + "reflect" + + "github.com/golang/protobuf/proto" +) + +var typeOfBytes = reflect.TypeOf([]byte(nil)) + +// GetUnrecognized fetches the bytes of unrecognized fields for the given message. +func GetUnrecognized(msg proto.Message) []byte { + val := reflect.Indirect(reflect.ValueOf(msg)) + u := val.FieldByName("XXX_unrecognized") + if u.IsValid() && u.Type() == typeOfBytes { + return u.Interface().([]byte) + } + + // Fallback to reflection for API v2 messages + get, _, _, ok := unrecognizedGetSetMethods(val) + if !ok { + return nil + } + + return get.Call([]reflect.Value(nil))[0].Convert(typeOfBytes).Interface().([]byte) +} + +// SetUnrecognized adds the given bytes to the unrecognized fields for the given message. +func SetUnrecognized(msg proto.Message, data []byte) { + val := reflect.Indirect(reflect.ValueOf(msg)) + u := val.FieldByName("XXX_unrecognized") + if u.IsValid() && u.Type() == typeOfBytes { + // Just store the bytes in the unrecognized field + ub := u.Interface().([]byte) + ub = append(ub, data...) + u.Set(reflect.ValueOf(ub)) + return + } + + // Fallback to reflection for API v2 messages + get, set, argType, ok := unrecognizedGetSetMethods(val) + if !ok { + return + } + + existing := get.Call([]reflect.Value(nil))[0].Convert(typeOfBytes).Interface().([]byte) + if len(existing) > 0 { + data = append(existing, data...) + } + set.Call([]reflect.Value{reflect.ValueOf(data).Convert(argType)}) +} + +func unrecognizedGetSetMethods(val reflect.Value) (get reflect.Value, set reflect.Value, argType reflect.Type, ok bool) { + // val could be an APIv2 message. We use reflection to interact with + // this message so that we don't have a hard dependency on the new + // version of the protobuf package. + refMethod := val.MethodByName("ProtoReflect") + if !refMethod.IsValid() { + if val.CanAddr() { + refMethod = val.Addr().MethodByName("ProtoReflect") + } + if !refMethod.IsValid() { + return + } + } + refType := refMethod.Type() + if refType.NumIn() != 0 || refType.NumOut() != 1 { + return + } + ref := refMethod.Call([]reflect.Value(nil)) + getMethod, setMethod := ref[0].MethodByName("GetUnknown"), ref[0].MethodByName("SetUnknown") + if !getMethod.IsValid() || !setMethod.IsValid() { + return + } + getType := getMethod.Type() + setType := setMethod.Type() + if getType.NumIn() != 0 || getType.NumOut() != 1 || setType.NumIn() != 1 || setType.NumOut() != 0 { + return + } + arg := setType.In(0) + if !arg.ConvertibleTo(typeOfBytes) || getType.Out(0) != arg { + return + } + + return getMethod, setMethod, arg, true +}