Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix panic in buf curl that can happen with some map fields #2711

Merged
merged 2 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions private/buf/bufcurl/invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,10 @@ func countUnrecognized(msg protoreflect.Message) int {
var count int
msg.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool {
switch {
case field.IsMap() && isMessageKind(field.MapValue().Kind()):
case field.IsMap():
if !isMessageKind(field.MapValue().Kind()) {
break
}
// Note: Technically, each message entry could have had unrecognized field
// bytes, but they are discarded by the runtime. So we can only look at
// unrecognized fields in message values inside the map.
Expand All @@ -460,7 +463,10 @@ func countUnrecognized(msg protoreflect.Message) int {
count += countUnrecognized(v.Message())
return true
})
case field.IsList() && isMessageKind(field.Kind()):
case field.IsList():
if !isMessageKind(field.Kind()) {
break
}
listVal := val.List()
for i, length := 0, listVal.Len(); i < length; i++ {
count += countUnrecognized(listVal.Get(i).Message())
Expand Down
73 changes: 73 additions & 0 deletions private/buf/bufcurl/invoker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright 2020-2023 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bufcurl

import (
"context"
"os"
"testing"

"github.com/bufbuild/protocompile"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/reflect/protoreflect"
)

func TestCountUnrecognized(t *testing.T) {
t.Parallel()
descriptors, err := (&protocompile.Compiler{
Resolver: &protocompile.SourceResolver{
ImportPaths: []string{"./testdata"},
},
}).Compile(context.Background(), "test.proto")
require.NoError(t, err)
msgType, err := descriptors.AsResolver().FindMessageByName("foo.bar.Message")
require.NoError(t, err)
msg := msgType.New()
msgData, err := os.ReadFile("./testdata/testdata.txt")
require.NoError(t, err)
err = prototext.Unmarshal(msgData, msg.Interface())
require.NoError(t, err)
// Add some unrecognized bytes
unknownBytes := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1}

msg.SetUnknown(unknownBytes)
expectedUnrecognized := len(unknownBytes)

msg.Get(msgType.Descriptor().Fields().ByName("msg")).Message().SetUnknown(unknownBytes[:10])
expectedUnrecognized += 10
msg.Get(msgType.Descriptor().Fields().ByName("grp")).Message().SetUnknown(unknownBytes[:6])
expectedUnrecognized += 6

slice := msg.Get(msgType.Descriptor().Fields().ByName("rmsg")).List()
slice.Get(0).Message().SetUnknown(unknownBytes[:10])
slice.Get(1).Message().SetUnknown(unknownBytes[:5])
expectedUnrecognized += 15
slice = msg.Get(msgType.Descriptor().Fields().ByName("rgrp")).List()
slice.Get(0).Message().SetUnknown(unknownBytes[:3])
slice.Get(1).Message().SetUnknown(unknownBytes[:8])
expectedUnrecognized += 11

mapVal := msg.Get(msgType.Descriptor().Fields().ByName("mvmsg")).Map()
mapVal.Range(func(_ protoreflect.MapKey, v protoreflect.Value) bool {
v.Message().SetUnknown(unknownBytes[:6])
expectedUnrecognized += 6
return true
})

unrecognized := countUnrecognized(msg)
assert.Equal(t, expectedUnrecognized, unrecognized)
}
84 changes: 84 additions & 0 deletions private/buf/bufcurl/testdata/test.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
syntax = "proto2";

package foo.bar;

message Message {
// Every kind of singular field
optional int32 i32 = 1;
optional int64 i64 = 2;
optional uint32 ui32 = 3;
optional uint64 ui64 = 4;
optional sint32 si32 = 5;
optional sint64 si64 = 6;
optional fixed32 f32 = 7;
optional fixed64 f64 = 8;
optional sfixed32 sf32 = 9;
optional sfixed64 sf64 = 10;
optional float fl = 11;
optional double dbl = 12;
optional bool b = 13;
optional string s = 14;
optional bytes bs = 15;
optional Enum en = 16;
optional Message msg = 17;
optional group Grp = 18 {
optional string name = 1;
};
// Every kind of repeated field
repeated int32 ri32 = 19;
repeated int64 ri64 = 20;
repeated uint32 rui32 = 21;
repeated uint64 rui64 = 22;
repeated sint32 rsi32 = 23;
repeated sint64 rsi64 = 24;
repeated fixed32 rf32 = 25;
repeated fixed64 rf64 = 26;
repeated sfixed32 rsf32 = 27;
repeated sfixed64 rsf64 = 28;
repeated float rfl = 29;
repeated double rdbl = 30;
repeated bool rb = 31;
repeated string rs = 32;
repeated bytes rbs = 33;
repeated Enum ren = 34;
repeated Message rmsg = 35;
repeated group Rgrp = 36 {
optional string name = 1;
};
// Every kind of map key
map<string, string> mks = 37;
map<int32, string> mki32 = 38;
map<int64, string> mki64 = 39;
map<uint32, string> mkui32 = 40;
map<uint64, string> mkui64 = 41;
map<sint32, string> mksi32 = 42;
map<sint64, string> mksi64 = 43;
map<fixed32, string> mkf32 = 44;
map<fixed64, string> mkf64 = 45;
map<sfixed32, string> mksf32 = 46;
map<sfixed64, string> mksf64 = 47;
map<bool, string> mkb = 48;
// Every kind of map value
map<string, int32> mvi32 = 51;
map<string, int64> mvi64 = 52;
map<string, uint32> mvui32 = 53;
map<string, uint64> mvui64 = 54;
map<string, sint32> mvsi32 = 55;
map<string, sint64> mvsi64 = 56;
map<string, fixed32> mvf32 = 57;
map<string, fixed64> mvf64 = 58;
map<string, sfixed32> mvsf32 = 59;
map<string, sfixed64> mvsf64 = 60;
map<string, float> mvfl = 61;
map<string, double> mvdbl = 62;
map<string, bool> mvb = 63;
map<string, string> mvs = 64;
map<string, bytes> mvbs = 65;
map<string, Enum> mven = 66;
map<string, Message> mvmsg = 67;
}

enum Enum {
A = 0;
B = 1;
}
Loading