fix: lib/trie: fix trie.NextKey #1449

Merged · 10 commits · Mar 11, 2021

Changes from all commits
22 changes: 1 addition & 21 deletions lib/runtime/storage/trie.go
@@ -17,7 +17,6 @@
package storage

import (
"bytes"
"encoding/binary"
"sync"

@@ -101,26 +100,7 @@ func (s *TrieState) Delete(key []byte) {
func (s *TrieState) NextKey(key []byte) []byte {
s.lock.RLock()
defer s.lock.RUnlock()
keys := s.t.GetKeysWithPrefix([]byte{})

for i, k := range keys {
if i == len(keys)-1 {
return nil
}

if bytes.Equal(key, k) {
return keys[i+1]
}

// `keys` is already in lexicographical order, so if k is greater than `key`, it's next
if bytes.Compare(k, key) == 1 {
return k
}
}

// TODO: fix this!!
//return s.t.NextKey(key)
return nil
return s.t.NextKey(key)
}

// ClearPrefix deletes all key-value pairs from the trie where the key starts with the given prefix
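
As a quick illustration of the contract TrieState.NextKey now forwards to (a sketch using only the lib/trie calls exercised by the tests below, not code from this PR), repeatedly asking the trie for the successor of the previous key walks all stored keys in byte-lexicographic order:

package main

import (
	"fmt"

	"github.com/ChainSafe/gossamer/lib/trie"
)

func main() {
	t := trie.NewEmptyTrie()
	for _, k := range []string{"asdf", "bnm", "ghjk", "qwerty"} {
		t.Put([]byte(k), []byte(k))
	}

	// NextKey returns the lexicographically next stored key, or nil once the
	// last key has been reached; starting from the empty key yields them all.
	for key := t.NextKey([]byte{}); key != nil; key = t.NextKey(key) {
		fmt.Printf("%s\n", key)
	}
}
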
4 changes: 3 additions & 1 deletion lib/runtime/storage/trie_test.go
@@ -21,6 +21,7 @@ import (
"sort"
"testing"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/trie"
"github.com/stretchr/testify/require"
)
@@ -151,6 +152,7 @@ func TestTrieState_NextKey(t *testing.T) {
ts.Set([]byte(tc), []byte(tc))
}

t.Log(ts.t)
sort.Slice(testCases, func(i, j int) bool {
return bytes.Compare([]byte(testCases[i]), []byte(testCases[j])) == -1
})
@@ -160,7 +162,7 @@
if i == len(testCases)-1 {
require.Nil(t, next)
} else {
require.Equal(t, []byte(testCases[i+1]), next)
require.Equal(t, []byte(testCases[i+1]), next, common.BytesToHex([]byte(tc)))
}
}
}
4 changes: 0 additions & 4 deletions lib/runtime/wasmer/exports_test.go
@@ -643,10 +643,6 @@ func TestInstance_ExecuteBlock_KusamaRuntime_KusamaBlock901442(t *testing.T) {
}

func TestInstance_ExecuteBlock_KusamaRuntime_KusamaBlock1377831(t *testing.T) {
if testing.Short() {
t.Skip("this test takes around 3min to run at the moment")
}

ksmTrie := newTrieFromPairs(t, "../test_data/block1377830_kusama.out")
expectedRoot := common.MustHexToHash("0xe4de6fecda9e9e35f937d159665cf984bc1a68048b6c78912de0aeb6bd7f7e99")
require.Equal(t, expectedRoot, ksmTrie.MustHash())
6 changes: 0 additions & 6 deletions lib/runtime/wasmer/imports.go
@@ -1632,12 +1632,6 @@ func ext_storage_clear_prefix_version_1(context unsafe.Pointer, prefixSpan C.int
if err != nil {
logger.Error("[ext_storage_clear_prefix_version_1]", "error", err)
}

// sanity check
next := storage.NextKey(prefix)
if len(next) >= len(prefix) && bytes.Equal(prefix, next[:len(prefix)]) {
panic("did not clear prefix")
}
}

//export ext_storage_exists_version_1
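
For reference, the deleted block asserted a post-condition: after ext_storage_clear_prefix_version_1 runs, the key NextKey(prefix) returns must no longer start with prefix. Stated as a standalone predicate (an illustrative sketch, not code from this change; the helper name is made up), the invariant is:

package main

import (
	"bytes"
	"fmt"
)

// prefixCleared reports whether the removed sanity check would have passed:
// next is the value NextKey(prefix) returned (nil when no key follows), and
// the check panicked whenever next still started with prefix.
func prefixCleared(prefix, next []byte) bool {
	return !bytes.HasPrefix(next, prefix)
}

func main() {
	fmt.Println(prefixCleared([]byte{0x01}, []byte{0x07, 0x3a})) // true: no 0x01... keys remain
	fmt.Println(prefixCleared([]byte{0x01}, []byte{0x01, 0x35})) // false: a 0x01... key survived
	fmt.Println(prefixCleared([]byte{0x01}, nil))                // true: trie exhausted
}
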
99 changes: 40 additions & 59 deletions lib/trie/trie.go
@@ -152,95 +152,76 @@ func (t *Trie) entries(current node, prefix []byte, kv map[string][]byte) map[st
func (t *Trie) NextKey(key []byte) []byte {
k := keyToNibbles(key)

next := t.nextKey([]node{}, t.root, nil, k)
next := t.nextKey(t.root, nil, k)
if next == nil {
return nil
}

return nibblesToKeyLE(next)
}

func (t *Trie) nextKey(ancestors []node, current node, prefix, target []byte) []byte {
switch c := current.(type) {
func (t *Trie) nextKey(curr node, prefix, key []byte) []byte {
switch c := curr.(type) {
case *branch:
fullKey := append(prefix, c.key...)
var cmp int
if len(key) < len(fullKey) {
// the key is lexicographically less than the current node key. return first key available
cmp = 1
} else {
// if cmp == 1, then node key is lexicographically greater than the key arg
cmp = bytes.Compare(fullKey, key[:len(fullKey)])
}

// length of key arg is less than branch key, return key of first child (or key of this branch, if it's a branch w/ value)
if (cmp == 0 && len(key) == len(fullKey)) || cmp == 1 {
if c.value != nil && bytes.Compare(fullKey, key) > 0 {
return fullKey
}

if bytes.Equal(target, fullKey) {
for i, child := range c.children {
if child == nil {
continue
}

// descend and return first key
return returnFirstKey(append(fullKey, byte(i)), child)
next := t.nextKey(child, append(fullKey, byte(i)), key)
if len(next) != 0 {
return next
}
}
}

if len(target) >= len(fullKey) && bytes.Equal(target[:len(fullKey)], fullKey) {
for i, child := range c.children {
if child == nil || byte(i) != target[len(fullKey)] {
// node key isn't greater than the arg key, continue to iterate
if cmp < 1 && len(key) > len(fullKey) {
idx := key[len(fullKey)]
for i, child := range c.children[idx:] {
if child == nil {
continue
}

return t.nextKey(append([]node{c}, ancestors...), child, append(fullKey, byte(i)), target)
next := t.nextKey(child, append(fullKey, byte(i)+idx), key)
if len(next) != 0 {
return next
}
}
}
case *leaf:
fullKey := append(prefix, c.key...)

if bytes.Equal(target, fullKey) {
// ancestors are all branches, find one with another child w/ index greater than ours
for _, anc := range ancestors {
// index of the current node in its parent branch
myIdx := prefix[len(prefix)-1]

br, ok := anc.(*branch)
if !ok {
return nil
}

prefix = prefix[:len(prefix)-len(br.key)-1]

if br.childrenBitmap()>>(myIdx+1) == 0 {
continue
}

// descend into ancestor's other children
for i, child := range br.children[myIdx+1:] {
idx := byte(i) + myIdx + 1

if child == nil {
continue
}

return returnFirstKey(append(prefix, append(br.key, idx)...), child)
}
}
var cmp int
if len(key) < len(fullKey) {
// the key is lexicographically less than the current node key. return first key available
cmp = 1
} else {
// if cmp == 1, then node key is lexicographically greater than the key arg
cmp = bytes.Compare(fullKey, key[:len(fullKey)])
}
}

return nil
}

// returnFirstKey descends into a node and returns the first key with an associated value
func returnFirstKey(prefix []byte, n node) []byte {
switch c := n.(type) {
case *branch:
if c.value != nil {
if cmp == 1 {
return append(prefix, c.key...)
}

for i, child := range c.children {
if child == nil {
continue
}

return returnFirstKey(append(prefix, append(c.key, byte(i))...), child)
}
case *leaf:
return append(prefix, c.key...)
case nil:
return nil
}

return nil
}

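
The rewritten nextKey walks the trie once, comparing each node's accumulated nibble key against the target and descending only into children that can still hold a larger key. Observably it must return the lexicographic successor over the set of stored keys, which is the same answer the deleted TrieState scan produced. A reference model of that contract over a plain sorted slice (an illustrative sketch, not code from this PR):

package main

import (
	"bytes"
	"fmt"
	"sort"
)

// nextKey returns the smallest key strictly greater than target, or nil if
// none exists. keys must already be sorted lexicographically.
func nextKey(keys [][]byte, target []byte) []byte {
	i := sort.Search(len(keys), func(j int) bool {
		return bytes.Compare(keys[j], target) > 0
	})
	if i == len(keys) {
		return nil
	}
	return keys[i]
}

func main() {
	keys := [][]byte{{0x01, 0x35}, {0x01, 0x35, 0x79}, {0x07, 0x3a}, {0xf2}}
	fmt.Printf("%x\n", nextKey(keys, []byte{0x01, 0x35})) // 013579
	fmt.Printf("%x\n", nextKey(keys, []byte{0x02}))       // 073a
	fmt.Println(nextKey(keys, []byte{0xf2}) == nil)       // true: no successor
}
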
113 changes: 112 additions & 1 deletion lib/trie/trie_test.go
@@ -24,6 +24,7 @@ import (
"math/rand"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"testing"
@@ -634,6 +635,7 @@ func TestNextKey_MoreAncestors(t *testing.T) {
{key: []byte{0x01, 0x35, 0x79}, value: []byte("gnocchi"), op: PUT},
{key: []byte{0x01, 0x35, 0x79, 0xab}, value: []byte("spaghetti"), op: PUT},
{key: []byte{0x01, 0x35, 0x79, 0xab, 0x9}, value: []byte("gnocchi"), op: PUT},
{key: []byte{0x01, 0x35, 0x79, 0xab, 0xf}, value: []byte("gnocchi"), op: PUT},
{key: []byte{0x07, 0x3a}, value: []byte("ramen"), op: PUT},
{key: []byte{0x07, 0x3b}, value: []byte("noodles"), op: PUT},
{key: []byte{0xf2}, value: []byte("pho"), op: PUT},
@@ -673,13 +675,67 @@ func TestNextKey_MoreAncestors(t *testing.T) {
},
{
tests[6].key,
tests[7].key,
},
{
tests[7].key,
nil,
},
{
[]byte{},
tests[0].key,
},
{
[]byte{0},
tests[0].key,
},
{
[]byte{0x01},
tests[0].key,
},
{
[]byte{0x02},
tests[5].key,
},
{
[]byte{0x05, 0x12, 0x34},
tests[5].key,
},
{
[]byte{0xf},
tests[7].key,
},
}

for _, tc := range testCases {
next := trie.NextKey(tc.input)
require.Equal(t, tc.expected, next)
require.Equal(t, tc.expected, next, common.BytesToHex(tc.input))
}
}

func TestNextKey_Again(t *testing.T) {
trie := NewEmptyTrie()

var testCases = []string{
"asdf",
"bnm",
"ghjk",
"qwerty",
"uiopl",
"zxcv",
}

for _, tc := range testCases {
trie.Put([]byte(tc), []byte(tc))
}

for i, tc := range testCases {
next := trie.NextKey([]byte(tc))
if i == len(testCases)-1 {
require.Nil(t, next)
} else {
require.Equal(t, []byte(testCases[i+1]), next, common.BytesToHex([]byte(tc)))
}
}
}

@@ -939,3 +995,58 @@ func TestSnapshot(t *testing.T) {
require.Equal(t, expectedTrie.MustHash(), newTrie.MustHash())
require.NotEqual(t, parentSnapshot.MustHash(), newTrie.MustHash())
}

const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

func RandStringBytes(n int) string {
b := make([]byte, n)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return string(b)
}

func TestNextKey_Random(t *testing.T) {
for i := 0; i < 100; i++ {
trie := NewEmptyTrie()

// Generate random test cases.
testCaseMap := make(map[string]struct{}) // ensure no duplicate keys
size := 1000 + rand.Intn(10000)

for ii := 0; ii < size; ii++ {
str := RandStringBytes(1 + rand.Intn(20))
if len(str) == 0 {
continue
}
testCaseMap[str] = struct{}{}
}

testCases := make([][]byte, len(testCaseMap))
j := 0

for k := range testCaseMap {
testCases[j] = []byte(k)
j++
}

sort.Slice(testCases, func(i, j int) bool {
return bytes.Compare(testCases[i], testCases[j]) < 0
})

for _, tc := range testCases {
trie.Put(tc, tc)
}

fmt.Println("Iteration: ", i)

for idx, tc := range testCases {
next := trie.NextKey(tc)
if idx == len(testCases)-1 {
require.Nil(t, next)
} else {
require.Equal(t, testCases[idx+1], next, common.BytesToHex(tc))
}
}
}
}