From 7dd7c4bf025b8655ac711ae74308b3a7bb6b947b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Fri, 8 Oct 2021 09:44:39 -0400 Subject: [PATCH 01/18] feat: add verify_proof function --- lib/trie/codec.go | 13 +++ lib/trie/recorder.go | 8 +- lib/trie/verify_proof.go | 224 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 241 insertions(+), 4 deletions(-) create mode 100644 lib/trie/verify_proof.go diff --git a/lib/trie/codec.go b/lib/trie/codec.go index a20389caaa..49d3afb6d5 100644 --- a/lib/trie/codec.go +++ b/lib/trie/codec.go @@ -16,6 +16,19 @@ package trie +import "bytes" + +type Nibbles []byte + +func (n *Nibbles) contains(in []byte, offset uint) bool { + if len(*n) < len(in) { + return false + } + + compareWith := (*n)[offset:len(in)] + return bytes.Equal(compareWith, in) +} + // keyToNibbles turns bytes into nibbles // does not rearrange the nibbles; assumes they are already ordered in LE func keyToNibbles(in []byte) []byte { diff --git a/lib/trie/recorder.go b/lib/trie/recorder.go index 7c2b9a40c9..3bcadaa121 100644 --- a/lib/trie/recorder.go +++ b/lib/trie/recorder.go @@ -6,15 +6,15 @@ type nodeRecord struct { hash []byte } -// Recorder keeps the list of nodes find by Lookup.Find +// recorder keeps the list of nodes find by Lookup.Find type recorder []nodeRecord -// Record insert a node insede the recorded list +// record insert a node insede the recorded list func (r *recorder) record(h, rd []byte) { *r = append(*r, nodeRecord{rawData: rd, hash: h}) } -// Next returns the current item the cursor is on and increment the cursor by 1 +// next returns the current item the cursor is on and increment the cursor by 1 func (r *recorder) next() *nodeRecord { if !r.isEmpty() { n := (*r)[0] @@ -25,7 +25,7 @@ func (r *recorder) next() *nodeRecord { return nil } -// IsEmpty returns bool if there is data inside the slice +// isEmpty returns bool if there is data inside the slice func (r *recorder) isEmpty() bool { return len(*r) <= 0 } diff --git a/lib/trie/verify_proof.go b/lib/trie/verify_proof.go new file mode 100644 index 0000000000..3a90b23c44 --- /dev/null +++ b/lib/trie/verify_proof.go @@ -0,0 +1,224 @@ +package trie + +import ( + "bytes" + "errors" + + "github.com/ChainSafe/gossamer/lib/common" +) + +const ( + MatchesLeaf = iota + MatchesBranch + NotFound + IsChild +) + +var ( + ErrDuplicateKeys = errors.New("duplicate keys on verify proof") + ErrIncompleteProof = errors.New("incomplete proof") + ErrNoMoreItemsOnIterable = errors.New("items iterable exhausted") + ErrExhaustedNibbles = errors.New("exhausted nibbles key") + ErrValueMatchNotFound = errors.New("value match not found") +) + +type stackItem struct { + value []byte + node node + rawNode []byte + path []byte +} + +func newStackItem(path, raw []byte) (*stackItem, error) { + decoded, err := decodeBytes(raw) + if err != nil { + return nil, err + } + + return &stackItem{nil, decoded, raw, path}, nil +} + +func (i *stackItem) advanceChildIndex(d []byte, prooI *proofIter) (*stackItem, error) { + +} + +func (i *stackItem) advanceItem(it *pairListIter) ([]byte, error) { + for { + item := it.peek() + if item == nil { + return nil, ErrNoMoreItemsOnIterable + } + + nk := Nibbles(keyToNibbles(item.key)) + if bytes.HasPrefix(nk, i.path) { + found, next, err := matchKeyToNode(nk, len(i.path), i.node) + + if err != nil { + return nil, err + } else if next != nil { + return next, nil + } else if found { + i.value = item.value + } + + it.next() + continue + } + + return nil, ErrNoMoreItemsOnIterable + } +} + +// matchKeyToNode return true if the leaf was found +// returns the byte array of the next node to keep searching +// returns error if the nibbles are exhausted or node key does not match +func matchKeyToNode(nk Nibbles, prefixOffset int, n node) (bool, []byte, error) { + switch node := n.(type) { + case nil: + return false, nil, ErrValueMatchNotFound + case *leaf: + if nk.contains(node.key, uint(prefixOffset)) && len(nk) == prefixOffset+len(node.key) { + return true, nil, nil + } + + return false, nil, ErrValueMatchNotFound + case *branch: + if nk.contains(node.key, uint(prefixOffset)) { + return matchKeyToBranchNode(nk, prefixOffset+len(node.key), node.children, node.value) + } + + return false, nil, ErrValueMatchNotFound + } + + return false, nil, ErrValueMatchNotFound +} + +func matchKeyToBranchNode(nk Nibbles, prefixPlusKeyLen int, children [16]node, value []byte) (bool, []byte, error) { + if len(nk) == prefixPlusKeyLen { + return false, nil, nil + } + + if len(nk) < prefixPlusKeyLen { + return false, nil, ErrExhaustedNibbles + } + + if children[nk[prefixPlusKeyLen]] == nil { + return false, nil, ErrValueMatchNotFound + } + + continueFrom := make([]byte, len(nk[prefixPlusKeyLen+1:])) + copy(continueFrom, nk[prefixPlusKeyLen+1:]) + return false, continueFrom, nil +} + +type stack []*stackItem + +func (s *stack) push(si *stackItem) { + *s = append(*s, si) +} + +type pair struct{ key, value []byte } +type PairList []*pair + +func (pl *PairList) Add(k, v []byte) { + *pl = append(*pl, &pair{k, v}) +} + +type pairListIter struct { + idx int + set []*pair +} + +func (i *pairListIter) peek() *pair { + if i.hasNext() { + return i.set[i.idx] + } + + return nil +} + +func (i *pairListIter) next() *pair { + if i.hasNext() { + return i.set[i.idx] + i.idx += 1 + } + + return nil +} + +func (i *pairListIter) hasNext() bool { + return len(i.set) < i.idx +} + +func (pl *PairList) toIter() *pairListIter { + return &pairListIter{0, *pl} +} + +type proofIter struct { + idx int + proof [][]byte +} + +func (p *proofIter) next() []byte { + if p.hasNext() { + return p.proof[p.idx] + p.idx += 1 + } + return nil +} + +func (p *proofIter) hasNext() bool { + return len(p.proof) < p.idx +} + +func newProofIter(proof [][]byte) *proofIter { + return &proofIter{0, proof} +} + +func VerifyProof(root common.Hash, proof [][]byte, items PairList) (bool, error) { + if len(proof) == 0 && len(items) == 0 { + return true, nil + } + + // check for duplicates + for i := 1; i < len(items); i++ { + if bytes.Equal(items[i].key, items[i-1].key) { + return false, ErrDuplicateKeys + } + } + + proofI := newProofIter(proof) + itemsI := items.toIter() + + var rootNode []byte + if rootNode = proofI.next(); rootNode == nil { + return false, ErrIncompleteProof + } + + lastEntry, err := newStackItem([]byte{}, rootNode) + if err != nil { + return false, err + } + + st := new(stack) + + for { + descend, err := lastEntry.advanceItem(itemsI) + + if errors.Is(err, ErrNoMoreItemsOnIterable) { + // TODO: implement the unwind stack + } else if err != nil { + return false, err + } + + nextEntry, err := lastEntry.advanceChildIndex(descend, proofI) + if err != nil { + return false, err + } + + st.push(lastEntry) + lastEntry = nextEntry + } + + return false, nil +} From e6a3415225ba5476814bac2006f132ca96422a30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Fri, 8 Oct 2021 15:48:59 -0400 Subject: [PATCH 02/18] chore: adding helpers --- lib/trie/verify_proof.go | 85 +++++++++++++++++++++++++++++++++++----- 1 file changed, 76 insertions(+), 9 deletions(-) diff --git a/lib/trie/verify_proof.go b/lib/trie/verify_proof.go index 3a90b23c44..c94554d3ef 100644 --- a/lib/trie/verify_proof.go +++ b/lib/trie/verify_proof.go @@ -19,14 +19,19 @@ var ( ErrIncompleteProof = errors.New("incomplete proof") ErrNoMoreItemsOnIterable = errors.New("items iterable exhausted") ErrExhaustedNibbles = errors.New("exhausted nibbles key") + ErrExhaustedStack = errors.New("no more itens to pop from stack") ErrValueMatchNotFound = errors.New("value match not found") + ErrExtraneousNode = errors.New("the proof contains at least one extraneous node") + ErrRootMismatch = errors.New("computed root does not match with the given one") ) type stackItem struct { - value []byte - node node - rawNode []byte - path []byte + value []byte + node node + rawNode []byte + path []byte + childIndex int + children [16]node } func newStackItem(path, raw []byte) (*stackItem, error) { @@ -35,11 +40,47 @@ func newStackItem(path, raw []byte) (*stackItem, error) { return nil, err } - return &stackItem{nil, decoded, raw, path}, nil + return &stackItem{nil, decoded, raw, path, 0, [16]node{}}, nil } -func (i *stackItem) advanceChildIndex(d []byte, prooI *proofIter) (*stackItem, error) { +func (i *stackItem) advanceChildIndex(path []byte, proofI *proofIter) (*stackItem, error) { + switch node := i.node.(type) { + case *branch: + if len(node.children) <= 0 { + return nil, errors.New("branch node must to has children nodes") + } + + if len(path) <= 0 { + return nil, errors.New("descend node key is empty") + } + + childIndex := (int)(path[len(path)-1]) + + for i.childIndex < childIndex { + child := node.children[i.childIndex] + if child != nil { + i.children[i.childIndex] = child + } + i.childIndex += 1 + } + + child := node.children[i.childIndex] + return i.makeChildEntry(proofI, child, path) + default: + return nil, errors.New("node must be a branch node") + } +} + +func (i *stackItem) makeChildEntry(proofI *proofIter, child node, path []byte) (*stackItem, error) { + if child == nil { + return newStackItem(path, proofI.next()) + } + encoded, err := encodeAndHash(child) + if err != nil { + return nil, err + } + return newStackItem(path, encoded) } func (i *stackItem) advanceItem(it *pairListIter) ([]byte, error) { @@ -117,6 +158,16 @@ func (s *stack) push(si *stackItem) { *s = append(*s, si) } +func (s *stack) pop() *stackItem { + if len(*s) == 0 { + return nil + } + + i := (*s)[len(*s)-1] + *s = (*s)[:len(*s)-1] + return i +} + type pair struct{ key, value []byte } type PairList []*pair @@ -161,8 +212,9 @@ type proofIter struct { func (p *proofIter) next() []byte { if p.hasNext() { - return p.proof[p.idx] + i := p.proof[p.idx] p.idx += 1 + return i } return nil } @@ -206,7 +258,23 @@ func VerifyProof(root common.Hash, proof [][]byte, items PairList) (bool, error) descend, err := lastEntry.advanceItem(itemsI) if errors.Is(err, ErrNoMoreItemsOnIterable) { - // TODO: implement the unwind stack + entry := st.pop() + if entry == nil { + if proofI.next() != nil { + return false, ErrExtraneousNode + } + + computedRoot := lastEntry.node.getHash() + if !root.Equal(common.BytesToHash(computedRoot)) { + return false, ErrRootMismatch + } + break + } + + lastEntry = entry + lastEntry.children[lastEntry.childIndex] = lastEntry.node + lastEntry.childIndex += 1 + } else if err != nil { return false, err } @@ -215,7 +283,6 @@ func VerifyProof(root common.Hash, proof [][]byte, items PairList) (bool, error) if err != nil { return false, err } - st.push(lastEntry) lastEntry = nextEntry } From cff7280b3a7cf10361e10543ca66606cb3bda319 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Mon, 11 Oct 2021 18:17:05 -0400 Subject: [PATCH 03/18] chore: build the tree from proof slice --- lib/trie/codec.go | 8 +- lib/trie/database.go | 73 ++++++++++ lib/trie/lookup.go | 49 +++++++ lib/trie/proof.go | 37 ++++- lib/trie/proof_test.go | 88 +++++++++++- lib/trie/verify_proof.go | 291 --------------------------------------- 6 files changed, 246 insertions(+), 300 deletions(-) delete mode 100644 lib/trie/verify_proof.go diff --git a/lib/trie/codec.go b/lib/trie/codec.go index 49d3afb6d5..1d23b00920 100644 --- a/lib/trie/codec.go +++ b/lib/trie/codec.go @@ -16,7 +16,10 @@ package trie -import "bytes" +import ( + "bytes" + "fmt" +) type Nibbles []byte @@ -26,6 +29,9 @@ func (n *Nibbles) contains(in []byte, offset uint) bool { } compareWith := (*n)[offset:len(in)] + + fmt.Printf("Current Nibbles: 0x%x | Offset: %v\n", *n, offset) + fmt.Printf("Comp: 0x%x\n", in) return bytes.Equal(compareWith, in) } diff --git a/lib/trie/database.go b/lib/trie/database.go index 8218bd1b38..d371fdc2d0 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -18,6 +18,7 @@ package trie import ( "bytes" + "errors" "fmt" "github.com/ChainSafe/gossamer/lib/common" @@ -25,6 +26,18 @@ import ( "github.com/ChainSafe/chaindb" ) +var ( + ErrDuplicateKeys = errors.New("duplicate keys on verify proof") + ErrIncompleteProof = errors.New("incomplete proof") + ErrNoMoreItemsOnIterable = errors.New("items iterable exhausted") + ErrExhaustedNibbles = errors.New("exhausted nibbles key") + ErrExhaustedStack = errors.New("no more itens to pop from stack") + ErrValueMatchNotFound = errors.New("value match not found") + ErrExtraneousNode = errors.New("the proof contains at least one extraneous node") + ErrRootMismatch = errors.New("computed root does not match with the given one") + ErrValueMismatch = errors.New("expected value not found in the trie") +) + // Store stores each trie node in the database, where the key is the hash of the encoded node and the value is the encoded node. // Generally, this will only be used for the genesis trie. func (t *Trie) Store(db chaindb.Database) error { @@ -73,6 +86,66 @@ func (t *Trie) store(db chaindb.Batch, curr node) error { return nil } +// LoadFromProof create a trie based on the proof slice. +func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { + mappedNodes := make(map[common.Hash]node) + + // map all the proofs hash -> decoded node + // and takes the loop to indentify the root node + for _, rawNode := range proof { + var ( + decNode node + err error + ) + + if decNode, err = decodeBytes(rawNode); err != nil { + return err + } + + decNode.setDirty(false) + decNode.setEncodingAndHash(rawNode, nil) + + _, computedRoot, err := decNode.encodeAndHash() + if err != nil { + return err + } + + mappedNodes[common.BytesToHash(computedRoot)] = decNode + + if bytes.Equal(computedRoot, root) { + t.root = decNode + } + } + + return t.loadFromProof(mappedNodes, t.root) +} + +// loadFromProof is a recursive function that will create all the trie paths based +// on the mapped proofs slice starting by the root +func (t *Trie) loadFromProof(proof map[common.Hash]node, curr node) error { + switch c := curr.(type) { + case *branch: + for i, child := range c.children { + if child == nil { + continue + } + + proofNode, ok := proof[common.BytesToHash(child.getHash())] + if !ok { + continue + } + + c.children[i] = proofNode + err := t.loadFromProof(proof, proofNode) + if err != nil { + return err + } + } + } + + return nil +} + // Load reconstructs the trie from the database from the given root hash. Used when restarting the node to load the current state trie. func (t *Trie) Load(db chaindb.Database, root common.Hash) error { if root == EmptyHash { diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index dd8600963e..4a916d1088 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -85,3 +85,52 @@ func (l *lookup) find(key []byte, recorder *recorder) ([]byte, error) { } } } + +func findAndRecord(t *Trie, key []byte, recorder *recorder) []byte { + l, err := find(t.root, key, recorder) + if l == nil || err != nil { + return nil + } + + return l.value +} + +func find(parent node, key []byte, recorder *recorder) (*leaf, error) { + enc, hash, err := parent.encodeAndHash() + if err != nil { + return nil, err + } + + recorder.record(hash, enc) + + switch p := parent.(type) { + case *branch: + length := lenCommonPrefix(p.key, key) + + // found the value at this node + if bytes.Equal(p.key, key) || len(key) == 0 { + return &leaf{key: p.key, value: p.value, dirty: false}, nil + } + + // did not find value + if bytes.Equal(p.key[:length], key) && len(key) < len(p.key) { + return nil, nil + } + + return find(p.children[key[length]], key[length+1:], recorder) + case *leaf: + enc, hash, err := p.encodeAndHash() + if err != nil { + return nil, err + } + + recorder.record(hash, enc) + if bytes.Equal(p.key, key) { + return p, nil + } + default: + return nil, nil + } + + return nil, nil +} diff --git a/lib/trie/proof.go b/lib/trie/proof.go index 7668b69df8..c1d4ac5dcf 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -17,6 +17,7 @@ package trie import ( + "bytes" "errors" "github.com/ChainSafe/chaindb" @@ -33,16 +34,17 @@ var ( func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, error) { trackedProofs := make(map[string][]byte) + proofTrie := NewEmptyTrie() + err := proofTrie.Load(db, common.BytesToHash(root)) + if err != nil { + return nil, err + } + for _, k := range keys { nk := keyToNibbles(k) - lookup := newLookup(root, db) recorder := new(recorder) - - _, err := lookup.find(nk, recorder) - if err != nil { - return nil, err - } + findAndRecord(proofTrie, nk, recorder) for !recorder.isEmpty() { recNode := recorder.next() @@ -54,10 +56,31 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e } proofs := make([][]byte, 0) - for _, p := range trackedProofs { proofs = append(proofs, p) } return proofs, nil } + +// Pair holds the key and value to check while verifying the proof +type Pair struct{ Key, Value []byte } + +// VerifyProof ensure a given key is inside a proof by creating a proof trie based on the proof slice +// this function ignores the order of proofs +func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { + proofTrie := NewEmptyTrie() + err := proofTrie.LoadFromProof(proof, root) + if err != nil { + return false, err + } + + for _, i := range items { + recValue := proofTrie.Get(i.Key) + if !bytes.Equal(i.Value, recValue) { + return false, ErrValueMismatch + } + } + + return true, nil +} diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index 9129d503c2..26725ac4c8 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -17,6 +17,7 @@ package trie import ( + "fmt" "io/ioutil" "testing" @@ -34,10 +35,12 @@ func TestProofGeneration(t *testing.T) { }) require.NoError(t, err) + expectedValue := rand32Bytes() + trie := NewEmptyTrie() trie.Put([]byte("cat"), rand32Bytes()) trie.Put([]byte("catapulta"), rand32Bytes()) - trie.Put([]byte("catapora"), rand32Bytes()) + trie.Put([]byte("catapora"), expectedValue) trie.Put([]byte("dog"), rand32Bytes()) trie.Put([]byte("doguinho"), rand32Bytes()) @@ -52,4 +55,87 @@ func TestProofGeneration(t *testing.T) { // TODO: use the verify_proof function to assert the tests require.Equal(t, 5, len(proof)) + + pl := []Pair{ + {Key: []byte("catapora"), Value: expectedValue}, + } + + v, err := VerifyProof(proof, hash.ToBytes(), pl) + require.True(t, v) + require.NoError(t, err) +} + +func testGenerateProof(t *testing.T, entries []Pair, keys [][]byte) ([]byte, [][]byte, []Pair) { + t.Helper() + + tmp, err := ioutil.TempDir("", "*-test-trie") + require.NoError(t, err) + + memdb, err := chaindb.NewBadgerDB(&chaindb.Config{ + InMemory: true, + DataDir: tmp, + }) + require.NoError(t, err) + + trie := NewEmptyTrie() + for _, e := range entries { + trie.Put(e.Key, e.Value) + } + + err = trie.Store(memdb) + require.NoError(t, err) + + root := trie.root.getHash() + proof, err := GenerateProof(root, keys, memdb) + require.NoError(t, err) + + items := make([]Pair, 0) + for _, i := range keys { + value := trie.Get(i) + require.NotNil(t, value) + + itemFromDB := Pair{ + Key: i, + Value: value, + } + items = append(items, itemFromDB) + } + + return root, proof, items +} + +func TestVerifyProofCorrectly(t *testing.T) { + entries := []Pair{ + {Key: []byte("alpha"), Value: make([]byte, 32)}, + {Key: []byte("bravo"), Value: []byte("bravo")}, + {Key: []byte("do"), Value: []byte("verb")}, + {Key: []byte("dog"), Value: []byte("puppy")}, + {Key: []byte("doge"), Value: make([]byte, 32)}, + {Key: []byte("horse"), Value: []byte("stallion")}, + {Key: []byte("house"), Value: []byte("building")}, + } + + keys := [][]byte{ + []byte("do"), + []byte("dog"), + []byte("doge"), + } + + root, proof, _ := testGenerateProof(t, entries, keys) + + fmt.Printf("ROOT: 0x%x\n", root) + fmt.Println("PROOF") + for _, p := range proof { + fmt.Printf("0x%x\n", p) + } + + pl := []Pair{ + {Key: []byte("do"), Value: []byte("verb")}, + {Key: []byte("dog"), Value: []byte("puppy")}, + {Key: []byte("doge"), Value: make([]byte, 32)}, + } + + v, err := VerifyProof(proof, root, pl) + require.True(t, v) + require.NoError(t, err) } diff --git a/lib/trie/verify_proof.go b/lib/trie/verify_proof.go deleted file mode 100644 index c94554d3ef..0000000000 --- a/lib/trie/verify_proof.go +++ /dev/null @@ -1,291 +0,0 @@ -package trie - -import ( - "bytes" - "errors" - - "github.com/ChainSafe/gossamer/lib/common" -) - -const ( - MatchesLeaf = iota - MatchesBranch - NotFound - IsChild -) - -var ( - ErrDuplicateKeys = errors.New("duplicate keys on verify proof") - ErrIncompleteProof = errors.New("incomplete proof") - ErrNoMoreItemsOnIterable = errors.New("items iterable exhausted") - ErrExhaustedNibbles = errors.New("exhausted nibbles key") - ErrExhaustedStack = errors.New("no more itens to pop from stack") - ErrValueMatchNotFound = errors.New("value match not found") - ErrExtraneousNode = errors.New("the proof contains at least one extraneous node") - ErrRootMismatch = errors.New("computed root does not match with the given one") -) - -type stackItem struct { - value []byte - node node - rawNode []byte - path []byte - childIndex int - children [16]node -} - -func newStackItem(path, raw []byte) (*stackItem, error) { - decoded, err := decodeBytes(raw) - if err != nil { - return nil, err - } - - return &stackItem{nil, decoded, raw, path, 0, [16]node{}}, nil -} - -func (i *stackItem) advanceChildIndex(path []byte, proofI *proofIter) (*stackItem, error) { - switch node := i.node.(type) { - case *branch: - if len(node.children) <= 0 { - return nil, errors.New("branch node must to has children nodes") - } - - if len(path) <= 0 { - return nil, errors.New("descend node key is empty") - } - - childIndex := (int)(path[len(path)-1]) - - for i.childIndex < childIndex { - child := node.children[i.childIndex] - if child != nil { - i.children[i.childIndex] = child - } - i.childIndex += 1 - } - - child := node.children[i.childIndex] - return i.makeChildEntry(proofI, child, path) - default: - return nil, errors.New("node must be a branch node") - } -} - -func (i *stackItem) makeChildEntry(proofI *proofIter, child node, path []byte) (*stackItem, error) { - if child == nil { - return newStackItem(path, proofI.next()) - } - encoded, err := encodeAndHash(child) - if err != nil { - return nil, err - } - - return newStackItem(path, encoded) -} - -func (i *stackItem) advanceItem(it *pairListIter) ([]byte, error) { - for { - item := it.peek() - if item == nil { - return nil, ErrNoMoreItemsOnIterable - } - - nk := Nibbles(keyToNibbles(item.key)) - if bytes.HasPrefix(nk, i.path) { - found, next, err := matchKeyToNode(nk, len(i.path), i.node) - - if err != nil { - return nil, err - } else if next != nil { - return next, nil - } else if found { - i.value = item.value - } - - it.next() - continue - } - - return nil, ErrNoMoreItemsOnIterable - } -} - -// matchKeyToNode return true if the leaf was found -// returns the byte array of the next node to keep searching -// returns error if the nibbles are exhausted or node key does not match -func matchKeyToNode(nk Nibbles, prefixOffset int, n node) (bool, []byte, error) { - switch node := n.(type) { - case nil: - return false, nil, ErrValueMatchNotFound - case *leaf: - if nk.contains(node.key, uint(prefixOffset)) && len(nk) == prefixOffset+len(node.key) { - return true, nil, nil - } - - return false, nil, ErrValueMatchNotFound - case *branch: - if nk.contains(node.key, uint(prefixOffset)) { - return matchKeyToBranchNode(nk, prefixOffset+len(node.key), node.children, node.value) - } - - return false, nil, ErrValueMatchNotFound - } - - return false, nil, ErrValueMatchNotFound -} - -func matchKeyToBranchNode(nk Nibbles, prefixPlusKeyLen int, children [16]node, value []byte) (bool, []byte, error) { - if len(nk) == prefixPlusKeyLen { - return false, nil, nil - } - - if len(nk) < prefixPlusKeyLen { - return false, nil, ErrExhaustedNibbles - } - - if children[nk[prefixPlusKeyLen]] == nil { - return false, nil, ErrValueMatchNotFound - } - - continueFrom := make([]byte, len(nk[prefixPlusKeyLen+1:])) - copy(continueFrom, nk[prefixPlusKeyLen+1:]) - return false, continueFrom, nil -} - -type stack []*stackItem - -func (s *stack) push(si *stackItem) { - *s = append(*s, si) -} - -func (s *stack) pop() *stackItem { - if len(*s) == 0 { - return nil - } - - i := (*s)[len(*s)-1] - *s = (*s)[:len(*s)-1] - return i -} - -type pair struct{ key, value []byte } -type PairList []*pair - -func (pl *PairList) Add(k, v []byte) { - *pl = append(*pl, &pair{k, v}) -} - -type pairListIter struct { - idx int - set []*pair -} - -func (i *pairListIter) peek() *pair { - if i.hasNext() { - return i.set[i.idx] - } - - return nil -} - -func (i *pairListIter) next() *pair { - if i.hasNext() { - return i.set[i.idx] - i.idx += 1 - } - - return nil -} - -func (i *pairListIter) hasNext() bool { - return len(i.set) < i.idx -} - -func (pl *PairList) toIter() *pairListIter { - return &pairListIter{0, *pl} -} - -type proofIter struct { - idx int - proof [][]byte -} - -func (p *proofIter) next() []byte { - if p.hasNext() { - i := p.proof[p.idx] - p.idx += 1 - return i - } - return nil -} - -func (p *proofIter) hasNext() bool { - return len(p.proof) < p.idx -} - -func newProofIter(proof [][]byte) *proofIter { - return &proofIter{0, proof} -} - -func VerifyProof(root common.Hash, proof [][]byte, items PairList) (bool, error) { - if len(proof) == 0 && len(items) == 0 { - return true, nil - } - - // check for duplicates - for i := 1; i < len(items); i++ { - if bytes.Equal(items[i].key, items[i-1].key) { - return false, ErrDuplicateKeys - } - } - - proofI := newProofIter(proof) - itemsI := items.toIter() - - var rootNode []byte - if rootNode = proofI.next(); rootNode == nil { - return false, ErrIncompleteProof - } - - lastEntry, err := newStackItem([]byte{}, rootNode) - if err != nil { - return false, err - } - - st := new(stack) - - for { - descend, err := lastEntry.advanceItem(itemsI) - - if errors.Is(err, ErrNoMoreItemsOnIterable) { - entry := st.pop() - if entry == nil { - if proofI.next() != nil { - return false, ErrExtraneousNode - } - - computedRoot := lastEntry.node.getHash() - if !root.Equal(common.BytesToHash(computedRoot)) { - return false, ErrRootMismatch - } - break - } - - lastEntry = entry - lastEntry.children[lastEntry.childIndex] = lastEntry.node - lastEntry.childIndex += 1 - - } else if err != nil { - return false, err - } - - nextEntry, err := lastEntry.advanceChildIndex(descend, proofI) - if err != nil { - return false, err - } - st.push(lastEntry) - lastEntry = nextEntry - } - - return false, nil -} From 2d816a4330f28cdabfe423fb7e6e9ce824f34ca3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Mon, 11 Oct 2021 18:21:48 -0400 Subject: [PATCH 04/18] chore: remove Nibbles custom type --- lib/trie/codec.go | 19 ------------------- lib/trie/proof_test.go | 2 +- 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/lib/trie/codec.go b/lib/trie/codec.go index 1d23b00920..a20389caaa 100644 --- a/lib/trie/codec.go +++ b/lib/trie/codec.go @@ -16,25 +16,6 @@ package trie -import ( - "bytes" - "fmt" -) - -type Nibbles []byte - -func (n *Nibbles) contains(in []byte, offset uint) bool { - if len(*n) < len(in) { - return false - } - - compareWith := (*n)[offset:len(in)] - - fmt.Printf("Current Nibbles: 0x%x | Offset: %v\n", *n, offset) - fmt.Printf("Comp: 0x%x\n", in) - return bytes.Equal(compareWith, in) -} - // keyToNibbles turns bytes into nibbles // does not rearrange the nibbles; assumes they are already ordered in LE func keyToNibbles(in []byte) []byte { diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index 26725ac4c8..1c4982fea0 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -104,7 +104,7 @@ func testGenerateProof(t *testing.T, entries []Pair, keys [][]byte) ([]byte, [][ return root, proof, items } -func TestVerifyProofCorrectly(t *testing.T) { +func TestVerifyProof_ShouldReturnTrue(t *testing.T) { entries := []Pair{ {Key: []byte("alpha"), Value: make([]byte, 32)}, {Key: []byte("bravo"), Value: []byte("bravo")}, From e6a02d88ad0041243410813862a964c79d4fe3c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Mon, 11 Oct 2021 18:35:07 -0400 Subject: [PATCH 05/18] chore: fix lint warns --- lib/trie/database.go | 17 ++++----- lib/trie/lookup.go | 82 -------------------------------------------- lib/trie/proof.go | 12 +++++++ 3 files changed, 18 insertions(+), 93 deletions(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index d371fdc2d0..d46a4775e4 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -26,17 +26,8 @@ import ( "github.com/ChainSafe/chaindb" ) -var ( - ErrDuplicateKeys = errors.New("duplicate keys on verify proof") - ErrIncompleteProof = errors.New("incomplete proof") - ErrNoMoreItemsOnIterable = errors.New("items iterable exhausted") - ErrExhaustedNibbles = errors.New("exhausted nibbles key") - ErrExhaustedStack = errors.New("no more itens to pop from stack") - ErrValueMatchNotFound = errors.New("value match not found") - ErrExtraneousNode = errors.New("the proof contains at least one extraneous node") - ErrRootMismatch = errors.New("computed root does not match with the given one") - ErrValueMismatch = errors.New("expected value not found in the trie") -) +// ErrIncompleteProof indicates the proof slice is empty +var ErrIncompleteProof = errors.New("incomplete proof") // Store stores each trie node in the database, where the key is the hash of the encoded node and the value is the encoded node. // Generally, this will only be used for the genesis trie. @@ -117,6 +108,10 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { } } + if len(mappedNodes) == 0 { + return ErrIncompleteProof + } + return t.loadFromProof(mappedNodes, t.root) } diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index 4a916d1088..021dd17b3b 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -2,90 +2,8 @@ package trie import ( "bytes" - "errors" - - "github.com/ChainSafe/chaindb" -) - -var ( - // ErrProofNodeNotFound when a needed proof node is not in the database - ErrProofNodeNotFound = errors.New("cannot find a trie node in the database") ) -// lookup struct holds the state root and database reference -// used to retrieve trie information from database -type lookup struct { - // root to start the lookup - root []byte - db chaindb.Database -} - -// newLookup returns a Lookup to helps the proof generator -func newLookup(rootHash []byte, db chaindb.Database) *lookup { - lk := &lookup{db: db} - lk.root = make([]byte, len(rootHash)) - copy(lk.root, rootHash) - - return lk -} - -// find will return the desired value or nil if key cannot be found and will record visited nodes -func (l *lookup) find(key []byte, recorder *recorder) ([]byte, error) { - partial := key - hash := l.root - - for { - nodeData, err := l.db.Get(hash) - if err != nil { - return nil, ErrProofNodeNotFound - } - - nodeHash := make([]byte, len(hash)) - copy(nodeHash, hash) - - recorder.record(nodeHash, nodeData) - - decoded, err := decodeBytes(nodeData) - if err != nil { - return nil, err - } - - switch currNode := decoded.(type) { - case nil: - return nil, nil - - case *leaf: - if bytes.Equal(currNode.key, partial) { - return currNode.value, nil - } - return nil, nil - - case *branch: - switch len(partial) { - case 0: - return currNode.value, nil - default: - if !bytes.HasPrefix(partial, currNode.key) { - return nil, nil - } - - if bytes.Equal(partial, currNode.key) { - return currNode.value, nil - } - - length := lenCommonPrefix(currNode.key, partial) - switch child := currNode.children[partial[length]].(type) { - case nil: - return nil, nil - default: - partial = partial[length+1:] - copy(hash, child.getHash()) - } - } - } - } -} - func findAndRecord(t *Trie, key []byte, recorder *recorder) []byte { l, err := find(t.root, key, recorder) if l == nil || err != nil { diff --git a/lib/trie/proof.go b/lib/trie/proof.go index c1d4ac5dcf..c3655ebfb4 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -27,6 +27,12 @@ import ( var ( // ErrEmptyTrieRoot occurs when trying to craft a prove with an empty trie root ErrEmptyTrieRoot = errors.New("provided trie must have a root") + + // ErrValueMismatch indicates that a returned verify proof value doesnt match with the expected value on items array + ErrValueMismatch = errors.New("expected value not found in the trie") + + // ErrDuplicateKeys not allowed to verify proof with duplicate keys + ErrDuplicateKeys = errors.New("duplicate keys on verify proof") ) // GenerateProof receive the keys to proof, the trie root and a reference to database @@ -69,6 +75,12 @@ type Pair struct{ Key, Value []byte } // VerifyProof ensure a given key is inside a proof by creating a proof trie based on the proof slice // this function ignores the order of proofs func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { + for i := 1; i < len(items); i++ { + if bytes.Equal(items[i-1].Key, items[i].Key) { + return false, ErrDuplicateKeys + } + } + proofTrie := NewEmptyTrie() err := proofTrie.LoadFromProof(proof, root) if err != nil { From e92276b223890244d520b5e07147398e93ff9523 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Mon, 11 Oct 2021 19:09:39 -0400 Subject: [PATCH 06/18] chore: add benchmark tests --- lib/trie/database.go | 18 +++++------ lib/trie/proof_test.go | 69 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 11 deletions(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index d46a4775e4..7c24184bc8 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -79,7 +79,7 @@ func (t *Trie) store(db chaindb.Batch, curr node) error { // LoadFromProof create a trie based on the proof slice. func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { - mappedNodes := make(map[common.Hash]node) + mappedNodes := make(map[string]node) // map all the proofs hash -> decoded node // and takes the loop to indentify the root node @@ -101,7 +101,7 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { return err } - mappedNodes[common.BytesToHash(computedRoot)] = decNode + mappedNodes[common.BytesToHex(computedRoot)] = decNode if bytes.Equal(computedRoot, root) { t.root = decNode @@ -112,12 +112,13 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { return ErrIncompleteProof } - return t.loadFromProof(mappedNodes, t.root) + t.loadFromProof(mappedNodes, t.root) + return nil } // loadFromProof is a recursive function that will create all the trie paths based // on the mapped proofs slice starting by the root -func (t *Trie) loadFromProof(proof map[common.Hash]node, curr node) error { +func (t *Trie) loadFromProof(proof map[string]node, curr node) { switch c := curr.(type) { case *branch: for i, child := range c.children { @@ -125,20 +126,15 @@ func (t *Trie) loadFromProof(proof map[common.Hash]node, curr node) error { continue } - proofNode, ok := proof[common.BytesToHash(child.getHash())] + proofNode, ok := proof[common.BytesToHex(child.getHash())] if !ok { continue } c.children[i] = proofNode - err := t.loadFromProof(proof, proofNode) - if err != nil { - return err - } + t.loadFromProof(proof, proofNode) } } - - return nil } // Load reconstructs the trie from the database from the given root hash. Used when restarting the node to load the current state trie. diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index 1c4982fea0..fff0da8f7a 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -19,12 +19,18 @@ package trie import ( "fmt" "io/ioutil" + "math/rand" "testing" + "time" "github.com/ChainSafe/chaindb" "github.com/stretchr/testify/require" ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + func TestProofGeneration(t *testing.T) { tmp, err := ioutil.TempDir("", "*-test-trie") require.NoError(t, err) @@ -139,3 +145,66 @@ func TestVerifyProof_ShouldReturnTrue(t *testing.T) { require.True(t, v) require.NoError(t, err) } + +func Benchmark_GenerateAndVerifyAllKeys(b *testing.B) { + tmp, err := ioutil.TempDir("", "*-test-trie") + require.NoError(b, err) + memdb, err := chaindb.NewBadgerDB(&chaindb.Config{ + InMemory: true, + DataDir: tmp, + }) + require.NoError(b, err) + + trie, keys, toProve := generateTrie(b, b.N*10) + trie.Store(memdb) + + root := trie.root.getHash() + proof, err := GenerateProof(root, keys, memdb) + require.NoError(b, err) + + v, err := VerifyProof(proof, root, *toProve) + require.True(b, v) + require.NoError(b, err) +} + +func Benchmark_GenerateAndVerifyAllKeys_ShuffleProof(b *testing.B) { + tmp, err := ioutil.TempDir("", "*-test-trie") + require.NoError(b, err) + memdb, err := chaindb.NewBadgerDB(&chaindb.Config{ + InMemory: true, + DataDir: tmp, + }) + require.NoError(b, err) + + trie, keys, toProve := generateTrie(b, b.N*10) + trie.Store(memdb) + + root := trie.root.getHash() + proof, err := GenerateProof(root, keys, memdb) + require.NoError(b, err) + + for i := len(proof) - 1; i > 0; i-- { + j := rand.Intn(i + 1) + proof[i], proof[j] = proof[j], proof[i] + } + v, err := VerifyProof(proof, root, *toProve) + require.True(b, v) + require.NoError(b, err) +} + +func generateTrie(t *testing.B, nodes int) (*Trie, [][]byte, *[]Pair) { + t.Helper() + + pairs := make([]Pair, 0) + keys := make([][]byte, 0) + + trie := NewEmptyTrie() + for i := 0; i < nodes; i++ { + key, value := rand32Bytes(), rand32Bytes() + trie.Put(key, value) + pairs = append(pairs, Pair{key, value}) + keys = append(keys, key) + } + + return trie, keys, &pairs +} From 493dddb7b0c02bff1479857aa592efc5d79a6b51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Tue, 12 Oct 2021 09:09:13 -0400 Subject: [PATCH 07/18] chore: fix deepsource warns --- lib/trie/database.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index 7c24184bc8..5c7df3f0c0 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -112,15 +112,14 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { return ErrIncompleteProof } - t.loadFromProof(mappedNodes, t.root) + t.loadProof(mappedNodes, t.root) return nil } // loadFromProof is a recursive function that will create all the trie paths based // on the mapped proofs slice starting by the root -func (t *Trie) loadFromProof(proof map[string]node, curr node) { - switch c := curr.(type) { - case *branch: +func (t *Trie) loadProof(proof map[string]node, curr node) { + if c, ok := curr.(*branch); ok { for i, child := range c.children { if child == nil { continue @@ -132,7 +131,7 @@ func (t *Trie) loadFromProof(proof map[string]node, curr node) { } c.children[i] = proofNode - t.loadFromProof(proof, proofNode) + t.loadProof(proof, proofNode) } } } From 9455e1389c20887a65e347beefe5ce1f8c698622 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Wed, 13 Oct 2021 17:24:14 -0400 Subject: [PATCH 08/18] chore: redefine LoadFromProof function --- lib/trie/database.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index 5c7df3f0c0..1f3a8394a9 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -77,7 +77,7 @@ func (t *Trie) store(db chaindb.Batch, curr node) error { return nil } -// LoadFromProof create a trie based on the proof slice. +// LoadFromProof create a partial trie based on the proof slice, as it only contains nodes that are in the proof afaik. func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { mappedNodes := make(map[string]node) From e8d4912c27b5ccfb07c4b87a51e6dd0235b5b3d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Wed, 13 Oct 2021 17:27:55 -0400 Subject: [PATCH 09/18] chore: remove logs --- lib/trie/proof_test.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index fff0da8f7a..09bdfbaeeb 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -17,7 +17,6 @@ package trie import ( - "fmt" "io/ioutil" "math/rand" "testing" @@ -129,12 +128,6 @@ func TestVerifyProof_ShouldReturnTrue(t *testing.T) { root, proof, _ := testGenerateProof(t, entries, keys) - fmt.Printf("ROOT: 0x%x\n", root) - fmt.Println("PROOF") - for _, p := range proof { - fmt.Printf("0x%x\n", p) - } - pl := []Pair{ {Key: []byte("do"), Value: []byte("verb")}, {Key: []byte("dog"), Value: []byte("puppy")}, From 10dd2960be746c1e3215582bf93410d9dd0d4090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Thu, 14 Oct 2021 11:46:56 -0400 Subject: [PATCH 10/18] chore: address comments --- lib/trie/database.go | 29 ++++++++++------- lib/trie/proof.go | 6 ++-- lib/trie/proof_test.go | 73 +++--------------------------------------- lib/trie/recorder.go | 2 +- 4 files changed, 26 insertions(+), 84 deletions(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index 1f3a8394a9..8023a0bd81 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -119,20 +119,27 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { // loadFromProof is a recursive function that will create all the trie paths based // on the mapped proofs slice starting by the root func (t *Trie) loadProof(proof map[string]node, curr node) { - if c, ok := curr.(*branch); ok { - for i, child := range c.children { - if child == nil { - continue - } + var ( + c *branch + ok bool + ) - proofNode, ok := proof[common.BytesToHex(child.getHash())] - if !ok { - continue - } + if c, ok = curr.(*branch); ok { + return + } - c.children[i] = proofNode - t.loadProof(proof, proofNode) + for i, child := range c.children { + if child == nil { + continue } + + proofNode, ok := proof[common.BytesToHex(child.getHash())] + if !ok { + continue + } + + c.children[i] = proofNode + t.loadProof(proof, proofNode) } } diff --git a/lib/trie/proof.go b/lib/trie/proof.go index c3655ebfb4..e122c4806d 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -41,8 +41,7 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e trackedProofs := make(map[string][]byte) proofTrie := NewEmptyTrie() - err := proofTrie.Load(db, common.BytesToHash(root)) - if err != nil { + if err := proofTrie.Load(db, common.BytesToHash(root)); err != nil { return nil, err } @@ -82,8 +81,7 @@ func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { } proofTrie := NewEmptyTrie() - err := proofTrie.LoadFromProof(proof, root) - if err != nil { + if err := proofTrie.LoadFromProof(proof, root); err != nil { return false, err } diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index 09bdfbaeeb..1fbb944184 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -94,16 +94,16 @@ func testGenerateProof(t *testing.T, entries []Pair, keys [][]byte) ([]byte, [][ proof, err := GenerateProof(root, keys, memdb) require.NoError(t, err) - items := make([]Pair, 0) - for _, i := range keys { - value := trie.Get(i) + items := make([]Pair, len(keys)) + for idx, key := range keys { + value := trie.Get(key) require.NotNil(t, value) itemFromDB := Pair{ - Key: i, + Key: key, Value: value, } - items = append(items, itemFromDB) + items[idx] = itemFromDB } return root, proof, items @@ -138,66 +138,3 @@ func TestVerifyProof_ShouldReturnTrue(t *testing.T) { require.True(t, v) require.NoError(t, err) } - -func Benchmark_GenerateAndVerifyAllKeys(b *testing.B) { - tmp, err := ioutil.TempDir("", "*-test-trie") - require.NoError(b, err) - memdb, err := chaindb.NewBadgerDB(&chaindb.Config{ - InMemory: true, - DataDir: tmp, - }) - require.NoError(b, err) - - trie, keys, toProve := generateTrie(b, b.N*10) - trie.Store(memdb) - - root := trie.root.getHash() - proof, err := GenerateProof(root, keys, memdb) - require.NoError(b, err) - - v, err := VerifyProof(proof, root, *toProve) - require.True(b, v) - require.NoError(b, err) -} - -func Benchmark_GenerateAndVerifyAllKeys_ShuffleProof(b *testing.B) { - tmp, err := ioutil.TempDir("", "*-test-trie") - require.NoError(b, err) - memdb, err := chaindb.NewBadgerDB(&chaindb.Config{ - InMemory: true, - DataDir: tmp, - }) - require.NoError(b, err) - - trie, keys, toProve := generateTrie(b, b.N*10) - trie.Store(memdb) - - root := trie.root.getHash() - proof, err := GenerateProof(root, keys, memdb) - require.NoError(b, err) - - for i := len(proof) - 1; i > 0; i-- { - j := rand.Intn(i + 1) - proof[i], proof[j] = proof[j], proof[i] - } - v, err := VerifyProof(proof, root, *toProve) - require.True(b, v) - require.NoError(b, err) -} - -func generateTrie(t *testing.B, nodes int) (*Trie, [][]byte, *[]Pair) { - t.Helper() - - pairs := make([]Pair, 0) - keys := make([][]byte, 0) - - trie := NewEmptyTrie() - for i := 0; i < nodes; i++ { - key, value := rand32Bytes(), rand32Bytes() - trie.Put(key, value) - pairs = append(pairs, Pair{key, value}) - keys = append(keys, key) - } - - return trie, keys, &pairs -} diff --git a/lib/trie/recorder.go b/lib/trie/recorder.go index 3bcadaa121..5443e55401 100644 --- a/lib/trie/recorder.go +++ b/lib/trie/recorder.go @@ -9,7 +9,7 @@ type nodeRecord struct { // recorder keeps the list of nodes find by Lookup.Find type recorder []nodeRecord -// record insert a node insede the recorded list +// record insert a node inside the recorded list func (r *recorder) record(h, rd []byte) { *r = append(*r, nodeRecord{rawData: rd, hash: h}) } From 8e4a1c1824fc386189710ccb4e0443dfde79b7e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Mon, 18 Oct 2021 17:42:55 -0400 Subject: [PATCH 11/18] chore: fix the condition to load the proof --- lib/trie/database.go | 2 +- lib/trie/proof_test.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index 8023a0bd81..7c06695487 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -124,7 +124,7 @@ func (t *Trie) loadProof(proof map[string]node, curr node) { ok bool ) - if c, ok = curr.(*branch); ok { + if c, ok = curr.(*branch); !ok { return } diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index 83659ac122..05277906cd 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -58,7 +58,6 @@ func TestProofGeneration(t *testing.T) { proof, err := GenerateProof(hash.ToBytes(), [][]byte{[]byte("catapulta"), []byte("catapora")}, memdb) require.NoError(t, err) - // TODO: use the verify_proof function to assert the tests (#1790) require.Equal(t, 5, len(proof)) pl := []Pair{ From 76893e2d44653c9f87bcc7e049fb0f95a7e32454 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Mon, 18 Oct 2021 18:23:56 -0400 Subject: [PATCH 12/18] chore: address comments --- lib/trie/database.go | 19 ++++++++++--------- lib/trie/lookup.go | 31 +++++++++---------------------- lib/trie/proof.go | 24 +++++++++++++++++++----- lib/trie/proof_test.go | 9 +++------ 4 files changed, 41 insertions(+), 42 deletions(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index 7c06695487..f450ef2bc2 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -26,8 +26,8 @@ import ( "github.com/ChainSafe/chaindb" ) -// ErrIncompleteProof indicates the proof slice is empty -var ErrIncompleteProof = errors.New("incomplete proof") +// ErrEmptyProof indicates the proof slice is empty +var ErrEmptyProof = errors.New("proof slice empty") // Store stores each trie node in the database, where the key is the hash of the encoded node and the value is the encoded node. // Generally, this will only be used for the genesis trie. @@ -79,7 +79,11 @@ func (t *Trie) store(db chaindb.Batch, curr node) error { // LoadFromProof create a partial trie based on the proof slice, as it only contains nodes that are in the proof afaik. func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { - mappedNodes := make(map[string]node) + if len(proof) == 0 { + return ErrEmptyProof + } + + mappedNodes := make(map[string]node, len(proof)) // map all the proofs hash -> decoded node // and takes the loop to indentify the root node @@ -89,7 +93,8 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { err error ) - if decNode, err = decodeBytes(rawNode); err != nil { + decNode, err = decodeBytes(rawNode) + if err != nil { return err } @@ -108,15 +113,11 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { } } - if len(mappedNodes) == 0 { - return ErrIncompleteProof - } - t.loadProof(mappedNodes, t.root) return nil } -// loadFromProof is a recursive function that will create all the trie paths based +// loadProof is a recursive function that will create all the trie paths based // on the mapped proofs slice starting by the root func (t *Trie) loadProof(proof map[string]node, curr node) { var ( diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index 021dd17b3b..6963aabb5b 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -4,21 +4,16 @@ import ( "bytes" ) -func findAndRecord(t *Trie, key []byte, recorder *recorder) []byte { - l, err := find(t.root, key, recorder) - if l == nil || err != nil { - return nil - } - - return l.value +// findAndRecord search for a desired key recording all the nodes in the path including the desired node +func findAndRecord(t *Trie, key []byte, recorder *recorder) error { + return find(t.root, key, recorder) } -func find(parent node, key []byte, recorder *recorder) (*leaf, error) { +func find(parent node, key []byte, recorder *recorder) error { enc, hash, err := parent.encodeAndHash() if err != nil { - return nil, err + return err } - recorder.record(hash, enc) switch p := parent.(type) { @@ -27,28 +22,20 @@ func find(parent node, key []byte, recorder *recorder) (*leaf, error) { // found the value at this node if bytes.Equal(p.key, key) || len(key) == 0 { - return &leaf{key: p.key, value: p.value, dirty: false}, nil + return nil } // did not find value if bytes.Equal(p.key[:length], key) && len(key) < len(p.key) { - return nil, nil + return nil } return find(p.children[key[length]], key[length+1:], recorder) case *leaf: - enc, hash, err := p.encodeAndHash() - if err != nil { - return nil, err - } - - recorder.record(hash, enc) if bytes.Equal(p.key, key) { - return p, nil + return nil } - default: - return nil, nil } - return nil, nil + return nil } diff --git a/lib/trie/proof.go b/lib/trie/proof.go index e122c4806d..325beadf59 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -19,6 +19,8 @@ package trie import ( "bytes" "errors" + "fmt" + "sort" "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/lib/common" @@ -28,11 +30,14 @@ var ( // ErrEmptyTrieRoot occurs when trying to craft a prove with an empty trie root ErrEmptyTrieRoot = errors.New("provided trie must have a root") - // ErrValueMismatch indicates that a returned verify proof value doesnt match with the expected value on items array - ErrValueMismatch = errors.New("expected value not found in the trie") + // ErrValueNotFound indicates that a returned verify proof value doesnt match with the expected value on items array + ErrValueNotFound = errors.New("expected value not found in the trie") // ErrDuplicateKeys not allowed to verify proof with duplicate keys ErrDuplicateKeys = errors.New("duplicate keys on verify proof") + + // ErrLoadFromProof occurs when there are problems with the proof slice while building the partial proof trie + ErrLoadFromProof = errors.New("failed to build the proof trie") ) // GenerateProof receive the keys to proof, the trie root and a reference to database @@ -49,7 +54,10 @@ func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, e nk := keyToNibbles(k) recorder := new(recorder) - findAndRecord(proofTrie, nk, recorder) + err := findAndRecord(proofTrie, nk, recorder) + if err != nil { + return nil, err + } for !recorder.isEmpty() { recNode := recorder.next() @@ -74,6 +82,12 @@ type Pair struct{ Key, Value []byte } // VerifyProof ensure a given key is inside a proof by creating a proof trie based on the proof slice // this function ignores the order of proofs func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { + // ordering in the asc order + sort.Slice(items, func(i, j int) bool { + return bytes.Compare(items[i].Key, items[j].Key) == -1 + }) + + // check for duplicates for i := 1; i < len(items); i++ { if bytes.Equal(items[i-1].Key, items[i].Key) { return false, ErrDuplicateKeys @@ -82,13 +96,13 @@ func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { proofTrie := NewEmptyTrie() if err := proofTrie.LoadFromProof(proof, root); err != nil { - return false, err + return false, fmt.Errorf("%w: %s", ErrLoadFromProof, err) } for _, i := range items { recValue := proofTrie.Get(i.Key) if !bytes.Equal(i.Value, recValue) { - return false, ErrValueMismatch + return false, ErrValueNotFound } } diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index 05277906cd..94e8b30637 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -18,19 +18,14 @@ package trie import ( "io/ioutil" - "math/rand" "testing" - "time" "github.com/ChainSafe/chaindb" "github.com/stretchr/testify/require" ) -func init() { - rand.Seed(time.Now().UnixNano()) -} - func TestProofGeneration(t *testing.T) { + t.Parallel() tmp, err := ioutil.TempDir("", "*-test-trie") require.NoError(t, err) @@ -109,6 +104,8 @@ func testGenerateProof(t *testing.T, entries []Pair, keys [][]byte) ([]byte, [][ } func TestVerifyProof_ShouldReturnTrue(t *testing.T) { + t.Parallel() + entries := []Pair{ {Key: []byte("alpha"), Value: make([]byte, 32)}, {Key: []byte("bravo"), Value: []byte("bravo")}, From 7f24a6576d2720085fac22b6564d7c5d60f6a16f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Tue, 19 Oct 2021 10:26:50 -0400 Subject: [PATCH 13/18] chore: improve find function --- lib/trie/database.go | 8 ++------ lib/trie/lookup.go | 37 +++++++++++++++++-------------------- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index f450ef2bc2..1bc12e4d80 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -120,12 +120,8 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { // loadProof is a recursive function that will create all the trie paths based // on the mapped proofs slice starting by the root func (t *Trie) loadProof(proof map[string]node, curr node) { - var ( - c *branch - ok bool - ) - - if c, ok = curr.(*branch); !ok { + c, ok := curr.(*branch) + if !ok { return } diff --git a/lib/trie/lookup.go b/lib/trie/lookup.go index 6963aabb5b..67f2792692 100644 --- a/lib/trie/lookup.go +++ b/lib/trie/lookup.go @@ -14,28 +14,25 @@ func find(parent node, key []byte, recorder *recorder) error { if err != nil { return err } + recorder.record(hash, enc) - switch p := parent.(type) { - case *branch: - length := lenCommonPrefix(p.key, key) - - // found the value at this node - if bytes.Equal(p.key, key) || len(key) == 0 { - return nil - } - - // did not find value - if bytes.Equal(p.key[:length], key) && len(key) < len(p.key) { - return nil - } - - return find(p.children[key[length]], key[length+1:], recorder) - case *leaf: - if bytes.Equal(p.key, key) { - return nil - } + b, ok := parent.(*branch) + if !ok { + return nil + } + + length := lenCommonPrefix(b.key, key) + + // found the value at this node + if bytes.Equal(b.key, key) || len(key) == 0 { + return nil + } + + // did not find value + if bytes.Equal(b.key[:length], key) && len(key) < len(b.key) { + return nil } - return nil + return find(b.children[key[length]], key[length+1:], recorder) } From 5a6d3f8da1378a8c7b91d515c072f1d5b36b341f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Tue, 19 Oct 2021 10:40:02 -0400 Subject: [PATCH 14/18] chore: use map to avoid duplicate keys --- lib/trie/proof.go | 31 +++++++++++++++++++------------ lib/trie/proof_test.go | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/lib/trie/proof.go b/lib/trie/proof.go index 325beadf59..62eb82e951 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -18,9 +18,9 @@ package trie import ( "bytes" + "encoding/hex" "errors" "fmt" - "sort" "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/lib/common" @@ -82,16 +82,16 @@ type Pair struct{ Key, Value []byte } // VerifyProof ensure a given key is inside a proof by creating a proof trie based on the proof slice // this function ignores the order of proofs func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { - // ordering in the asc order - sort.Slice(items, func(i, j int) bool { - return bytes.Compare(items[i].Key, items[j].Key) == -1 - }) - - // check for duplicates - for i := 1; i < len(items); i++ { - if bytes.Equal(items[i-1].Key, items[i].Key) { + set := make(map[string][]byte, len(items)) + + // check for duplicate keys + for _, item := range items { + hexKey := hex.EncodeToString(item.Key) + if _, ok := set[hexKey]; ok { return false, ErrDuplicateKeys } + + set[hexKey] = item.Value } proofTrie := NewEmptyTrie() @@ -99,9 +99,16 @@ func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { return false, fmt.Errorf("%w: %s", ErrLoadFromProof, err) } - for _, i := range items { - recValue := proofTrie.Get(i.Key) - if !bytes.Equal(i.Value, recValue) { + for k, v := range set { + key, err := hex.DecodeString(k) + if err != nil { + return false, fmt.Errorf("%w: %s", ErrLoadFromProof, err) + } + + recValue := proofTrie.Get(key) + + // here we need to compare value only if the caller pass the value + if v != nil && !bytes.Equal(v, recValue) { return false, ErrValueNotFound } } diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index 94e8b30637..695584ecab 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -134,3 +134,35 @@ func TestVerifyProof_ShouldReturnTrue(t *testing.T) { require.True(t, v) require.NoError(t, err) } + +func TestVerifyProof_ShouldReturnDuplicateKeysError(t *testing.T) { + t.Parallel() + + entries := []Pair{ + {Key: []byte("alpha"), Value: make([]byte, 32)}, + {Key: []byte("bravo"), Value: []byte("bravo")}, + {Key: []byte("do"), Value: []byte("verb")}, + {Key: []byte("dog"), Value: []byte("puppy")}, + {Key: []byte("doge"), Value: make([]byte, 32)}, + {Key: []byte("horse"), Value: []byte("stallion")}, + {Key: []byte("house"), Value: []byte("building")}, + } + + keys := [][]byte{ + []byte("do"), + []byte("dog"), + []byte("doge"), + } + + root, proof, _ := testGenerateProof(t, entries, keys) + + pl := []Pair{ + {Key: []byte("do"), Value: []byte("verb")}, + {Key: []byte("dog"), Value: []byte("puppy")}, + {Key: []byte("doge"), Value: make([]byte, 32)}, + } + + v, err := VerifyProof(proof, root, pl) + require.True(t, v) + require.NoError(t, err) +} From 8bd6a335cd0ffb26c6915ce6dc184b7e4fa3cd28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Tue, 19 Oct 2021 10:45:57 -0400 Subject: [PATCH 15/18] chore: add test cases to duplicate values and nil values --- lib/trie/proof_test.go | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index 695584ecab..ce1faacd83 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -138,6 +138,19 @@ func TestVerifyProof_ShouldReturnTrue(t *testing.T) { func TestVerifyProof_ShouldReturnDuplicateKeysError(t *testing.T) { t.Parallel() + pl := []Pair{ + {Key: []byte("do"), Value: []byte("verb")}, + {Key: []byte("do"), Value: []byte("puppy")}, + } + + v, err := VerifyProof([][]byte{}, []byte{}, pl) + require.False(t, v) + require.Error(t, err, ErrDuplicateKeys) +} + +func TestVerifyProof_ShouldReturnTrueWithouCompareValues(t *testing.T) { + t.Parallel() + entries := []Pair{ {Key: []byte("alpha"), Value: make([]byte, 32)}, {Key: []byte("bravo"), Value: []byte("bravo")}, @@ -157,9 +170,9 @@ func TestVerifyProof_ShouldReturnDuplicateKeysError(t *testing.T) { root, proof, _ := testGenerateProof(t, entries, keys) pl := []Pair{ - {Key: []byte("do"), Value: []byte("verb")}, - {Key: []byte("dog"), Value: []byte("puppy")}, - {Key: []byte("doge"), Value: make([]byte, 32)}, + {Key: []byte("do"), Value: nil}, + {Key: []byte("dog"), Value: nil}, + {Key: []byte("doge"), Value: nil}, } v, err := VerifyProof(proof, root, pl) From f4081bab063ec91b6622c7ca5ca80edda41475e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Tue, 19 Oct 2021 11:03:41 -0400 Subject: [PATCH 16/18] chore: fix unused param lint error --- lib/trie/proof_test.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/lib/trie/proof_test.go b/lib/trie/proof_test.go index ce1faacd83..72a42ad81b 100644 --- a/lib/trie/proof_test.go +++ b/lib/trie/proof_test.go @@ -122,13 +122,7 @@ func TestVerifyProof_ShouldReturnTrue(t *testing.T) { []byte("doge"), } - root, proof, _ := testGenerateProof(t, entries, keys) - - pl := []Pair{ - {Key: []byte("do"), Value: []byte("verb")}, - {Key: []byte("dog"), Value: []byte("puppy")}, - {Key: []byte("doge"), Value: make([]byte, 32)}, - } + root, proof, pl := testGenerateProof(t, entries, keys) v, err := VerifyProof(proof, root, pl) require.True(t, v) From d8a5e41cc8cb6c679632aecc511a87a203195adb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Tue, 19 Oct 2021 14:57:08 -0400 Subject: [PATCH 17/18] chore: use the shortest form --- lib/trie/database.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/lib/trie/database.go b/lib/trie/database.go index 1bc12e4d80..1f15b0159a 100644 --- a/lib/trie/database.go +++ b/lib/trie/database.go @@ -88,12 +88,7 @@ func (t *Trie) LoadFromProof(proof [][]byte, root []byte) error { // map all the proofs hash -> decoded node // and takes the loop to indentify the root node for _, rawNode := range proof { - var ( - decNode node - err error - ) - - decNode, err = decodeBytes(rawNode) + decNode, err := decodeBytes(rawNode) if err != nil { return err } From 6465be8d7521323ead7a8089eca6c7ed35af853c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20J=C3=BAnior?= Date: Tue, 19 Oct 2021 14:58:35 -0400 Subject: [PATCH 18/18] chore: use set just for find dupl keys --- lib/trie/proof.go | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/lib/trie/proof.go b/lib/trie/proof.go index 62eb82e951..a4a83b919f 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -82,7 +82,7 @@ type Pair struct{ Key, Value []byte } // VerifyProof ensure a given key is inside a proof by creating a proof trie based on the proof slice // this function ignores the order of proofs func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { - set := make(map[string][]byte, len(items)) + set := make(map[string]struct{}, len(items)) // check for duplicate keys for _, item := range items { @@ -90,8 +90,7 @@ func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { if _, ok := set[hexKey]; ok { return false, ErrDuplicateKeys } - - set[hexKey] = item.Value + set[hexKey] = struct{}{} } proofTrie := NewEmptyTrie() @@ -99,16 +98,11 @@ func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { return false, fmt.Errorf("%w: %s", ErrLoadFromProof, err) } - for k, v := range set { - key, err := hex.DecodeString(k) - if err != nil { - return false, fmt.Errorf("%w: %s", ErrLoadFromProof, err) - } - - recValue := proofTrie.Get(key) + for _, item := range items { + recValue := proofTrie.Get(item.Key) // here we need to compare value only if the caller pass the value - if v != nil && !bytes.Equal(v, recValue) { + if item.Value != nil && !bytes.Equal(item.Value, recValue) { return false, ErrValueNotFound } }