Skip to content

Commit

Permalink
rlp: improve nil pointer handling (#20064)
Browse files Browse the repository at this point in the history
* rlp: improve nil pointer handling

In both encoder and decoder, the rules for encoding nil pointers were a
bit hard to understand, and didn't leave much choice. Since RLP allows
two empty values (empty list, empty string), any protocol built on RLP
must choose either of these values to represent the null value in a
certain context.

This change adds choice in the form of two new struct tags, "nilString"
and "nilList". These can be used to specify how a nil pointer value is
encoded. The "nil" tag still exists, but its implementation is now
explicit and defines exactly how nil pointers are handled in a single
place.

Another important change in this commit is how nil pointers and the
Encoder interface interact. The EncodeRLP method was previously called
even on nil values, which was supposed to give users a choice of how
their value would be handled when nil. It turns out this is a stupid
idea. If you create a network protocol containing an object defined in
another package, it's better to be able to say that the object should be
a list or string when nil in the definition of the protocol message
rather than defining the encoding of nil on the object itself.

As of this commit, the encoding rules for pointers now take precedence
over the Encoder interface rule. I think the "nil" tag will work fine
for most cases. For special kinds of objects which are a struct in Go
but strings in RLP, code using the object can specify the desired
encoding of nil using the "nilString" and "nilList" tags.

* rlp: propagate struct field type errors

If a struct contained fields of undecodable type, the encoder and
decoder would panic instead of returning an error. Fix this by
propagating type errors in makeStruct{Writer,Decoder} and add a test.
  • Loading branch information
fjl authored Sep 13, 2019
1 parent 3b6c990 commit 96fb839
Show file tree
Hide file tree
Showing 7 changed files with 416 additions and 244 deletions.
168 changes: 54 additions & 114 deletions rlp/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,81 +55,23 @@ var (
}
)

// Decoder is implemented by types that require custom RLP
// decoding rules or need to decode into private fields.
// Decoder is implemented by types that require custom RLP decoding rules or need to decode
// into private fields.
//
// The DecodeRLP method should read one value from the given
// Stream. It is not forbidden to read less or more, but it might
// be confusing.
// The DecodeRLP method should read one value from the given Stream. It is not forbidden to
// read less or more, but it might be confusing.
type Decoder interface {
DecodeRLP(*Stream) error
}

// Decode parses RLP-encoded data from r and stores the result in the
// value pointed to by val. Val must be a non-nil pointer. If r does
// not implement ByteReader, Decode will do its own buffering.
// Decode parses RLP-encoded data from r and stores the result in the value pointed to by
// val. Please see package-level documentation for the decoding rules. Val must be a
// non-nil pointer.
//
// Decode uses the following type-dependent decoding rules:
// If r does not implement ByteReader, Decode will do its own buffering.
//
// If the type implements the Decoder interface, decode calls
// DecodeRLP.
//
// To decode into a pointer, Decode will decode into the value pointed
// to. If the pointer is nil, a new value of the pointer's element
// type is allocated. If the pointer is non-nil, the existing value
// will be reused.
//
// To decode into a struct, Decode expects the input to be an RLP
// list. The decoded elements of the list are assigned to each public
// field in the order given by the struct's definition. The input list
// must contain an element for each decoded field. Decode returns an
// error if there are too few or too many elements.
//
// The decoding of struct fields honours certain struct tags, "tail",
// "nil" and "-".
//
// The "-" tag ignores fields.
//
// For an explanation of "tail", see the example.
//
// The "nil" tag applies to pointer-typed fields and changes the decoding
// rules for the field such that input values of size zero decode as a nil
// pointer. This tag can be useful when decoding recursive types.
//
// type StructWithEmptyOK struct {
// Foo *[20]byte `rlp:"nil"`
// }
//
// To decode into a slice, the input must be a list and the resulting
// slice will contain the input elements in order. For byte slices,
// the input must be an RLP string. Array types decode similarly, with
// the additional restriction that the number of input elements (or
// bytes) must match the array's length.
//
// To decode into a Go string, the input must be an RLP string. The
// input bytes are taken as-is and will not necessarily be valid UTF-8.
//
// To decode into an unsigned integer type, the input must also be an RLP
// string. The bytes are interpreted as a big endian representation of
// the integer. If the RLP string is larger than the bit size of the
// type, Decode will return an error. Decode also supports *big.Int.
// There is no size limit for big integers.
//
// To decode into a boolean, the input must contain an unsigned integer
// of value zero (false) or one (true).
//
// To decode into an interface value, Decode stores one of these
// in the value:
//
// []interface{}, for RLP lists
// []byte, for RLP strings
//
// Non-empty interface types are not supported, nor are signed integers,
// floating point numbers, maps, channels and functions.
//
// Note that Decode does not set an input limit for all readers
// and may be vulnerable to panics cause by huge value sizes. If
// you need an input limit, use
// Note that Decode does not set an input limit for all readers and may be vulnerable to
// panics cause by huge value sizes. If you need an input limit, use
//
// NewStream(r, limit).Decode(val)
func Decode(r io.Reader, val interface{}) error {
Expand All @@ -140,9 +82,8 @@ func Decode(r io.Reader, val interface{}) error {
return stream.Decode(val)
}

// DecodeBytes parses RLP data from b into val.
// Please see the documentation of Decode for the decoding rules.
// The input must contain exactly one value and no trailing data.
// DecodeBytes parses RLP data from b into val. Please see package-level documentation for
// the decoding rules. The input must contain exactly one value and no trailing data.
func DecodeBytes(b []byte, val interface{}) error {
r := bytes.NewReader(b)

Expand Down Expand Up @@ -211,14 +152,15 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) {
switch {
case typ == rawValueType:
return decodeRawValue, nil
case typ.Implements(decoderInterface):
return decodeDecoder, nil
case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(decoderInterface):
return decodeDecoderNoPtr, nil
case typ.AssignableTo(reflect.PtrTo(bigInt)):
return decodeBigInt, nil
case typ.AssignableTo(bigInt):
return decodeBigIntNoPtr, nil
case kind == reflect.Ptr:
return makePtrDecoder(typ, tags)
case reflect.PtrTo(typ).Implements(decoderInterface):
return decodeDecoder, nil
case isUint(kind):
return decodeUint, nil
case kind == reflect.Bool:
Expand All @@ -229,11 +171,6 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) {
return makeListDecoder(typ, tags)
case kind == reflect.Struct:
return makeStructDecoder(typ)
case kind == reflect.Ptr:
if tags.nilOK {
return makeOptionalPtrDecoder(typ)
}
return makePtrDecoder(typ)
case kind == reflect.Interface:
return decodeInterface, nil
default:
Expand Down Expand Up @@ -448,6 +385,11 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
if err != nil {
return nil, err
}
for _, f := range fields {
if f.info.decoderErr != nil {
return nil, structFieldError{typ, f.index, f.info.decoderErr}
}
}
dec := func(s *Stream, val reflect.Value) (err error) {
if _, err := s.List(); err != nil {
return wrapStreamError(err, typ)
Expand All @@ -465,15 +407,22 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
return dec, nil
}

// makePtrDecoder creates a decoder that decodes into
// the pointer's element type.
func makePtrDecoder(typ reflect.Type) (decoder, error) {
// makePtrDecoder creates a decoder that decodes into the pointer's element type.
func makePtrDecoder(typ reflect.Type, tag tags) (decoder, error) {
etype := typ.Elem()
etypeinfo := cachedTypeInfo1(etype, tags{})
if etypeinfo.decoderErr != nil {
switch {
case etypeinfo.decoderErr != nil:
return nil, etypeinfo.decoderErr
case !tag.nilOK:
return makeSimplePtrDecoder(etype, etypeinfo), nil
default:
return makeNilPtrDecoder(etype, etypeinfo, tag.nilKind), nil
}
dec := func(s *Stream, val reflect.Value) (err error) {
}

func makeSimplePtrDecoder(etype reflect.Type, etypeinfo *typeinfo) decoder {
return func(s *Stream, val reflect.Value) (err error) {
newval := val
if val.IsNil() {
newval = reflect.New(etype)
Expand All @@ -483,30 +432,35 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) {
}
return err
}
return dec, nil
}

// makeOptionalPtrDecoder creates a decoder that decodes empty values
// as nil. Non-empty values are decoded into a value of the element type,
// just like makePtrDecoder does.
// makeNilPtrDecoder creates a decoder that decodes empty values as nil. Non-empty
// values are decoded into a value of the element type, just like makePtrDecoder does.
//
// This decoder is used for pointer-typed struct fields with struct tag "nil".
func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) {
etype := typ.Elem()
etypeinfo := cachedTypeInfo1(etype, tags{})
if etypeinfo.decoderErr != nil {
return nil, etypeinfo.decoderErr
}
dec := func(s *Stream, val reflect.Value) (err error) {
func makeNilPtrDecoder(etype reflect.Type, etypeinfo *typeinfo, nilKind Kind) decoder {
typ := reflect.PtrTo(etype)
nilPtr := reflect.Zero(typ)
return func(s *Stream, val reflect.Value) (err error) {
kind, size, err := s.Kind()
if err != nil || size == 0 && kind != Byte {
if err != nil {
val.Set(nilPtr)
return wrapStreamError(err, typ)
}
// Handle empty values as a nil pointer.
if kind != Byte && size == 0 {
if kind != nilKind {
return &decodeError{
msg: fmt.Sprintf("wrong kind of empty value (got %v, want %v)", kind, nilKind),
typ: typ,
}
}
// rearm s.Kind. This is important because the input
// position must advance to the next value even though
// we don't read anything.
s.kind = -1
// set the pointer to nil.
val.Set(reflect.Zero(typ))
return err
val.Set(nilPtr)
return nil
}
newval := val
if val.IsNil() {
Expand All @@ -517,7 +471,6 @@ func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) {
}
return err
}
return dec, nil
}

var ifsliceType = reflect.TypeOf([]interface{}{})
Expand Down Expand Up @@ -546,21 +499,8 @@ func decodeInterface(s *Stream, val reflect.Value) error {
return nil
}

// This decoder is used for non-pointer values of types
// that implement the Decoder interface using a pointer receiver.
func decodeDecoderNoPtr(s *Stream, val reflect.Value) error {
return val.Addr().Interface().(Decoder).DecodeRLP(s)
}

func decodeDecoder(s *Stream, val reflect.Value) error {
// Decoder instances are not handled using the pointer rule if the type
// implements Decoder with pointer receiver (i.e. always)
// because it might handle empty values specially.
// We need to allocate one here in this case, like makePtrDecoder does.
if val.Kind() == reflect.Ptr && val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
return val.Interface().(Decoder).DecodeRLP(s)
return val.Addr().Interface().(Decoder).DecodeRLP(s)
}

// Kind represents the kind of value contained in an RLP stream.
Expand Down
95 changes: 87 additions & 8 deletions rlp/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,10 @@ type recstruct struct {
Child *recstruct `rlp:"nil"`
}

type invalidNilTag struct {
X []byte `rlp:"nil"`
}

type invalidTail1 struct {
A uint `rlp:"tail"`
B string
Expand All @@ -353,6 +357,18 @@ type tailPrivateFields struct {
x, y bool
}

type nilListUint struct {
X *uint `rlp:"nilList"`
}

type nilStringSlice struct {
X *[]uint `rlp:"nilString"`
}

type intField struct {
X int
}

var (
veryBigInt = big.NewInt(0).Add(
big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16),
Expand Down Expand Up @@ -485,20 +501,20 @@ var decodeTests = []decodeTest{
error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I",
},
{
input: "C0",
ptr: new(invalidTail1),
error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail1.A (must be on last field)",
},
{
input: "C0",
ptr: new(invalidTail2),
error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail2.B (field type is not slice)",
input: "C103",
ptr: new(intField),
error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)",
},
{
input: "C50102C20102",
ptr: new(tailUint),
error: "rlp: expected input string or byte for uint, decoding into (rlp.tailUint).Tail[1]",
},
{
input: "C0",
ptr: new(invalidNilTag),
error: `rlp: invalid struct tag "nil" for rlp.invalidNilTag.X (field is not a pointer)`,
},

// struct tag "tail"
{
Expand All @@ -521,6 +537,16 @@ var decodeTests = []decodeTest{
ptr: new(tailPrivateFields),
value: tailPrivateFields{A: 1, Tail: []uint{2, 3}},
},
{
input: "C0",
ptr: new(invalidTail1),
error: `rlp: invalid struct tag "tail" for rlp.invalidTail1.A (must be on last field)`,
},
{
input: "C0",
ptr: new(invalidTail2),
error: `rlp: invalid struct tag "tail" for rlp.invalidTail2.B (field type is not slice)`,
},

// struct tag "-"
{
Expand All @@ -529,6 +555,43 @@ var decodeTests = []decodeTest{
value: hasIgnoredField{A: 1, C: 2},
},

// struct tag "nilList"
{
input: "C180",
ptr: new(nilListUint),
error: "rlp: wrong kind of empty value (got String, want List) for *uint, decoding into (rlp.nilListUint).X",
},
{
input: "C1C0",
ptr: new(nilListUint),
value: nilListUint{},
},
{
input: "C103",
ptr: new(nilListUint),
value: func() interface{} {
v := uint(3)
return nilListUint{X: &v}
}(),
},

// struct tag "nilString"
{
input: "C1C0",
ptr: new(nilStringSlice),
error: "rlp: wrong kind of empty value (got List, want String) for *[]uint, decoding into (rlp.nilStringSlice).X",
},
{
input: "C180",
ptr: new(nilStringSlice),
value: nilStringSlice{},
},
{
input: "C2C103",
ptr: new(nilStringSlice),
value: nilStringSlice{X: &[]uint{3}},
},

// RawValue
{input: "01", ptr: new(RawValue), value: RawValue(unhex("01"))},
{input: "82FFFF", ptr: new(RawValue), value: RawValue(unhex("82FFFF"))},
Expand Down Expand Up @@ -672,6 +735,22 @@ func TestDecodeDecoder(t *testing.T) {
}
}

func TestDecodeDecoderNilPointer(t *testing.T) {
var s struct {
T1 *testDecoder `rlp:"nil"`
T2 *testDecoder
}
if err := Decode(bytes.NewReader(unhex("C2C002")), &s); err != nil {
t.Fatalf("Decode error: %v", err)
}
if s.T1 != nil {
t.Errorf("decoder T1 allocated for empty input (called: %v)", s.T1.called)
}
if s.T2 == nil || !s.T2.called {
t.Errorf("decoder T2 not allocated/called")
}
}

type byteDecoder byte

func (bd *byteDecoder) DecodeRLP(s *Stream) error {
Expand Down
Loading

0 comments on commit 96fb839

Please sign in to comment.