Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistent params validation #394

Merged
merged 9 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions x/ccv/consumer/keeper/params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)

// TestParams tests the default params set for a consumer chain, and related getters/setters
// TestParams tests getters/setters for consumer params
func TestParams(t *testing.T) {
consumerKeeper, ctx, ctrl, _ := testkeeper.GetConsumerKeeperAndCtx(t, testkeeper.NewInMemKeeperParams(t))
defer ctrl.Finish()
Expand All @@ -21,7 +21,8 @@ func TestParams(t *testing.T) {
params := consumerKeeper.GetParams(ctx)
require.Equal(t, expParams, params)

newParams := types.NewParams(false, 1000, "abc", "def", 7*24*time.Hour)
newParams := types.NewParams(false, 1000,
"channel-2", "cosmos19pe9pg5dv9k5fzgzmsrgnw9rl9asf7ddwhu7lm", 7*24*time.Hour)
consumerKeeper.SetParams(ctx, newParams)
params = consumerKeeper.GetParams(ctx)
require.Equal(t, newParams, params)
Expand All @@ -30,12 +31,12 @@ func TestParams(t *testing.T) {
gotBPDT := consumerKeeper.GetBlocksPerDistributionTransmission(ctx)
require.Equal(t, gotBPDT, int64(10))

consumerKeeper.SetDistributionTransmissionChannel(ctx, "foobarbaz")
consumerKeeper.SetDistributionTransmissionChannel(ctx, "channel-7")
gotChan := consumerKeeper.GetDistributionTransmissionChannel(ctx)
require.Equal(t, gotChan, "foobarbaz")
require.Equal(t, gotChan, "channel-7")

consumerKeeper.SetProviderFeePoolAddrStr(ctx, "foobar")
consumerKeeper.SetProviderFeePoolAddrStr(ctx, "cosmos1dkas8mu4kyhl5jrh4nzvm65qz588hy9qcz08la")
gotAddr := consumerKeeper.
GetProviderFeePoolAddrStr(ctx)
require.Equal(t, gotAddr, "foobar")
require.Equal(t, gotAddr, "cosmos1dkas8mu4kyhl5jrh4nzvm65qz588hy9qcz08la")
}
6 changes: 3 additions & 3 deletions x/ccv/consumer/keeper/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (k Keeper) SendVSCMaturedPackets(ctx sdk.Context) error {
channelID, // source channel id
ccv.ConsumerPortID, // source port id
packetData.GetBytes(),
k.GetParams(ctx).CcvTimeoutPeriod,
k.GetCCVTimeoutPeriod(ctx),
)
if err != nil {
return err
Expand Down Expand Up @@ -154,7 +154,7 @@ func (k Keeper) SendSlashPacket(ctx sdk.Context, validator abci.Validator, valse
channelID, // source channel id
ccv.ConsumerPortID, // source port id
packetData.GetBytes(),
k.GetParams(ctx).CcvTimeoutPeriod,
k.GetCCVTimeoutPeriod(ctx),
)
if err != nil {
panic(err)
Expand Down Expand Up @@ -191,7 +191,7 @@ func (k Keeper) SendPendingSlashRequests(ctx sdk.Context) {
channelID, // source channel id
ccv.ConsumerPortID, // source port id
slashReq.Packet.GetBytes(),
k.GetParams(ctx).CcvTimeoutPeriod,
k.GetCCVTimeoutPeriod(ctx),
)
if err != nil {
panic(err)
Expand Down
3 changes: 3 additions & 0 deletions x/ccv/consumer/types/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ func (gs GenesisState) Validate() error {
if len(gs.InitialValSet) == 0 {
return sdkerrors.Wrap(ccv.ErrInvalidGenesis, "initial validator set is empty")
}
if err := gs.Params.Validate(); err != nil {
return err
}

if gs.NewChain {
if gs.ProviderClientState == nil {
Expand Down
24 changes: 24 additions & 0 deletions x/ccv/consumer/types/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,18 @@ func TestValidateInitialGenesisState(t *testing.T) {
valUpdates, types.SlashRequests{}, params),
true,
},
{
"invalid new consumer genesis state: invalid params",
types.NewInitialGenesisState(cs, consensusState, valUpdates, types.SlashRequests{},
types.NewParams(
true,
types.DefaultBlocksPerDistributionTransmission,
"",
"",
0, // CCV timeout period cannot be 0
)),
true,
},
}

for _, c := range cases {
Expand Down Expand Up @@ -257,6 +269,18 @@ func TestValidateRestartGenesisState(t *testing.T) {
types.NewRestartGenesisState("ccvclient", "ccvchannel", nil, nil, nil, nil, params),
true,
},
{
"invalid restart consumer genesis state: invalid params",
types.NewRestartGenesisState("ccvclient", "ccvchannel", nil, valUpdates, nil, nil,
types.NewParams(
true,
types.DefaultBlocksPerDistributionTransmission,
"",
"",
0, // CCV timeout period cannot be 0
)),
true,
},
}

for _, c := range cases {
Expand Down
53 changes: 32 additions & 21 deletions x/ccv/consumer/types/params.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package types

import (
"fmt"
"time"

paramtypes "github.com/cosmos/cosmos-sdk/x/params/types"
Expand Down Expand Up @@ -52,41 +51,53 @@ func DefaultParams() Params {

// Validate all ccv-consumer module parameters
func (p Params) Validate() error {
if err := ccvtypes.ValidateBool(p.Enabled); err != nil {
return err
}
if err := ccvtypes.ValidatePositiveInt64(p.BlocksPerDistributionTransmission); err != nil {
return err
}
if err := validateDistributionTransmissionChannel(p.DistributionTransmissionChannel); err != nil {
return err
}
if err := validateProviderFeePoolAddrStr(p.ProviderFeePoolAddrStr); err != nil {
return err
}
if err := ccvtypes.ValidateDuration(p.CcvTimeoutPeriod); err != nil {
shaspitz marked this conversation as resolved.
Show resolved Hide resolved
return err
}
return nil
}

// ParamSetPairs implements params.ParamSet
func (p *Params) ParamSetPairs() paramtypes.ParamSetPairs {
return paramtypes.ParamSetPairs{
paramtypes.NewParamSetPair(KeyEnabled, p.Enabled, validateBool),
paramtypes.NewParamSetPair(KeyEnabled, p.Enabled, ccvtypes.ValidateBool),
paramtypes.NewParamSetPair(KeyBlocksPerDistributionTransmission,
p.BlocksPerDistributionTransmission, validateInt64),
p.BlocksPerDistributionTransmission, ccvtypes.ValidatePositiveInt64),
paramtypes.NewParamSetPair(KeyDistributionTransmissionChannel,
p.DistributionTransmissionChannel, validateString),
p.DistributionTransmissionChannel, validateDistributionTransmissionChannel),
paramtypes.NewParamSetPair(KeyProviderFeePoolAddrStr,
p.ProviderFeePoolAddrStr, validateString),
p.ProviderFeePoolAddrStr, validateProviderFeePoolAddrStr),
paramtypes.NewParamSetPair(ccvtypes.KeyCCVTimeoutPeriod,
p.CcvTimeoutPeriod, ccvtypes.ValidateCCVTimeoutPeriod),
p.CcvTimeoutPeriod, ccvtypes.ValidateDuration),
}
}

func validateBool(i interface{}) error {
if _, ok := i.(bool); !ok {
return fmt.Errorf("invalid parameter type: %T", i)
func validateDistributionTransmissionChannel(i interface{}) error {
// Accept empty string as valid, since this will be the default value on genesis
if i == "" {
return nil
}
return nil
// Otherwise validate as usual for a channelID
return ccvtypes.ValidateChannelIdentifier(i)
}

func validateInt64(i interface{}) error {
if _, ok := i.(int64); !ok {
return fmt.Errorf("invalid parameter type: %T", i)
func validateProviderFeePoolAddrStr(i interface{}) error {
// Accept empty string as valid, since this will be the default value on genesis
if i == "" {
return nil
}
return nil
}

func validateString(i interface{}) error {
if _, ok := i.(string); !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
return nil
// Otherwise validate as usual for a bech32 address
return ccvtypes.ValidateBech32(i)
}
35 changes: 35 additions & 0 deletions x/ccv/consumer/types/params_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package types_test

import (
"testing"

"github.com/stretchr/testify/require"

consumertypes "github.com/cosmos/interchain-security/x/ccv/consumer/types"
)

// Tests the validation of consumer params that happens at genesis
func TestValidateParams(t *testing.T) {

testCases := []struct {
name string
params consumertypes.Params
expPass bool
}{
{"default params", consumertypes.DefaultParams(), true},
{"custom valid params", consumertypes.NewParams(true, 5, "", "", 5), true},
{"custom invalid params, block per dist transmission", consumertypes.NewParams(true, -5, "", "", 5), false},
{"custom invalid params, dist transmission channel", consumertypes.NewParams(true, 5, "badchannel/", "", 5), false},
{"custom invalid params, provider fee pool addr string", consumertypes.NewParams(true, 5, "", "imabadaddress", 5), false},
{"custom invalid params, ccv timeout", consumertypes.NewParams(true, 5, "", "", -5), false},
}

for _, tc := range testCases {
err := tc.params.Validate()
if tc.expPass {
require.Nil(t, err, "expected error to be nil for test case: %s", tc.name)
} else {
require.NotNil(t, err, "expected error but got nil for test case: %s", tc.name)
}
}
}
4 changes: 2 additions & 2 deletions x/ccv/provider/types/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (p Params) Validate() error {
if p.TemplateClient == nil {
return fmt.Errorf("template client is nil")
}
if ccvtypes.ValidateCCVTimeoutPeriod(p.CcvTimeoutPeriod) != nil {
if ccvtypes.ValidateDuration(p.CcvTimeoutPeriod) != nil {
shaspitz marked this conversation as resolved.
Show resolved Hide resolved
return fmt.Errorf("ccv timeout period is invalid")
}
return validateTemplateClient(*p.TemplateClient)
Expand All @@ -67,7 +67,7 @@ func (p Params) Validate() error {
func (p *Params) ParamSetPairs() paramtypes.ParamSetPairs {
return paramtypes.ParamSetPairs{
paramtypes.NewParamSetPair(KeyTemplateClient, p.TemplateClient, validateTemplateClient),
paramtypes.NewParamSetPair(ccvtypes.KeyCCVTimeoutPeriod, p.CcvTimeoutPeriod, ccvtypes.ValidateCCVTimeoutPeriod),
paramtypes.NewParamSetPair(ccvtypes.KeyCCVTimeoutPeriod, p.CcvTimeoutPeriod, ccvtypes.ValidateDuration),
}
}

Expand Down
55 changes: 53 additions & 2 deletions x/ccv/types/shared_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package types
import (
fmt "fmt"
"time"

sdktypes "github.com/cosmos/cosmos-sdk/types"
ibchost "github.com/cosmos/ibc-go/v3/modules/core/24-host"
)

const (
Expand All @@ -14,13 +17,61 @@ var (
KeyCCVTimeoutPeriod = []byte("CcvTimeoutPeriod")
)

func ValidateCCVTimeoutPeriod(i interface{}) error {
func ValidateDuration(i interface{}) error {
period, ok := i.(time.Duration)
if !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
if period <= time.Duration(0) {
return fmt.Errorf("ibc timeout period is not positive")
return fmt.Errorf("duration must be positive")
}
return nil
}

func ValidateBool(i interface{}) error {
if _, ok := i.(bool); !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
return nil
}

func ValidateInt64(i interface{}) error {
if _, ok := i.(int64); !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
return nil
}

func ValidatePositiveInt64(i interface{}) error {
if err := ValidateInt64(i); err != nil {
return err
}
if i.(int64) <= int64(0) {
return fmt.Errorf("int must be positive")
}
return nil
}

func ValidateString(i interface{}) error {
if _, ok := i.(string); !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
return nil
}

func ValidateChannelIdentifier(i interface{}) error {
value, ok := i.(string)
if !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
return ibchost.ChannelIdentifierValidator(value)
}

func ValidateBech32(i interface{}) error {
value, ok := i.(string)
if !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}
_, err := sdktypes.AccAddressFromBech32(value)
return err
}