Skip to content

Commit

Permalink
Merge branch 'development' into ed/issue-1794
Browse files Browse the repository at this point in the history
  • Loading branch information
edwardmack authored Oct 19, 2021
2 parents f065500 + 88c59ea commit 9ba77ed
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 24 deletions.
23 changes: 13 additions & 10 deletions lib/trie/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import (
"sync"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/scale"
"github.com/ChainSafe/gossamer/pkg/scale"
)

// node is the interface for trie methods
Expand Down Expand Up @@ -337,26 +337,28 @@ func (b *branch) decode(r io.Reader, header byte) (err error) {
return err
}

sd := &scale.Decoder{Reader: r}
sd := scale.NewDecoder(r)

if nodeType == 3 {
var value []byte
// branch w/ value
value, err := sd.Decode([]byte{})
err := sd.Decode(&value)
if err != nil {
return err
}
b.value = value.([]byte)
b.value = value
}

for i := 0; i < 16; i++ {
if (childrenBitmap[i/8]>>(i%8))&1 == 1 {
hash, err := sd.Decode([]byte{})
var hash []byte
err := sd.Decode(&hash)
if err != nil {
return err
}

b.children[i] = &leaf{
hash: hash.([]byte),
hash: hash,
}
}
}
Expand Down Expand Up @@ -386,14 +388,15 @@ func (l *leaf) decode(r io.Reader, header byte) (err error) {
return err
}

sd := &scale.Decoder{Reader: r}
value, err := sd.Decode([]byte{})
sd := scale.NewDecoder(r)
var value []byte
err = sd.Decode(&value)
if err != nil {
return err
}

if len(value.([]byte)) > 0 {
l.value = value.([]byte)
if len(value) > 0 {
l.value = value
}

l.dirty = true
Expand Down
14 changes: 5 additions & 9 deletions lib/trie/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"testing"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/scale"
"github.com/ChainSafe/gossamer/pkg/scale"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -160,14 +160,12 @@ func TestBranchEncode(t *testing.T) {
expected = append(expected, nibblesToKeyLE(b.key)...)
expected = append(expected, common.Uint16ToBytes(b.childrenBitmap())...)

buf := bytes.Buffer{}
encoder := &scale.Encoder{Writer: &buf}
_, err = encoder.Encode(b.value)
enc, err := scale.Marshal(b.value)
if err != nil {
t.Fatalf("Fail when encoding value with scale: %s", err)
}

expected = append(expected, buf.Bytes()...)
expected = append(expected, enc...)

for _, child := range b.children {
if child != nil {
Expand Down Expand Up @@ -207,14 +205,12 @@ func TestLeafEncode(t *testing.T) {
expected = append(expected, header...)
expected = append(expected, nibblesToKeyLE(l.key)...)

buf := bytes.Buffer{}
encoder := &scale.Encoder{Writer: &buf}
_, err = encoder.Encode(l.value)
enc, err := scale.Marshal(l.value)
if err != nil {
t.Fatalf("Fail when encoding value with scale: %s", err)
}

expected = append(expected, buf.Bytes()...)
expected = append(expected, enc...)

hasher := newHasher(false)
defer hasher.returnToPool()
Expand Down
46 changes: 42 additions & 4 deletions pkg/scale/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"math/big"
"reflect"
)
Expand Down Expand Up @@ -87,7 +89,7 @@ func Unmarshal(data []byte, dst interface{}) (err error) {
if err != nil {
return
}
ds.Buffer = *buf
ds.Reader = buf

err = ds.unmarshal(elem)
if err != nil {
Expand All @@ -96,8 +98,36 @@ func Unmarshal(data []byte, dst interface{}) (err error) {
return
}

// Decoder is used to decode from an io.Reader
type Decoder struct {
decodeState
}

// Decode accepts a pointer to a destination and decodes into supplied destination
func (d *Decoder) Decode(dst interface{}) (err error) {
dstv := reflect.ValueOf(dst)
if dstv.Kind() != reflect.Ptr || dstv.IsNil() {
err = fmt.Errorf("unsupported dst: %T, must be a pointer to a destination", dst)
return
}

elem := indirect(dstv)
if err != nil {
return
}
return d.unmarshal(elem)
}

// NewDecoder is constructor for Decoder
func NewDecoder(r io.Reader) (d *Decoder) {
d = &Decoder{
decodeState{r},
}
return
}

type decodeState struct {
bytes.Buffer
io.Reader
}

func (ds *decodeState) unmarshal(dstv reflect.Value) (err error) {
Expand Down Expand Up @@ -230,6 +260,12 @@ func (ds *decodeState) decodeCustomPrimitive(dstv reflect.Value) (err error) {
return
}

func (ds *decodeState) ReadByte() (byte, error) {
b := make([]byte, 1) // make buffer
_, err := ds.Reader.Read(b) // read what's in the Decoder's underlying buffer to our new buffer b
return b[0], err
}

func (ds *decodeState) decodeResult(dstv reflect.Value) (err error) {
res := dstv.Interface().(Result)
var rb byte
Expand Down Expand Up @@ -263,7 +299,8 @@ func (ds *decodeState) decodeResult(dstv reflect.Value) (err error) {
}
dstv.Set(reflect.ValueOf(res))
default:
err = fmt.Errorf("unsupported Result value: %v, bytes: %v", rb, ds.Bytes())
bytes, _ := ioutil.ReadAll(ds.Reader)
err = fmt.Errorf("unsupported Result value: %v, bytes: %v", rb, bytes)
}
return
}
Expand Down Expand Up @@ -295,7 +332,8 @@ func (ds *decodeState) decodePointer(dstv reflect.Value) (err error) {
dstv.Set(tempElem)
}
default:
err = fmt.Errorf("unsupported Option value: %v, bytes: %v", rb, ds.Bytes())
bytes, _ := ioutil.ReadAll(ds.Reader)
err = fmt.Errorf("unsupported Option value: %v, bytes: %v", rb, bytes)
}
return
}
Expand Down
65 changes: 64 additions & 1 deletion pkg/scale/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package scale

import (
"bytes"
"math/big"
"reflect"
"testing"
Expand Down Expand Up @@ -189,7 +190,6 @@ func Test_unmarshal_optionality(t *testing.T) {
if diff != "" {
t.Errorf("decodeState.unmarshal() = %s", diff)
}

}
})
}
Expand Down Expand Up @@ -238,3 +238,66 @@ func Test_unmarshal_optionality_nil_case(t *testing.T) {
})
}
}

func Test_Decoder_Decode(t *testing.T) {
for _, tt := range newTests(fixedWidthIntegerTests, variableWidthIntegerTests, stringTests,
boolTests, sliceTests, arrayTests,
) {
t.Run(tt.name, func(t *testing.T) {
dst := reflect.New(reflect.TypeOf(tt.in)).Elem().Interface()
wantBuf := bytes.NewBuffer(tt.want)
d := NewDecoder(wantBuf)
if err := d.Decode(&dst); (err != nil) != tt.wantErr {
t.Errorf("Decoder.Decode() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(dst, tt.in) {
t.Errorf("Decoder.Decode() = %v, want %v", dst, tt.in)
}
})
}
}

func Test_Decoder_Decode_MultipleCalls(t *testing.T) {
tests := []struct {
name string
ins []interface{}
want []byte
wantErr []bool
}{
{
name: "int64 and []byte",
ins: []interface{}{int64(9223372036854775807), []byte{0x01}},
want: append([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, []byte{0x04, 0x01}...),
},
{
name: "eof error",
ins: []interface{}{int64(9223372036854775807), []byte{0x01}},
want: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f},
wantErr: []bool{false, true},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := bytes.NewBuffer(tt.want)
d := NewDecoder(buf)

for i := range tt.ins {
in := tt.ins[i]
dst := reflect.New(reflect.TypeOf(in)).Elem().Interface()
var wantErr bool
if len(tt.wantErr) > i {
wantErr = tt.wantErr[i]
}
if err := d.Decode(&dst); (err != nil) != wantErr {
t.Errorf("Decoder.Decode() error = %v, wantErr %v", err, tt.wantErr[i])
return
}
if !wantErr && !reflect.DeepEqual(dst, in) {
t.Errorf("Decoder.Decode() = %v, want %v", dst, in)
return
}
}
})
}
}

0 comments on commit 9ba77ed

Please sign in to comment.