diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index 6b92c59665706..b5fb71d70ab46 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -184,11 +184,11 @@ message RetrieveRequest { bool is_count = 13; int64 iteration_extension_reduce_rate = 14; string username = 15; - bool reduce_stop_for_best = 16; + bool reduce_stop_for_best = 16; //deprecated + int32 reduce_type = 17; } - message RetrieveResults { common.MsgBase base = 1; common.Status status = 2; diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 8ea5802e1b0f7..a79af7e7d752f 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -20,6 +20,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/exprutil" + "github.com/milvus-io/milvus/internal/util/reduce" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -75,9 +76,9 @@ type queryTask struct { } type queryParams struct { - limit int64 - offset int64 - reduceStopForBest bool + limit int64 + offset int64 + reduceType reduce.IReduceType } // translateToOutputFieldIDs translates output fields name to output fields id. @@ -142,6 +143,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e limit int64 offset int64 reduceStopForBest bool + isIterator bool err error ) reduceStopForBestStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ReduceStopForBestKey, queryParamsPair) @@ -154,10 +156,29 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e } } + isIteratorStr, err := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, queryParamsPair) + // if reduce_stop_for_best is provided + if err == nil { + isIterator, err = strconv.ParseBool(isIteratorStr) + if err != nil { + return nil, merr.WrapErrParameterInvalid("true or false", isIteratorStr, + "value for iterator field is invalid") + } + } + + reduceType := reduce.IReduceNoOrder + if isIterator { + if reduceStopForBest { + reduceType = reduce.IReduceInOrderForBest + } else { + reduceType = reduce.IReduceInOrder + } + } + limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair) // if limit is not provided if err != nil { - return &queryParams{limit: typeutil.Unlimited, reduceStopForBest: reduceStopForBest}, nil + return &queryParams{limit: typeutil.Unlimited, reduceType: reduceType}, nil } limit, err = strconv.ParseInt(limitStr, 0, 64) if err != nil { @@ -179,9 +200,9 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e } return &queryParams{ - limit: limit, - offset: offset, - reduceStopForBest: reduceStopForBest, + limit: limit, + offset: offset, + reduceType: reduceType, }, nil } @@ -343,7 +364,10 @@ func (t *queryTask) PreExecute(ctx context.Context) error { if err != nil { return err } - t.RetrieveRequest.ReduceStopForBest = queryParams.reduceStopForBest + if queryParams.reduceType == reduce.IReduceInOrderForBest { + t.RetrieveRequest.ReduceStopForBest = true + } + t.RetrieveRequest.ReduceType = int32(queryParams.reduceType) t.queryParams = queryParams t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset @@ -612,9 +636,10 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re cursors := make([]int64, len(validRetrieveResults)) if queryParams != nil && queryParams.limit != typeutil.Unlimited { - // reduceStopForBest will try to get as many results as possible + // IReduceInOrderForBest will try to get as many results as possible // so loopEnd in this case will be set to the sum of all results' size - if !queryParams.reduceStopForBest { + // to get as many qualified results as possible + if reduce.ShouldUseInputLimit(queryParams.reduceType) { loopEnd = int(queryParams.limit) } } @@ -623,7 +648,7 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re if queryParams != nil && queryParams.offset > 0 { for i := int64(0); i < queryParams.offset; i++ { sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) - if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) { + if sel == -1 || (reduce.ShouldStopWhenDrained(queryParams.reduceType) && drainOneResult) { return ret, nil } cursors[sel]++ @@ -635,7 +660,7 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; j++ { sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) - if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) { + if sel == -1 || (reduce.ShouldStopWhenDrained(queryParams.reduceType) && drainOneResult) { break } retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel]) diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index f45b45ff42295..9f2ec742ef9be 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -442,6 +443,101 @@ func TestTaskQuery_functions(t *testing.T) { } }) + t.Run("test parseQueryParams for reduce type", func(t *testing.T) { + { + var inParams []*commonpb.KeyValuePair + inParams = append(inParams, &commonpb.KeyValuePair{ + Key: ReduceStopForBestKey, + Value: "True", + }) + inParams = append(inParams, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }) + ret, err := parseQueryParams(inParams) + assert.NoError(t, err) + assert.Equal(t, reduce.IReduceInOrderForBest, ret.reduceType) + } + { + var inParams []*commonpb.KeyValuePair + inParams = append(inParams, &commonpb.KeyValuePair{ + Key: ReduceStopForBestKey, + Value: "True", + }) + inParams = append(inParams, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "TrueXXXX", + }) + ret, err := parseQueryParams(inParams) + assert.Error(t, err) + assert.Nil(t, ret) + } + { + var inParams []*commonpb.KeyValuePair + inParams = append(inParams, &commonpb.KeyValuePair{ + Key: ReduceStopForBestKey, + Value: "TrueXXXXX", + }) + inParams = append(inParams, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }) + ret, err := parseQueryParams(inParams) + assert.Error(t, err) + assert.Nil(t, ret) + } + { + var inParams []*commonpb.KeyValuePair + inParams = append(inParams, &commonpb.KeyValuePair{ + Key: ReduceStopForBestKey, + Value: "True", + }) + // when not setting iterator tag, ignore reduce_stop_for_best + ret, err := parseQueryParams(inParams) + assert.NoError(t, err) + assert.Equal(t, reduce.IReduceNoOrder, ret.reduceType) + } + { + var inParams []*commonpb.KeyValuePair + inParams = append(inParams, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }) + // when not setting reduce_stop_for_best tag, reduce by keep results in order + ret, err := parseQueryParams(inParams) + assert.NoError(t, err) + assert.Equal(t, reduce.IReduceInOrder, ret.reduceType) + } + { + var inParams []*commonpb.KeyValuePair + inParams = append(inParams, &commonpb.KeyValuePair{ + Key: ReduceStopForBestKey, + Value: "False", + }) + inParams = append(inParams, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "True", + }) + ret, err := parseQueryParams(inParams) + assert.NoError(t, err) + assert.Equal(t, reduce.IReduceInOrder, ret.reduceType) + } + { + var inParams []*commonpb.KeyValuePair + inParams = append(inParams, &commonpb.KeyValuePair{ + Key: ReduceStopForBestKey, + Value: "False", + }) + inParams = append(inParams, &commonpb.KeyValuePair{ + Key: IteratorField, + Value: "False", + }) + ret, err := parseQueryParams(inParams) + assert.NoError(t, err) + assert.Equal(t, reduce.IReduceNoOrder, ret.reduceType) + } + }) + t.Run("test reduceRetrieveResults", func(t *testing.T) { const ( Dim = 8 @@ -572,7 +668,7 @@ func TestTaskQuery_functions(t *testing.T) { r2.HasMoreResult = false result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - &queryParams{limit: 2, reduceStopForBest: true}) + &queryParams{limit: 2, reduceType: reduce.IReduceInOrderForBest}) assert.NoError(t, err) assert.Equal(t, 2, len(result.GetFieldsData())) assert.Equal(t, []int64{11, 11, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) @@ -585,7 +681,7 @@ func TestTaskQuery_functions(t *testing.T) { r2.HasMoreResult = true result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - &queryParams{limit: 1, offset: 1, reduceStopForBest: true}) + &queryParams{limit: 1, offset: 1, reduceType: reduce.IReduceInOrderForBest}) assert.NoError(t, err) assert.Equal(t, 2, len(result.GetFieldsData())) assert.Equal(t, []int64{11, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) @@ -596,7 +692,7 @@ func TestTaskQuery_functions(t *testing.T) { r2.HasMoreResult = true result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - &queryParams{limit: 2, offset: 1, reduceStopForBest: true}) + &queryParams{limit: 2, offset: 1, reduceType: reduce.IReduceInOrderForBest}) assert.NoError(t, err) assert.Equal(t, 2, len(result.GetFieldsData())) @@ -609,7 +705,7 @@ func TestTaskQuery_functions(t *testing.T) { r2.HasMoreResult = false result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - &queryParams{limit: typeutil.Unlimited, reduceStopForBest: true}) + &queryParams{limit: typeutil.Unlimited, reduceType: reduce.IReduceInOrderForBest}) assert.NoError(t, err) assert.Equal(t, 2, len(result.GetFieldsData())) assert.Equal(t, []int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) @@ -620,11 +716,21 @@ func TestTaskQuery_functions(t *testing.T) { t.Run("test stop reduce for best for unlimited set amd offset", func(t *testing.T) { result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - &queryParams{limit: typeutil.Unlimited, offset: 3, reduceStopForBest: true}) + &queryParams{limit: typeutil.Unlimited, offset: 3, reduceType: reduce.IReduceInOrderForBest}) assert.NoError(t, err) assert.Equal(t, 2, len(result.GetFieldsData())) assert.Equal(t, []int64{22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) }) + t.Run("test iterator without setting reduce stop for best", func(t *testing.T) { + r1.HasMoreResult = true + r2.HasMoreResult = true + result, err := reduceRetrieveResults(context.Background(), + []*internalpb.RetrieveResults{r1, r2}, + &queryParams{limit: 1, reduceType: reduce.IReduceInOrder}) + assert.NoError(t, err) + assert.Equal(t, 2, len(result.GetFieldsData())) + assert.Equal(t, []int64{11}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) + }) }) }) } diff --git a/internal/querynodev2/segments/default_limit_reducer.go b/internal/querynodev2/segments/default_limit_reducer.go index 4334b464c5d08..58ea0a06fba04 100644 --- a/internal/querynodev2/segments/default_limit_reducer.go +++ b/internal/querynodev2/segments/default_limit_reducer.go @@ -7,6 +7,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/internal/util/reduce" ) type defaultLimitReducer struct { @@ -15,24 +16,24 @@ type defaultLimitReducer struct { } type mergeParam struct { - limit int64 - outputFieldsId []int64 - schema *schemapb.CollectionSchema - mergeStopForBest bool + limit int64 + outputFieldsId []int64 + schema *schemapb.CollectionSchema + reduceType reduce.IReduceType } -func NewMergeParam(limit int64, outputFieldsId []int64, schema *schemapb.CollectionSchema, reduceStopForBest bool) *mergeParam { +func NewMergeParam(limit int64, outputFieldsId []int64, schema *schemapb.CollectionSchema, reduceType reduce.IReduceType) *mergeParam { return &mergeParam{ - limit: limit, - outputFieldsId: outputFieldsId, - schema: schema, - mergeStopForBest: reduceStopForBest, + limit: limit, + outputFieldsId: outputFieldsId, + schema: schema, + reduceType: reduceType, } } func (r *defaultLimitReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) { reduceParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), - r.schema, r.req.GetReq().GetReduceStopForBest()) + r.schema, reduce.ToReduceType(r.req.GetReq().GetReduceType())) return mergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, reduceParam) } @@ -50,7 +51,7 @@ type defaultLimitReducerSegcore struct { } func (r *defaultLimitReducerSegcore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults, segments []Segment, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { - mergeParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema, r.req.GetReq().GetReduceStopForBest()) + mergeParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema, reduce.ToReduceType(r.req.GetReq().GetReduceType())) return mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, mergeParam, segments, plan, r.manager) } diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index fac6ba23c6772..f6e3efc0eae64 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -293,7 +293,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna return ret, nil } - if param.limit != typeutil.Unlimited && !param.mergeStopForBest { + if param.limit != typeutil.Unlimited && reduce.ShouldUseInputLimit(param.reduceType) { loopEnd = int(param.limit) } @@ -305,7 +305,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; { sel, drainOneResult := typeutil.SelectMinPKWithTimestamp(validRetrieveResults, cursors) - if sel == -1 || (param.mergeStopForBest && drainOneResult) { + if sel == -1 || (reduce.ShouldStopWhenDrained(param.reduceType) && drainOneResult) { break } @@ -416,7 +416,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore } var limit int = -1 - if param.limit != typeutil.Unlimited && !param.mergeStopForBest { + if param.limit != typeutil.Unlimited && reduce.ShouldUseInputLimit(param.reduceType) { limit = int(param.limit) } @@ -438,7 +438,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore for j := 0; j < loopEnd && (limit == -1 || availableCount < limit); j++ { sel, drainOneResult := typeutil.SelectMinPKWithTimestamp(validRetrieveResults, cursors) - if sel == -1 || (param.mergeStopForBest && drainOneResult) { + if sel == -1 || (reduce.ShouldStopWhenDrained(param.reduceType) && drainOneResult) { break } diff --git a/internal/querynodev2/segments/result_test.go b/internal/querynodev2/segments/result_test.go index 2d9a2fa9939f2..35d2a48699438 100644 --- a/internal/querynodev2/segments/result_test.go +++ b/internal/querynodev2/segments/result_test.go @@ -100,7 +100,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { } result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData()) @@ -114,7 +114,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { suite.Run("test nil results", func() { ret, err := MergeSegcoreRetrieveResultsV1(context.Background(), nil, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Empty(ret.GetIds()) suite.Empty(ret.GetFieldsData()) @@ -133,7 +133,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { } ret, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Empty(ret.GetIds()) suite.Empty(ret.GetFieldsData()) @@ -185,7 +185,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { for _, test := range tests { suite.Run(test.description, func() { result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, - NewMergeParam(test.limit, make([]int64, 0), nil, false)) + NewMergeParam(test.limit, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.Equal(3, len(result.GetFieldsData())) suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData())) suite.Equal(resultIDs[0:test.limit], result.GetIds().GetIntId().GetData()) @@ -225,14 +225,14 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { } _, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result}, - NewMergeParam(reqLimit, make([]int64, 0), nil, false)) + NewMergeParam(reqLimit, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.Error(err) paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600") }) suite.Run("test int ID", func() { result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) intFieldData, has := getFieldData(result, Int64FieldID) @@ -262,7 +262,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { } result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData()) @@ -321,7 +321,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { } result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData()) @@ -335,7 +335,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { suite.Run("test nil results", func() { ret, err := MergeInternalRetrieveResult(context.Background(), nil, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Empty(ret.GetIds()) suite.Empty(ret.GetFieldsData()) @@ -373,7 +373,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { }, } result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{ret1, ret2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Equal(2, len(result.GetFieldsData())) suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData()) @@ -424,7 +424,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { for _, test := range tests { suite.Run(test.description, func() { result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - NewMergeParam(test.limit, make([]int64, 0), nil, false)) + NewMergeParam(test.limit, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.Equal(3, len(result.GetFieldsData())) suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData())) suite.Equal(resultIDs[0:test.limit], result.GetIds().GetIntId().GetData()) @@ -463,14 +463,14 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { } _, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result, result}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.Error(err) paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600") }) suite.Run("test int ID", func() { result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) @@ -501,7 +501,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { } result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData()) @@ -568,7 +568,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result1.HasMoreResult = true result2.HasMoreResult = true result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, - NewMergeParam(3, make([]int64, 0), nil, true)) + NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) // has more result both, stop reduce when draining one result @@ -586,7 +586,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result1.HasMoreResult = false result2.HasMoreResult = false result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceInOrderForBest)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) // as result1 and result2 don't have better results neither @@ -604,7 +604,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result1.HasMoreResult = true result2.HasMoreResult = false result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceInOrderForBest)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) // as result1 may have better results, stop reducing when draining it @@ -643,7 +643,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result1.HasMoreResult = true result2.HasMoreResult = false result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, - NewMergeParam(3, make([]int64, 0), nil, true)) + NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 2, 4, 6, 7}, result.GetIds().GetIntId().GetData()) @@ -687,7 +687,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result1.HasMoreResult = false result2.HasMoreResult = false result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, - NewMergeParam(3, make([]int64, 0), nil, true)) + NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 2, 4, 7}, result.GetIds().GetIntId().GetData()) @@ -696,11 +696,20 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result1.HasMoreResult = false result2.HasMoreResult = true result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, - NewMergeParam(3, make([]int64, 0), nil, true)) + NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 2}, result.GetIds().GetIntId().GetData()) }) + suite.Run("test no stop reduce for best ", func() { + result1.HasMoreResult = true + result2.HasMoreResult = true + result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, + NewMergeParam(1, make([]int64, 0), nil, reduce.IReduceInOrder)) + suite.NoError(err) + suite.Equal(3, len(result.GetFieldsData())) + suite.Equal([]int64{0}, result.GetIds().GetIntId().GetData()) + }) }) } diff --git a/internal/util/reduce/reduce_info.go b/internal/util/reduce/reduce_info.go index 91de9f2df262e..f401e9f357669 100644 --- a/internal/util/reduce/reduce_info.go +++ b/internal/util/reduce/reduce_info.go @@ -90,3 +90,30 @@ func (r *ResultInfo) GetIsAdvance() bool { func (r *ResultInfo) SetMetricType(metricType string) { r.metricType = metricType } + +type IReduceType int32 + +const ( + IReduceNoOrder IReduceType = iota + IReduceInOrder + IReduceInOrderForBest +) + +func ShouldStopWhenDrained(reduceType IReduceType) bool { + return reduceType == IReduceInOrder || reduceType == IReduceInOrderForBest +} + +func ToReduceType(val int32) IReduceType { + switch val { + case 1: + return IReduceInOrder + case 2: + return IReduceInOrderForBest + default: + return IReduceNoOrder + } +} + +func ShouldUseInputLimit(reduceType IReduceType) bool { + return reduceType == IReduceNoOrder || reduceType == IReduceInOrder +}