Skip to content

Commit

Permalink
Fix panic in buf curl that can happen with some map fields (#2711)
Browse files Browse the repository at this point in the history
If response messages include a non-empty map field whose values
are not message types, the command could fail with the following:
        panic: type mismatch: cannot convert map to message
This commit adds a repro test and fixes the panic bug.
  • Loading branch information
jhump authored Jan 17, 2024
1 parent 1412b86 commit 7774390
Show file tree
Hide file tree
Showing 4 changed files with 425 additions and 2 deletions.
10 changes: 8 additions & 2 deletions private/buf/bufcurl/invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,10 @@ func countUnrecognized(msg protoreflect.Message) int {
var count int
msg.Range(func(field protoreflect.FieldDescriptor, val protoreflect.Value) bool {
switch {
case field.IsMap() && isMessageKind(field.MapValue().Kind()):
case field.IsMap():
if !isMessageKind(field.MapValue().Kind()) {
break
}
// Note: Technically, each message entry could have had unrecognized field
// bytes, but they are discarded by the runtime. So we can only look at
// unrecognized fields in message values inside the map.
Expand All @@ -460,7 +463,10 @@ func countUnrecognized(msg protoreflect.Message) int {
count += countUnrecognized(v.Message())
return true
})
case field.IsList() && isMessageKind(field.Kind()):
case field.IsList():
if !isMessageKind(field.Kind()) {
break
}
listVal := val.List()
for i, length := 0, listVal.Len(); i < length; i++ {
count += countUnrecognized(listVal.Get(i).Message())
Expand Down
73 changes: 73 additions & 0 deletions private/buf/bufcurl/invoker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright 2020-2023 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bufcurl

import (
"context"
"os"
"testing"

"github.com/bufbuild/protocompile"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/reflect/protoreflect"
)

func TestCountUnrecognized(t *testing.T) {
t.Parallel()
descriptors, err := (&protocompile.Compiler{
Resolver: &protocompile.SourceResolver{
ImportPaths: []string{"./testdata"},
},
}).Compile(context.Background(), "test.proto")
require.NoError(t, err)
msgType, err := descriptors.AsResolver().FindMessageByName("foo.bar.Message")
require.NoError(t, err)
msg := msgType.New()
msgData, err := os.ReadFile("./testdata/testdata.txt")
require.NoError(t, err)
err = prototext.Unmarshal(msgData, msg.Interface())
require.NoError(t, err)
// Add some unrecognized bytes
unknownBytes := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1}

msg.SetUnknown(unknownBytes)
expectedUnrecognized := len(unknownBytes)

msg.Get(msgType.Descriptor().Fields().ByName("msg")).Message().SetUnknown(unknownBytes[:10])
expectedUnrecognized += 10
msg.Get(msgType.Descriptor().Fields().ByName("grp")).Message().SetUnknown(unknownBytes[:6])
expectedUnrecognized += 6

slice := msg.Get(msgType.Descriptor().Fields().ByName("rmsg")).List()
slice.Get(0).Message().SetUnknown(unknownBytes[:10])
slice.Get(1).Message().SetUnknown(unknownBytes[:5])
expectedUnrecognized += 15
slice = msg.Get(msgType.Descriptor().Fields().ByName("rgrp")).List()
slice.Get(0).Message().SetUnknown(unknownBytes[:3])
slice.Get(1).Message().SetUnknown(unknownBytes[:8])
expectedUnrecognized += 11

mapVal := msg.Get(msgType.Descriptor().Fields().ByName("mvmsg")).Map()
mapVal.Range(func(_ protoreflect.MapKey, v protoreflect.Value) bool {
v.Message().SetUnknown(unknownBytes[:6])
expectedUnrecognized += 6
return true
})

unrecognized := countUnrecognized(msg)
assert.Equal(t, expectedUnrecognized, unrecognized)
}
84 changes: 84 additions & 0 deletions private/buf/bufcurl/testdata/test.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
syntax = "proto2";

package foo.bar;

message Message {
// Every kind of singular field
optional int32 i32 = 1;
optional int64 i64 = 2;
optional uint32 ui32 = 3;
optional uint64 ui64 = 4;
optional sint32 si32 = 5;
optional sint64 si64 = 6;
optional fixed32 f32 = 7;
optional fixed64 f64 = 8;
optional sfixed32 sf32 = 9;
optional sfixed64 sf64 = 10;
optional float fl = 11;
optional double dbl = 12;
optional bool b = 13;
optional string s = 14;
optional bytes bs = 15;
optional Enum en = 16;
optional Message msg = 17;
optional group Grp = 18 {
optional string name = 1;
};
// Every kind of repeated field
repeated int32 ri32 = 19;
repeated int64 ri64 = 20;
repeated uint32 rui32 = 21;
repeated uint64 rui64 = 22;
repeated sint32 rsi32 = 23;
repeated sint64 rsi64 = 24;
repeated fixed32 rf32 = 25;
repeated fixed64 rf64 = 26;
repeated sfixed32 rsf32 = 27;
repeated sfixed64 rsf64 = 28;
repeated float rfl = 29;
repeated double rdbl = 30;
repeated bool rb = 31;
repeated string rs = 32;
repeated bytes rbs = 33;
repeated Enum ren = 34;
repeated Message rmsg = 35;
repeated group Rgrp = 36 {
optional string name = 1;
};
// Every kind of map key
map<string, string> mks = 37;
map<int32, string> mki32 = 38;
map<int64, string> mki64 = 39;
map<uint32, string> mkui32 = 40;
map<uint64, string> mkui64 = 41;
map<sint32, string> mksi32 = 42;
map<sint64, string> mksi64 = 43;
map<fixed32, string> mkf32 = 44;
map<fixed64, string> mkf64 = 45;
map<sfixed32, string> mksf32 = 46;
map<sfixed64, string> mksf64 = 47;
map<bool, string> mkb = 48;
// Every kind of map value
map<string, int32> mvi32 = 51;
map<string, int64> mvi64 = 52;
map<string, uint32> mvui32 = 53;
map<string, uint64> mvui64 = 54;
map<string, sint32> mvsi32 = 55;
map<string, sint64> mvsi64 = 56;
map<string, fixed32> mvf32 = 57;
map<string, fixed64> mvf64 = 58;
map<string, sfixed32> mvsf32 = 59;
map<string, sfixed64> mvsf64 = 60;
map<string, float> mvfl = 61;
map<string, double> mvdbl = 62;
map<string, bool> mvb = 63;
map<string, string> mvs = 64;
map<string, bytes> mvbs = 65;
map<string, Enum> mven = 66;
map<string, Message> mvmsg = 67;
}

enum Enum {
A = 0;
B = 1;
}
Loading

0 comments on commit 7774390

Please sign in to comment.