From b8a1b6d75961c0f76d50de0a8be8af60bc345ec6 Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 26 Apr 2024 11:15:22 +0800 Subject: [PATCH] fix: Pass offset param in rerank params for HybridSearch (#737) See also milvus-io/milvus#32562 --------- Signed-off-by: Congqi Xia Signed-off-by: ThreadDao Co-authored-by: ThreadDao --- client/data.go | 5 +++++ test/base/milvus_client.go | 6 +++--- test/common/utils.go | 7 ++++--- test/testcases/hybrid_search_test.go | 18 +++++++++++------- 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/client/data.go b/client/data.go index 4cb19b274..62a4e4c4c 100644 --- a/client/data.go +++ b/client/data.go @@ -67,8 +67,13 @@ func (c *GrpcClient) HybridSearch(ctx context.Context, collName string, partitio sReqs = append(sReqs, r) } + opt := &SearchQueryOption{} + for _, o := range opts { + o(opt) + } params := reranker.GetParams() params = append(params, &commonpb.KeyValuePair{Key: limitKey, Value: strconv.FormatInt(int64(limit), 10)}) + params = append(params, &commonpb.KeyValuePair{Key: offsetKey, Value: strconv.FormatInt(int64(opt.Offset), 10)}) req := &milvuspb.HybridSearchRequest{ CollectionName: collName, diff --git a/test/base/milvus_client.go b/test/base/milvus_client.go index bb94dd2a9..1a21d615c 100644 --- a/test/base/milvus_client.go +++ b/test/base/milvus_client.go @@ -427,11 +427,11 @@ func (mc *MilvusClient) Search(ctx context.Context, collName string, partitions } func (mc *MilvusClient) HybridSearch(ctx context.Context, collName string, partitions []string, limit int, outputFields []string, - reranker client.Reranker, subRequests []*client.ANNSearchRequest) ([]client.SearchResult, error) { + reranker client.Reranker, subRequests []*client.ANNSearchRequest, opts ...client.SearchQueryOptionFunc) ([]client.SearchResult, error) { funcName := "HybridSearch" - preRequest(funcName, ctx, collName, partitions, limit, outputFields, reranker, subRequests) + preRequest(funcName, ctx, collName, partitions, limit, outputFields, reranker, subRequests, opts) - searchResult, err := mc.mClient.HybridSearch(ctx, collName, partitions, limit, outputFields, reranker, subRequests) + searchResult, err := mc.mClient.HybridSearch(ctx, collName, partitions, limit, outputFields, reranker, subRequests, opts...) postResponse(funcName, err, searchResult) return searchResult, err diff --git a/test/common/utils.go b/test/common/utils.go index 5753887c1..02358e16a 100644 --- a/test/common/utils.go +++ b/test/common/utils.go @@ -52,9 +52,10 @@ const ( DefaultShards = int32(2) DefaultNb = 3000 DefaultNq = 5 - DefaultTopK = 10 - TestCapacity = 100 // default array field capacity - TestMaxLen = 100 // default varchar field max length + //DefaultNq = 1 + DefaultTopK = 10 + TestCapacity = 100 // default array field capacity + TestMaxLen = 100 // default varchar field max length ) // const default value from milvus diff --git a/test/testcases/hybrid_search_test.go b/test/testcases/hybrid_search_test.go index 3fd41febf..a9e7c6c80 100644 --- a/test/testcases/hybrid_search_test.go +++ b/test/testcases/hybrid_search_test.go @@ -208,8 +208,6 @@ func TestHybridSearchInvalidVectors(t *testing.T) { // hybrid search Pagination -> verify success func TestHybridSearchMultiVectorsPagination(t *testing.T) { - // TODO "https://github.com/milvus-io/milvus/issues/32174" - // TODO "https://github.com/milvus-io/milvus-sdk-go/issues/718" t.Parallel() ctx := createContext(t, time.Second*common.DefaultTimeout*2) // connect @@ -219,7 +217,7 @@ func TestHybridSearchMultiVectorsPagination(t *testing.T) { cp := CollectionParams{CollectionFieldsType: AllVectors, AutoID: false, EnableDynamicField: false, ShardsNum: common.DefaultShards, Dim: common.DefaultDim} - dp := DataParams{DoInsert: true, CollectionFieldsType: AllVectors, start: 0, nb: common.DefaultNb, + dp := DataParams{DoInsert: true, CollectionFieldsType: AllVectors, start: 0, nb: common.DefaultNb * 5, dim: common.DefaultDim, EnableDynamicField: false} // index params @@ -240,9 +238,9 @@ func TestHybridSearchMultiVectorsPagination(t *testing.T) { _, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs) common.CheckErr(t, errSearch, true) - // hybrid search with invalid offset - //_, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs, client.WithOffset(invalidOffset)) - //common.CheckErr(t, errSearch, false, "top k should be in range [1, 16384]") + //hybrid search with invalid offset + _, errSearch = mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, client.NewRRFReranker(), sReqs, client.WithOffset(invalidOffset)) + common.CheckErr(t, errSearch, false, "should be gte than 0", "(offset+limit) should be in range [1, 16384]") } // search with different reranker and offset @@ -253,13 +251,19 @@ func TestHybridSearchMultiVectorsPagination(t *testing.T) { client.NewWeightedReranker([]float64{0.4, 1.0}), } { sReqs := []*client.ANNSearchRequest{ - client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK, client.WithOffset(5)), + client.NewANNSearchRequest(common.DefaultFloatVecFieldName, entity.L2, expr, queryVec1, sp, common.DefaultTopK), client.NewANNSearchRequest(common.DefaultFloat16VecFieldName, entity.L2, expr, queryVec2, sp, common.DefaultTopK), } // hybrid search searchRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, common.DefaultTopK, []string{}, reranker, sReqs) common.CheckErr(t, errSearch, true) + offsetRes, errSearch := mc.HybridSearch(ctx, collName, []string{}, 5, []string{}, reranker, sReqs, client.WithOffset(5)) + common.CheckErr(t, errSearch, true) common.CheckSearchResult(t, searchRes, 1, common.DefaultTopK) + common.CheckSearchResult(t, offsetRes, 1, 5) + for i := 0; i < len(searchRes); i++ { + require.Equal(t, searchRes[i].IDs.(*entity.ColumnInt64).Data()[5:], offsetRes[i].IDs.(*entity.ColumnInt64).Data()) + } } }