diff --git a/dot/rpc/modules/state_integration_test.go b/dot/rpc/modules/state_integration_test.go index 86ffab0807..04cd189554 100644 --- a/dot/rpc/modules/state_integration_test.go +++ b/dot/rpc/modules/state_integration_test.go @@ -105,19 +105,30 @@ func TestStateModule_GetPairs(t *testing.T) { randomHash, err := common.HexToHash(RandomHash) require.NoError(t, err) + hexEncode := func(s string) string { + return "0x" + hex.EncodeToString([]byte(s)) + } + testCases := []struct { params []string expected []interface{} errMsg string }{ {params: []string{"0x00"}, expected: nil}, - {params: []string{""}, expected: []interface{}{[]string{":key1", "value1"}, []string{":key2", "value2"}}}, - {params: []string{":key1"}, expected: []interface{}{[]string{":key1", "value1"}}}, + {params: []string{""}, expected: []interface{}{ + []string{hexEncode(":child_storage:default::child1"), + "0x8f733acc98dff0e6527f97e2a87e4834cd8b2e601f56fb003084e9d43183d7ff"}, + []string{hexEncode(":key1"), hexEncode("value1")}, + []string{hexEncode(":key2"), hexEncode("value2")}}}, + {params: []string{hexEncode(":key1")}, expected: []interface{}{[]string{hexEncode(":key1"), hexEncode("value1")}}}, {params: []string{"0x00", hash.String()}, expected: nil}, {params: []string{"", hash.String()}, expected: []interface{}{ - []string{":key1", "value1"}, - []string{":key2", "value2"}}}, - {params: []string{":key1", hash.String()}, expected: []interface{}{[]string{":key1", "value1"}}}, + []string{hexEncode(":child_storage:default::child1"), + "0x8f733acc98dff0e6527f97e2a87e4834cd8b2e601f56fb003084e9d43183d7ff"}, + []string{hexEncode(":key1"), hexEncode("value1")}, + []string{hexEncode(":key2"), hexEncode("value2")}}}, + {params: []string{hexEncode(":key1"), hash.String()}, + expected: []interface{}{[]string{hexEncode(":key1"), hexEncode("value1")}}}, {params: []string{"", randomHash.String()}, errMsg: "pebble: not found"}, } @@ -134,6 +145,7 @@ func TestStateModule_GetPairs(t *testing.T) { if len(test.params) > 1 && test.params[1] != "" { req.Bhash = &common.Hash{} + var err error *req.Bhash, err = common.HexToHash(test.params[1]) require.NoError(t, err) } @@ -160,10 +172,7 @@ func TestStateModule_GetPairs(t *testing.T) { // Convert human-readable result value to hex. expectedKV, _ := val.([]string) - expectedKey := "0x" + hex.EncodeToString([]byte(expectedKV[0])) - expectedVal := "0x" + hex.EncodeToString([]byte(expectedKV[1])) - - require.Equal(t, []string{expectedKey, expectedVal}, kv) + require.Equal(t, []string{expectedKV[0], expectedKV[1]}, kv) } }) } @@ -446,17 +455,22 @@ func TestStateModule_GetKeysPaged(t *testing.T) { params: StateStorageKeyRequest{ Qty: 10, Block: nil, - }, expected: []string{"0x3a6b657931", "0x3a6b657932"}}, + }, expected: []string{ + "0x3a6368696c645f73746f726167653a64656661756c743a3a6368696c6431", + "0x3a6b657931", "0x3a6b657932"}}, {name: "allKeysTestBlockHash", params: StateStorageKeyRequest{ Qty: 10, Block: stateRootHash, - }, expected: []string{"0x3a6b657931", "0x3a6b657932"}}, + }, expected: []string{ + "0x3a6368696c645f73746f726167653a64656661756c743a3a6368696c6431", + "0x3a6b657931", "0x3a6b657932"}}, {name: "prefixMatchAll", params: StateStorageKeyRequest{ Prefix: "0x3a6b6579", Qty: 10, - }, expected: []string{"0x3a6b657931", "0x3a6b657932"}}, + }, expected: []string{ + "0x3a6b657931", "0x3a6b657932"}}, {name: "prefixMatchOne", params: StateStorageKeyRequest{ Prefix: "0x3a6b657931", @@ -470,7 +484,7 @@ func TestStateModule_GetKeysPaged(t *testing.T) { {name: "qtyOne", params: StateStorageKeyRequest{ Qty: 1, - }, expected: []string{"0x3a6b657931"}}, + }, expected: []string{"0x3a6368696c645f73746f726167653a64656661756c743a3a6368696c6431"}}, {name: "afterKey", params: StateStorageKeyRequest{ Qty: 10, @@ -551,9 +565,14 @@ func setupStateModule(t *testing.T) (*StateModule, *common.Hash, *common.Hash) { ts, err := chain.Storage.TrieState(nil) require.NoError(t, err) - ts.Put([]byte(`:key2`), []byte(`value2`)) - ts.Put([]byte(`:key1`), []byte(`value1`)) - ts.SetChildStorage([]byte(`:child1`), []byte(`:key1`), []byte(`:childValue1`)) + err = ts.Put([]byte(`:key2`), []byte(`value2`)) + require.NoError(t, err) + + err = ts.Put([]byte(`:key1`), []byte(`value1`)) + require.NoError(t, err) + + err = ts.SetChildStorage([]byte(`:child1`), []byte(`:key1`), []byte(`:childValue1`)) + require.NoError(t, err) sr1, err := ts.Root() require.NoError(t, err) diff --git a/lib/runtime/wazero/imports_test.go b/lib/runtime/wazero/imports_test.go index 80f09ea38c..3ea6e2274a 100644 --- a/lib/runtime/wazero/imports_test.go +++ b/lib/runtime/wazero/imports_test.go @@ -812,31 +812,89 @@ func Test_ext_default_child_storage_read_version_1(t *testing.T) { } func Test_ext_default_child_storage_set_version_1(t *testing.T) { - inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME) + cases := map[string]struct { + setupInstance func(*testing.T) *Instance + existsBeforehand bool + }{ + "child_trie_exists_should_not_panic": { + existsBeforehand: true, + setupInstance: func(t *testing.T) *Instance { + inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME) - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) - require.NoError(t, err) + err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) + require.NoError(t, err) - // Check if value is not set - val, err := inst.Context.Storage.GetChildStorage(testChildKey, testKey) - require.NoError(t, err) - require.Nil(t, val) + return inst + }, + }, + "child_trie_not_found_should_create_a_empty_one": { + existsBeforehand: false, + setupInstance: func(t *testing.T) *Instance { + inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME) + return inst + }, + }, + } - encChildKey, err := scale.Marshal(testChildKey) - require.NoError(t, err) + insertKeyAndValue := func(t *testing.T, inst *Instance, childKey, key, value []byte) { + encChildKey, err := scale.Marshal(childKey) + require.NoError(t, err) - encKey, err := scale.Marshal(testKey) - require.NoError(t, err) + encKey, err := scale.Marshal(key) + require.NoError(t, err) - encVal, err := scale.Marshal(testValue) - require.NoError(t, err) + encVal, err := scale.Marshal(value) + require.NoError(t, err) - _, err = inst.Exec("rtm_ext_default_child_storage_set_version_1", append(append(encChildKey, encKey...), encVal...)) - require.NoError(t, err) + args := bytes.Join([][]byte{ + encChildKey, encKey, encVal, + }, nil) - val, err = inst.Context.Storage.GetChildStorage(testChildKey, testKey) - require.NoError(t, err) - require.Equal(t, testValue, val) + _, err = inst.Exec("rtm_ext_default_child_storage_set_version_1", args) + require.NoError(t, err) + } + + getValueFromChildStorage := func(t *testing.T, inst *Instance, childKey, key []byte) *[]byte { + encChildKey, err := scale.Marshal(childKey) + require.NoError(t, err) + + encKey, err := scale.Marshal(key) + require.NoError(t, err) + + ret, err := inst.Exec("rtm_ext_default_child_storage_get_version_1", append(encChildKey, encKey...)) + require.NoError(t, err) + + var retrieved *[]byte + err = scale.Unmarshal(ret, &retrieved) + require.NoError(t, err) + + return retrieved + } + + for tname, tt := range cases { + tt := tt + + t.Run(tname, func(t *testing.T) { + inst := tt.setupInstance(t) + + exampleChildKey := []byte("example_child_key") + exampleKey := []byte("key_to_account") + exampleValue := []byte("some_acc_address") + + insertKeyAndValue(t, inst, exampleChildKey, exampleKey, exampleValue) + + anotherKey := []byte("key_to_account_2") + anotherValue := []byte("some_acc_address_2") + insertKeyAndValue(t, inst, exampleChildKey, anotherKey, anotherValue) + + // should be possible to retrieve the first address and the new inserted one + acc1 := getValueFromChildStorage(t, inst, exampleChildKey, exampleKey) + require.Equal(t, &exampleValue, acc1) + + acc2 := getValueFromChildStorage(t, inst, exampleChildKey, anotherKey) + require.Equal(t, &anotherValue, acc2) + }) + } } func Test_ext_default_child_storage_clear_version_1(t *testing.T) { diff --git a/lib/trie/child_storage.go b/lib/trie/child_storage.go index 78506eaa8c..498a7a3137 100644 --- a/lib/trie/child_storage.go +++ b/lib/trie/child_storage.go @@ -55,7 +55,11 @@ func (t *Trie) GetChild(keyToChild []byte) (*Trie, error) { func (t *Trie) PutIntoChild(keyToChild, key, value []byte) error { child, err := t.GetChild(keyToChild) if err != nil { - return err + if errors.Is(err, ErrChildTrieDoesNotExist) { + child = NewEmptyTrie() + } else { + return fmt.Errorf("getting child: %w", err) + } } origChildHash, err := child.Hash() @@ -68,14 +72,7 @@ func (t *Trie) PutIntoChild(keyToChild, key, value []byte) error { return fmt.Errorf("putting into child trie located at key 0x%x: %w", keyToChild, err) } - childHash, err := child.Hash() - if err != nil { - return err - } - delete(t.childTries, origChildHash) - t.childTries[childHash] = child - return t.SetChild(keyToChild, child) }