diff --git a/runtime/storage.go b/runtime/storage.go index 4a2c79136..87f3cf094 100644 --- a/runtime/storage.go +++ b/runtime/storage.go @@ -40,12 +40,12 @@ type StorageConfig struct { StorageFormatV2Enabled bool } -type storageFormat uint8 +type StorageFormat uint8 const ( - storageFormatUnknown storageFormat = iota - storageFormatV1 - storageFormatV2 + StorageFormatUnknown StorageFormat = iota + StorageFormatV1 + StorageFormatV2 ) type Storage struct { @@ -292,7 +292,7 @@ func (s *Storage) getDomainStorageMapForV2Account( } func (s *Storage) getDomainStorageMap( - format storageFormat, + format StorageFormat, inter *interpreter.Interpreter, address common.Address, domain common.StorageDomain, @@ -300,14 +300,14 @@ func (s *Storage) getDomainStorageMap( ) *interpreter.DomainStorageMap { switch format { - case storageFormatV1: + case StorageFormatV1: return s.getDomainStorageMapForV1Account( address, domain, createIfNotExists, ) - case storageFormatV2: + case StorageFormatV2: return s.getDomainStorageMapForV2Account( inter, address, @@ -320,15 +320,15 @@ func (s *Storage) getDomainStorageMap( } } -func (s *Storage) getCachedAccountFormat(address common.Address) (format storageFormat, known bool) { +func (s *Storage) getCachedAccountFormat(address common.Address) (format StorageFormat, known bool) { isV1, cached := s.cachedV1Accounts[address] if !cached { - return storageFormatUnknown, false + return StorageFormatUnknown, false } if isV1 { - return storageFormatV1, true + return StorageFormatV1, true } else { - return storageFormatV2, true + return StorageFormatV2, true } } @@ -708,6 +708,37 @@ func (s *Storage) CheckHealth() error { return nil } +// AccountStorageFormat returns either StorageFormatV1 or StorageFormatV2 for existing accounts, +// and StorageFormatUnknown for non-existing accounts. +func (s *Storage) AccountStorageFormat(address common.Address) (format StorageFormat) { + cachedFormat, known := s.getCachedAccountFormat(address) + if known { + return cachedFormat + } + + defer func() { + // Cache account fomat + switch format { + case StorageFormatV1: + s.cacheIsV1Account(address, true) + case StorageFormatV2: + s.cacheIsV1Account(address, false) + } + }() + + if s.Config.StorageFormatV2Enabled { + if s.isV2Account(address) { + return StorageFormatV2 + } + } + + if s.isV1Account(address) { + return StorageFormatV1 + } + + return StorageFormatUnknown +} + type UnreferencedRootSlabsError struct { UnreferencedRootSlabIDs []atree.SlabID } diff --git a/runtime/storage_test.go b/runtime/storage_test.go index 470510149..c9efbfdc5 100644 --- a/runtime/storage_test.go +++ b/runtime/storage_test.go @@ -8976,6 +8976,249 @@ func TestGetDomainStorageMapRegisterReadsForV2Account(t *testing.T) { } } +func TestAccountStorageFormatForNonExistingAccount(t *testing.T) { + + t.Parallel() + + address := common.MustBytesToAddress([]byte{0x1}) + + testCases := []struct { + name string + storageFormatV2Enabled bool + format StorageFormat + }{ + { + name: "non-existing account, storageFormatV2Enabled = false", + storageFormatV2Enabled: false, + format: StorageFormatUnknown, + }, + { + name: "non-existing account, storageFormatV2Enabled = true", + storageFormatV2Enabled: true, + format: StorageFormatUnknown, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ledger := NewTestLedger(nil, nil) + + storage := NewStorage( + ledger, + nil, + StorageConfig{ + StorageFormatV2Enabled: tc.storageFormatV2Enabled, + }, + ) + + for range 2 { + format := storage.AccountStorageFormat(address) + require.Equal(t, tc.format, format) + } + }) + } +} + +func TestAccountStorageFormatForV1Account(t *testing.T) { + t.Parallel() + + address := common.MustBytesToAddress([]byte{0x1}) + + createV1AccountWithDomain := func( + address common.Address, + domain common.StorageDomain, + ) (storedValues map[string][]byte, StorageIndices map[string]uint64) { + ledger := NewTestLedger(nil, nil) + + persistentSlabStorage := NewPersistentSlabStorage(ledger, nil) + + orderedMap, err := atree.NewMap( + persistentSlabStorage, + atree.Address(address), + atree.NewDefaultDigesterBuilder(), + interpreter.EmptyTypeInfo{}, + ) + require.NoError(t, err) + + slabIndex := orderedMap.SlabID().Index() + + for i := range 3 { + + key := interpreter.StringStorageMapKey(strconv.Itoa(i)) + + value := interpreter.NewUnmeteredIntValueFromInt64(int64(i)) + + existingStorable, err := orderedMap.Set( + key.AtreeValueCompare, + key.AtreeValueHashInput, + key.AtreeValue(), + value, + ) + require.NoError(t, err) + require.Nil(t, existingStorable) + } + + // Commit domain storage map + err = persistentSlabStorage.FastCommit(runtime.NumCPU()) + require.NoError(t, err) + + // Create domain register + err = ledger.SetValue(address[:], []byte(domain.Identifier()), slabIndex[:]) + require.NoError(t, err) + + return ledger.StoredValues, ledger.StorageIndices + } + + testCases := []struct { + name string + storageFormatV2Enabled bool + format StorageFormat + }{ + { + name: "v1 account, storageFormatV2Enabled = false", + storageFormatV2Enabled: false, + format: StorageFormatV1, + }, + { + name: "v1 account, storageFormatV2Enabled = true", + storageFormatV2Enabled: true, + format: StorageFormatV1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + storedValues, storedIndices := createV1AccountWithDomain( + address, + common.StorageDomainPathStorage, + ) + + ledger := NewTestLedgerWithData(nil, nil, storedValues, storedIndices) + + storage := NewStorage( + ledger, + nil, + StorageConfig{ + StorageFormatV2Enabled: tc.storageFormatV2Enabled, + }, + ) + + for range 2 { + format := storage.AccountStorageFormat(address) + require.Equal(t, tc.format, format) + } + }) + } +} + +func TestAccountStorageFormatForV2Account(t *testing.T) { + t.Parallel() + + address := common.MustBytesToAddress([]byte{0x1}) + + createV2AccountWithDomain := func( + address common.Address, + domain common.StorageDomain, + ) (storedValues map[string][]byte, StorageIndices map[string]uint64) { + ledger := NewTestLedger(nil, nil) + + persistentSlabStorage := NewPersistentSlabStorage(ledger, nil) + + accountOrderedMap, err := atree.NewMap( + persistentSlabStorage, + atree.Address(address), + atree.NewDefaultDigesterBuilder(), + interpreter.EmptyTypeInfo{}, + ) + require.NoError(t, err) + + slabIndex := accountOrderedMap.SlabID().Index() + + domainOrderedMap, err := atree.NewMap( + persistentSlabStorage, + atree.Address(address), + atree.NewDefaultDigesterBuilder(), + interpreter.EmptyTypeInfo{}, + ) + require.NoError(t, err) + + domainKey := interpreter.Uint64StorageMapKey(domain) + + existingDomain, err := accountOrderedMap.Set( + domainKey.AtreeValueCompare, + domainKey.AtreeValueHashInput, + domainKey.AtreeValue(), + domainOrderedMap, + ) + require.NoError(t, err) + require.Nil(t, existingDomain) + + for i := range 3 { + + key := interpreter.StringStorageMapKey(strconv.Itoa(i)) + + value := interpreter.NewUnmeteredIntValueFromInt64(int64(i)) + + existingStorable, err := domainOrderedMap.Set( + key.AtreeValueCompare, + key.AtreeValueHashInput, + key.AtreeValue(), + value, + ) + require.NoError(t, err) + require.Nil(t, existingStorable) + } + + // Commit domain storage map + err = persistentSlabStorage.FastCommit(runtime.NumCPU()) + require.NoError(t, err) + + // Create account register + err = ledger.SetValue(address[:], []byte(AccountStorageKey), slabIndex[:]) + require.NoError(t, err) + + return ledger.StoredValues, ledger.StorageIndices + } + + testCases := []struct { + name string + storageFormatV2Enabled bool + format StorageFormat + }{ + { + name: "v2 account, storageFormatV2Enabled = true", + storageFormatV2Enabled: true, + format: StorageFormatV2, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + storedValues, storedIndices := createV2AccountWithDomain( + address, + common.StorageDomainPathStorage, + ) + + ledger := NewTestLedgerWithData(nil, nil, storedValues, storedIndices) + + storage := NewStorage( + ledger, + nil, + StorageConfig{ + StorageFormatV2Enabled: tc.storageFormatV2Enabled, + }, + ) + + for range 2 { + format := storage.AccountStorageFormat(address) + require.Equal(t, tc.format, format) + } + }) + } +} + // createAndWriteAccountStorageMap creates account storage map with given domains and writes random values to domain storage map. func createAndWriteAccountStorageMap( t testing.TB,