Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(query): filtered collections pagination #16905

Merged
merged 15 commits into from
Jul 13, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Ref: https://keepachangelog.com/en/1.0.0/

* (server) [#16827](https://github.com/cosmos/cosmos-sdk/pull/16827) Properly use `--trace` flag (before it was setting the trace level instead of displaying the stacktraces).
* (x/bank) [#16841](https://github.com/cosmos/cosmos-sdk/pull/16841) correctly process legacy `DenomAddressIndex` values.
* (types/query) [#16905](https://github.com/cosmos/cosmos-sdk/pull/16905) – Collections Pagination now applies proper count when filtering results.

### API Breaking Changes

Expand Down
104 changes: 76 additions & 28 deletions types/query/collections_pagination.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,40 @@ type Collection[K, V any] interface {
KeyCodec() collcodec.KeyCodec[K]
}

// CollectionPaginate follows the same behavior as Paginate but works on a Collection.
func CollectionPaginate[K, V any, C Collection[K, V]](
// CollectionPaginate follows the same logic as Paginate but for collection types.
// transformFunc is used to transform the result to a different type.
func CollectionPaginate[K, V any, C Collection[K, V], T any](
ctx context.Context,
coll C,
pageReq *PageRequest,
) ([]collections.KeyValue[K, V], *PageResponse, error) {
return CollectionFilteredPaginate[K, V](ctx, coll, pageReq, nil)
transformFunc func(key K, value V) (T, error),
opts ...func(opt *CollectionsPaginateOptions[K]),
) ([]T, *PageResponse, error) {
return CollectionFilteredPaginate(
ctx,
coll,
pageReq,
nil,
transformFunc,
opts...,
)
}

// CollectionFilteredPaginate works in the same way as FilteredPaginate but for collection types.
// CollectionFilteredPaginate works in the same way as CollectionPaginate but allows to filter
// results using a predicateFunc.
// A nil predicateFunc means no filtering is applied and results are collected as is.
func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
// TransformFunc is applied only to results which are in range of the pagination and allow
// to convert the result to a different type.
// NOTE: do not collect results using the values/keys passed to predicateFunc as they are not
// guaranteed to be in the pagination range requested.
func CollectionFilteredPaginate[K, V any, C Collection[K, V], T any](
ctx context.Context,
coll C,
pageReq *PageRequest,
predicateFunc func(key K, value V) (include bool, err error),
transformFunc func(key K, value V) (T, error),
opts ...func(opt *CollectionsPaginateOptions[K]),
) ([]collections.KeyValue[K, V], *PageResponse, error) {
) (results []T, pageRes *PageResponse, err error) {
pageReq = initPageRequestDefaults(pageReq)

offset := pageReq.Offset
Expand All @@ -65,12 +81,6 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
return nil, nil, fmt.Errorf("invalid request, either offset or key is expected, got both")
}

var (
results []collections.KeyValue[K, V]
pageRes *PageResponse
err error
)

opt := new(CollectionsPaginateOptions[K])
for _, o := range opts {
o(opt)
Expand All @@ -85,9 +95,9 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](
}

if len(key) != 0 {
results, pageRes, err = collFilteredPaginateByKey(ctx, coll, prefix, key, reverse, limit, predicateFunc)
results, pageRes, err = collFilteredPaginateByKey(ctx, coll, prefix, key, reverse, limit, predicateFunc, transformFunc)
} else {
results, pageRes, err = collFilteredPaginateNoKey(ctx, coll, prefix, reverse, offset, limit, countTotal, predicateFunc)
results, pageRes, err = collFilteredPaginateNoKey(ctx, coll, prefix, reverse, offset, limit, countTotal, predicateFunc, transformFunc)
}
// invalid iter error is ignored to retain Paginate behavior
if errors.Is(err, collections.ErrInvalidIterator) {
Expand All @@ -102,7 +112,7 @@ func CollectionFilteredPaginate[K, V any, C Collection[K, V]](

// collFilteredPaginateNoKey applies the provided pagination on the collection when the starting key is not set.
// If predicateFunc is nil no filtering is applied.
func collFilteredPaginateNoKey[K, V any, C Collection[K, V]](
func collFilteredPaginateNoKey[K, V any, C Collection[K, V], T any](
ctx context.Context,
coll C,
prefix []byte,
Expand All @@ -111,7 +121,8 @@ func collFilteredPaginateNoKey[K, V any, C Collection[K, V]](
limit uint64,
countTotal bool,
predicateFunc func(K, V) (bool, error),
) ([]collections.KeyValue[K, V], *PageResponse, error) {
transformFunc func(K, V) (T, error),
) ([]T, *PageResponse, error) {
iterator, err := getCollIter[K, V](ctx, coll, prefix, nil, reverse)
if err != nil {
return nil, nil, err
Expand All @@ -125,7 +136,7 @@ func collFilteredPaginateNoKey[K, V any, C Collection[K, V]](
var (
count uint64
nextKey []byte
results []collections.KeyValue[K, V]
results []T
)

for ; iterator.Valid(); iterator.Next() {
Expand All @@ -138,18 +149,28 @@ func collFilteredPaginateNoKey[K, V any, C Collection[K, V]](
}
// if no predicate function is specified then we just include the result
if predicateFunc == nil {
results = append(results, kv)
transformed, err := transformFunc(kv.Key, kv.Value)
if err != nil {
return nil, nil, err
}
results = append(results, transformed)
count++

// if predicate function is defined we check if the result matches the filtering criteria
} else {
include, err := predicateFunc(kv.Key, kv.Value)
if err != nil {
return nil, nil, err
}
if include {
results = append(results, kv)
transformed, err := transformFunc(kv.Key, kv.Value)
if err != nil {
return nil, nil, err
}
results = append(results, transformed)
count++
}
}
count++
// second case, we found all the objects specified within the limit
case count == limit:
key, err := iterator.Key()
Expand All @@ -172,12 +193,31 @@ func collFilteredPaginateNoKey[K, V any, C Collection[K, V]](
// but we need to count how many possible results exist in total.
// so we keep increasing the count until the iterator is fully consumed.
case count > limit:
count++
if predicateFunc == nil {
count++

// if predicate function is defined we check if the result matches the filtering criteria
} else {
kv, err := iterator.KeyValue()
if err != nil {
return nil, nil, err
}

include, err := predicateFunc(kv.Key, kv.Value)
if err != nil {
return nil, nil, err
}
if include {
count++
}
}
}
}

resp := &PageResponse{
NextKey: nextKey,
}

if countTotal {
resp.Total = count + offset
}
Expand All @@ -200,15 +240,16 @@ func advanceIter[I interface {

// collFilteredPaginateByKey paginates a collection when a starting key
// is provided in the PageRequest. Predicate is applied only if not nil.
func collFilteredPaginateByKey[K, V any, C Collection[K, V]](
func collFilteredPaginateByKey[K, V any, C Collection[K, V], T any](
ctx context.Context,
coll C,
prefix []byte,
key []byte,
reverse bool,
limit uint64,
predicateFunc func(K, V) (bool, error),
) ([]collections.KeyValue[K, V], *PageResponse, error) {
predicateFunc func(key K, value V) (bool, error),
transformFunc func(key K, value V) (transformed T, err error),
) (results []T, pageRes *PageResponse, err error) {
iterator, err := getCollIter[K, V](ctx, coll, prefix, key, reverse)
if err != nil {
return nil, nil, err
Expand All @@ -218,7 +259,6 @@ func collFilteredPaginateByKey[K, V any, C Collection[K, V]](
var (
count uint64
nextKey []byte
results []collections.KeyValue[K, V]
)

for ; iterator.Valid(); iterator.Next() {
Expand All @@ -243,7 +283,11 @@ func collFilteredPaginateByKey[K, V any, C Collection[K, V]](
}
// if no predicate is specified then we just append the result
if predicateFunc == nil {
results = append(results, kv)
transformed, err := transformFunc(kv.Key, kv.Value)
if err != nil {
return nil, nil, err
}
results = append(results, transformed)
// if predicate is applied we execute the predicate function
// and append only if predicateFunc yields true.
} else {
Expand All @@ -252,7 +296,11 @@ func collFilteredPaginateByKey[K, V any, C Collection[K, V]](
return nil, nil, err
}
if include {
results = append(results, kv)
transformed, err := transformFunc(kv.Key, kv.Value)
if err != nil {
return nil, nil, err
}
results = append(results, transformed)
}
}
count++
Expand Down
13 changes: 11 additions & 2 deletions types/query/collections_pagination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,15 @@ func TestCollectionPagination(t *testing.T) {
Limit: 3,
},
expResp: &PageResponse{
NextKey: encodeKey(3),
NextKey: encodeKey(5),
},
filter: func(key, value uint64) (bool, error) {
return key%2 == 0, nil
},
expResults: []collections.KeyValue[uint64, uint64]{
{Key: 0, Value: 0},
{Key: 2, Value: 2},
{Key: 4, Value: 4},
},
},
"filtered with key": {
Expand All @@ -131,7 +132,15 @@ func TestCollectionPagination(t *testing.T) {
for name, tc := range tcs {
tc := tc
t.Run(name, func(t *testing.T) {
gotResults, gotResponse, err := CollectionFilteredPaginate(ctx, m, tc.req, tc.filter)
gotResults, gotResponse, err := CollectionFilteredPaginate(
ctx,
m,
tc.req,
tc.filter,
func(key, value uint64) (collections.KeyValue[uint64, uint64], error) {
return collections.KeyValue[uint64, uint64]{Key: key, Value: value}, nil
},
)
if tc.wantErr != nil {
require.ErrorIs(t, err, tc.wantErr)
return
Expand Down
17 changes: 8 additions & 9 deletions x/auth/keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,14 @@ func (s queryServer) Accounts(ctx context.Context, req *types.QueryAccountsReque
return nil, status.Error(codes.InvalidArgument, "empty request")
}

var accounts []*codectypes.Any
_, pageRes, err := query.CollectionFilteredPaginate(ctx, s.k.Accounts, req.Pagination, func(_ sdk.AccAddress, value sdk.AccountI) (include bool, err error) {
accountAny, err := codectypes.NewAnyWithValue(value)
if err != nil {
return false, err
}
accounts = append(accounts, accountAny)
return false, nil // we don't include it since we're already appending the account
})
accounts, pageRes, err := query.CollectionPaginate(
ctx,
s.k.Accounts,
req.Pagination,
func(_ sdk.AccAddress, value sdk.AccountI) (*codectypes.Any, error) {
return codectypes.NewAnyWithValue(value)
},
)

return &types.QueryAccountsResponse{Accounts: accounts, Pagination: pageRes}, err
}
Expand Down
63 changes: 30 additions & 33 deletions x/bank/keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,20 @@ func (k BaseKeeper) AllBalances(ctx context.Context, req *types.QueryAllBalances
}

sdkCtx := sdk.UnwrapSDKContext(ctx)

balances := sdk.NewCoins()

_, pageRes, err := query.CollectionFilteredPaginate(ctx, k.Balances, req.Pagination, func(key collections.Pair[sdk.AccAddress, string], value math.Int) (include bool, err error) {
denom := key.K2()
if req.ResolveDenom {
if metadata, ok := k.GetDenomMetaData(sdkCtx, denom); ok {
denom = metadata.Display
balances, pageRes, err := query.CollectionPaginate(
ctx,
k.Balances,
req.Pagination,
func(key collections.Pair[sdk.AccAddress, string], value math.Int) (sdk.Coin, error) {
if req.ResolveDenom {
if metadata, ok := k.GetDenomMetaData(sdkCtx, key.K2()); ok {
return sdk.NewCoin(metadata.Display, value), nil
}
}
}
balances = append(balances, sdk.NewCoin(denom, value))
return false, nil // we don't include results because we're appending them here.
}, query.WithCollectionPaginationPairPrefix[sdk.AccAddress, string](addr))
return sdk.NewCoin(key.K2(), value), nil
},
query.WithCollectionPaginationPairPrefix[sdk.AccAddress, string](addr),
)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "paginate: %v", err)
}
Expand All @@ -94,12 +95,10 @@ func (k BaseKeeper) SpendableBalances(ctx context.Context, req *types.QuerySpend

sdkCtx := sdk.UnwrapSDKContext(ctx)

balances := sdk.NewCoins()
zeroAmt := math.ZeroInt()

_, pageRes, err := query.CollectionFilteredPaginate(ctx, k.Balances, req.Pagination, func(key collections.Pair[sdk.AccAddress, string], _ math.Int) (include bool, err error) {
balances = append(balances, sdk.NewCoin(key.K2(), zeroAmt))
return false, nil // not including results as they're appended here
balances, pageRes, err := query.CollectionPaginate(ctx, k.Balances, req.Pagination, func(key collections.Pair[sdk.AccAddress, string], _ math.Int) (coin sdk.Coin, err error) {
return sdk.NewCoin(key.K2(), zeroAmt), nil
}, query.WithCollectionPaginationPairPrefix[sdk.AccAddress, string](addr))
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "paginate: %v", err)
Expand Down Expand Up @@ -280,19 +279,16 @@ func (k BaseKeeper) DenomOwners(
return nil, status.Error(codes.InvalidArgument, err.Error())
}

var denomOwners []*types.DenomOwner

_, pageRes, err := query.CollectionFilteredPaginate(goCtx, k.Balances.Indexes.Denom, req.Pagination,
func(key collections.Pair[string, sdk.AccAddress], value collections.NoValue) (include bool, err error) {
denomOwners, pageRes, err := query.CollectionPaginate(
goCtx,
k.Balances.Indexes.Denom,
req.Pagination,
func(key collections.Pair[string, sdk.AccAddress], value collections.NoValue) (*types.DenomOwner, error) {
amt, err := k.Balances.Get(goCtx, collections.Join(key.K2(), req.Denom))
if err != nil {
return false, err
return nil, err
}
denomOwners = append(denomOwners, &types.DenomOwner{
Address: key.K2().String(),
Balance: sdk.NewCoin(req.Denom, amt),
})
return false, nil
return &types.DenomOwner{Address: key.K2().String(), Balance: sdk.NewCoin(req.Denom, amt)}, nil
},
query.WithCollectionPaginationPairPrefix[string, sdk.AccAddress](req.Denom),
)
Expand All @@ -316,16 +312,17 @@ func (k BaseKeeper) SendEnabled(goCtx context.Context, req *types.QuerySendEnabl
}
}
} else {
results, pageResp, err := query.CollectionPaginate[string, bool](ctx, k.BaseViewKeeper.SendEnabled, req.Pagination)
results, pageResp, err := query.CollectionPaginate(
ctx,
k.BaseViewKeeper.SendEnabled,
req.Pagination, func(key string, value bool) (*types.SendEnabled, error) {
return types.NewSendEnabled(key, value), nil
},
)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
for _, r := range results {
resp.SendEnabled = append(resp.SendEnabled, &types.SendEnabled{
Denom: r.Key,
Enabled: r.Value,
})
}
resp.SendEnabled = results
resp.Pagination = pageResp
}

Expand Down
Loading