Skip to content

Commit

Permalink
Add unit tests on migration
Browse files Browse the repository at this point in the history
  • Loading branch information
0Tech committed Aug 31, 2023
1 parent 8cdd61c commit eee7b93
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 20 deletions.
56 changes: 40 additions & 16 deletions x/collection/keeper/migrations/v2/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,46 @@ var (

balanceKeyPrefix = []byte{0x20}

supplyKeyPrefix = []byte{0x40}
mintedKeyPrefix = []byte{0x41}
burntKeyPrefix = []byte{0x42}
SupplyKeyPrefix = []byte{0x40}
MintedKeyPrefix = []byte{0x41}
BurntKeyPrefix = []byte{0x42}
)

func ContractKey(contractID string) []byte {
key := make([]byte, len(contractKeyPrefix)+len(contractID))

copy(key, contractKeyPrefix)
copy(key[len(contractKeyPrefix):], contractID)

return key
}

func BalanceKey(contractID string, address sdk.AccAddress, tokenID string) []byte {
prefix := balanceKeyPrefixByAddress(contractID, address)
key := make([]byte, len(prefix)+len(tokenID))

copy(key, prefix)
copy(key[len(prefix):], tokenID)

return key
}

func balanceKeyPrefixByAddress(contractID string, address sdk.AccAddress) []byte {
prefix := balanceKeyPrefixByContractID(contractID)
key := make([]byte, len(prefix)+1+len(address))

begin := 0
copy(key, prefix)

begin += len(prefix)
key[begin] = byte(len(address))

begin++
copy(key[begin:], address)

return key
}

func balanceKeyPrefixByContractID(contractID string) []byte {
key := make([]byte, len(balanceKeyPrefix)+1+len(contractID))

Expand Down Expand Up @@ -46,7 +81,7 @@ func splitBalanceKey(key []byte) (contractID string, address sdk.AccAddress, tok
return
}

func statisticKey(keyPrefix []byte, contractID string, classID string) []byte {
func StatisticKey(keyPrefix []byte, contractID string, classID string) []byte {
prefix := statisticKeyPrefixByContractID(keyPrefix, contractID)
key := make([]byte, len(prefix)+len(classID))

Expand All @@ -71,18 +106,7 @@ func statisticKeyPrefixByContractID(keyPrefix []byte, contractID string) []byte
return key
}

func splitStatisticKey(keyPrefix, key []byte) (contractID string, classID string) {
begin := len(keyPrefix) + 1
end := begin + int(key[begin-1])
contractID = string(key[begin:end])

begin = end
classID = string(key[begin:])

return
}

func nextClassIDKey(contractID string) []byte {
func NextClassIDKey(contractID string) []byte {
key := make([]byte, len(nextClassIDKeyPrefix)+len(contractID))

copy(key, nextClassIDKeyPrefix)
Expand Down
8 changes: 4 additions & 4 deletions x/collection/keeper/migrations/v2/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func evalContractFTSupplies(store storetypes.KVStore, contractID string) (map[st
}

func updateContractFTStatistics(store storetypes.KVStore, contractID string, supplies map[string]sdk.Int) error {
bz := store.Get(nextClassIDKey(contractID))
bz := store.Get(NextClassIDKey(contractID))
if bz == nil {
return fmt.Errorf("no next class ids of contract %s", contractID)
}
Expand All @@ -95,7 +95,7 @@ func updateContractFTStatistics(store storetypes.KVStore, contractID string, sup
classID := fmt.Sprintf("%08x", intClassID)

// update supply
supplyKey := statisticKey(supplyKeyPrefix, contractID, classID)
supplyKey := StatisticKey(SupplyKeyPrefix, contractID, classID)
supply, ok := supplies[classID]
if ok {
bz, err := supply.Marshal()
Expand All @@ -108,7 +108,7 @@ func updateContractFTStatistics(store storetypes.KVStore, contractID string, sup
}

// get burnt
burntKey := statisticKey(burntKeyPrefix, contractID, classID)
burntKey := StatisticKey(BurntKeyPrefix, contractID, classID)
burnt := sdk.ZeroInt()
if bz := store.Get(burntKey); bz != nil {
if err := burnt.Unmarshal(bz); err != nil {
Expand All @@ -118,7 +118,7 @@ func updateContractFTStatistics(store storetypes.KVStore, contractID string, sup

// update minted
minted := supply.Add(burnt)
mintedKey := statisticKey(mintedKeyPrefix, contractID, classID)
mintedKey := StatisticKey(MintedKeyPrefix, contractID, classID)
if !minted.IsZero() {
bz, err := minted.Marshal()
if err != nil {
Expand Down
152 changes: 152 additions & 0 deletions x/collection/keeper/migrations/v2/store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package v2_test

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"

simappparams "github.com/Finschia/finschia-sdk/simapp/params"
"github.com/Finschia/finschia-sdk/testutil"
sdk "github.com/Finschia/finschia-sdk/types"
"github.com/Finschia/finschia-sdk/x/collection"

"github.com/Finschia/finschia-sdk/x/collection/keeper/migrations/v2"
)

func TestMigrateStore(t *testing.T) {
collectionKey := sdk.NewKVStoreKey(collection.StoreKey)
newKey := sdk.NewTransientStoreKey("transient_test")
encCfg := simappparams.MakeTestEncodingConfig()
ctx := testutil.DefaultContext(collectionKey, newKey)

// set state
store := ctx.KVStore(collectionKey)

contractID := "deadbeef"
store.Set(v2.ContractKey(contractID), encCfg.Marshaler.MustMarshal(&collection.Contract{Id: contractID}))
nextClassIDs := collection.DefaultNextClassIDs(contractID)
classID := fmt.Sprintf("%08x", nextClassIDs.Fungible.Uint64())
nextClassIDs.Fungible = nextClassIDs.Fungible.Incr()
store.Set(v2.NextClassIDKey(contractID), encCfg.Marshaler.MustMarshal(&nextClassIDs))

tokenID := collection.NewFTID(classID)
oneIntBz, err := sdk.OneInt().Marshal()
require.NoError(t, err)
addresses := []sdk.AccAddress{
sdk.AccAddress("fennec"),
sdk.AccAddress("penguin"),
sdk.AccAddress("cheetah"),
}
for _, addr := range addresses {
store.Set(v2.BalanceKey(contractID, addr, tokenID), oneIntBz)
}
store.Set(v2.StatisticKey(v2.SupplyKeyPrefix, contractID, classID), oneIntBz)
store.Set(v2.StatisticKey(v2.MintedKeyPrefix, contractID, classID), oneIntBz)
store.Set(v2.StatisticKey(v2.BurntKeyPrefix, contractID, classID), oneIntBz)

for name, tc := range map[string]struct {
malleate func(ctx sdk.Context)
valid bool
supply int
minted int
}{
"valid": {
valid: true,
supply: len(addresses),
minted: len(addresses) + 1,
},
"valid (nil supply)": {
malleate: func(ctx sdk.Context) {
store := ctx.KVStore(collectionKey)
store.Delete(v2.StatisticKey(v2.SupplyKeyPrefix, contractID, classID))
},
valid: true,
supply: len(addresses),
minted: len(addresses) + 1,
},
"valid (nil minted)": {
malleate: func(ctx sdk.Context) {
store := ctx.KVStore(collectionKey)
store.Delete(v2.StatisticKey(v2.MintedKeyPrefix, contractID, classID))
},
valid: true,
supply: len(addresses),
minted: len(addresses) + 1,
},
"valid (nil burnt)": {
malleate: func(ctx sdk.Context) {
store := ctx.KVStore(collectionKey)
store.Delete(v2.StatisticKey(v2.BurntKeyPrefix, contractID, classID))
},
valid: true,
supply: len(addresses),
minted: len(addresses),
},
"contract unmarshal failed": {
malleate: func(ctx sdk.Context) {
store := ctx.KVStore(collectionKey)
store.Set(v2.ContractKey(contractID), encCfg.Marshaler.MustMarshal(&collection.GenesisState{}))
},
},
"balance unmarshal failed": {
malleate: func(ctx sdk.Context) {
store := ctx.KVStore(collectionKey)
store.Set(v2.BalanceKey(contractID, sdk.AccAddress("hyena"), tokenID), encCfg.Marshaler.MustMarshal(&collection.GenesisState{}))
},
},
"no next class id": {
malleate: func(ctx sdk.Context) {
store := ctx.KVStore(collectionKey)
store.Delete(v2.NextClassIDKey(contractID))
},
},
"next class id unmarshal failed": {
malleate: func(ctx sdk.Context) {
store := ctx.KVStore(collectionKey)
store.Set(v2.NextClassIDKey(contractID), []byte("invalid"))
},
},
"burnt unmarshal failed": {
malleate: func(ctx sdk.Context) {
store := ctx.KVStore(collectionKey)
store.Set(v2.StatisticKey(v2.BurntKeyPrefix, contractID, classID), encCfg.Marshaler.MustMarshal(&collection.GenesisState{}))
},
},
} {
t.Run(name, func(t *testing.T) {
ctx, _ := ctx.CacheContext()
if tc.malleate != nil {
tc.malleate(ctx)
}

// migrate
err := v2.MigrateStore(ctx, collectionKey, encCfg.Marshaler)
if !tc.valid {
require.Error(t, err)
return
}
require.NoError(t, err)

store := ctx.KVStore(collectionKey)

// supply
supplyKey := v2.StatisticKey(v2.SupplyKeyPrefix, contractID, classID)
supply := sdk.ZeroInt()
if bz := store.Get(supplyKey); bz != nil {
err := supply.Unmarshal(bz)
require.NoError(t, err)
}
require.Equal(t, int64(tc.supply), supply.Int64())

// minted
mintedKey := v2.StatisticKey(v2.MintedKeyPrefix, contractID, classID)
minted := sdk.ZeroInt()
if bz := store.Get(mintedKey); bz != nil {
err := minted.Unmarshal(bz)
require.NoError(t, err)
}
require.Equal(t, int64(tc.minted), minted.Int64())
})
}
}

0 comments on commit eee7b93

Please sign in to comment.