Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reject unknown fields in TxDecoder and sign mode handlers #6883

Merged
merged 17 commits into from
Aug 3, 2020
Merged
2 changes: 1 addition & 1 deletion client/tx/legacy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ type TestSuite struct {
func (s *TestSuite) SetupSuite() {
encCfg := simapp.MakeEncodingConfig()
s.encCfg = encCfg
s.protoCfg = tx.NewTxConfig(codec.NewProtoCodec(encCfg.InterfaceRegistry), std.DefaultPublicKeyCodec{}, tx.DefaultSignModeHandler())
s.protoCfg = tx.NewTxConfig(codec.NewProtoCodec(encCfg.InterfaceRegistry), std.DefaultPublicKeyCodec{}, tx.DefaultSignModes)
s.aminoCfg = types3.StdTxConfig{Cdc: encCfg.Amino}
}

Expand Down
File renamed without changes.
6 changes: 2 additions & 4 deletions codec/unknownproto/benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ func benchmarkRejectUnknownFields(b *testing.B, parallel bool) {
b.ReportAllocs()

if !parallel {
ckr := new(unknownproto.Checker)
b.ResetTimer()
for i := 0; i < b.N; i++ {
n1A := new(testdata.Nested1A)
if err := ckr.RejectUnknownFields(n1BBlob, n1A); err == nil {
if err := unknownproto.RejectUnknownFieldsStrict(n1BBlob, n1A); err == nil {
b.Fatal("expected an error")
}
b.SetBytes(int64(len(n1BBlob)))
Expand All @@ -66,11 +65,10 @@ func benchmarkRejectUnknownFields(b *testing.B, parallel bool) {
var mu sync.Mutex
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ckr := new(unknownproto.Checker)
for pb.Next() {
// To simulate the conditions of multiple transactions being processed in parallel.
n1A := new(testdata.Nested1A)
if err := ckr.RejectUnknownFields(n1BBlob, n1A); err == nil {
if err := unknownproto.RejectUnknownFieldsStrict(n1BBlob, n1A); err == nil {
b.Fatal("expected an error")
}
mu.Lock()
Expand Down
12 changes: 4 additions & 8 deletions codec/unknownproto/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,18 @@ a) Unknown fields in the stream -- this is indicative of mismatched services, pe

b) Mismatched wire types for a field -- this is indicative of mismatched services

Its API signature is similar to proto.Unmarshal([]byte, proto.Message) as
Its API signature is similar to proto.Unmarshal([]byte, proto.Message) in the strict case

ckr := new(unknownproto.Checker)
if err := ckr.RejectUnknownFields(protoBlob, protoMessage); err != nil {
if err := RejectUnknownFieldsStrict(protoBlob, protoMessage, false); err != nil {
// Handle the error.
}

and ideally should be added before invoking proto.Unmarshal, if you'd like to enforce the features mentioned above.

By default, for security we report every single field that's unknown, whether a non-critical field or not. To customize
this behavior, please create a Checker and set the AllowUnknownNonCriticals to true, for example:
this behavior, please set the boolean parameter allowUnknownNonCriticals to true to RejectUnknownFields:

ckr := &unknownproto.Checker{
AllowUnknownNonCriticals: true,
}
if err := ckr.RejectUnknownFields(protoBlob, protoMessage); err != nil {
if err := RejectUnknownFields(protoBlob, protoMessage, true); err != nil {
// Handle the error.
}
*/
Expand Down
76 changes: 50 additions & 26 deletions codec/unknownproto/unknown_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,45 @@ type descriptorIface interface {
Descriptor() ([]byte, []int)
}

type Checker struct {
// AllowUnknownNonCriticals when set will skip over non-critical fields that are unknown.
AllowUnknownNonCriticals bool
// RejectUnknownFieldsStrict rejects any bytes bz with an error that has unknown fields for the provided proto.Message type.
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
func RejectUnknownFieldsStrict(bz []byte, msg proto.Message) error {
_, err := RejectUnknownFields(bz, msg, false)
return err
}

func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
if len(b) == 0 {
return nil
// RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an
// option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the
// hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be
// used to treat a message with non-critical field different in different security contexts (such as transaction signing).
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool) (hasUnknownNonCriticals bool, err error) {
if len(bz) == 0 {
return hasUnknownNonCriticals, nil
odeke-em marked this conversation as resolved.
Show resolved Hide resolved
}

desc, ok := msg.(descriptorIface)
if !ok {
return fmt.Errorf("%T does not have a Descriptor() method", msg)
return hasUnknownNonCriticals, fmt.Errorf("%T does not have a Descriptor() method", msg)
odeke-em marked this conversation as resolved.
Show resolved Hide resolved
}

fieldDescProtoFromTagNum, _, err := getDescriptorInfo(desc, msg)
if err != nil {
return err
return hasUnknownNonCriticals, err
}

for len(b) > 0 {
tagNum, wireType, n := protowire.ConsumeField(b)
if n < 0 {
return errors.New("invalid length")
for len(bz) > 0 {
tagNum, wireType, m := protowire.ConsumeTag(bz)
if m < 0 {
return hasUnknownNonCriticals, errors.New("invalid length")
}

fieldDescProto, ok := fieldDescProtoFromTagNum[int32(tagNum)]
switch {
case ok:
// Assert that the wireTypes match.
if !canEncodeType(wireType, fieldDescProto.GetType()) {
return &errMismatchedWireType{
return hasUnknownNonCriticals, &errMismatchedWireType{
Type: reflect.ValueOf(msg).Type().String(),
TagNum: tagNum,
GotWireType: wireType,
Expand All @@ -62,19 +69,27 @@ func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
}

default:
if !ckr.AllowUnknownNonCriticals || tagNum&bit11NonCritical == 0 {
isCriticalField := tagNum&bit11NonCritical == 0

if !isCriticalField {
hasUnknownNonCriticals = true
}

if isCriticalField || !allowUnknownNonCriticals {
// The tag is critical, so report it.
return &errUnknownField{
return hasUnknownNonCriticals, &errUnknownField{
Type: reflect.ValueOf(msg).Type().String(),
TagNum: tagNum,
WireType: wireType,
}
}
}
aaronc marked this conversation as resolved.
Show resolved Hide resolved

// Skip over the 2 bytes that store fieldNumber and wireType bytes.
fieldBytes := b[2:n]
b = b[n:]
// Skip over the bytes that store fieldNumber and wireType bytes.
bz = bz[m:]
n := protowire.ConsumeFieldValue(tagNum, wireType, bz)
fieldBytes := bz[:n]
aaronc marked this conversation as resolved.
Show resolved Hide resolved
bz = bz[n:]

// An unknown but non-critical field or just a scalar type (aka *INT and BYTES like).
if fieldDescProto == nil || fieldDescProto.IsScalar() {
Expand All @@ -89,37 +104,46 @@ func (ckr *Checker) RejectUnknownFields(b []byte, msg proto.Message) error {
// TYPE_BYTES and TYPE_STRING as per
// https://github.com/gogo/protobuf/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118
default:
return fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ)
return hasUnknownNonCriticals, fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ)
}
continue
}

// Let's recursively traverse and typecheck the field.

// consume length prefix of nested message
_, o := protowire.ConsumeVarint(fieldBytes)
fieldBytes = fieldBytes[o:]

if protoMessageName == ".google.protobuf.Any" {
// Firstly typecheck types.Any to ensure nothing snuck in.
if err := ckr.RejectUnknownFields(fieldBytes, (*types.Any)(nil)); err != nil {
return err
hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals)
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
}
// And finally we can extract the TypeURL containing the protoMessageName.
any := new(types.Any)
if err := proto.Unmarshal(fieldBytes, any); err != nil {
return err
return hasUnknownNonCriticals, err
}
protoMessageName = any.TypeUrl
fieldBytes = any.Value
}

msg, err := protoMessageForTypeName(protoMessageName[1:])
if err != nil {
return err
return hasUnknownNonCriticals, err
}
if err := ckr.RejectUnknownFields(fieldBytes, msg); err != nil {
return err

hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals)
hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
if err != nil {
return hasUnknownNonCriticals, err
}
}

return nil
return hasUnknownNonCriticals, nil
}

var protoMessageForTypeNameMu sync.RWMutex
Expand Down
31 changes: 15 additions & 16 deletions codec/unknownproto/unknown_fields_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"reflect"
"testing"

"github.com/stretchr/testify/require"

"github.com/gogo/protobuf/proto"

"github.com/cosmos/cosmos-sdk/codec/types"
Expand All @@ -17,6 +19,7 @@ func TestRejectUnknownFieldsRepeated(t *testing.T) {
recv proto.Message
wantErr error
allowUnknownNonCriticals bool
hasUnknownNonCriticals bool
}{
{
name: "Unknown field in midst of repeated values",
Expand Down Expand Up @@ -172,6 +175,7 @@ func TestRejectUnknownFieldsRepeated(t *testing.T) {
TagNum: 1031,
WireType: 2,
},
hasUnknownNonCriticals: true,
},
{
name: "Unknown field in midst of repeated values, non-critical field ignored",
Expand Down Expand Up @@ -213,8 +217,9 @@ func TestRejectUnknownFieldsRepeated(t *testing.T) {
},
},
},
recv: new(testdata.TestVersion1),
wantErr: nil,
recv: new(testdata.TestVersion1),
wantErr: nil,
hasUnknownNonCriticals: true,
},
}

Expand All @@ -225,11 +230,9 @@ func TestRejectUnknownFieldsRepeated(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ckr := &Checker{AllowUnknownNonCriticals: tt.allowUnknownNonCriticals}
gotErr := ckr.RejectUnknownFields(protoBlob, tt.recv)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%v\n\nWant:\n%v", gotErr, tt.wantErr)
}
hasUnknownNonCriticals, gotErr := RejectUnknownFields(protoBlob, tt.recv, tt.allowUnknownNonCriticals)
require.Equal(t, tt.wantErr, gotErr)
require.Equal(t, tt.hasUnknownNonCriticals, hasUnknownNonCriticals)
})
}
}
Expand Down Expand Up @@ -263,7 +266,7 @@ func TestRejectUnknownFields_allowUnknownNonCriticals(t *testing.T) {
wantErr: nil,
},
{
name: "Unkown fields that are critical, but with allowUnknownNonCriticals set",
name: "Unknown fields that are critical, but with allowUnknownNonCriticals set",
allowUnknownNonCriticals: true,
in: &testdata.Customer2{
Id: 289,
Expand All @@ -285,9 +288,8 @@ func TestRejectUnknownFields_allowUnknownNonCriticals(t *testing.T) {
t.Fatalf("Failed to marshal input: %v", err)
}

ckr := &Checker{AllowUnknownNonCriticals: tt.allowUnknownNonCriticals}
c1 := new(testdata.Customer1)
gotErr := ckr.RejectUnknownFields(blob, c1)
_, gotErr := RejectUnknownFields(blob, c1, tt.allowUnknownNonCriticals)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
}
Expand Down Expand Up @@ -498,8 +500,7 @@ func TestRejectUnknownFieldsNested(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ckr := new(Checker)
gotErr := ckr.RejectUnknownFields(protoBlob, tt.recv)
gotErr := RejectUnknownFieldsStrict(protoBlob, tt.recv)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
}
Expand Down Expand Up @@ -652,8 +653,7 @@ func TestRejectUnknownFieldsFlat(t *testing.T) {
}

c1 := new(testdata.Customer1)
ckr := new(Checker)
gotErr := ckr.RejectUnknownFields(blob, c1)
gotErr := RejectUnknownFieldsStrict(blob, c1)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
}
Expand Down Expand Up @@ -738,8 +738,7 @@ func TestMismatchedTypes_Nested(t *testing.T) {
if err != nil {
t.Fatal(err)
}
ckr := new(Checker)
gotErr := ckr.RejectUnknownFields(protoBlob, tt.recv)
_, gotErr := RejectUnknownFields(protoBlob, tt.recv, false)
if !reflect.DeepEqual(gotErr, tt.wantErr) {
t.Fatalf("Error mismatch\nGot:\n%s\n\nWant:\n%s", gotErr, tt.wantErr)
}
Expand Down
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ require (
github.com/cosmos/go-bip39 v0.0.0-20180819234021-555e2067c45d
github.com/cosmos/ledger-cosmos-go v0.11.1
github.com/enigmampc/btcutil v1.0.3-0.20200723161021-e2fb6adb2a25
github.com/gibson042/canonicaljson-go v1.0.3
github.com/gogo/protobuf v1.3.1
github.com/golang/mock v1.4.4
github.com/golang/protobuf v1.4.2
Expand Down
4 changes: 0 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/gibson042/canonicaljson-go v1.0.3 h1:EAyF8L74AWabkyUmrvEFHEt/AGFQeD6RfwbAuf0j1bI=
github.com/gibson042/canonicaljson-go v1.0.3/go.mod h1:DsLpJTThXyGNO+KZlI85C1/KDcImpP67k/RKVjcaEqo=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
Expand Down Expand Up @@ -480,8 +478,6 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An
github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE=
github.com/spf13/viper v1.6.2/go.mod h1:t3iDnF5Jlj76alVNuyFBk5oUMCvsrkbvZK0WQdfDi5k=
github.com/spf13/viper v1.6.3/go.mod h1:jUMtyi0/lB5yZH/FjyGAoH7IMNrIhlBf6pXZmbMDvzw=
github.com/spf13/viper v1.7.0 h1:xVKxvI7ouOI5I+U9s2eeiUfMaWBVoXA3AWskkrqK0VM=
github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg=
github.com/spf13/viper v1.7.1 h1:pM5oEahlgWv/WnHXpgbKz7iLIxRf65tye2Ci+XFK5sk=
github.com/spf13/viper v1.7.1/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg=
github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
Expand Down
2 changes: 1 addition & 1 deletion simapp/params/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func MakeEncodingConfig() EncodingConfig {
amino := codec.New()
interfaceRegistry := types.NewInterfaceRegistry()
marshaler := codec.NewHybridCodec(amino, interfaceRegistry)
txGen := tx.NewTxConfig(codec.NewProtoCodec(interfaceRegistry), std.DefaultPublicKeyCodec{}, tx.DefaultSignModeHandler())
txGen := tx.NewTxConfig(codec.NewProtoCodec(interfaceRegistry), std.DefaultPublicKeyCodec{}, tx.DefaultSignModes)

return EncodingConfig{
InterfaceRegistry: interfaceRegistry,
Expand Down
3 changes: 2 additions & 1 deletion store/cachekv/memiterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import (
"container/list"
"errors"

"github.com/cosmos/cosmos-sdk/types/kv"
dbm "github.com/tendermint/tm-db"

"github.com/cosmos/cosmos-sdk/types/kv"
)

// Iterates over iterKVCache items.
Expand Down
Loading