diff --git a/trie/proof.go b/trie/proof.go index b5bc8e435..fa1426cd7 100644 --- a/trie/proof.go +++ b/trie/proof.go @@ -42,7 +42,7 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ctxcdb.KeyValueWriter) for len(key) > 0 && tn != nil { switch n := tn.(type) { case *shortNode: - if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + if !bytes.HasPrefix(key, n.Key) { // The trie doesn't contain the key. tn = nil } else { @@ -359,8 +359,8 @@ func unset(parent node, child node, key []byte, pos int, removeLeft bool) error } return unset(cld, cld.Children[key[pos]], key, pos+1, removeLeft) case *shortNode: - if len(key[pos:]) < len(cld.Key) || !bytes.Equal(cld.Key, key[pos:pos+len(cld.Key)]) { - // Find the fork point, it's an non-existent branch. + if !bytes.HasPrefix(key[pos:], cld.Key) { + // Find the fork point, it's a non-existent branch. if removeLeft { if bytes.Compare(cld.Key, key[pos:]) < 0 { // The key of fork shortnode is less than the path @@ -420,7 +420,7 @@ func hasRightElement(node node, key []byte) bool { } node, pos = rn.Children[key[pos]], pos+1 case *shortNode: - if len(key)-pos < len(rn.Key) || !bytes.Equal(rn.Key, key[pos:pos+len(rn.Key)]) { + if !bytes.HasPrefix(key[pos:], rn.Key) { return bytes.Compare(rn.Key, key[pos:]) > 0 } node, pos = rn.Val, pos+len(rn.Key) @@ -613,7 +613,7 @@ func get(tn node, key []byte, skipResolved bool) ([]byte, node) { for { switch n := tn.(type) { case *shortNode: - if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { + if !bytes.HasPrefix(key, n.Key) { return nil, nil } tn = n.Val diff --git a/trie/trie.go b/trie/trie.go index 839e2eb31..d7e0f267c 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -151,7 +151,7 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode case valueNode: return n, n, false, nil case *shortNode: - if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { + if !bytes.HasPrefix(key[pos:], n.Key) { // key not found in trie return nil, n, false, nil } @@ -190,10 +190,7 @@ func (t *Trie) TryGetNode(path []byte) ([]byte, int, error) { if resolved > 0 { t.root = newroot } - if item == nil { - return nil, resolved, nil - } - return item, resolved, err + return item, resolved, nil } func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, newnode node, resolved int, err error) { @@ -225,7 +222,7 @@ func (t *Trie) tryGetNode(origNode node, path []byte, pos int) (item []byte, new return nil, nil, 0, nil case *shortNode: - if len(path)-pos < len(n.Key) || !bytes.Equal(n.Key, path[pos:pos+len(n.Key)]) { + if !bytes.HasPrefix(path[pos:], n.Key) { // Path branches off from short node return nil, n, 0, nil }