diff --git a/x/wasm/client/cli/query.go b/x/wasm/client/cli/query.go index 2b064f5fbd..7dd820e3fb 100644 --- a/x/wasm/client/cli/query.go +++ b/x/wasm/client/cli/query.go @@ -144,7 +144,7 @@ func GetCmdListCode() *cobra.Command { SilenceUsage: true, } flags.AddQueryFlagsToCmd(cmd) - flags.AddPaginationFlagsToCmd(cmd, "list codes") + addPaginationFlags(cmd, "list codes") return cmd } @@ -190,7 +190,7 @@ func GetCmdListContractByCode() *cobra.Command { SilenceUsage: true, } flags.AddQueryFlagsToCmd(cmd) - flags.AddPaginationFlagsToCmd(cmd, "list contracts by code") + addPaginationFlags(cmd, "list contracts by code") return cmd } @@ -368,10 +368,7 @@ func GetCmdGetContractStateAll() *cobra.Command { SilenceUsage: true, } flags.AddQueryFlagsToCmd(cmd) - cmd.Flags().String(flags.FlagPageKey, "", "pagination page-key of contract state to query for") - cmd.Flags().Uint64(flags.FlagLimit, 100, "pagination limit of contract state to query for") - cmd.Flags().Bool(flags.FlagReverse, false, "results are sorted in descending order") - + addPaginationFlags(cmd, "contract state") return cmd } @@ -507,7 +504,7 @@ func GetCmdGetContractHistory() *cobra.Command { } flags.AddQueryFlagsToCmd(cmd) - flags.AddPaginationFlagsToCmd(cmd, "contract history") + addPaginationFlags(cmd, "contract history") return cmd } @@ -543,7 +540,7 @@ func GetCmdListPinnedCode() *cobra.Command { SilenceUsage: true, } flags.AddQueryFlagsToCmd(cmd) - flags.AddPaginationFlagsToCmd(cmd, "list codes") + addPaginationFlags(cmd, "list codes") return cmd } @@ -584,7 +581,7 @@ func GetCmdListContractsByCreator() *cobra.Command { SilenceUsage: true, } flags.AddQueryFlagsToCmd(cmd) - flags.AddPaginationFlagsToCmd(cmd, "list contracts by creator") + addPaginationFlags(cmd, "list contracts by creator") return cmd } @@ -677,3 +674,10 @@ func GetCmdQueryParams() *cobra.Command { return cmd } + +// supports a subset of the SDK pagination params for better resource utilization +func addPaginationFlags(cmd *cobra.Command, query string) { + cmd.Flags().String(flags.FlagPageKey, "", fmt.Sprintf("pagination page-key of %s to query for", query)) + cmd.Flags().Uint64(flags.FlagLimit, 100, fmt.Sprintf("pagination limit of %s to query for", query)) + cmd.Flags().Bool(flags.FlagReverse, false, "results are sorted in descending order") +} diff --git a/x/wasm/keeper/querier.go b/x/wasm/keeper/querier.go index 78c48c3a42..3cc80d2d6e 100644 --- a/x/wasm/keeper/querier.go +++ b/x/wasm/keeper/querier.go @@ -61,12 +61,15 @@ func (q GrpcQuerier) ContractHistory(c context.Context, req *types.QueryContract if err != nil { return nil, err } + paginationParams, err := ensurePaginationParams(req.Pagination) + if err != nil { + return nil, err + } ctx := sdk.UnwrapSDKContext(c) r := make([]types.ContractCodeHistoryEntry, 0) - prefixStore := prefix.NewStore(ctx.KVStore(q.storeKey), types.GetContractCodeHistoryElementPrefix(contractAddr)) - pageRes, err := query.FilteredPaginate(prefixStore, req.Pagination, func(key, value []byte, accumulate bool) (bool, error) { + pageRes, err := query.FilteredPaginate(prefixStore, paginationParams, func(key, value []byte, accumulate bool) (bool, error) { if accumulate { var e types.ContractCodeHistoryEntry if err := q.cdc.Unmarshal(value, &e); err != nil { @@ -93,11 +96,15 @@ func (q GrpcQuerier) ContractsByCode(c context.Context, req *types.QueryContract if req.CodeId == 0 { return nil, errorsmod.Wrap(types.ErrInvalid, "code id") } + paginationParams, err := ensurePaginationParams(req.Pagination) + if err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) r := make([]string, 0) - prefixStore := prefix.NewStore(ctx.KVStore(q.storeKey), types.GetContractByCodeIDSecondaryIndexPrefix(req.CodeId)) - pageRes, err := query.FilteredPaginate(prefixStore, req.Pagination, func(key, value []byte, accumulate bool) (bool, error) { + pageRes, err := query.FilteredPaginate(prefixStore, paginationParams, func(key, value []byte, accumulate bool) (bool, error) { if accumulate { var contractAddr sdk.AccAddress = key[types.AbsoluteTxPositionLen:] r = append(r, contractAddr.String()) @@ -117,14 +124,16 @@ func (q GrpcQuerier) AllContractState(c context.Context, req *types.QueryAllCont if req == nil { return nil, status.Error(codes.InvalidArgument, "empty request") } - if req.Pagination != nil && - (req.Pagination.Offset != 0 || req.Pagination.CountTotal) { - return nil, status.Error(codes.InvalidArgument, "offset and count queries not supported anymore") - } + contractAddr, err := sdk.AccAddressFromBech32(req.Address) if err != nil { return nil, err } + paginationParams, err := ensurePaginationParams(req.Pagination) + if err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) if !q.keeper.HasContractInfo(ctx, contractAddr) { return nil, types.ErrNoSuchContractFn(contractAddr.String()). @@ -133,7 +142,7 @@ func (q GrpcQuerier) AllContractState(c context.Context, req *types.QueryAllCont r := make([]types.Model, 0) prefixStore := prefix.NewStore(ctx.KVStore(q.storeKey), types.GetContractStorePrefix(contractAddr)) - pageRes, err := query.FilteredPaginate(prefixStore, req.Pagination, func(key, value []byte, accumulate bool) (bool, error) { + pageRes, err := query.FilteredPaginate(prefixStore, paginationParams, func(key, value []byte, accumulate bool) (bool, error) { if accumulate { r = append(r, types.Model{ Key: key, @@ -238,10 +247,15 @@ func (q GrpcQuerier) Codes(c context.Context, req *types.QueryCodesRequest) (*ty if req == nil { return nil, status.Error(codes.InvalidArgument, "empty request") } + paginationParams, err := ensurePaginationParams(req.Pagination) + if err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) r := make([]types.CodeInfoResponse, 0) prefixStore := prefix.NewStore(ctx.KVStore(q.storeKey), types.CodeKeyPrefix) - pageRes, err := query.FilteredPaginate(prefixStore, req.Pagination, func(key, value []byte, accumulate bool) (bool, error) { + pageRes, err := query.FilteredPaginate(prefixStore, paginationParams, func(key, value []byte, accumulate bool) (bool, error) { if accumulate { var c types.CodeInfo if err := q.cdc.Unmarshal(value, &c); err != nil { @@ -302,11 +316,15 @@ func (q GrpcQuerier) PinnedCodes(c context.Context, req *types.QueryPinnedCodesR if req == nil { return nil, status.Error(codes.InvalidArgument, "empty request") } + paginationParams, err := ensurePaginationParams(req.Pagination) + if err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) r := make([]uint64, 0) - prefixStore := prefix.NewStore(ctx.KVStore(q.storeKey), types.PinnedCodeIndexPrefix) - pageRes, err := query.FilteredPaginate(prefixStore, req.Pagination, func(key, _ []byte, accumulate bool) (bool, error) { + pageRes, err := query.FilteredPaginate(prefixStore, paginationParams, func(key, _ []byte, accumulate bool) (bool, error) { if accumulate { r = append(r, sdk.BigEndianToUint64(key)) } @@ -332,6 +350,11 @@ func (q GrpcQuerier) ContractsByCreator(c context.Context, req *types.QueryContr if req == nil { return nil, status.Error(codes.InvalidArgument, "empty request") } + paginationParams, err := ensurePaginationParams(req.Pagination) + if err != nil { + return nil, err + } + ctx := sdk.UnwrapSDKContext(c) contracts := make([]string, 0) @@ -340,7 +363,7 @@ func (q GrpcQuerier) ContractsByCreator(c context.Context, req *types.QueryContr return nil, err } prefixStore := prefix.NewStore(ctx.KVStore(q.storeKey), types.GetContractsByCreatorPrefix(creatorAddress)) - pageRes, err := query.FilteredPaginate(prefixStore, req.Pagination, func(key, _ []byte, accumulate bool) (bool, error) { + pageRes, err := query.FilteredPaginate(prefixStore, paginationParams, func(key, _ []byte, accumulate bool) (bool, error) { if accumulate { accAddres := sdk.AccAddress(key[types.AbsoluteTxPositionLen:]) contracts = append(contracts, accAddres.String()) @@ -356,3 +379,25 @@ func (q GrpcQuerier) ContractsByCreator(c context.Context, req *types.QueryContr Pagination: pageRes, }, nil } + +// max limit to pagination queries +const maxResultEntries = 100 + +var errLegacyPaginationUnsupported = status.Error(codes.InvalidArgument, "offset and count queries not supported") + +// ensure that pagination is done via key iterator with reasonable limit +func ensurePaginationParams(req *query.PageRequest) (*query.PageRequest, error) { + if req == nil { + return &query.PageRequest{ + Key: nil, + Limit: query.DefaultLimit, + }, nil + } + if req.Offset != 0 || req.CountTotal { + return nil, errLegacyPaginationUnsupported + } + if req.Limit > maxResultEntries || req.Limit <= 0 { + req.Limit = maxResultEntries + } + return req, nil +} diff --git a/x/wasm/keeper/querier_test.go b/x/wasm/keeper/querier_test.go index af9a6e3f4f..8de6b6aabf 100644 --- a/x/wasm/keeper/querier_test.go +++ b/x/wasm/keeper/querier_test.go @@ -64,7 +64,7 @@ func TestQueryAllContractState(t *testing.T) { Offset: 1, }, }, - expErr: status.Error(codes.InvalidArgument, "offset and count queries not supported anymore"), + expErr: errLegacyPaginationUnsupported, }, "with pagination count": { srcQuery: &types.QueryAllContractStateRequest{ @@ -73,7 +73,7 @@ func TestQueryAllContractState(t *testing.T) { CountTotal: true, }, }, - expErr: status.Error(codes.InvalidArgument, "offset and count queries not supported anymore"), + expErr: errLegacyPaginationUnsupported, }, "with pagination limit": { srcQuery: &types.QueryAllContractStateRequest{ @@ -347,7 +347,7 @@ func TestQueryContractsByCode(t *testing.T) { Offset: 5, }, }, - expAddr: contractAddrs[5:10], + expErr: errLegacyPaginationUnsupported, }, "with invalid pagination key": { req: &types.QueryContractsByCodeRequest{ @@ -357,7 +357,7 @@ func TestQueryContractsByCode(t *testing.T) { Key: []byte("test"), }, }, - expErr: fmt.Errorf("invalid request, either offset or key is expected, got both"), + expErr: errLegacyPaginationUnsupported, }, "with pagination limit": { req: &types.QueryContractsByCodeRequest{ @@ -406,6 +406,7 @@ func TestQueryContractHistory(t *testing.T) { srcHistory []types.ContractCodeHistoryEntry req types.QueryContractHistoryRequest expContent []types.ContractCodeHistoryEntry + expErr error }{ "response with internal fields cleared": { srcHistory: []types.ContractCodeHistoryEntry{{ @@ -475,12 +476,7 @@ func TestQueryContractHistory(t *testing.T) { Offset: 1, }, }, - expContent: []types.ContractCodeHistoryEntry{{ - Operation: types.ContractCodeHistoryOperationTypeMigrate, - CodeID: 2, - Msg: []byte(`"migrate message 1"`), - Updated: &types.AbsoluteTxPosition{BlockHeight: 3, TxIndex: 4}, - }}, + expErr: errLegacyPaginationUnsupported, }, "with pagination limit": { srcHistory: []types.ContractCodeHistoryEntry{{ @@ -515,7 +511,7 @@ func TestQueryContractHistory(t *testing.T) { Updated: types.NewAbsoluteTxPosition(ctx), Msg: []byte(`"init message"`), }}, - expContent: nil, + expContent: []types.ContractCodeHistoryEntry{}, }, } for msg, spec := range specs { @@ -527,14 +523,14 @@ func TestQueryContractHistory(t *testing.T) { // when q := Querier(keeper) - got, err := q.ContractHistory(sdk.WrapSDKContext(xCtx), &spec.req) //nolint:gosec - + got, gotErr := q.ContractHistory(sdk.WrapSDKContext(xCtx), &spec.req) //nolint:gosec // then - if spec.expContent == nil { - require.Error(t, types.ErrEmpty) + if spec.expErr != nil { + require.Error(t, gotErr) + assert.ErrorIs(t, gotErr, spec.expErr) return } - require.NoError(t, err) + require.NoError(t, gotErr) assert.Equal(t, spec.expContent, got.Entries) }) } @@ -551,6 +547,7 @@ func TestQueryCodeList(t *testing.T) { storedCodeIDs []uint64 req types.QueryCodesRequest expCodeIDs []uint64 + expErr error }{ "none": {}, "no gaps": { @@ -568,7 +565,7 @@ func TestQueryCodeList(t *testing.T) { Offset: 1, }, }, - expCodeIDs: []uint64{2, 3}, + expErr: errLegacyPaginationUnsupported, }, "with pagination limit": { storedCodeIDs: []uint64{1, 2, 3}, @@ -602,10 +599,15 @@ func TestQueryCodeList(t *testing.T) { } // when q := Querier(keeper) - got, err := q.Codes(sdk.WrapSDKContext(xCtx), &spec.req) //nolint:gosec + got, gotErr := q.Codes(sdk.WrapSDKContext(xCtx), &spec.req) //nolint:gosec // then - require.NoError(t, err) + if spec.expErr != nil { + require.Error(t, gotErr) + require.ErrorIs(t, gotErr, spec.expErr) + return + } + require.NoError(t, gotErr) require.NotNil(t, got.CodeInfos) require.Len(t, got.CodeInfos, len(spec.expCodeIDs)) for i, exp := range spec.expCodeIDs { @@ -695,7 +697,7 @@ func TestQueryPinnedCodes(t *testing.T) { specs := map[string]struct { srcQuery *types.QueryPinnedCodesRequest expCodeIDs []uint64 - expErr *errorsmod.Error + expErr error }{ "query all": { srcQuery: &types.QueryPinnedCodesRequest{}, @@ -707,7 +709,7 @@ func TestQueryPinnedCodes(t *testing.T) { Offset: 1, }, }, - expCodeIDs: []uint64{exampleContract2.CodeID}, + expErr: errLegacyPaginationUnsupported, }, "with pagination limit": { srcQuery: &types.QueryPinnedCodesRequest{ @@ -728,11 +730,13 @@ func TestQueryPinnedCodes(t *testing.T) { } for msg, spec := range specs { t.Run(msg, func(t *testing.T) { - got, err := q.PinnedCodes(sdk.WrapSDKContext(ctx), spec.srcQuery) - require.True(t, spec.expErr.Is(err), err) + got, gotErr := q.PinnedCodes(sdk.WrapSDKContext(ctx), spec.srcQuery) if spec.expErr != nil { + require.Error(t, gotErr) + assert.ErrorIs(t, gotErr, spec.expErr) return } + require.NoError(t, gotErr) require.NotNil(t, got) assert.Equal(t, spec.expCodeIDs, got.CodeIDs) }) @@ -948,8 +952,7 @@ func TestQueryContractsByCreatorList(t *testing.T) { Offset: 1, }, }, - expContractAddr: allExpecedContracts[1:], - expErr: nil, + expErr: errLegacyPaginationUnsupported, }, "with pagination limit": { srcQuery: &types.QueryContractsByCreatorRequest{ @@ -978,13 +981,13 @@ func TestQueryContractsByCreatorList(t *testing.T) { q := Querier(keepers.WasmKeeper) for msg, spec := range specs { t.Run(msg, func(t *testing.T) { - got, err := q.ContractsByCreator(sdk.WrapSDKContext(ctx), spec.srcQuery) - + got, gotErr := q.ContractsByCreator(sdk.WrapSDKContext(ctx), spec.srcQuery) if spec.expErr != nil { - require.Equal(t, spec.expErr, err) + require.Error(t, gotErr) + assert.ErrorContains(t, gotErr, spec.expErr.Error()) return } - require.NoError(t, err) + require.NoError(t, gotErr) require.NotNil(t, got) assert.Equal(t, spec.expContractAddr, got.ContractAddresses) }) @@ -998,3 +1001,47 @@ func fromBase64(s string) []byte { } return r } + +func TestEnsurePaginationParams(t *testing.T) { + specs := map[string]struct { + src *query.PageRequest + exp *query.PageRequest + expErr error + }{ + "custom limit": { + src: &query.PageRequest{Limit: 10}, + exp: &query.PageRequest{Limit: 10}, + }, + "limit not set": { + src: &query.PageRequest{}, + exp: &query.PageRequest{Limit: 100}, + }, + "limit > max": { + src: &query.PageRequest{Limit: 101}, + exp: &query.PageRequest{Limit: 100}, + }, + "no pagination params set": { + exp: &query.PageRequest{Limit: 100}, + }, + "non empty offset": { + src: &query.PageRequest{Offset: 1}, + expErr: errLegacyPaginationUnsupported, + }, + "count enabled": { + src: &query.PageRequest{CountTotal: true}, + expErr: errLegacyPaginationUnsupported, + }, + } + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + got, gotErr := ensurePaginationParams(spec.src) + if spec.expErr != nil { + require.Error(t, gotErr) + assert.ErrorIs(t, gotErr, spec.expErr) + return + } + require.NoError(t, gotErr) + assert.Equal(t, spec.exp, got) + }) + } +}