diff --git a/x/collection/keeper/invariants.go b/x/collection/keeper/invariants.go new file mode 100644 index 0000000000..7d000bea1d --- /dev/null +++ b/x/collection/keeper/invariants.go @@ -0,0 +1,76 @@ +package keeper + +import ( + "strings" + + sdk "github.com/Finschia/finschia-sdk/types" + "github.com/Finschia/finschia-sdk/x/collection" +) + +const ( + totalSupplyInvariant = "total-supply" +) + +func RegisterInvariants(ir sdk.InvariantRegistry, k Keeper) { + for name, invariant := range map[string]func(k Keeper) sdk.Invariant{ + totalSupplyInvariant: TotalSupplyInvariant, + } { + ir.RegisterRoute(collection.ModuleName, name, invariant(k)) + } +} + +func TotalSupplyInvariant(k Keeper) sdk.Invariant { + return func(ctx sdk.Context) (string, bool) { + // cache, we don't want to write changes + ctx, _ = ctx.CacheContext() + + invalidClassIDs := map[string][]string{} + k.iterateContracts(ctx, func(contract collection.Contract) (stop bool) { + supplies := map[string]sdk.Int{} + k.iterateContractSupplies(ctx, contract.Id, func(classID string, amount sdk.Int) (stop bool) { + supplies[classID] = amount + return false + }) + + k.iterateContractBalances(ctx, contract.Id, func(address sdk.AccAddress, balance collection.Coin) (stop bool) { + classID := collection.SplitTokenID(balance.TokenId) + amount, ok := supplies[classID] + if !ok { + amount = sdk.ZeroInt() + } + + supplies[classID] = amount.Sub(balance.Amount) + return false + }) + + invalidClassIDsCandidate := []string{} + for classID, supply := range supplies { + if !supply.IsZero() { + invalidClassIDsCandidate = append(invalidClassIDsCandidate, classID) + } + } + + if len(invalidClassIDsCandidate) != 0 { + invalidClassIDs[contract.Id] = invalidClassIDsCandidate + } + + return false + }) + + broken := len(invalidClassIDs) != 0 + msg := "no violation found" + if broken { + concatenated := []string{} + delimiter := ":" + for contractID, classIDs := range invalidClassIDs { + for _, classID := range classIDs { + concatenated = append(concatenated, contractID+delimiter+classID) + } + } + + msg = "violation found on following classIDs: " + strings.Join(concatenated, ", ") + } + + return sdk.FormatInvariant(collection.ModuleName, totalSupplyInvariant, msg), broken + } +} diff --git a/x/collection/module/module.go b/x/collection/module/module.go index 1c9a323cd5..51c0ae6d12 100644 --- a/x/collection/module/module.go +++ b/x/collection/module/module.go @@ -89,7 +89,9 @@ func NewAppModule(cdc codec.Codec, keeper keeper.Keeper) AppModule { } // RegisterInvariants does nothing, there are no invariants to enforce -func (AppModule) RegisterInvariants(_ sdk.InvariantRegistry) {} +func (am AppModule) RegisterInvariants(ir sdk.InvariantRegistry) { + keeper.RegisterInvariants(ir, am.keeper) +} // Route returns the message routing key for the collection module. func (am AppModule) Route() sdk.Route { return sdk.Route{} }