diff --git a/go.mod b/go.mod index 82cc0151152..8dfd6c58750 100644 --- a/go.mod +++ b/go.mod @@ -200,4 +200,4 @@ retract ( replace github.com/cometbft/cometbft => github.com/dydxprotocol/cometbft v0.38.6-0.20240220185844-e704122c8540 -replace cosmossdk.io/store => github.com/dydxprotocol/cosmos-sdk/store v1.0.3-0.20240515175455-8168b4407fac +replace cosmossdk.io/store => github.com/dydxprotocol/cosmos-sdk/store v1.0.3-0.20240326190927-d35618165018 diff --git a/go.sum b/go.sum index ce29d3d21be..7e5e9d5ee5b 100644 --- a/go.sum +++ b/go.sum @@ -199,8 +199,8 @@ github.com/dvsekhvalnov/jose2go v1.6.0 h1:Y9gnSnP4qEI0+/uQkHvFXeD2PLPJeXEL+ySMEA github.com/dvsekhvalnov/jose2go v1.6.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= github.com/dydxprotocol/cometbft v0.38.6-0.20240220185844-e704122c8540 h1:pkYQbAdOAAoZBSId9kLupCgZHj8YvA9LzM31fVYpjlw= github.com/dydxprotocol/cometbft v0.38.6-0.20240220185844-e704122c8540/go.mod h1:REQN+ObgfYxi39TcYR/Hv95C9bPxY3sYJCvghryj7vY= -github.com/dydxprotocol/cosmos-sdk/store v1.0.3-0.20240515175455-8168b4407fac h1:frUaYZlrs9/Tk8fAHjMhcrpk73UEZ36fD7s+megReKQ= -github.com/dydxprotocol/cosmos-sdk/store v1.0.3-0.20240515175455-8168b4407fac/go.mod h1:zMcD3hfNwd0WMTpdRUhS3QxoCoEtBXWeoKsu3iaLBbQ= +github.com/dydxprotocol/cosmos-sdk/store v1.0.3-0.20240326190927-d35618165018 h1:Dn08pzQTajFp1GHaZFd0istbjl793PaT50vfj4mVKNs= +github.com/dydxprotocol/cosmos-sdk/store v1.0.3-0.20240326190927-d35618165018/go.mod h1:zMcD3hfNwd0WMTpdRUhS3QxoCoEtBXWeoKsu3iaLBbQ= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= diff --git a/store/cachemulti/locking_test.go b/store/cachemulti/locking_test.go deleted file mode 100644 index d590890e42c..00000000000 --- a/store/cachemulti/locking_test.go +++ /dev/null @@ -1,363 +0,0 @@ -package cachemulti_test - -import ( - "sync" - "testing" - "time" - - "cosmossdk.io/log" - "cosmossdk.io/store/metrics" - pruningtypes "cosmossdk.io/store/pruning/types" - "cosmossdk.io/store/rootmulti" - "cosmossdk.io/store/types" - dbm "github.com/cosmos/cosmos-db" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestStore_LinearizeReadsAndWrites(t *testing.T) { - key := []byte("kv_store_key") - storeKey := types.NewKVStoreKey("store1") - lockKey := []byte("a") - - db := dbm.NewMemDB() - store := rootmulti.NewStore(db, log.NewNopLogger(), metrics.NewNoOpMetrics()) - store.SetPruning(pruningtypes.NewPruningOptions(pruningtypes.PruningNothing)) - store.MountStoreWithDB(storeKey, types.StoreTypeIAVL, db) - err := store.LoadLatestVersion() - assert.NoError(t, err) - lockingCms := store.LockingCacheMultiStore() - - wg := sync.WaitGroup{} - wg.Add(100) - for i := 0; i < 100; i++ { - go func() { - defer wg.Done() - - lockingCms.Lock([][]byte{lockKey}) - defer lockingCms.Unlock([][]byte{lockKey}) - kvStore := lockingCms.GetKVStore(storeKey) - v := kvStore.Get(key) - if v == nil { - kvStore.Set(key, []byte{1}) - } else { - v[0]++ - kvStore.Set(key, v) - } - lockingCms.Write() - }() - } - - wg.Wait() - require.Equal(t, []byte{100}, lockingCms.GetKVStore(storeKey).Get(key)) -} - -func TestStore_LockOrderToPreventDeadlock(t *testing.T) { - key := []byte("kv_store_key") - storeKey := types.NewKVStoreKey("store1") - lockKeyA := []byte("a") - lockKeyB := []byte("b") - - db := 
dbm.NewMemDB() - store := rootmulti.NewStore(db, log.NewNopLogger(), metrics.NewNoOpMetrics()) - store.SetPruning(pruningtypes.NewPruningOptions(pruningtypes.PruningNothing)) - store.MountStoreWithDB(storeKey, types.StoreTypeIAVL, db) - err := store.LoadLatestVersion() - assert.NoError(t, err) - lockingCms := store.LockingCacheMultiStore() - - // Acquire keys in two different orders ensuring that we don't reach deadlock. - wg := sync.WaitGroup{} - wg.Add(200) - for i := 0; i < 100; i++ { - go func() { - defer wg.Done() - - lockingCms.Lock([][]byte{lockKeyA, lockKeyB}) - defer lockingCms.Unlock([][]byte{lockKeyA, lockKeyB}) - kvStore := lockingCms.GetKVStore(storeKey) - v := kvStore.Get(key) - if v == nil { - kvStore.Set(key, []byte{1}) - } else { - v[0]++ - kvStore.Set(key, v) - } - lockingCms.Write() - }() - - go func() { - defer wg.Done() - - lockingCms.Lock([][]byte{lockKeyB, lockKeyA}) - defer lockingCms.Unlock([][]byte{lockKeyB, lockKeyA}) - kvStore := lockingCms.GetKVStore(storeKey) - v := kvStore.Get(key) - if v == nil { - kvStore.Set(key, []byte{1}) - } else { - v[0]++ - kvStore.Set(key, v) - } - lockingCms.Write() - }() - } - - wg.Wait() - require.Equal(t, []byte{200}, lockingCms.GetKVStore(storeKey).Get(key)) -} - -func TestStore_AllowForParallelUpdates(t *testing.T) { - storeKey := types.NewKVStoreKey("store1") - - db := dbm.NewMemDB() - store := rootmulti.NewStore(db, log.NewNopLogger(), metrics.NewNoOpMetrics()) - store.SetPruning(pruningtypes.NewPruningOptions(pruningtypes.PruningNothing)) - store.MountStoreWithDB(storeKey, types.StoreTypeIAVL, db) - err := store.LoadLatestVersion() - assert.NoError(t, err) - lockingCms := store.LockingCacheMultiStore() - - wg := sync.WaitGroup{} - wg.Add(100) - - for i := byte(0); i < 100; i++ { - k := []byte{i} - go func() { - defer wg.Done() - - // We specifically don't unlock the keys during processing so that we can show that we must process all - // of these in parallel before the wait group is done. 
- lockingCms.Lock([][]byte{k}) - lockingCms.GetKVStore(storeKey).Set(k, k) - lockingCms.Write() - }() - } - - wg.Wait() - for i := byte(0); i < 100; i++ { - lockingCms.Unlock([][]byte{{i}}) - } - for i := byte(0); i < 100; i++ { - require.Equal(t, []byte{i}, lockingCms.GetKVStore(storeKey).Get([]byte{i})) - } -} - -func TestStore_AddLocksDuringTransaction(t *testing.T) { - key := []byte("kv_store_key") - storeKey := types.NewKVStoreKey("store1") - lockKey := []byte("lockkey") - - db := dbm.NewMemDB() - store := rootmulti.NewStore(db, log.NewNopLogger(), metrics.NewNoOpMetrics()) - store.SetPruning(pruningtypes.NewPruningOptions(pruningtypes.PruningNothing)) - store.MountStoreWithDB(storeKey, types.StoreTypeIAVL, db) - err := store.LoadLatestVersion() - assert.NoError(t, err) - lockingCms := store.LockingCacheMultiStore() - - wg := sync.WaitGroup{} - wg.Add(100) - for i := byte(0); i < 100; i++ { - k := []byte{i} - go func() { - defer wg.Done() - - lockingCms.Lock([][]byte{k}) - defer lockingCms.Unlock([][]byte{k}) - lockingCms.GetKVStore(storeKey).Set(k, k) - - lockingCms.Lock([][]byte{lockKey}) - defer lockingCms.Unlock([][]byte{lockKey}) - kvStore := lockingCms.GetKVStore(storeKey) - v := kvStore.Get(key) - if v == nil { - kvStore.Set(key, []byte{1}) - } else { - v[0]++ - kvStore.Set(key, v) - } - lockingCms.Write() - }() - } - - wg.Wait() - for i := byte(0); i < 100; i++ { - require.Equal(t, []byte{i}, lockingCms.GetKVStore(storeKey).Get([]byte{i})) - } - require.Equal(t, []byte{100}, lockingCms.GetKVStore(storeKey).Get(key)) -} - -func TestStore_MaintainLockOverMultipleTransactions(t *testing.T) { - keyA := []byte("kv_store_key_a") - keyB := []byte("kv_store_key_b") - storeKey := types.NewKVStoreKey("store1") - lockKeyA := []byte("lockkeya") - lockKeyB := []byte("lockkeyb") - - db := dbm.NewMemDB() - store := rootmulti.NewStore(db, log.NewNopLogger(), metrics.NewNoOpMetrics()) - store.SetPruning(pruningtypes.NewPruningOptions(pruningtypes.PruningNothing)) - store.MountStoreWithDB(storeKey, types.StoreTypeIAVL, db) - err := store.LoadLatestVersion() - assert.NoError(t, err) - lockingCms := store.LockingCacheMultiStore() - - // Key A is set differently in the first and second transaction so we can check it - // to see what transaction was run last. 
- lockingCms.GetKVStore(storeKey).Set(keyA, []byte{0}) - lockingCms.GetKVStore(storeKey).Set(keyB, []byte{0}) - - wg := sync.WaitGroup{} - wg.Add(100) - for i := byte(0); i < 100; i++ { - k := []byte{i} - go func() { - defer wg.Done() - - lockingCms.Lock([][]byte{k}) - defer lockingCms.Unlock([][]byte{k}) - lockingCms.GetKVStore(storeKey).Set(k, k) - - lockingCms.Lock([][]byte{lockKeyA}) - defer lockingCms.Unlock([][]byte{lockKeyA}) - - func() { - lockingCms.Lock([][]byte{lockKeyB}) - defer lockingCms.Unlock([][]byte{lockKeyB}) - - assert.Equal(t, []byte{0}, lockingCms.GetKVStore(storeKey).Get(keyA)) - lockingCms.GetKVStore(storeKey).Set(keyA, []byte{1}) - v := lockingCms.GetKVStore(storeKey).Get(keyB) - v[0]++ - lockingCms.GetKVStore(storeKey).Set(keyB, v) - lockingCms.Write() - }() - - func() { - lockingCms.Lock([][]byte{lockKeyB}) - defer lockingCms.Unlock([][]byte{lockKeyB}) - - assert.Equal(t, []byte{1}, lockingCms.GetKVStore(storeKey).Get(keyA)) - lockingCms.GetKVStore(storeKey).Set(keyA, []byte{0}) - v := lockingCms.GetKVStore(storeKey).Get(keyB) - v[0]++ - lockingCms.GetKVStore(storeKey).Set(keyB, v) - lockingCms.Write() - }() - }() - } - - wg.Wait() - require.Equal(t, []byte{200}, lockingCms.GetKVStore(storeKey).Get(keyB)) -} - -func TestStore_ReadWriteLock(t *testing.T) { - numReadersKey := []byte("kv_store_key_a") - numWritersKey := []byte("kv_store_key_b") - maxNumReadersKey := []byte("kv_store_key_c") - maxNumWritersKey := []byte("kv_store_key_d") - storeKey := types.NewKVStoreKey("store1") - rwLockKey := []byte("lockkeya") - lockKey := []byte("lockkeyb") - - db := dbm.NewMemDB() - store := rootmulti.NewStore(db, log.NewNopLogger(), metrics.NewNoOpMetrics()) - store.SetPruning(pruningtypes.NewPruningOptions(pruningtypes.PruningNothing)) - store.MountStoreWithDB(storeKey, types.StoreTypeIAVL, db) - err := store.LoadLatestVersion() - assert.NoError(t, err) - lockingCms := store.LockingCacheMultiStore() - - lockingCms.GetKVStore(storeKey).Set(numReadersKey, []byte{0}) - lockingCms.GetKVStore(storeKey).Set(numWritersKey, []byte{0}) - lockingCms.GetKVStore(storeKey).Set(maxNumReadersKey, []byte{0}) - lockingCms.GetKVStore(storeKey).Set(maxNumWritersKey, []byte{0}) - - wg := sync.WaitGroup{} - wg.Add(200) - // Start 100 readers and 100 writers. Record the maximum number of readers and writers seen. 
- for i := 0; i < 100; i++ { - go func() { - defer wg.Done() - - lockingCms.RLockRW([][]byte{rwLockKey}) - defer lockingCms.RUnlockRW([][]byte{rwLockKey}) - - func() { - lockingCms.Lock([][]byte{lockKey}) - defer lockingCms.Unlock([][]byte{lockKey}) - v := lockingCms.GetKVStore(storeKey).Get(numReadersKey) - v[0]++ - lockingCms.GetKVStore(storeKey).Set(numReadersKey, v) - lockingCms.Write() - }() - - time.Sleep(100 * time.Millisecond) - - func() { - lockingCms.Lock([][]byte{lockKey}) - defer lockingCms.Unlock([][]byte{lockKey}) - numReaders := lockingCms.GetKVStore(storeKey).Get(numReadersKey)[0] - maxNumReaders := lockingCms.GetKVStore(storeKey).Get(maxNumReadersKey)[0] - if numReaders > maxNumReaders { - lockingCms.GetKVStore(storeKey).Set(maxNumReadersKey, []byte{numReaders}) - } - lockingCms.Write() - }() - - func() { - lockingCms.Lock([][]byte{lockKey}) - defer lockingCms.Unlock([][]byte{lockKey}) - v := lockingCms.GetKVStore(storeKey).Get(numReadersKey) - v[0]-- - lockingCms.GetKVStore(storeKey).Set(numReadersKey, v) - lockingCms.Write() - }() - }() - - go func() { - defer wg.Done() - - lockingCms.LockRW([][]byte{rwLockKey}) - defer lockingCms.UnlockRW([][]byte{rwLockKey}) - - func() { - lockingCms.Lock([][]byte{lockKey}) - defer lockingCms.Unlock([][]byte{lockKey}) - v := lockingCms.GetKVStore(storeKey).Get(numWritersKey) - v[0]++ - lockingCms.GetKVStore(storeKey).Set(numWritersKey, v) - lockingCms.Write() - }() - - func() { - lockingCms.Lock([][]byte{lockKey}) - defer lockingCms.Unlock([][]byte{lockKey}) - numWriters := lockingCms.GetKVStore(storeKey).Get(numWritersKey)[0] - maxNumWriters := lockingCms.GetKVStore(storeKey).Get(maxNumWritersKey)[0] - if numWriters > maxNumWriters { - lockingCms.GetKVStore(storeKey).Set(maxNumWritersKey, []byte{numWriters}) - } - lockingCms.Write() - lockingCms.Write() - }() - - func() { - lockingCms.Lock([][]byte{lockKey}) - defer lockingCms.Unlock([][]byte{lockKey}) - v := lockingCms.GetKVStore(storeKey).Get(numWritersKey) - v[0]-- - lockingCms.GetKVStore(storeKey).Set(numWritersKey, v) - lockingCms.Write() - }() - }() - } - - wg.Wait() - // At some point there should be more than one reader. If this test is flaky, sleep time - // can be added to the reader to deflake. - require.Less(t, []byte{1}, lockingCms.GetKVStore(storeKey).Get(maxNumReadersKey)) - // There must be at most one writer at once. - require.Equal(t, []byte{1}, lockingCms.GetKVStore(storeKey).Get(maxNumWritersKey)) -} diff --git a/store/cachemulti/store.go b/store/cachemulti/store.go index 251d104c27c..722af21f153 100644 --- a/store/cachemulti/store.go +++ b/store/cachemulti/store.go @@ -3,14 +3,14 @@ package cachemulti import ( "fmt" "io" - "sync" + + dbm "github.com/cosmos/cosmos-db" "cosmossdk.io/store/cachekv" "cosmossdk.io/store/dbadapter" + "cosmossdk.io/store/lockingkv" "cosmossdk.io/store/tracekv" "cosmossdk.io/store/types" - dbm "github.com/cosmos/cosmos-db" - "golang.org/x/exp/slices" ) // storeNameCtxKey is the TraceContext metadata key that identifies @@ -31,24 +31,19 @@ type Store struct { traceWriter io.Writer traceContext types.TraceContext - - locks *sync.Map // map from string key to *sync.Mutex or *sync.RWMutex } var ( _ types.CacheMultiStore = Store{} + _ types.LockingStore = Store{} ) // NewFromKVStore creates a new Store object from a mapping of store keys to // CacheWrapper objects and a KVStore as the database. Each CacheWrapper store // is a branched store. 
func NewFromKVStore( - store types.KVStore, - stores map[types.StoreKey]types.CacheWrapper, - keys map[string]types.StoreKey, - traceWriter io.Writer, - traceContext types.TraceContext, - locks *sync.Map, + store types.KVStore, stores map[types.StoreKey]types.CacheWrapper, + keys map[string]types.StoreKey, traceWriter io.Writer, traceContext types.TraceContext, ) Store { cms := Store{ db: cachekv.NewStore(store), @@ -56,7 +51,6 @@ func NewFromKVStore( keys: keys, traceWriter: traceWriter, traceContext: traceContext, - locks: locks, } for key, store := range stores { @@ -73,13 +67,46 @@ func NewFromKVStore( return cms } +// NewLockingFromKVStore creates a new Store object from a mapping of store keys to +// CacheWrapper objects and a KVStore as the database. Each CacheWrapper store +// is a branched store. +func NewLockingFromKVStore( + store types.KVStore, stores map[types.StoreKey]types.CacheWrapper, + keys map[string]types.StoreKey, traceWriter io.Writer, traceContext types.TraceContext, +) Store { + cms := Store{ + db: cachekv.NewStore(store), + stores: make(map[types.StoreKey]types.CacheWrap, len(stores)), + keys: keys, + traceWriter: traceWriter, + traceContext: traceContext, + } + + for key, store := range stores { + if cms.TracingEnabled() { + tctx := cms.traceContext.Clone().Merge(types.TraceContext{ + storeNameCtxKey: key.Name(), + }) + + store = tracekv.NewStore(store.(types.KVStore), cms.traceWriter, tctx) + } + if kvStoreKey, ok := key.(*types.KVStoreKey); ok && kvStoreKey.IsLocking() { + cms.stores[key] = lockingkv.NewStore(store.(types.KVStore)) + } else { + cms.stores[key] = cachekv.NewStore(store.(types.KVStore)) + } + } + + return cms +} + // NewStore creates a new Store object from a mapping of store keys to // CacheWrapper objects. Each CacheWrapper store is a branched store. func NewStore( db dbm.DB, stores map[types.StoreKey]types.CacheWrapper, keys map[string]types.StoreKey, traceWriter io.Writer, traceContext types.TraceContext, ) Store { - return NewFromKVStore(dbadapter.Store{DB: db}, stores, keys, traceWriter, traceContext, nil) + return NewFromKVStore(dbadapter.Store{DB: db}, stores, keys, traceWriter, traceContext) } // NewLockingStore creates a new Store object from a mapping of store keys to @@ -88,14 +115,7 @@ func NewLockingStore( db dbm.DB, stores map[types.StoreKey]types.CacheWrapper, keys map[string]types.StoreKey, traceWriter io.Writer, traceContext types.TraceContext, ) Store { - return NewFromKVStore( - dbadapter.Store{DB: db}, - stores, - keys, - traceWriter, - traceContext, - &sync.Map{}, - ) + return NewLockingFromKVStore(dbadapter.Store{DB: db}, stores, keys, traceWriter, traceContext) } func newCacheMultiStoreFromCMS(cms Store) Store { @@ -104,7 +124,7 @@ func newCacheMultiStoreFromCMS(cms Store) Store { stores[k] = v } - return NewFromKVStore(cms.db, stores, nil, cms.traceWriter, cms.traceContext, cms.locks) + return NewFromKVStore(cms.db, stores, nil, cms.traceWriter, cms.traceContext) } // SetTracer sets the tracer for the MultiStore that the underlying @@ -153,88 +173,13 @@ func (cms Store) Write() { } } -// Lock, Unlock, RLockRW, LockRW, RUnlockRW, UnlockRW constitute a permissive locking interface -// that can be used to synchronize concurrent access to the store. Locking of a key should -// represent locking of some part of the store. Note that improper access is not enforced, and it is -// the user's responsibility to ensure proper locking of any access by concurrent goroutines. 
-// -// Common mistakes may include: -// - Introducing data races by reading or writing state that is claimed by a competing goroutine -// - Introducing deadlocks by locking in different orders through multiple calls to locking methods. -// i.e. if A calls Lock(a) followed by Lock(b), and B calls Lock(b) followed by Lock(a) -// - Using a key as an exclusive lock after it has already been initialized as a read-write lock - -// Lock acquires exclusive locks on a set of keys. -func (cms Store) Lock(keys [][]byte) { - for _, stringKey := range keysToSortedStrings(keys) { - v, _ := cms.locks.LoadOrStore(stringKey, &sync.Mutex{}) - lock := v.(*sync.Mutex) - lock.Lock() - } -} - -// Unlock releases exclusive locks on a set of keys. -func (cms Store) Unlock(keys [][]byte) { - for _, key := range keys { - v, ok := cms.locks.Load(string(key)) - if !ok { - panic("Key not found") - } - lock := v.(*sync.Mutex) - lock.Unlock() - } -} - -// RLockRW acquires read locks on a set of keys. -func (cms Store) RLockRW(keys [][]byte) { - for _, stringKey := range keysToSortedStrings(keys) { - v, _ := cms.locks.LoadOrStore(stringKey, &sync.RWMutex{}) - lock := v.(*sync.RWMutex) - lock.RLock() - } -} - -// LockRW acquires write locks on a set of keys. -func (cms Store) LockRW(keys [][]byte) { - for _, stringKey := range keysToSortedStrings(keys) { - v, _ := cms.locks.LoadOrStore(stringKey, &sync.RWMutex{}) - lock := v.(*sync.RWMutex) - lock.Lock() - } -} - -// RUnlockRW releases read locks on a set of keys. -func (cms Store) RUnlockRW(keys [][]byte) { - for _, key := range keys { - v, ok := cms.locks.Load(string(key)) - if !ok { - panic("Key not found") - } - lock := v.(*sync.RWMutex) - lock.RUnlock() - } -} - -// UnlockRW releases write locks on a set of keys. -func (cms Store) UnlockRW(keys [][]byte) { - for _, key := range keys { - v, ok := cms.locks.Load(string(key)) - if !ok { - panic("Key not found") +// Unlock calls Unlock on each underlying LockingStore. +func (cms Store) Unlock() { + for _, store := range cms.stores { + if s, ok := store.(types.LockingStore); ok { + s.Unlock() } - lock := v.(*sync.RWMutex) - lock.Unlock() - } -} - -func keysToSortedStrings(keys [][]byte) []string { - // Ensure that we always operate in a deterministic ordering when acquiring locks to prevent deadlock. - stringLockedKeys := make([]string, len(keys)) - for i, key := range keys { - stringLockedKeys[i] = string(key) } - slices.Sort(stringLockedKeys) - return stringLockedKeys } // Implements CacheWrapper. @@ -252,6 +197,40 @@ func (cms Store) CacheMultiStore() types.CacheMultiStore { return newCacheMultiStoreFromCMS(cms) } +// CacheMultiStoreWithLocking branches each store wrapping each store with a cachekv store if not locked or +// delegating to CacheWrapWithLocks if it is a LockingCacheWrapper. 
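Aside: the removed comment above warns about deadlocks caused by acquiring locks in different orders, and the removed keysToSortedStrings helper (like the new lockingkv store) avoids this by sorting lock keys before acquisition. A standalone sketch of that rule, not taken from this diff, with made-up key names:

package main

import (
	"sort"
	"sync"
)

// lockSorted acquires the mutexes for keys in one global (sorted) order, so two
// callers with overlapping key sets can never wait on each other in a cycle.
func lockSorted(locks map[string]*sync.Mutex, keys []string) {
	ordered := append([]string(nil), keys...)
	sort.Strings(ordered) // single global acquisition order -> no circular wait
	for _, k := range ordered {
		locks[k].Lock()
	}
}

func unlockAll(locks map[string]*sync.Mutex, keys []string) {
	for _, k := range keys {
		locks[k].Unlock()
	}
}

func main() {
	locks := map[string]*sync.Mutex{"a": {}, "b": {}}
	var wg sync.WaitGroup
	for _, keys := range [][]string{{"a", "b"}, {"b", "a"}} {
		wg.Add(1)
		go func(keys []string) {
			defer wg.Done()
			// Both goroutines end up locking "a" then "b", so no deadlock.
			lockSorted(locks, keys)
			defer unlockAll(locks, keys)
		}(keys)
	}
	wg.Wait()
}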
+func (cms Store) CacheMultiStoreWithLocking(storeLocks map[types.StoreKey][][]byte) types.CacheMultiStore { + stores := make(map[types.StoreKey]types.CacheWrapper) + for k, v := range cms.stores { + stores[k] = v + } + + cms2 := Store{ + db: cachekv.NewStore(cms.db), + stores: make(map[types.StoreKey]types.CacheWrap, len(stores)), + keys: cms.keys, + traceWriter: cms.traceWriter, + traceContext: cms.traceContext, + } + + for key, store := range stores { + if lockKeys, ok := storeLocks[key]; ok { + cms2.stores[key] = store.(types.LockingCacheWrapper).CacheWrapWithLocks(lockKeys) + } else { + if cms.TracingEnabled() { + tctx := cms.traceContext.Clone().Merge(types.TraceContext{ + storeNameCtxKey: key.Name(), + }) + + store = tracekv.NewStore(store.(types.KVStore), cms.traceWriter, tctx) + } + cms2.stores[key] = cachekv.NewStore(store.(types.KVStore)) + } + } + + return cms2 +} + // CacheMultiStoreWithVersion implements the MultiStore interface. It will panic // as an already cached multi-store cannot load previous versions. // diff --git a/store/lockingkv/lockingkv.go b/store/lockingkv/lockingkv.go new file mode 100644 index 00000000000..7038ba07cb6 --- /dev/null +++ b/store/lockingkv/lockingkv.go @@ -0,0 +1,252 @@ +package lockingkv + +import ( + "io" + "sort" + "sync" + + "golang.org/x/exp/slices" + + "cosmossdk.io/store/cachekv" + "cosmossdk.io/store/tracekv" + storetypes "cosmossdk.io/store/types" +) + +var ( + _ storetypes.CacheKVStore = &LockableKV{} + _ storetypes.LockingCacheWrapper = &LockableKV{} + _ storetypes.CacheKVStore = &LockedKV{} + _ storetypes.LockingStore = &LockedKV{} +) + +func NewStore(parent storetypes.KVStore) *LockableKV { + return &LockableKV{ + parent: parent, + locks: sync.Map{}, + } +} + +// LockableKV is a store that is able to provide locks. Each locking key that is used for a lock must represent a +// disjoint partition of store keys that are able to be mutated. For example, locking per account public key would +// provide a lock over all mutations related to that account. +type LockableKV struct { + parent storetypes.KVStore + locks sync.Map // map from string key to *sync.Mutex. + mutations sync.Map // map from string key to []byte. +} + +func (s *LockableKV) Write() { + s.locks.Range(func(key, value any) bool { + lock := value.(*sync.Mutex) + // We should be able to acquire the lock and only would not be able to if for some reason a child + // store was not unlocked. + if !lock.TryLock() { + panic("LockedKV is missing Unlock() invocation.") + } + + // We specifically don't unlock here which prevents users from acquiring the locks again and + // mutating the values allowing the Write() invocation only to happen once effectively. + + return true + }) + + values := make(map[string][]byte) + s.mutations.Range(func(key, value any) bool { + values[key.(string)] = value.([]byte) + return true + }) + + // We need to make the mutations to the parent in a deterministic order to ensure a deterministic hash. + for _, sortedKey := range getSortedKeys[sort.StringSlice](values) { + value := values[sortedKey] + + if value == nil { + s.parent.Delete([]byte(sortedKey)) + } else { + s.parent.Set([]byte(sortedKey), value) + } + } +} + +func (s *LockableKV) GetStoreType() storetypes.StoreType { + return s.parent.GetStoreType() +} + +// CacheWrap allows for branching the store. Care must be taken to ensure that synchronization outside of this +// store is performed to ensure that reads and writes are linearized. 
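Aside: the LockableKV comment above requires each lock key to cover a disjoint partition of mutable store keys (for example, one lock per account). A hypothetical convention illustrating that requirement; the identifiers are invented and not part of this diff:

// accountLockKey maps every store key mutated on behalf of an account to a
// single lock key, so branches touching different accounts run in parallel
// while branches touching the same account are serialized.
func accountLockKey(acct string) []byte {
	return []byte("lock/acct/" + acct)
}

// e.g. lockable.CacheWrapWithLocks([][]byte{accountLockKey("alice")})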
+func (s *LockableKV) CacheWrap() storetypes.CacheWrap { + return cachekv.NewStore(s) +} + +// CacheWrapWithTrace allows for branching the store with tracing. Care must be taken to ensure that synchronization +// outside of this store is performed to ensure that reads and writes are linearized. +func (s *LockableKV) CacheWrapWithTrace(w io.Writer, tc storetypes.TraceContext) storetypes.CacheWrap { + return cachekv.NewStore(tracekv.NewStore(s, w, tc)) +} + +// CacheWrapWithLocks returns a store that allows mutating a set of store keys that are covered by the +// set of lock keys. Each lock key should represent a disjoint partitioned space of store keys for which +// the caller is acquiring locks for. +func (s *LockableKV) CacheWrapWithLocks(lockKeys [][]byte) storetypes.CacheWrap { + stringLockedKeys := make([]string, len(lockKeys)) + for i, key := range lockKeys { + stringLockedKeys[i] = string(key) + } + // Ensure that we always operate in a deterministic ordering when acquiring locks to prevent deadlock. + slices.Sort(stringLockedKeys) + for _, stringKey := range stringLockedKeys { + v, _ := s.locks.LoadOrStore(stringKey, &sync.Mutex{}) + lock := v.(*sync.Mutex) + lock.Lock() + } + + return &LockedKV{ + parent: s, + lockKeys: stringLockedKeys, + mutations: make(map[string][]byte), + } +} + +func (s *LockableKV) Get(key []byte) []byte { + v, loaded := s.mutations.Load(string(key)) + if loaded { + return v.([]byte) + } + + return s.parent.Get(key) +} + +func (s *LockableKV) Has(key []byte) bool { + v, loaded := s.mutations.Load(string(key)) + if loaded { + return v.([]byte) != nil + } + + return s.parent.Has(key) +} + +func (s *LockableKV) Set(key, value []byte) { + s.mutations.Store(string(key), value) +} + +func (s *LockableKV) Delete(key []byte) { + s.Set(key, nil) +} + +func (s *LockableKV) Iterator(start, end []byte) storetypes.Iterator { + panic("This store does not support iteration.") +} + +func (s *LockableKV) ReverseIterator(start, end []byte) storetypes.Iterator { + panic("This store does not support iteration.") +} + +func (s *LockableKV) writeMutations(mutations map[string][]byte) { + // We don't need to sort here since the sync.Map stores keys and values in an arbitrary order. + // LockableKV.Write is responsible for sorting all the keys to ensure a deterministic write order. + for key, mutation := range mutations { + s.mutations.Store(key, mutation) + } +} + +func (s *LockableKV) unlock(lockKeys []string) { + for _, key := range lockKeys { + v, ok := s.locks.Load(key) + if !ok { + panic("Key not found") + } + lock := v.(*sync.Mutex) + + lock.Unlock() + } +} + +// LockedKV is a store that only allows setting of keys that have been locked via CacheWrapWithLocks. +// All other keys are allowed to be read but the user must ensure that no one else is able to mutate those +// values without the appropriate synchronization occurring outside of this store. +// +// This store does not support iteration. 
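Aside: a minimal lifecycle sketch for CacheWrapWithLocks (not part of this diff), using a transient store as the parent and an invented "owner/alice" key. Locks are taken by CacheWrapWithLocks, mutations are staged by Write on the locked branch, Unlock releases the lock keys, and LockableKV.Write flushes the staged mutations to the parent in sorted key order:

package main

import (
	"fmt"

	"cosmossdk.io/store/lockingkv"
	"cosmossdk.io/store/transient"
	storetypes "cosmossdk.io/store/types"
)

func main() {
	parent := transient.NewStore()
	lockable := lockingkv.NewStore(parent)

	locked := lockable.CacheWrapWithLocks([][]byte{[]byte("owner/alice")})
	locked.(storetypes.KVStore).Set([]byte("owner/alice/balance"), []byte{42})
	locked.Write()                            // stage mutations into the LockableKV
	locked.(storetypes.LockingStore).Unlock() // release "owner/alice"

	lockable.Write() // flush staged mutations to the parent in sorted key order
	fmt.Println(parent.Get([]byte("owner/alice/balance"))) // [42]
}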
+type LockedKV struct { + parent *LockableKV + + lockKeys []string + mutations map[string][]byte +} + +func (s *LockedKV) Write() { + s.parent.writeMutations(s.mutations) +} + +func (s *LockedKV) Unlock() { + s.parent.unlock(s.lockKeys) +} + +func (s *LockedKV) GetStoreType() storetypes.StoreType { + return s.parent.GetStoreType() +} + +func (s *LockedKV) CacheWrap() storetypes.CacheWrap { + return cachekv.NewStore(s) +} + +func (s *LockedKV) CacheWrapWithTrace(w io.Writer, tc storetypes.TraceContext) storetypes.CacheWrap { + return cachekv.NewStore(tracekv.NewStore(s, w, tc)) +} + +func (s *LockedKV) Get(key []byte) []byte { + if key == nil { + panic("nil key") + } + + if value, ok := s.mutations[string(key)]; ok { + return value + } + + return s.parent.Get(key) +} + +func (s *LockedKV) Has(key []byte) bool { + if key == nil { + panic("nil key") + } + + if value, ok := s.mutations[string(key)]; ok { + return value != nil + } + + return s.parent.Has(key) +} + +func (s *LockedKV) Set(key, value []byte) { + if key == nil { + panic("nil key") + } + + s.mutations[string(key)] = value +} + +func (s *LockedKV) Delete(key []byte) { + s.Set(key, nil) +} + +func (s *LockedKV) Iterator(start, end []byte) storetypes.Iterator { + panic("This store does not support iteration.") +} + +func (s *LockedKV) ReverseIterator(start, end []byte) storetypes.Iterator { + panic("This store does not support iteration.") +} + +// getSortedKeys returns the keys of the map in sorted order. +func getSortedKeys[R interface { + ~[]K + sort.Interface +}, K comparable, V any](m map[K]V, +) []K { + keys := make([]K, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Sort(R(keys)) + return keys +} diff --git a/store/lockingkv/lockingkv_test.go b/store/lockingkv/lockingkv_test.go new file mode 100644 index 00000000000..162b362e679 --- /dev/null +++ b/store/lockingkv/lockingkv_test.go @@ -0,0 +1,184 @@ +package lockingkv_test + +import ( + "sync" + "testing" + + "github.com/stretchr/testify/require" + + "cosmossdk.io/store/lockingkv" + "cosmossdk.io/store/transient" + storetypes "cosmossdk.io/store/types" +) + +var ( + a = []byte("a") + b = []byte("b") + key = []byte("key") +) + +func TestLockingKV_LinearizeReadsAndWrites(t *testing.T) { + parent := transient.NewStore() + locking := lockingkv.NewStore(parent) + + wg := sync.WaitGroup{} + wg.Add(100) + for i := 0; i < 100; i++ { + go func() { + defer wg.Done() + + locked := locking.CacheWrapWithLocks([][]byte{a}) + defer locked.(storetypes.LockingStore).Unlock() + v := locked.(storetypes.KVStore).Get(key) + if v == nil { + locked.(storetypes.KVStore).Set(key, []byte{1}) + } else { + v[0]++ + locked.(storetypes.KVStore).Set(key, v) + } + locked.Write() + }() + } + + wg.Wait() + require.Equal(t, []byte{100}, locking.Get(key)) +} + +func TestLockingKV_LockOrderToPreventDeadlock(t *testing.T) { + parent := transient.NewStore() + locking := lockingkv.NewStore(parent) + + // Acquire keys in two different orders ensuring that we don't reach deadlock. 
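Aside: a tiny, package-internal illustration (hypothetical values) of the generic getSortedKeys helper defined above; sort.StringSlice supplies the sort.Interface implementation for the []string key slice, which is how LockableKV.Write obtains its deterministic write order:

// Inside package lockingkv (illustrative only):
values := map[string][]byte{"b": nil, "a": {0x01}}
for _, k := range getSortedKeys[sort.StringSlice](values) { // yields "a", then "b"
	_ = values[k]
}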
+ wg := sync.WaitGroup{} + wg.Add(200) + for i := 0; i < 100; i++ { + go func() { + defer wg.Done() + + locked := locking.CacheWrapWithLocks([][]byte{a, b}) + defer locked.(storetypes.LockingStore).Unlock() + v := locked.(storetypes.KVStore).Get(key) + if v == nil { + locked.(storetypes.KVStore).Set(key, []byte{1}) + } else { + v[0]++ + locked.(storetypes.KVStore).Set(key, v) + } + locked.Write() + }() + + go func() { + defer wg.Done() + + locked := locking.CacheWrapWithLocks([][]byte{b, a}) + defer locked.(storetypes.LockingStore).Unlock() + v := locked.(storetypes.KVStore).Get(key) + if v == nil { + locked.(storetypes.KVStore).Set(key, []byte{1}) + } else { + v[0]++ + locked.(storetypes.KVStore).Set(key, v) + } + locked.Write() + }() + } + + wg.Wait() + require.Equal(t, []byte{200}, locking.Get(key)) +} + +func TestLockingKV_AllowForParallelUpdates(t *testing.T) { + parent := transient.NewStore() + locking := lockingkv.NewStore(parent) + + wg := sync.WaitGroup{} + wg.Add(100) + + lockeds := make([]storetypes.LockingStore, 100) + for i := byte(0); i < 100; i++ { + k := []byte{i} + // We specifically don't unlock the keys during processing so that we can show that we must process all + // of these in parallel before the wait group is done. + locked := locking.CacheWrapWithLocks([][]byte{k}) + lockeds[i] = locked.(storetypes.LockingStore) + go func() { + // The defer order is from last to first so we mark that we are done and then exit. + defer wg.Done() + + locked.(storetypes.KVStore).Set(k, k) + locked.Write() + }() + } + + wg.Wait() + for _, locked := range lockeds { + locked.Unlock() + } + for i := byte(0); i < 100; i++ { + require.Equal(t, []byte{i}, locking.Get([]byte{i})) + } +} + +func TestLockingKV_SetGetHas(t *testing.T) { + parent := transient.NewStore() + parent.Set(a, b) + locking := lockingkv.NewStore(parent) + + // Check that Get is transitive to the parent. + require.Equal(t, b, locking.Get(a)) + require.Nil(t, locking.Get(b)) + + // Check that Has is transitive to the parent. + require.True(t, locking.Has(a)) + require.False(t, locking.Has(b)) + + // Check that Set isn't transitive to the parent. + locking.Set(key, a) + require.False(t, parent.Has(key)) + + // Check that we can read our writes. + require.True(t, locking.Has(key)) + require.Equal(t, a, locking.Get(key)) + + // Check that committing the writes to the parent. + locking.Write() + require.True(t, parent.Has(key)) + require.Equal(t, a, parent.Get(key)) +} + +func TestLockedKV_SetGetHas(t *testing.T) { + parent := transient.NewStore() + parent.Set(a, b) + locking := lockingkv.NewStore(parent) + locked := locking.CacheWrapWithLocks([][]byte{key}).(storetypes.CacheKVStore) + + // Check that Get is transitive to the parent. + require.Equal(t, b, locked.Get(a)) + require.Nil(t, locked.Get(b)) + + // Check that Has is transitive to the parent. + require.True(t, locked.Has(a)) + require.False(t, locked.Has(b)) + + // Check that Set isn't transitive to the parent. + locked.Set(key, a) + require.False(t, locking.Has(key)) + + // Check that we can read our writes. + require.True(t, locked.Has(key)) + require.Equal(t, a, locked.Get(key)) + + // Check that committing the writes to the parent and not the parent's parent. + locked.Write() + require.True(t, locking.Has(key)) + require.Equal(t, a, locking.Get(key)) + require.False(t, parent.Has(key)) + require.Nil(t, parent.Get(key)) + + // Unlock and get another instance of the store to see that the mutations in the locking store are visible. 
+ locked.(storetypes.LockingStore).Unlock() + locked = locking.CacheWrapWithLocks([][]byte{key}).(storetypes.CacheKVStore) + require.True(t, locked.Has(key)) + require.Equal(t, a, locked.Get(key)) + locked.(storetypes.LockingStore).Unlock() +} diff --git a/store/types/store.go b/store/types/store.go index 57c98636d16..106bab8c91a 100644 --- a/store/types/store.go +++ b/store/types/store.go @@ -158,12 +158,6 @@ type MultiStore interface { type CacheMultiStore interface { MultiStore Write() // Writes operations to underlying KVStore - Lock(keys [][]byte) - Unlock(keys [][]byte) - RLockRW(Rkeys [][]byte) - LockRW(Rkeys [][]byte) - RUnlockRW(keys [][]byte) - UnlockRW(keys [][]byte) } // CommitMultiStore is an interface for a MultiStore without cache capabilities. @@ -284,6 +278,14 @@ type CacheKVStore interface { Write() } +// LockingStore allows for unlocking the associated lock keys that were acquired during +// locking with CacheWrapWithLocks on a LockingCacheWrapper. +type LockingStore interface { + Store + + Unlock() +} + // CommitKVStore is an interface for MultiStore. type CommitKVStore interface { Committer @@ -316,6 +318,13 @@ type CacheWrapper interface { CacheWrapWithTrace(w io.Writer, tc TraceContext) CacheWrap } +type LockingCacheWrapper interface { + CacheWrapper + + // CacheWrapWithLocks branches a store with the specific lock keys being acquired. + CacheWrapWithLocks(lockKeys [][]byte) CacheWrap +} + func (cid CommitID) IsZero() bool { return cid.Version == 0 && len(cid.Hash) == 0 } @@ -383,7 +392,8 @@ type CapabilityKey StoreKey // KVStoreKey is used for accessing substores. // Only the pointer value should ever be used - it functions as a capabilities key. type KVStoreKey struct { - name string + name string + locking bool } // NewKVStoreKey returns a new pointer to a KVStoreKey. @@ -415,7 +425,19 @@ func (key *KVStoreKey) Name() string { } func (key *KVStoreKey) String() string { - return fmt.Sprintf("KVStoreKey{%p, %s}", key, key.name) + return fmt.Sprintf("KVStoreKey{%p, %s, locking: %t}", key, key.name, key.locking) +} + +func (key *KVStoreKey) IsLocking() bool { + return key.locking +} + +// Enables locking for the store key. +func (key *KVStoreKey) WithLocking() *KVStoreKey { + return &KVStoreKey{ + name: key.name, + locking: true, + } } // TransientStoreKey is used for indexing transient stores in a MultiStore diff --git a/store/types/store_test.go b/store/types/store_test.go index b6304d131bc..26337b132cc 100644 --- a/store/types/store_test.go +++ b/store/types/store_test.go @@ -81,7 +81,9 @@ func TestKVStoreKey(t *testing.T) { key := NewKVStoreKey("test") require.Equal(t, "test", key.name) require.Equal(t, key.name, key.Name()) - require.Equal(t, fmt.Sprintf("KVStoreKey{%p, test}", key), key.String()) + require.Equal(t, fmt.Sprintf("KVStoreKey{%p, test, locking: false}", key), key.String()) + keyWithLocking := key.WithLocking() + require.Equal(t, fmt.Sprintf("KVStoreKey{%p, test, locking: true}", keyWithLocking), keyWithLocking.String()) } func TestNilKVStoreKey(t *testing.T) {
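Aside: an end-to-end sketch (not part of this diff) of how the pieces introduced here compose: KVStoreKey.WithLocking opts a store into lockingkv, NewLockingStore builds the multistore, CacheMultiStoreWithLocking branches it while acquiring lock keys, and Unlock releases them. The "bank" store name and account key are invented; nil is passed for the key map and trace writer/context for brevity:

package main

import (
	"fmt"

	dbm "github.com/cosmos/cosmos-db"

	"cosmossdk.io/store/cachemulti"
	"cosmossdk.io/store/dbadapter"
	"cosmossdk.io/store/types"
)

func main() {
	// Opting the store key into locking makes NewLockingStore back it with a
	// lockingkv store instead of a plain cachekv store.
	key := types.NewKVStoreKey("bank").WithLocking()

	stores := map[types.StoreKey]types.CacheWrapper{
		key: dbadapter.Store{DB: dbm.NewMemDB()},
	}
	cms := cachemulti.NewLockingStore(dbm.NewMemDB(), stores, nil, nil, nil)

	// Branch the multistore while acquiring the listed lock keys on "bank".
	branch := cms.CacheMultiStoreWithLocking(map[types.StoreKey][][]byte{
		key: {[]byte("acct/alice")},
	})
	branch.GetKVStore(key).Set([]byte("acct/alice"), []byte{1})
	branch.Write()                       // stage mutations into the locking store
	branch.(types.LockingStore).Unlock() // release the branch's lock keys

	// Mutations staged by the branch are now visible through the parent multistore.
	fmt.Println(cms.GetKVStore(key).Get([]byte("acct/alice"))) // [1]
}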