diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index b50ea4004cc36..79022354a3f97 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -184,11 +184,11 @@ message RetrieveRequest { bool is_count = 13; int64 iteration_extension_reduce_rate = 14; string username = 15; - bool reduce_stop_for_best = 16; + bool reduce_stop_for_best = 16; //deprecated + int32 reduce_type = 17; } - message RetrieveResults { common.MsgBase base = 1; common.Status status = 2; diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 845477a0ea2ab..8e25deb39f06c 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -20,6 +20,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/exprutil" + "github.com/milvus-io/milvus/internal/util/reduce" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -75,9 +76,9 @@ type queryTask struct { } type queryParams struct { - limit int64 - offset int64 - reduceStopForBest bool + limit int64 + offset int64 + reduceType reduce.IReduceType } // translateToOutputFieldIDs translates output fields name to output fields id. @@ -142,6 +143,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e limit int64 offset int64 reduceStopForBest bool + isIterator bool err error ) reduceStopForBestStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ReduceStopForBestKey, queryParamsPair) @@ -154,10 +156,29 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e } } + isIteratorStr, err := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, queryParamsPair) + // if reduce_stop_for_best is provided + if err == nil { + isIterator, err = strconv.ParseBool(isIteratorStr) + if err != nil { + return nil, merr.WrapErrParameterInvalid("true or false", isIteratorStr, + "value for iterator field is invalid") + } + } + + reduceType := reduce.IReduceNoOrder + if isIterator { + if reduceStopForBest { + reduceType = reduce.IReduceInOrderForBest + } else { + reduceType = reduce.IReduceInOrder + } + } + limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair) // if limit is not provided if err != nil { - return &queryParams{limit: typeutil.Unlimited, reduceStopForBest: reduceStopForBest}, nil + return &queryParams{limit: typeutil.Unlimited, reduceType: reduceType}, nil } limit, err = strconv.ParseInt(limitStr, 0, 64) if err != nil { @@ -179,9 +200,9 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e } return &queryParams{ - limit: limit, - offset: offset, - reduceStopForBest: reduceStopForBest, + limit: limit, + offset: offset, + reduceType: reduceType, }, nil } @@ -343,7 +364,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error { if err != nil { return err } - t.RetrieveRequest.ReduceStopForBest = queryParams.reduceStopForBest + t.RetrieveRequest.ReduceType = int32(queryParams.reduceType) t.queryParams = queryParams t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset @@ -615,9 +636,10 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re cursors := make([]int64, len(validRetrieveResults)) if queryParams != nil && queryParams.limit != typeutil.Unlimited { - // reduceStopForBest will try to get as many results as possible + // IReduceInOrderForBest will try to get as many results as possible // so loopEnd in this case will be set to the sum of all results' size - if !queryParams.reduceStopForBest { + // to get as many qualified results as possible + if reduce.ShouldUseInputLimit(queryParams.reduceType) { loopEnd = int(queryParams.limit) } } @@ -626,7 +648,7 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re if queryParams != nil && queryParams.offset > 0 { for i := int64(0); i < queryParams.offset; i++ { sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) - if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) { + if sel == -1 || (reduce.ShouldStopWhenDrained(queryParams.reduceType) && drainOneResult) { return ret, nil } cursors[sel]++ @@ -638,7 +660,7 @@ func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.Re maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; { sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors) - if sel == -1 || (queryParams.reduceStopForBest && drainOneResult) { + if sel == -1 || (reduce.ShouldStopWhenDrained(queryParams.reduceType) && drainOneResult) { break } diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 592a90e0f32b5..63ad8849f558e 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -600,7 +601,7 @@ func TestTaskQuery_functions(t *testing.T) { r2.HasMoreResult = false result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - &queryParams{limit: 2, reduceStopForBest: true}) + &queryParams{limit: 2, reduceType: reduce.IReduceInOrderForBest}) assert.NoError(t, err) assert.Equal(t, 2, len(result.GetFieldsData())) assert.Equal(t, []int64{11, 11, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) @@ -613,7 +614,7 @@ func TestTaskQuery_functions(t *testing.T) { r2.HasMoreResult = true result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - &queryParams{limit: 1, offset: 1, reduceStopForBest: true}) + &queryParams{limit: 1, offset: 1, reduceType: reduce.IReduceInOrderForBest}) assert.NoError(t, err) assert.Equal(t, 2, len(result.GetFieldsData())) assert.Equal(t, []int64{11, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) @@ -624,7 +625,7 @@ func TestTaskQuery_functions(t *testing.T) { r2.HasMoreResult = true result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - &queryParams{limit: 2, offset: 1, reduceStopForBest: true}) + &queryParams{limit: 2, offset: 1, reduceType: reduce.IReduceInOrderForBest}) assert.NoError(t, err) assert.Equal(t, 2, len(result.GetFieldsData())) @@ -637,7 +638,7 @@ func TestTaskQuery_functions(t *testing.T) { r2.HasMoreResult = false result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - &queryParams{limit: typeutil.Unlimited, reduceStopForBest: true}) + &queryParams{limit: typeutil.Unlimited, reduceType: reduce.IReduceInOrderForBest}) assert.NoError(t, err) assert.Equal(t, 2, len(result.GetFieldsData())) assert.Equal(t, []int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) @@ -648,7 +649,7 @@ func TestTaskQuery_functions(t *testing.T) { t.Run("test stop reduce for best for unlimited set amd offset", func(t *testing.T) { result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - &queryParams{limit: typeutil.Unlimited, offset: 3, reduceStopForBest: true}) + &queryParams{limit: typeutil.Unlimited, offset: 3, reduceType: reduce.IReduceInOrderForBest}) assert.NoError(t, err) assert.Equal(t, 2, len(result.GetFieldsData())) assert.Equal(t, []int64{22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data) diff --git a/internal/querynodev2/segments/default_limit_reducer.go b/internal/querynodev2/segments/default_limit_reducer.go index 4334b464c5d08..58ea0a06fba04 100644 --- a/internal/querynodev2/segments/default_limit_reducer.go +++ b/internal/querynodev2/segments/default_limit_reducer.go @@ -7,6 +7,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/internal/util/reduce" ) type defaultLimitReducer struct { @@ -15,24 +16,24 @@ type defaultLimitReducer struct { } type mergeParam struct { - limit int64 - outputFieldsId []int64 - schema *schemapb.CollectionSchema - mergeStopForBest bool + limit int64 + outputFieldsId []int64 + schema *schemapb.CollectionSchema + reduceType reduce.IReduceType } -func NewMergeParam(limit int64, outputFieldsId []int64, schema *schemapb.CollectionSchema, reduceStopForBest bool) *mergeParam { +func NewMergeParam(limit int64, outputFieldsId []int64, schema *schemapb.CollectionSchema, reduceType reduce.IReduceType) *mergeParam { return &mergeParam{ - limit: limit, - outputFieldsId: outputFieldsId, - schema: schema, - mergeStopForBest: reduceStopForBest, + limit: limit, + outputFieldsId: outputFieldsId, + schema: schema, + reduceType: reduceType, } } func (r *defaultLimitReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveResults) (*internalpb.RetrieveResults, error) { reduceParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), - r.schema, r.req.GetReq().GetReduceStopForBest()) + r.schema, reduce.ToReduceType(r.req.GetReq().GetReduceType())) return mergeInternalRetrieveResultsAndFillIfEmpty(ctx, results, reduceParam) } @@ -50,7 +51,7 @@ type defaultLimitReducerSegcore struct { } func (r *defaultLimitReducerSegcore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults, segments []Segment, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { - mergeParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema, r.req.GetReq().GetReduceStopForBest()) + mergeParam := NewMergeParam(r.req.GetReq().GetLimit(), r.req.GetReq().GetOutputFieldsId(), r.schema, reduce.ToReduceType(r.req.GetReq().GetReduceType())) return mergeSegcoreRetrieveResultsAndFillIfEmpty(ctx, results, mergeParam, segments, plan, r.manager) } diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index 70c3b28a225ff..b3c2c7677c1e4 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -293,7 +293,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna return ret, nil } - if param.limit != typeutil.Unlimited && !param.mergeStopForBest { + if param.limit != typeutil.Unlimited && reduce.ShouldUseInputLimit(param.reduceType) { loopEnd = int(param.limit) } @@ -305,7 +305,7 @@ func MergeInternalRetrieveResult(ctx context.Context, retrieveResults []*interna maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd; { sel, drainOneResult := typeutil.SelectMinPKWithTimestamp(validRetrieveResults, cursors) - if sel == -1 || (param.mergeStopForBest && drainOneResult) { + if sel == -1 || (reduce.ShouldStopWhenDrained(param.reduceType) && drainOneResult) { break } @@ -422,7 +422,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore selected := make([]int, 0, ret.GetAllRetrieveCount()) var limit int = -1 - if param.limit != typeutil.Unlimited && !param.mergeStopForBest { + if param.limit != typeutil.Unlimited && reduce.ShouldUseInputLimit(param.reduceType) { limit = int(param.limit) } @@ -435,7 +435,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() for j := 0; j < loopEnd && (limit == -1 || availableCount < limit); j++ { sel, drainOneResult := typeutil.SelectMinPKWithTimestamp(validRetrieveResults, cursors) - if sel == -1 || (param.mergeStopForBest && drainOneResult) { + if sel == -1 || (reduce.ShouldStopWhenDrained(param.reduceType) && drainOneResult) { break } diff --git a/internal/querynodev2/segments/result_test.go b/internal/querynodev2/segments/result_test.go index 2d9a2fa9939f2..b9e3bacff55bf 100644 --- a/internal/querynodev2/segments/result_test.go +++ b/internal/querynodev2/segments/result_test.go @@ -100,7 +100,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { } result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData()) @@ -114,7 +114,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { suite.Run("test nil results", func() { ret, err := MergeSegcoreRetrieveResultsV1(context.Background(), nil, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Empty(ret.GetIds()) suite.Empty(ret.GetFieldsData()) @@ -133,7 +133,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { } ret, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Empty(ret.GetIds()) suite.Empty(ret.GetFieldsData()) @@ -185,7 +185,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { for _, test := range tests { suite.Run(test.description, func() { result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, - NewMergeParam(test.limit, make([]int64, 0), nil, false)) + NewMergeParam(test.limit, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.Equal(3, len(result.GetFieldsData())) suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData())) suite.Equal(resultIDs[0:test.limit], result.GetIds().GetIntId().GetData()) @@ -225,14 +225,14 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { } _, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result}, - NewMergeParam(reqLimit, make([]int64, 0), nil, false)) + NewMergeParam(reqLimit, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.Error(err) paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600") }) suite.Run("test int ID", func() { result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) intFieldData, has := getFieldData(result, Int64FieldID) @@ -262,7 +262,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { } result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{r1, r2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData()) @@ -321,7 +321,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { } result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData()) @@ -335,7 +335,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { suite.Run("test nil results", func() { ret, err := MergeInternalRetrieveResult(context.Background(), nil, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Empty(ret.GetIds()) suite.Empty(ret.GetFieldsData()) @@ -373,7 +373,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { }, } result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{ret1, ret2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Equal(2, len(result.GetFieldsData())) suite.Equal([]int64{0, 1}, result.GetIds().GetIntId().GetData()) @@ -424,7 +424,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { for _, test := range tests { suite.Run(test.description, func() { result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - NewMergeParam(test.limit, make([]int64, 0), nil, false)) + NewMergeParam(test.limit, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.Equal(3, len(result.GetFieldsData())) suite.Equal(int(test.limit), len(result.GetIds().GetIntId().GetData())) suite.Equal(resultIDs[0:test.limit], result.GetIds().GetIntId().GetData()) @@ -463,14 +463,14 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { } _, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result, result}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.Error(err) paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600") }) suite.Run("test int ID", func() { result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{1, 2, 3, 4}, result.GetIds().GetIntId().GetData()) @@ -501,7 +501,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { } result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{r1, r2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, false)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceNoOrder)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]string{"a", "b", "c", "d"}, result.GetIds().GetStrId().GetData()) @@ -568,7 +568,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result1.HasMoreResult = true result2.HasMoreResult = true result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, - NewMergeParam(3, make([]int64, 0), nil, true)) + NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) // has more result both, stop reduce when draining one result @@ -586,7 +586,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result1.HasMoreResult = false result2.HasMoreResult = false result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceInOrderForBest)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) // as result1 and result2 don't have better results neither @@ -604,7 +604,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result1.HasMoreResult = true result2.HasMoreResult = false result, err := MergeSegcoreRetrieveResultsV1(context.Background(), []*segcorepb.RetrieveResults{result1, result2}, - NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, true)) + NewMergeParam(typeutil.Unlimited, make([]int64, 0), nil, reduce.IReduceInOrderForBest)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) // as result1 may have better results, stop reducing when draining it @@ -643,7 +643,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result1.HasMoreResult = true result2.HasMoreResult = false result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, - NewMergeParam(3, make([]int64, 0), nil, true)) + NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 2, 4, 6, 7}, result.GetIds().GetIntId().GetData()) @@ -687,7 +687,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result1.HasMoreResult = false result2.HasMoreResult = false result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, - NewMergeParam(3, make([]int64, 0), nil, true)) + NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 2, 4, 7}, result.GetIds().GetIntId().GetData()) @@ -696,7 +696,7 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { result1.HasMoreResult = false result2.HasMoreResult = true result, err := MergeInternalRetrieveResult(context.Background(), []*internalpb.RetrieveResults{result1, result2}, - NewMergeParam(3, make([]int64, 0), nil, true)) + NewMergeParam(3, make([]int64, 0), nil, reduce.IReduceInOrderForBest)) suite.NoError(err) suite.Equal(3, len(result.GetFieldsData())) suite.Equal([]int64{0, 2}, result.GetIds().GetIntId().GetData()) diff --git a/internal/util/reduce/reduce_info.go b/internal/util/reduce/reduce_info.go index 91de9f2df262e..f401e9f357669 100644 --- a/internal/util/reduce/reduce_info.go +++ b/internal/util/reduce/reduce_info.go @@ -90,3 +90,30 @@ func (r *ResultInfo) GetIsAdvance() bool { func (r *ResultInfo) SetMetricType(metricType string) { r.metricType = metricType } + +type IReduceType int32 + +const ( + IReduceNoOrder IReduceType = iota + IReduceInOrder + IReduceInOrderForBest +) + +func ShouldStopWhenDrained(reduceType IReduceType) bool { + return reduceType == IReduceInOrder || reduceType == IReduceInOrderForBest +} + +func ToReduceType(val int32) IReduceType { + switch val { + case 1: + return IReduceInOrder + case 2: + return IReduceInOrderForBest + default: + return IReduceNoOrder + } +} + +func ShouldUseInputLimit(reduceType IReduceType) bool { + return reduceType == IReduceNoOrder || reduceType == IReduceInOrder +}