diff --git a/private/buf/bufcurl/invoker.go b/private/buf/bufcurl/invoker.go index 20649a9291..bfcc4c752b 100644 --- a/private/buf/bufcurl/invoker.go +++ b/private/buf/bufcurl/invoker.go @@ -451,7 +451,10 @@ func countUnrecognized(msg protoreflect.Message) int { var count int msg.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool { switch { - case field.IsMap() && isMessageKind(field.MapValue().Kind()): + case field.IsMap(): + if !isMessageKind(field.MapValue().Kind()) { + break + } // Note: Technically, each message entry could have had unrecognized field // bytes, but they are discarded by the runtime. So we can only look at // unrecognized fields in message values inside the map. @@ -460,7 +463,10 @@ func countUnrecognized(msg protoreflect.Message) int { count += countUnrecognized(v.Message()) return true }) - case field.IsList() && isMessageKind(field.Kind()): + case field.IsList(): + if !isMessageKind(field.Kind()) { + break + } listVal := val.List() for i, length := 0, listVal.Len(); i < length; i++ { count += countUnrecognized(listVal.Get(i).Message()) diff --git a/private/buf/bufcurl/invoker_test.go b/private/buf/bufcurl/invoker_test.go new file mode 100644 index 0000000000..1b69edd5fd --- /dev/null +++ b/private/buf/bufcurl/invoker_test.go @@ -0,0 +1,73 @@ +// Copyright 2020-2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bufcurl + +import ( + "context" + "os" + "testing" + + "github.com/bufbuild/protocompile" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/reflect/protoreflect" +) + +func TestCountUnrecognized(t *testing.T) { + t.Parallel() + descriptors, err := (&protocompile.Compiler{ + Resolver: &protocompile.SourceResolver{ + ImportPaths: []string{"./testdata"}, + }, + }).Compile(context.Background(), "test.proto") + require.NoError(t, err) + msgType, err := descriptors.AsResolver().FindMessageByName("foo.bar.Message") + require.NoError(t, err) + msg := msgType.New() + msgData, err := os.ReadFile("./testdata/testdata.txt") + require.NoError(t, err) + err = prototext.Unmarshal(msgData, msg.Interface()) + require.NoError(t, err) + // Add some unrecognized bytes + unknownBytes := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1} + + msg.SetUnknown(unknownBytes) + expectedUnrecognized := len(unknownBytes) + + msg.Get(msgType.Descriptor().Fields().ByName("msg")).Message().SetUnknown(unknownBytes[:10]) + expectedUnrecognized += 10 + msg.Get(msgType.Descriptor().Fields().ByName("grp")).Message().SetUnknown(unknownBytes[:6]) + expectedUnrecognized += 6 + + slice := msg.Get(msgType.Descriptor().Fields().ByName("rmsg")).List() + slice.Get(0).Message().SetUnknown(unknownBytes[:10]) + slice.Get(1).Message().SetUnknown(unknownBytes[:5]) + expectedUnrecognized += 15 + slice = msg.Get(msgType.Descriptor().Fields().ByName("rgrp")).List() + slice.Get(0).Message().SetUnknown(unknownBytes[:3]) + slice.Get(1).Message().SetUnknown(unknownBytes[:8]) + expectedUnrecognized += 11 + + mapVal := msg.Get(msgType.Descriptor().Fields().ByName("mvmsg")).Map() + mapVal.Range(func(_ protoreflect.MapKey, v protoreflect.Value) bool { + v.Message().SetUnknown(unknownBytes[:6]) + expectedUnrecognized += 6 + return true + }) + + unrecognized := countUnrecognized(msg) + assert.Equal(t, expectedUnrecognized, unrecognized) +} diff --git a/private/buf/bufcurl/testdata/test.proto b/private/buf/bufcurl/testdata/test.proto new file mode 100644 index 0000000000..c63bcd532e --- /dev/null +++ b/private/buf/bufcurl/testdata/test.proto @@ -0,0 +1,84 @@ +syntax = "proto2"; + +package foo.bar; + +message Message { + // Every kind of singular field + optional int32 i32 = 1; + optional int64 i64 = 2; + optional uint32 ui32 = 3; + optional uint64 ui64 = 4; + optional sint32 si32 = 5; + optional sint64 si64 = 6; + optional fixed32 f32 = 7; + optional fixed64 f64 = 8; + optional sfixed32 sf32 = 9; + optional sfixed64 sf64 = 10; + optional float fl = 11; + optional double dbl = 12; + optional bool b = 13; + optional string s = 14; + optional bytes bs = 15; + optional Enum en = 16; + optional Message msg = 17; + optional group Grp = 18 { + optional string name = 1; + }; + // Every kind of repeated field + repeated int32 ri32 = 19; + repeated int64 ri64 = 20; + repeated uint32 rui32 = 21; + repeated uint64 rui64 = 22; + repeated sint32 rsi32 = 23; + repeated sint64 rsi64 = 24; + repeated fixed32 rf32 = 25; + repeated fixed64 rf64 = 26; + repeated sfixed32 rsf32 = 27; + repeated sfixed64 rsf64 = 28; + repeated float rfl = 29; + repeated double rdbl = 30; + repeated bool rb = 31; + repeated string rs = 32; + repeated bytes rbs = 33; + repeated Enum ren = 34; + repeated Message rmsg = 35; + repeated group Rgrp = 36 { + optional string name = 1; + }; + // Every kind of map key + map mks = 37; + map mki32 = 38; + map mki64 = 39; + map mkui32 = 40; + map mkui64 = 41; + map mksi32 = 42; + map mksi64 = 43; + map mkf32 = 44; + map mkf64 = 45; + map mksf32 = 46; + map mksf64 = 47; + map mkb = 48; + // Every kind of map value + map mvi32 = 51; + map mvi64 = 52; + map mvui32 = 53; + map mvui64 = 54; + map mvsi32 = 55; + map mvsi64 = 56; + map mvf32 = 57; + map mvf64 = 58; + map mvsf32 = 59; + map mvsf64 = 60; + map mvfl = 61; + map mvdbl = 62; + map mvb = 63; + map mvs = 64; + map mvbs = 65; + map mven = 66; + map mvmsg = 67; +} + +enum Enum { + A = 0; + B = 1; +} diff --git a/private/buf/bufcurl/testdata/testdata.txt b/private/buf/bufcurl/testdata/testdata.txt new file mode 100644 index 0000000000..64b9c5db85 --- /dev/null +++ b/private/buf/bufcurl/testdata/testdata.txt @@ -0,0 +1,260 @@ +i32: -123 +i64: -456 +ui32: 123 +ui64: 456 +si32: -123 +si64: -456 +f32: 123 +f64: 456 +sf32: -123 +sf64: -456 +fl: 1.23 +dbl: 4.56 +b: true +s: "abcdef" +bs: "\x01\x02\x03\x04" +en: A +msg: { + i32: 1 + s: "foo" +} +grp: { + name: "abc" +} +ri32: [-123,-123] +ri64: [-456,-456] +rui32: [123,123] +rui64: [456,456] +rsi32: [-123,-123] +rsi64: [-456,-456] +rf32: [123,123] +rf64: [456,456] +rsf32: [-123,-123] +rsf64: [-456,-456] +rfl: [1.23,1.23] +rdbl: [4.56,4.56] +rb: [true,true] +rs: ["abcdef","ghijkl"] +rbs: ["\x01\x02\x03\x04", "\x05\x06\x07\x08"] +ren: [A, B] +rmsg: [{ + i32: 1 + s: "foo" +},{ + i32: 2 + s: "bar" +}] +rgrp: [{ + name: "abc" +},{ + name: "def" +}] +mks: [{ + key: "a" + value: "abc" +},{ + key: "b" + value: "def" +}] +mki32: [{ + key: 123 + value: "abc" +},{ + key: -123 + value: "def" +}] +mki64: [{ + key: 456 + value: "abc" +},{ + key: -456 + value: "def" +}] +mkui32: [{ + key: 123 + value: "abc" +},{ + key: 234 + value: "def" +}] +mkui64: [{ + key: 456 + value: "abc" +},{ + key: 567 + value: "def" +}] +mksi32: [{ + key: 123 + value: "abc" +},{ + key: -123 + value: "def" +}] +mksi64: [{ + key: 456 + value: "abc" +},{ + key: -456 + value: "def" +}] +mkf32: [{ + key: 123 + value: "abc" +},{ + key: 234 + value: "def" +}] +mkf64: [{ + key: 456 + value: "abc" +},{ + key: 567 + value: "def" +}] +mksf32: [{ + key: 123 + value: "abc" +},{ + key: -123 + value: "def" +}] +mksf64: [{ + key: 456 + value: "abc" +},{ + key: -456 + value: "def" +}] +mkb: [{ + key: true + value: "abc" +},{ + key: false + value: "def" +}] +mvi32: [{ + key: "a" + value: -123 +},{ + key: "b" + value: -123 +}] +mvi64: [{ + key: "a" + value: -456 +},{ + key: "b" + value: -456 +}] +mvui32: [{ + key: "a" + value: 123 +},{ + key: "b" + value: 123 +}] +mvui64: [{ + key: "a" + value: 456 +},{ + key: "b" + value: 456 +}] +mvsi32: [{ + key: "a" + value: -123 +},{ + key: "b" + value: -123 +}] +mvsi64: [{ + key: "a" + value: -456 +},{ + key: "b" + value: -456 +}] +mvf32: [{ + key: "a" + value: 123 +},{ + key: "b" + value: 123 +}] +mvf64: [{ + key: "a" + value: 456 +},{ + key: "b" + value: 456 +}] +mvsf32: [{ + key: "a" + value: -123 +},{ + key: "b" + value: -123 +}] +mvsf64: [{ + key: "a" + value: -456 +},{ + key: "b" + value: -456 +}] +mvfl: [{ + key: "a" + value: 1.23 +},{ + key: "b" + value: 1.23 +}] +mvdbl: [{ + key: "a" + value: 4.56 +},{ + key: "b" + value: 4.56 +}] +mvb: [{ + key: "a" + value: true +},{ + key: "b" + value: false +}] +mvs: [{ + key: "a" + value: "abc" +},{ + key: "b" + value: "def" +}] +mvbs: [{ + key: "a" + value: "\x01\x02\x03\x04" +},{ + key: "b" + value: "\x05\x06\x07\x08" +}] +mven: [{ + key: "a" + value: A +},{ + key: "b" + value: B +}] +mvmsg: [{ + key: "a" + value: { + i32: 1 + s: "foo" + } +},{ + key: "b" + value: { + i32: 2 + s: "bar" + } +}] \ No newline at end of file