From e0f30276eaba4eac0b930b26414f79e52df8d249 Mon Sep 17 00:00:00 2001 From: Gary Rong Date: Mon, 9 Oct 2023 11:35:15 +0800 Subject: [PATCH] trie: polish --- trie/stacktrie.go | 146 +++++++++++++++++----------------- trie/stacktrie_marshalling.go | 56 +++++++------ trie/stacktrie_test.go | 2 +- 3 files changed, 104 insertions(+), 100 deletions(-) diff --git a/trie/stacktrie.go b/trie/stacktrie.go index 48820986c59d..6b232882d31f 100644 --- a/trie/stacktrie.go +++ b/trie/stacktrie.go @@ -63,48 +63,51 @@ func NewStackTrieWithOwner(writeFn NodeWriteFunc, owner common.Hash) *StackTrie } // Update inserts a (key, value) pair into the stack trie. -func (stack *StackTrie) Update(key, value []byte) error { +func (t *StackTrie) Update(key, value []byte) error { k := keybytesToHex(key) if len(value) == 0 { panic("deletion not supported") } - stack.insert(stack.root, k[:len(k)-1], value, nil) + t.insert(t.root, k[:len(k)-1], value, nil) return nil } // MustUpdate is a wrapper of Update and will omit any encountered error but // just print out an error message. -func (st *StackTrie) MustUpdate(key, value []byte) { - if err := st.Update(key, value); err != nil { +func (t *StackTrie) MustUpdate(key, value []byte) { + if err := t.Update(key, value); err != nil { log.Error("Unhandled trie error in StackTrie.Update", "err", err) } } -func (stack *StackTrie) Reset() { - stack.owner = (common.Hash{}) - stack.writeFn = nil - stack.root = stPool.Get().(*stNode) +func (t *StackTrie) Reset() { + t.root = stPool.Get().(*stNode) } // stNode represents a node within a StackTrie type stNode struct { - nodeType uint8 // node type (as in branch, ext, leaf) - val []byte // value contained by this node if it's a leaf + typ uint8 // node type (as in branch, ext, leaf) key []byte // key chunk covered by this (leaf|ext) node + val []byte // value contained by this node if it's a leaf children [16]*stNode // list of children (for branch and exts) } +// newLeaf constructs a leaf node with provided node key and value. The key +// will be deep-copied in the function and safe to modify afterwards, but +// value is not. func newLeaf(key, val []byte) *stNode { st := stPool.Get().(*stNode) - st.nodeType = leafNode + st.typ = leafNode st.key = append(st.key, key...) st.val = val return st } +// newExt constructs an extension node with provided node key and child. The +// key will be deep-copied in the function and safe to modify afterwards. func newExt(key []byte, child *stNode) *stNode { st := stPool.Get().(*stNode) - st.nodeType = extNode + st.typ = extNode st.key = append(st.key, key...) st.children[0] = child return st @@ -119,40 +122,40 @@ const ( hashedNode ) -func (st *stNode) Reset() *stNode { - st.key = st.key[:0] - st.val = nil - for i := range st.children { - st.children[i] = nil +func (n *stNode) reset() *stNode { + n.key = n.key[:0] + n.val = nil + for i := range n.children { + n.children[i] = nil } - st.nodeType = emptyNode - return st + n.typ = emptyNode + return n } // Helper function that, given a full key, determines the index // at which the chunk pointed by st.keyOffset is different from // the same chunk in the full key. -func (st *stNode) getDiffIndex(key []byte) int { - for idx, nibble := range st.key { +func (n *stNode) getDiffIndex(key []byte) int { + for idx, nibble := range n.key { if nibble != key[idx] { return idx } } - return len(st.key) + return len(n.key) } // Helper function to that inserts a (key, value) pair into // the trie. -func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) { - switch st.nodeType { +func (t *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) { + switch st.typ { case branchNode: /* Branch */ idx := int(key[0]) // Unresolve elder siblings for i := idx - 1; i >= 0; i-- { if st.children[i] != nil { - if st.children[i].nodeType != hashedNode { - stack.hash(st.children[i], append(prefix, byte(i))) + if st.children[i].typ != hashedNode { + t.hash(st.children[i], append(prefix, byte(i))) } break } @@ -162,7 +165,7 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) { if st.children[idx] == nil { st.children[idx] = newLeaf(key[1:], value) } else { - stack.insert(st.children[idx], key[1:], value, append(prefix, key[0])) + t.insert(st.children[idx], key[1:], value, append(prefix, key[0])) } case extNode: /* Ext */ @@ -177,7 +180,7 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) { if diffidx == len(st.key) { // Ext key and key segment are identical, recurse into // the child node. - stack.insert(st.children[0], key[diffidx:], value, append(prefix, key[:diffidx]...)) + t.insert(st.children[0], key[diffidx:], value, append(prefix, key[:diffidx]...)) return } // Save the original part. Depending if the break is @@ -190,14 +193,14 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) { // extension. The path prefix of the newly-inserted // extension should also contain the different byte. n = newExt(st.key[diffidx+1:], st.children[0]) - stack.hash(n, append(prefix, st.key[:diffidx+1]...)) + t.hash(n, append(prefix, st.key[:diffidx+1]...)) } else { // Break on the last byte, no need to insert // an extension node: reuse the current node. // The path prefix of the original part should // still be same. n = st.children[0] - stack.hash(n, append(prefix, st.key...)) + t.hash(n, append(prefix, st.key...)) } var p *stNode if diffidx == 0 { @@ -206,13 +209,13 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) { // a branch node. st.children[0] = nil p = st - st.nodeType = branchNode + st.typ = branchNode } else { // the common prefix is at least one byte // long, insert a new intermediate branch // node. st.children[0] = stPool.Get().(*stNode) - st.children[0].nodeType = branchNode + st.children[0].typ = branchNode p = st.children[0] } // Create a leaf for the inserted part @@ -245,15 +248,15 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) { var p *stNode if diffidx == 0 { // Convert current leaf into a branch - st.nodeType = branchNode + st.typ = branchNode p = st st.children[0] = nil } else { // Convert current node into an ext, // and insert a child branch node. - st.nodeType = extNode + st.typ = extNode st.children[0] = stPool.Get().(*stNode) - st.children[0].nodeType = branchNode + st.children[0].typ = branchNode p = st.children[0] } @@ -262,7 +265,7 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) { // is hashed directly in order to free up some memory. origIdx := st.key[diffidx] p.children[origIdx] = newLeaf(st.key[diffidx+1:], st.val) - stack.hash(p.children[origIdx], append(prefix, st.key[:diffidx+1]...)) + t.hash(p.children[origIdx], append(prefix, st.key[:diffidx+1]...)) newIdx := key[diffidx] p.children[newIdx] = newLeaf(key[diffidx+1:], value) @@ -273,7 +276,7 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) { st.val = nil case emptyNode: /* Empty */ - st.nodeType = leafNode + st.typ = leafNode st.key = key st.val = value @@ -296,18 +299,18 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) { // - And the 'st.type' will be 'hashedNode' AGAIN // // This method also sets 'st.type' to hashedNode, and clears 'st.key'. -func (stack *StackTrie) hash(st *stNode, path []byte) { +func (t *StackTrie) hash(st *stNode, path []byte) { // The switch below sets this to the RLP-encoding of this node. var encodedNode []byte - switch st.nodeType { + switch st.typ { case hashedNode: return case emptyNode: st.val = types.EmptyRootHash.Bytes() st.key = st.key[:0] - st.nodeType = hashedNode + st.typ = hashedNode return case branchNode: @@ -317,23 +320,21 @@ func (stack *StackTrie) hash(st *stNode, path []byte) { nodes.Children[i] = nilValueNode continue } - stack.hash(child, append(path, byte(i))) + t.hash(child, append(path, byte(i))) + if len(child.val) < 32 { nodes.Children[i] = rawNode(child.val) } else { nodes.Children[i] = hashNode(child.val) } - - // Release child back to pool. st.children[i] = nil - stPool.Put(child.Reset()) + stPool.Put(child.reset()) // Release child back to pool. } - - nodes.encode(stack.h.encbuf) - encodedNode = stack.h.encodedBytes() + nodes.encode(t.h.encbuf) + encodedNode = t.h.encodedBytes() case extNode: - stack.hash(st.children[0], append(path, st.key...)) + t.hash(st.children[0], append(path, st.key...)) n := shortNode{Key: hexToCompactInPlace(st.key)} if len(st.children[0].val) < 32 { @@ -341,27 +342,24 @@ func (stack *StackTrie) hash(st *stNode, path []byte) { } else { n.Val = hashNode(st.children[0].val) } + n.encode(t.h.encbuf) + encodedNode = t.h.encodedBytes() - n.encode(stack.h.encbuf) - encodedNode = stack.h.encodedBytes() - - // Release child back to pool. - stPool.Put(st.children[0].Reset()) - + stPool.Put(st.children[0].reset()) // Release child back to pool. st.children[0] = nil case leafNode: st.key = append(st.key, byte(16)) n := shortNode{Key: hexToCompactInPlace(st.key), Val: valueNode(st.val)} - n.encode(stack.h.encbuf) - encodedNode = stack.h.encodedBytes() + n.encode(t.h.encbuf) + encodedNode = t.h.encodedBytes() default: panic("invalid node type") } - st.nodeType = hashedNode + st.typ = hashedNode st.key = st.key[:0] if len(encodedNode) < 32 { st.val = common.CopyBytes(encodedNode) @@ -370,16 +368,16 @@ func (stack *StackTrie) hash(st *stNode, path []byte) { // Write the hash to the 'val'. We allocate a new val here to not mutate // input values - st.val = stack.h.hashData(encodedNode) - if stack.writeFn != nil { - stack.writeFn(stack.owner, path, common.BytesToHash(st.val), encodedNode) + st.val = t.h.hashData(encodedNode) + if t.writeFn != nil { + t.writeFn(t.owner, path, common.BytesToHash(st.val), encodedNode) } } // Hash returns the hash of the current node. -func (stack *StackTrie) Hash() (h common.Hash) { - st := stack.root - stack.hash(st, nil) +func (t *StackTrie) Hash() (h common.Hash) { + st := t.root + t.hash(st, nil) if len(st.val) == 32 { copy(h[:], st.val) return h @@ -387,9 +385,9 @@ func (stack *StackTrie) Hash() (h common.Hash) { // If the node's RLP isn't 32 bytes long, the node will not // be hashed, and instead contain the rlp-encoding of the // node. For the top level node, we need to force the hashing. - stack.h.sha.Reset() - stack.h.sha.Write(st.val) - stack.h.sha.Read(h[:]) + t.h.sha.Reset() + t.h.sha.Write(st.val) + t.h.sha.Read(h[:]) return h } @@ -400,12 +398,12 @@ func (stack *StackTrie) Hash() (h common.Hash) { // // The associated database is expected, otherwise the whole commit // functionality should be disabled. -func (stack *StackTrie) Commit() (h common.Hash, err error) { - if stack.writeFn == nil { +func (t *StackTrie) Commit() (h common.Hash, err error) { + if t.writeFn == nil { return common.Hash{}, ErrCommitDisabled } - st := stack.root - stack.hash(st, nil) + st := t.root + t.hash(st, nil) if len(st.val) == 32 { copy(h[:], st.val) return h, nil @@ -413,10 +411,10 @@ func (stack *StackTrie) Commit() (h common.Hash, err error) { // If the node's RLP isn't 32 bytes long, the node will not // be hashed (and committed), and instead contain the rlp-encoding of the // node. For the top level node, we need to force the hashing+commit. - stack.h.sha.Reset() - stack.h.sha.Write(st.val) - stack.h.sha.Read(h[:]) + t.h.sha.Reset() + t.h.sha.Write(st.val) + t.h.sha.Read(h[:]) - stack.writeFn(stack.owner, nil, h, st.val) + t.writeFn(t.owner, nil, h, st.val) return h, nil } diff --git a/trie/stacktrie_marshalling.go b/trie/stacktrie_marshalling.go index bc46b7c57283..c0bb07f8685b 100644 --- a/trie/stacktrie_marshalling.go +++ b/trie/stacktrie_marshalling.go @@ -23,60 +23,66 @@ import ( "encoding/gob" ) -var ( //Compile-time interface checks +// Compile-time interface checks. +var ( _ = encoding.BinaryMarshaler((*StackTrie)(nil)) _ = encoding.BinaryUnmarshaler((*StackTrie)(nil)) ) // NewFromBinaryV2 initialises a serialized stacktrie with the given db. // OBS! Format was changed along with the name of this constructor. -func NewFromBinaryV2(data []byte, writeFn NodeWriteFunc) (*StackTrie, error) { - stack := NewStackTrie(writeFn) +func NewFromBinaryV2(data []byte) (*StackTrie, error) { + stack := NewStackTrie(nil) if err := stack.UnmarshalBinary(data); err != nil { return nil, err } return stack, nil } -// UnmarshalBinary implements encoding.BinaryMarshaler -func (st *StackTrie) MarshalBinary() (data []byte, err error) { +// MarshalBinary implements encoding.BinaryMarshaler. +func (t *StackTrie) MarshalBinary() (data []byte, err error) { var ( b bytes.Buffer w = bufio.NewWriter(&b) ) - if err := gob.NewEncoder(w).Encode(st.owner); err != nil { + if err := gob.NewEncoder(w).Encode(t.owner); err != nil { return nil, err } - if err := st.root.marshalInto(w); err != nil { + if err := t.root.marshalInto(w); err != nil { return nil, err } w.Flush() return b.Bytes(), nil } -// UnmarshalBinary implements encoding.BinaryUnmarshaler -func (stack *StackTrie) UnmarshalBinary(data []byte) error { +// UnmarshalBinary implements encoding.BinaryUnmarshaler. +func (t *StackTrie) UnmarshalBinary(data []byte) error { r := bytes.NewReader(data) - if err := gob.NewDecoder(r).Decode(&stack.owner); err != nil { + if err := gob.NewDecoder(r).Decode(&t.owner); err != nil { return err } - if err := stack.root.unmarshalFrom(r); err != nil { + if err := t.root.unmarshalFrom(r); err != nil { return err } return nil } -type encodedNode struct { - NodeType uint8 - Val []byte - Key []byte +type stackNodeMarshaling struct { + Typ uint8 + Key []byte + Val []byte } -func (st *stNode) marshalInto(w *bufio.Writer) (err error) { - if err := gob.NewEncoder(w).Encode(encodedNode{st.nodeType, st.val, st.key}); err != nil { +func (n *stNode) marshalInto(w *bufio.Writer) (err error) { + enc := stackNodeMarshaling{ + Typ: n.typ, + Key: n.key, + Val: n.val, + } + if err := gob.NewEncoder(w).Encode(enc); err != nil { return err } - for _, child := range st.children { + for _, child := range n.children { if child == nil { w.WriteByte(0) continue @@ -89,16 +95,16 @@ func (st *stNode) marshalInto(w *bufio.Writer) (err error) { return nil } -func (st *stNode) unmarshalFrom(r *bytes.Reader) error { - var dec encodedNode +func (n *stNode) unmarshalFrom(r *bytes.Reader) error { + var dec stackNodeMarshaling if err := gob.NewDecoder(r).Decode(&dec); err != nil { return err } - st.nodeType = dec.NodeType - st.val = dec.Val - st.key = dec.Key + n.typ = dec.Typ + n.key = dec.Key + n.val = dec.Val - for i := range st.children { + for i := range n.children { if b, err := r.ReadByte(); err != nil { return err } else if b == 0 { @@ -108,7 +114,7 @@ func (st *stNode) unmarshalFrom(r *bytes.Reader) error { if err := child.unmarshalFrom(r); err != nil { return err } - st.children[i] = &child + n.children[i] = &child } return nil } diff --git a/trie/stacktrie_test.go b/trie/stacktrie_test.go index 6b2be147ddb6..5b86a971e10c 100644 --- a/trie/stacktrie_test.go +++ b/trie/stacktrie_test.go @@ -410,7 +410,7 @@ func TestStacktrieSerialization(t *testing.T) { if err != nil { t.Fatal(err) } - newSt, err := NewFromBinaryV2(blob, nil) + newSt, err := NewFromBinaryV2(blob) if err != nil { t.Fatal(err) }