Skip to content

Commit

Permalink
more changes
Browse files Browse the repository at this point in the history
  • Loading branch information
tac0turtle committed Aug 8, 2023
1 parent b558e3a commit 1a9f45f
Show file tree
Hide file tree
Showing 12 changed files with 64 additions and 47 deletions.
6 changes: 3 additions & 3 deletions tests/integration/staking/keeper/deterministic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,15 +663,15 @@ func TestGRPCHistoricalInfo(t *testing.T) {

rapid.Check(t, func(rt *rapid.T) {
numVals := rapid.IntRange(1, 5).Draw(rt, "num-vals")
vals := make(stakingtypes.Validators, 0, numVals)
vals := stakingtypes.Validators{}
for i := 0; i < numVals; i++ {
validator := createAndSetValidatorWithStatus(t, rt, f, stakingtypes.Bonded)
vals = append(vals, validator)
vals.Validators = append(vals.Validators, validator)
}

historicalInfo := stakingtypes.HistoricalInfo{
Header: cmtproto.Header{},
Valset: vals,
Valset: vals.Validators,
}

height := rapid.Int64Min(0).Draw(rt, "height")
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/staking/keeper/grpc_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func createValidatorAccs(t *testing.T, f *fixture) ([]sdk.AccAddress, []types.Va
// have its order changed
sortedVals := make([]types.Validator, len(validators))
copy(sortedVals, validators)
hi := types.NewHistoricalInfo(header, sortedVals, f.stakingKeeper.PowerReduction(f.sdkCtx))
hi := types.NewHistoricalInfo(header, types.Validators{Validators: sortedVals}, f.stakingKeeper.PowerReduction(f.sdkCtx))
assert.NilError(t, f.stakingKeeper.HistoricalInfo.Set(f.sdkCtx, uint64(5), hi))

return addrs, validators
Expand Down Expand Up @@ -179,7 +179,7 @@ func TestGRPCQueryDelegatorValidators(t *testing.T) {
assert.NilError(t, err)
assert.Equal(t, 1, len(res.Validators))
assert.Assert(t, res.Pagination.NextKey != nil)
assert.Equal(t, uint64(len(delValidators)), res.Pagination.Total)
assert.Equal(t, uint64(len(delValidators.Validators)), res.Pagination.Total)
} else {
assert.ErrorContains(t, err, tc.expErrMsg)
assert.Assert(t, res == nil)
Expand Down
4 changes: 2 additions & 2 deletions x/staking/keeper/delegation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ func (s *KeeperTestSuite) TestDelegation() {

resVals, err := keeper.GetDelegatorValidators(ctx, addrDels[0], 3)
require.NoError(err)
require.Equal(3, len(resVals))
require.Equal(3, len(resVals.Validators))
resVals, err = keeper.GetDelegatorValidators(ctx, addrDels[1], 4)
require.NoError(err)
require.Equal(3, len(resVals))
require.Equal(3, len(resVals.Validators))

for i := 0; i < 3; i++ {
resVal, err := keeper.GetDelegatorValidator(ctx, addrDels[0], valAddrs[i])
Expand Down
8 changes: 4 additions & 4 deletions x/staking/keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ func (k Querier) Validators(ctx context.Context, req *types.QueryValidatorsReque

vals := types.Validators{}
for _, val := range validators {
vals = append(vals, *val)
vals.Validators = append(vals.Validators, *val)
}

return &types.QueryValidatorsResponse{Validators: vals, Pagination: pageRes}, nil
return &types.QueryValidatorsResponse{Validators: vals.Validators, Pagination: pageRes}, nil
}

// Validator queries validator info for given validator address
Expand Down Expand Up @@ -465,14 +465,14 @@ func (k Querier) DelegatorValidators(ctx context.Context, req *types.QueryDelega
return err
}

validators = append(validators, validator)
validators.Validators = append(validators.Validators, validator)
return nil
})
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}

return &types.QueryDelegatorValidatorsResponse{Validators: validators, Pagination: pageRes}, nil
return &types.QueryDelegatorValidatorsResponse{Validators: validators.Validators, Pagination: pageRes}, nil
}

// Pool queries the pool info
Expand Down
2 changes: 1 addition & 1 deletion x/staking/keeper/historical_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (k Keeper) TrackHistoricalInfo(ctx context.Context) error {
return err
}

historicalEntry := types.NewHistoricalInfo(sdkCtx.BlockHeader(), lastVals, k.PowerReduction(ctx))
historicalEntry := types.NewHistoricalInfo(sdkCtx.BlockHeader(), types.Validators{Validators: lastVals, ValidatorCodec: k.validatorAddressCodec}, k.PowerReduction(ctx))

// Set latest HistoricalInfo at current height
return k.HistoricalInfo.Set(ctx, uint64(sdkCtx.BlockHeight()), historicalEntry)
Expand Down
6 changes: 3 additions & 3 deletions x/staking/keeper/historical_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (s *KeeperTestSuite) TestHistoricalInfo() {
validators[i] = testutil.NewValidator(s.T(), valAddr, PKs[i])
}

hi := stakingtypes.NewHistoricalInfo(ctx.BlockHeader(), validators, keeper.PowerReduction(ctx))
hi := stakingtypes.NewHistoricalInfo(ctx.BlockHeader(), stakingtypes.Validators{Validators: validators}, keeper.PowerReduction(ctx))
require.NoError(keeper.HistoricalInfo.Set(ctx, uint64(2), hi))

recv, err := keeper.HistoricalInfo.Get(ctx, uint64(2))
Expand Down Expand Up @@ -73,8 +73,8 @@ func (s *KeeperTestSuite) TestTrackHistoricalInfo() {
testutil.NewValidator(s.T(), addrVals[0], PKs[0]),
testutil.NewValidator(s.T(), addrVals[1], PKs[1]),
}
hi4 := stakingtypes.NewHistoricalInfo(h4, valSet, keeper.PowerReduction(ctx))
hi5 := stakingtypes.NewHistoricalInfo(h5, valSet, keeper.PowerReduction(ctx))
hi4 := stakingtypes.NewHistoricalInfo(h4, stakingtypes.Validators{Validators: valSet}, keeper.PowerReduction(ctx))
hi5 := stakingtypes.NewHistoricalInfo(h5, stakingtypes.Validators{Validators: valSet}, keeper.PowerReduction(ctx))
require.NoError(keeper.HistoricalInfo.Set(ctx, uint64(4), hi4))
require.NoError(keeper.HistoricalInfo.Set(ctx, uint64(5), hi5))
recv, err := keeper.HistoricalInfo.Get(ctx, uint64(4))
Expand Down
8 changes: 4 additions & 4 deletions x/staking/keeper/query_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (k Keeper) GetDelegatorValidators(

iterator, err := store.Iterator(delegatorPrefixKey, storetypes.PrefixEndBytes(delegatorPrefixKey)) // smallest to largest
if err != nil {
return nil, err
return types.Validators{}, err
}
defer iterator.Close()

Expand All @@ -29,19 +29,19 @@ func (k Keeper) GetDelegatorValidators(

valAddr, err := k.validatorAddressCodec.StringToBytes(delegation.GetValidatorAddr())
if err != nil {
return nil, err
return types.Validators{}, err
}

validator, err := k.GetValidator(ctx, valAddr)
if err != nil {
return nil, err
return types.Validators{}, err
}

validators[i] = validator
i++
}

return validators[:i], nil // trim
return types.Validators{Validators: validators[:i], ValidatorCodec: k.validatorAddressCodec}, nil // trim
}

// GetDelegatorValidator returns a validator that a delegator is bonded to
Expand Down
4 changes: 2 additions & 2 deletions x/staking/testutil/cmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ func ToCmtValidator(v types.Validator, r math.Int) (*cmttypes.Validator, error)

// ToCmtValidators casts all validators to the corresponding CometBFT type.
func ToCmtValidators(v types.Validators, r math.Int) ([]*cmttypes.Validator, error) {
validators := make([]*cmttypes.Validator, len(v))
validators := make([]*cmttypes.Validator, len(v.Validators))
var err error
for i, val := range v {
for i, val := range v.Validators {
validators[i], err = ToCmtValidator(val, r)
if err != nil {
return nil, err
Expand Down
11 changes: 6 additions & 5 deletions x/staking/types/historical_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
cmtproto "github.com/cometbft/cometbft/proto/tendermint/types"
"github.com/cosmos/gogoproto/proto"

"cosmossdk.io/core/address"
"cosmossdk.io/errors"
"cosmossdk.io/math"

Expand All @@ -16,23 +17,23 @@ import (
// it will first sort valset before inclusion into historical info
func NewHistoricalInfo(header cmtproto.Header, valSet Validators, powerReduction math.Int) HistoricalInfo {
// Must sort in the same way that CometBFT does
sort.SliceStable(valSet, func(i, j int) bool {
return ValidatorsByVotingPower(valSet).Less(i, j, powerReduction)
sort.SliceStable(valSet.Validators, func(i, j int) bool {
return ValidatorsByVotingPower(valSet.Validators).Less(i, j, powerReduction)
})

return HistoricalInfo{
Header: header,
Valset: valSet,
Valset: valSet.Validators,
}
}

// ValidateBasic will ensure HistoricalInfo is not nil and sorted
func ValidateBasic(hi HistoricalInfo) error {
func ValidateBasic(hi HistoricalInfo, valAc address.Codec) error {
if len(hi.Valset) == 0 {
return errors.Wrap(ErrInvalidHistoricalInfo, "validator set is empty")
}

if !sort.IsSorted(Validators(hi.Valset)) {
if !sort.IsSorted(Validators{Validators: hi.Valset, ValidatorCodec: valAc}) {
return errors.Wrap(ErrInvalidHistoricalInfo, "validator set is not sorted by address")
}

Expand Down
12 changes: 7 additions & 5 deletions x/staking/types/historical_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
cmtproto "github.com/cometbft/cometbft/proto/tendermint/types"
"github.com/stretchr/testify/require"

"github.com/cosmos/cosmos-sdk/codec/address"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/cosmos/cosmos-sdk/x/staking/types"
)
Expand All @@ -31,11 +32,12 @@ func TestValidateBasic(t *testing.T) {
hi := types.HistoricalInfo{
Header: header,
}
err := types.ValidateBasic(hi)
ac := address.NewBech32Codec("cosmosvaloper")
err := types.ValidateBasic(hi, ac)
require.Error(t, err, "ValidateBasic passed on nil ValSet")

// Ensure validators are not sorted
for sort.IsSorted(types.Validators(validators)) {
for sort.IsSorted(types.Validators{Validators: validators, ValidatorCodec: ac}) {
rand.Shuffle(len(validators), func(i, j int) {
validators[i], validators[j] = validators[j], validators[i]
})
Expand All @@ -44,10 +46,10 @@ func TestValidateBasic(t *testing.T) {
Header: header,
Valset: validators,
}
err = types.ValidateBasic(hi)
err = types.ValidateBasic(hi, ac)
require.Error(t, err, "ValidateBasic passed on unsorted ValSet")

hi = types.NewHistoricalInfo(header, validators, sdk.DefaultPowerReduction)
err = types.ValidateBasic(hi)
hi = types.NewHistoricalInfo(header, types.Validators{Validators: validators, ValidatorCodec: ac}, sdk.DefaultPowerReduction)
err = types.ValidateBasic(hi, ac)
require.NoError(t, err, "ValidateBasic failed on valid HistoricalInfo")
}
29 changes: 21 additions & 8 deletions x/staking/types/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
abci "github.com/cometbft/cometbft/abci/types"
cmtprotocrypto "github.com/cometbft/cometbft/proto/tendermint/crypto"

"cosmossdk.io/core/address"
"cosmossdk.io/errors"
"cosmossdk.io/math"

Expand Down Expand Up @@ -63,10 +64,13 @@ func NewValidator(operator string, pubKey cryptotypes.PubKey, description Descri
}

// Validators is a collection of Validator
type Validators []Validator
type Validators struct {
Validators []Validator
ValidatorCodec address.Codec
}

func (v Validators) String() (out string) {
for _, val := range v {
for _, val := range v.Validators {
out += val.String() + "\n"
}

Expand All @@ -75,7 +79,7 @@ func (v Validators) String() (out string) {

// ToSDKValidators - convenience function convert []Validator to []sdk.ValidatorI
func (v Validators) ToSDKValidators() (validators []ValidatorI) {
for _, val := range v {
for _, val := range v.Validators {
validators = append(validators, val)
}

Expand All @@ -89,17 +93,26 @@ func (v Validators) Sort() {

// Implements sort interface
func (v Validators) Len() int {
return len(v)
return len(v.Validators)
}

// Implements sort interface
func (v Validators) Less(i, j int) bool {
return strings.Compare(v[i].GetOperator(), v[j].GetOperator()) == -1
vi, err := v.ValidatorCodec.StringToBytes(v.Validators[i].GetOperator())
if err != nil {
panic(err)
}
vj, err := v.ValidatorCodec.StringToBytes(v.Validators[j].GetOperator())
if err != nil {
panic(err)
}

return bytes.Compare(vi, vj) == -1
}

// Implements sort interface
func (v Validators) Swap(i, j int) {
v[i], v[j] = v[j], v[i]
v.Validators[i], v.Validators[j] = v.Validators[j], v.Validators[i]
}

// ValidatorsByVotingPower implements sort.Interface for []Validator based on
Expand Down Expand Up @@ -129,8 +142,8 @@ func (valz ValidatorsByVotingPower) Swap(i, j int) {

// UnpackInterfaces implements UnpackInterfacesMessage.UnpackInterfaces
func (v Validators) UnpackInterfaces(c codectypes.AnyUnpacker) error {
for i := range v {
if err := v[i].UnpackInterfaces(c); err != nil {
for i := range v.Validators {
if err := v.Validators[i].UnpackInterfaces(c); err != nil {
return err
}
}
Expand Down
17 changes: 9 additions & 8 deletions x/staking/types/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"cosmossdk.io/math"

"github.com/cosmos/cosmos-sdk/codec/address"
"github.com/cosmos/cosmos-sdk/codec/legacy"
cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec"
"github.com/cosmos/cosmos-sdk/crypto/keys/ed25519"
Expand Down Expand Up @@ -256,7 +257,7 @@ func TestValidatorsSortDeterminism(t *testing.T) {
}

// Save sorted copy
sort.Sort(types.Validators(vals))
sort.Sort(types.Validators{Validators: vals, ValidatorCodec: address.NewBech32Codec("cosmosvaloper")})
copy(sortedVals, vals)

// Randomly shuffle validators, sort, and check it is equal to original sort
Expand All @@ -265,7 +266,7 @@ func TestValidatorsSortDeterminism(t *testing.T) {
vals[i], vals[j] = vals[j], vals[i]
})

types.Validators(vals).Sort()
types.Validators{Validators: vals, ValidatorCodec: address.NewBech32Codec("cosmosvaloper")}.Sort()
require.Equal(t, sortedVals, vals, "Validator sort returned different slices")
}
}
Expand All @@ -286,16 +287,16 @@ func TestValidatorsSortCometBFT(t *testing.T) {
vals[i].Tokens = math.NewInt(1000000)
}

valz := types.Validators(vals)
valz := types.Validators{Validators: vals, ValidatorCodec: address.NewBech32Codec("cosmosvaloper")}

// create expected CometBFT validators by converting to CometBFT then sorting
expectedVals, err := testutil.ToCmtValidators(valz, sdk.DefaultPowerReduction)
require.NoError(t, err)
sort.Sort(cmttypes.ValidatorsByVotingPower(expectedVals))

// sort in SDK and then convert to CometBFT
sort.SliceStable(valz, func(i, j int) bool {
return types.ValidatorsByVotingPower(valz).Less(i, j, sdk.DefaultPowerReduction)
sort.SliceStable(valz.Validators, func(i, j int) bool {
return types.ValidatorsByVotingPower(valz.Validators).Less(i, j, sdk.DefaultPowerReduction)
})
actualVals, err := testutil.ToCmtValidators(valz, sdk.DefaultPowerReduction)
require.NoError(t, err)
Expand All @@ -304,15 +305,15 @@ func TestValidatorsSortCometBFT(t *testing.T) {
}

func TestValidatorToCmt(t *testing.T) {
vals := make(types.Validators, 10)
vals := types.Validators{}
expected := make([]*cmttypes.Validator, 10)

for i := range vals {
for i := 0; i < 10; i++ {
pk := ed25519.GenPrivKey().PubKey()
val := newValidator(t, sdk.ValAddress(pk.Address()), pk)
val.Status = types.Bonded
val.Tokens = math.NewInt(rand.Int63())
vals[i] = val
vals.Validators = append(vals.Validators, val)
cmtPk, err := cryptocodec.ToCmtPubKeyInterface(pk)
require.NoError(t, err)
expected[i] = cmttypes.NewValidator(cmtPk, val.ConsensusPower(sdk.DefaultPowerReduction))
Expand Down

0 comments on commit 1a9f45f

Please sign in to comment.