Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core/state: better randomized testing (postcheck) on journalling #29627

Merged
merged 7 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions core/state/access_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
package state

import (
"fmt"
"maps"
"slices"
"strings"

"github.com/ethereum/go-ethereum/common"
)
Expand Down Expand Up @@ -130,3 +133,35 @@ func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) {
func (al *accessList) DeleteAddress(address common.Address) {
delete(al.addresses, address)
}

// Equal returns true if the two access lists are identical
func (al *accessList) Equal(other *accessList) bool {
if !maps.Equal(al.addresses, other.addresses) {
return false
}
return slices.EqualFunc(al.slots, other.slots,
func(m map[common.Hash]struct{}, m2 map[common.Hash]struct{}) bool {
return maps.Equal(m, m2)
})
}

holiman marked this conversation as resolved.
Show resolved Hide resolved
// PrettyPrint prints the contents of the access list in a human-readable form
func (al *accessList) PrettyPrint() string {
out := new(strings.Builder)
var sortedAddrs []common.Address
for addr := range al.addresses {
sortedAddrs = append(sortedAddrs, addr)
}
slices.SortFunc(sortedAddrs, common.Address.Cmp)
for _, addr := range sortedAddrs {
idx := al.addresses[addr]
fmt.Fprintf(out, "%#x : (idx %d)\n", addr, idx)
if idx >= 0 {
slotmap := al.slots[idx]
for h := range slotmap {
fmt.Fprintf(out, " %#x\n", h)
}
}
}
return out.String()
}
24 changes: 12 additions & 12 deletions core/state/state_object.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,22 +459,22 @@ func (s *stateObject) setBalance(amount *uint256.Int) {

func (s *stateObject) deepCopy(db *StateDB) *stateObject {
obj := &stateObject{
db: db,
address: s.address,
addrHash: s.addrHash,
origin: s.origin,
data: s.data,
db: db,
address: s.address,
addrHash: s.addrHash,
origin: s.origin,
data: s.data,
code: s.code,
originStorage: s.originStorage.Copy(),
pendingStorage: s.pendingStorage.Copy(),
dirtyStorage: s.dirtyStorage.Copy(),
dirtyCode: s.dirtyCode,
selfDestructed: s.selfDestructed,
newContract: s.newContract,
}
if s.trie != nil {
obj.trie = db.db.CopyTrie(s.trie)
}
obj.code = s.code
obj.originStorage = s.originStorage.Copy()
obj.pendingStorage = s.pendingStorage.Copy()
obj.dirtyStorage = s.dirtyStorage.Copy()
obj.dirtyCode = s.dirtyCode
obj.selfDestructed = s.selfDestructed
obj.newContract = s.newContract
return obj
}

Expand Down
68 changes: 66 additions & 2 deletions core/state/statedb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ import (
"encoding/binary"
"errors"
"fmt"
"maps"
"math"
"math/rand"
"reflect"
"slices"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -557,10 +559,14 @@ func forEachStorage(s *StateDB, addr common.Address, cb func(key, value common.H
if err != nil {
return err
}
it := trie.NewIterator(trieIt)
var (
it = trie.NewIterator(trieIt)
visited = make(map[common.Hash]bool)
)

for it.Next() {
key := common.BytesToHash(s.trie.GetKey(it.Key))
visited[key] = true
if value, dirty := so.dirtyStorage[key]; dirty {
if !cb(key, value) {
return nil
Expand Down Expand Up @@ -600,6 +606,10 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
// Check newContract-flag
if obj := state.getStateObject(addr); obj != nil {
checkeq("IsNewContract", obj.newContract, checkstate.getStateObject(addr).newContract)
}
// Check storage.
if obj := state.getStateObject(addr); obj != nil {
forEachStorage(state, addr, func(key, value common.Hash) bool {
Expand All @@ -608,12 +618,49 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
forEachStorage(checkstate, addr, func(key, value common.Hash) bool {
return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
})
other := checkstate.getStateObject(addr)
// Check dirty storage which is not in trie
if !maps.Equal(obj.dirtyStorage, other.dirtyStorage) {
print := func(dirty map[common.Hash]common.Hash) string {
var keys []common.Hash
out := new(strings.Builder)
for key := range dirty {
keys = append(keys, key)
}
slices.SortFunc(keys, common.Hash.Cmp)
for i, key := range keys {
fmt.Fprintf(out, " %d. %v %v\n", i, key, dirty[key])
}
return out.String()
}
return fmt.Errorf("dirty storage err, have\n%v\nwant\n%v",
print(obj.dirtyStorage),
print(other.dirtyStorage))
}
}
// Check transient storage.
{
have := state.transientStorage
want := checkstate.transientStorage
eq := maps.EqualFunc(have, want,
func(a Storage, b Storage) bool {
return maps.Equal(a, b)
})
if !eq {
return fmt.Errorf("transient storage differs ,have\n%v\nwant\n%v",
have.PrettyPrint(),
want.PrettyPrint())
}
}
if err != nil {
return err
}
}

if !checkstate.accessList.Equal(state.accessList) { // Check access lists
return fmt.Errorf("AccessLists are wrong, have \n%v\nwant\n%v",
checkstate.accessList.PrettyPrint(),
state.accessList.PrettyPrint())
}
if state.GetRefund() != checkstate.GetRefund() {
return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
state.GetRefund(), checkstate.GetRefund())
Expand All @@ -622,6 +669,23 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
state.GetLogs(common.Hash{}, 0, common.Hash{}), checkstate.GetLogs(common.Hash{}, 0, common.Hash{}))
}
if !maps.Equal(state.journal.dirties, checkstate.journal.dirties) {
getKeys := func(dirty map[common.Address]int) string {
var keys []common.Address
out := new(strings.Builder)
for key := range dirty {
keys = append(keys, key)
}
slices.SortFunc(keys, common.Address.Cmp)
for i, key := range keys {
fmt.Fprintf(out, " %d. %v\n", i, key)
}
return out.String()
}
have := getKeys(state.journal.dirties)
want := getKeys(checkstate.journal.dirties)
return fmt.Errorf("dirty-journal set mismatch.\nhave:\n%v\nwant:\n%v\n", have, want)
}
return nil
}

Expand Down
43 changes: 40 additions & 3 deletions core/state/transient_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
package state

import (
"fmt"
"slices"
"strings"

"github.com/ethereum/go-ethereum/common"
)

Expand All @@ -30,10 +34,19 @@ func newTransientStorage() transientStorage {

// Set sets the transient-storage `value` for `key` at the given `addr`.
func (t transientStorage) Set(addr common.Address, key, value common.Hash) {
if _, ok := t[addr]; !ok {
t[addr] = make(Storage)
if value == (common.Hash{}) { // this is a 'delete'
if _, ok := t[addr]; ok {
delete(t[addr], key)
if len(t[addr]) == 0 {
delete(t, addr)
}
}
} else {
if _, ok := t[addr]; !ok {
t[addr] = make(Storage)
}
t[addr][key] = value
}
t[addr][key] = value
}

// Get gets the transient storage for `key` at the given `addr`.
Expand All @@ -53,3 +66,27 @@ func (t transientStorage) Copy() transientStorage {
}
return storage
}

// PrettyPrint prints the contents of the access list in a human-readable form
func (t transientStorage) PrettyPrint() string {
out := new(strings.Builder)
var sortedAddrs []common.Address
for addr := range t {
sortedAddrs = append(sortedAddrs, addr)
slices.SortFunc(sortedAddrs, common.Address.Cmp)
}

for _, addr := range sortedAddrs {
fmt.Fprintf(out, "%#x:", addr)
var sortedKeys []common.Hash
storage := t[addr]
for key := range storage {
sortedKeys = append(sortedKeys, key)
}
slices.SortFunc(sortedKeys, common.Hash.Cmp)
for _, key := range sortedKeys {
fmt.Fprintf(out, " %X : %X\n", key, storage[key])
}
}
return out.String()
}
Loading