Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

trie: polish #52

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 72 additions & 74 deletions trie/stacktrie.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,48 +63,51 @@ func NewStackTrieWithOwner(writeFn NodeWriteFunc, owner common.Hash) *StackTrie
}

// Update inserts a (key, value) pair into the stack trie.
func (stack *StackTrie) Update(key, value []byte) error {
func (t *StackTrie) Update(key, value []byte) error {
k := keybytesToHex(key)
if len(value) == 0 {
panic("deletion not supported")
}
stack.insert(stack.root, k[:len(k)-1], value, nil)
t.insert(t.root, k[:len(k)-1], value, nil)
return nil
}

// MustUpdate is a wrapper of Update and will omit any encountered error but
// just print out an error message.
func (st *StackTrie) MustUpdate(key, value []byte) {
if err := st.Update(key, value); err != nil {
func (t *StackTrie) MustUpdate(key, value []byte) {
if err := t.Update(key, value); err != nil {
log.Error("Unhandled trie error in StackTrie.Update", "err", err)
}
}

func (stack *StackTrie) Reset() {
stack.owner = (common.Hash{})
stack.writeFn = nil
stack.root = stPool.Get().(*stNode)
func (t *StackTrie) Reset() {
t.root = stPool.Get().(*stNode)
}

// stNode represents a node within a StackTrie
type stNode struct {
nodeType uint8 // node type (as in branch, ext, leaf)
val []byte // value contained by this node if it's a leaf
typ uint8 // node type (as in branch, ext, leaf)
key []byte // key chunk covered by this (leaf|ext) node
val []byte // value contained by this node if it's a leaf
children [16]*stNode // list of children (for branch and exts)
}

// newLeaf constructs a leaf node with provided node key and value. The key
// will be deep-copied in the function and safe to modify afterwards, but
// value is not.
func newLeaf(key, val []byte) *stNode {
st := stPool.Get().(*stNode)
st.nodeType = leafNode
st.typ = leafNode
st.key = append(st.key, key...)
st.val = val
return st
}

// newExt constructs an extension node with provided node key and child. The
// key will be deep-copied in the function and safe to modify afterwards.
func newExt(key []byte, child *stNode) *stNode {
st := stPool.Get().(*stNode)
st.nodeType = extNode
st.typ = extNode
st.key = append(st.key, key...)
st.children[0] = child
return st
Expand All @@ -119,40 +122,40 @@ const (
hashedNode
)

func (st *stNode) Reset() *stNode {
st.key = st.key[:0]
st.val = nil
for i := range st.children {
st.children[i] = nil
func (n *stNode) reset() *stNode {
n.key = n.key[:0]
n.val = nil
for i := range n.children {
n.children[i] = nil
}
st.nodeType = emptyNode
return st
n.typ = emptyNode
return n
}

// Helper function that, given a full key, determines the index
// at which the chunk pointed by st.keyOffset is different from
// the same chunk in the full key.
func (st *stNode) getDiffIndex(key []byte) int {
for idx, nibble := range st.key {
func (n *stNode) getDiffIndex(key []byte) int {
for idx, nibble := range n.key {
if nibble != key[idx] {
return idx
}
}
return len(st.key)
return len(n.key)
}

// Helper function to that inserts a (key, value) pair into
// the trie.
func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
switch st.nodeType {
func (t *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
switch st.typ {
case branchNode: /* Branch */
idx := int(key[0])

// Unresolve elder siblings
for i := idx - 1; i >= 0; i-- {
if st.children[i] != nil {
if st.children[i].nodeType != hashedNode {
stack.hash(st.children[i], append(prefix, byte(i)))
if st.children[i].typ != hashedNode {
t.hash(st.children[i], append(prefix, byte(i)))
}
break
}
Expand All @@ -162,7 +165,7 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
if st.children[idx] == nil {
st.children[idx] = newLeaf(key[1:], value)
} else {
stack.insert(st.children[idx], key[1:], value, append(prefix, key[0]))
t.insert(st.children[idx], key[1:], value, append(prefix, key[0]))
}

case extNode: /* Ext */
Expand All @@ -177,7 +180,7 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
if diffidx == len(st.key) {
// Ext key and key segment are identical, recurse into
// the child node.
stack.insert(st.children[0], key[diffidx:], value, append(prefix, key[:diffidx]...))
t.insert(st.children[0], key[diffidx:], value, append(prefix, key[:diffidx]...))
return
}
// Save the original part. Depending if the break is
Expand All @@ -190,14 +193,14 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
// extension. The path prefix of the newly-inserted
// extension should also contain the different byte.
n = newExt(st.key[diffidx+1:], st.children[0])
stack.hash(n, append(prefix, st.key[:diffidx+1]...))
t.hash(n, append(prefix, st.key[:diffidx+1]...))
} else {
// Break on the last byte, no need to insert
// an extension node: reuse the current node.
// The path prefix of the original part should
// still be same.
n = st.children[0]
stack.hash(n, append(prefix, st.key...))
t.hash(n, append(prefix, st.key...))
}
var p *stNode
if diffidx == 0 {
Expand All @@ -206,13 +209,13 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
// a branch node.
st.children[0] = nil
p = st
st.nodeType = branchNode
st.typ = branchNode
} else {
// the common prefix is at least one byte
// long, insert a new intermediate branch
// node.
st.children[0] = stPool.Get().(*stNode)
st.children[0].nodeType = branchNode
st.children[0].typ = branchNode
p = st.children[0]
}
// Create a leaf for the inserted part
Expand Down Expand Up @@ -245,15 +248,15 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
var p *stNode
if diffidx == 0 {
// Convert current leaf into a branch
st.nodeType = branchNode
st.typ = branchNode
p = st
st.children[0] = nil
} else {
// Convert current node into an ext,
// and insert a child branch node.
st.nodeType = extNode
st.typ = extNode
st.children[0] = stPool.Get().(*stNode)
st.children[0].nodeType = branchNode
st.children[0].typ = branchNode
p = st.children[0]
}

Expand All @@ -262,7 +265,7 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
// is hashed directly in order to free up some memory.
origIdx := st.key[diffidx]
p.children[origIdx] = newLeaf(st.key[diffidx+1:], st.val)
stack.hash(p.children[origIdx], append(prefix, st.key[:diffidx+1]...))
t.hash(p.children[origIdx], append(prefix, st.key[:diffidx+1]...))

newIdx := key[diffidx]
p.children[newIdx] = newLeaf(key[diffidx+1:], value)
Expand All @@ -273,7 +276,7 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
st.val = nil

case emptyNode: /* Empty */
st.nodeType = leafNode
st.typ = leafNode
st.key = key
st.val = value

Expand All @@ -296,18 +299,18 @@ func (stack *StackTrie) insert(st *stNode, key, value []byte, prefix []byte) {
// - And the 'st.type' will be 'hashedNode' AGAIN
//
// This method also sets 'st.type' to hashedNode, and clears 'st.key'.
func (stack *StackTrie) hash(st *stNode, path []byte) {
func (t *StackTrie) hash(st *stNode, path []byte) {
// The switch below sets this to the RLP-encoding of this node.
var encodedNode []byte

switch st.nodeType {
switch st.typ {
case hashedNode:
return

case emptyNode:
st.val = types.EmptyRootHash.Bytes()
st.key = st.key[:0]
st.nodeType = hashedNode
st.typ = hashedNode
return

case branchNode:
Expand All @@ -317,51 +320,46 @@ func (stack *StackTrie) hash(st *stNode, path []byte) {
nodes.Children[i] = nilValueNode
continue
}
stack.hash(child, append(path, byte(i)))
t.hash(child, append(path, byte(i)))

if len(child.val) < 32 {
nodes.Children[i] = rawNode(child.val)
} else {
nodes.Children[i] = hashNode(child.val)
}

// Release child back to pool.
st.children[i] = nil
stPool.Put(child.Reset())
stPool.Put(child.reset()) // Release child back to pool.
}

nodes.encode(stack.h.encbuf)
encodedNode = stack.h.encodedBytes()
nodes.encode(t.h.encbuf)
encodedNode = t.h.encodedBytes()

case extNode:
stack.hash(st.children[0], append(path, st.key...))
t.hash(st.children[0], append(path, st.key...))

n := shortNode{Key: hexToCompactInPlace(st.key)}
if len(st.children[0].val) < 32 {
n.Val = rawNode(st.children[0].val)
} else {
n.Val = hashNode(st.children[0].val)
}
n.encode(t.h.encbuf)
encodedNode = t.h.encodedBytes()

n.encode(stack.h.encbuf)
encodedNode = stack.h.encodedBytes()

// Release child back to pool.
stPool.Put(st.children[0].Reset())

stPool.Put(st.children[0].reset()) // Release child back to pool.
st.children[0] = nil

case leafNode:
st.key = append(st.key, byte(16))
n := shortNode{Key: hexToCompactInPlace(st.key), Val: valueNode(st.val)}

n.encode(stack.h.encbuf)
encodedNode = stack.h.encodedBytes()
n.encode(t.h.encbuf)
encodedNode = t.h.encodedBytes()

default:
panic("invalid node type")
}

st.nodeType = hashedNode
st.typ = hashedNode
st.key = st.key[:0]
if len(encodedNode) < 32 {
st.val = common.CopyBytes(encodedNode)
Expand All @@ -370,26 +368,26 @@ func (stack *StackTrie) hash(st *stNode, path []byte) {

// Write the hash to the 'val'. We allocate a new val here to not mutate
// input values
st.val = stack.h.hashData(encodedNode)
if stack.writeFn != nil {
stack.writeFn(stack.owner, path, common.BytesToHash(st.val), encodedNode)
st.val = t.h.hashData(encodedNode)
if t.writeFn != nil {
t.writeFn(t.owner, path, common.BytesToHash(st.val), encodedNode)
}
}

// Hash returns the hash of the current node.
func (stack *StackTrie) Hash() (h common.Hash) {
st := stack.root
stack.hash(st, nil)
func (t *StackTrie) Hash() (h common.Hash) {
st := t.root
t.hash(st, nil)
if len(st.val) == 32 {
copy(h[:], st.val)
return h
}
// If the node's RLP isn't 32 bytes long, the node will not
// be hashed, and instead contain the rlp-encoding of the
// node. For the top level node, we need to force the hashing.
stack.h.sha.Reset()
stack.h.sha.Write(st.val)
stack.h.sha.Read(h[:])
t.h.sha.Reset()
t.h.sha.Write(st.val)
t.h.sha.Read(h[:])
return h
}

Expand All @@ -400,23 +398,23 @@ func (stack *StackTrie) Hash() (h common.Hash) {
//
// The associated database is expected, otherwise the whole commit
// functionality should be disabled.
func (stack *StackTrie) Commit() (h common.Hash, err error) {
if stack.writeFn == nil {
func (t *StackTrie) Commit() (h common.Hash, err error) {
if t.writeFn == nil {
return common.Hash{}, ErrCommitDisabled
}
st := stack.root
stack.hash(st, nil)
st := t.root
t.hash(st, nil)
if len(st.val) == 32 {
copy(h[:], st.val)
return h, nil
}
// If the node's RLP isn't 32 bytes long, the node will not
// be hashed (and committed), and instead contain the rlp-encoding of the
// node. For the top level node, we need to force the hashing+commit.
stack.h.sha.Reset()
stack.h.sha.Write(st.val)
stack.h.sha.Read(h[:])
t.h.sha.Reset()
t.h.sha.Write(st.val)
t.h.sha.Read(h[:])

stack.writeFn(stack.owner, nil, h, st.val)
t.writeFn(t.owner, nil, h, st.val)
return h, nil
}
Loading