diff --git a/internal/proxy/look_aside_balancer.go b/internal/proxy/look_aside_balancer.go index 18e91a47e0188..dd1115d8b092a 100644 --- a/internal/proxy/look_aside_balancer.go +++ b/internal/proxy/look_aside_balancer.go @@ -19,8 +19,6 @@ package proxy import ( "context" "math" - "math/rand" - "strconv" "sync" "time" @@ -31,46 +29,50 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) +type CostMetrics struct { + cost atomic.Pointer[internalpb.CostAggregation] + executingNQ atomic.Int64 + ts atomic.Int64 + unavailable atomic.Bool +} + type LookAsideBalancer struct { clientMgr shardClientMgr - // query node -> workload latest metrics - metricsMap *typeutil.ConcurrentMap[int64, *internalpb.CostAggregation] - - // query node -> last update metrics ts - metricsUpdateTs *typeutil.ConcurrentMap[int64, int64] - - // query node -> total nq of requests which already send but response hasn't received - executingTaskTotalNQ *typeutil.ConcurrentMap[int64, *atomic.Int64] - - unreachableQueryNodes *typeutil.ConcurrentSet[int64] - + metricsMap *typeutil.ConcurrentMap[int64, *CostMetrics] // query node id -> number of consecutive heartbeat failures failedHeartBeatCounter *typeutil.ConcurrentMap[int64, *atomic.Int64] + // idx for round_robin + idx atomic.Int64 + closeCh chan struct{} closeOnce sync.Once wg sync.WaitGroup + + // param for replica selection + metricExpireInterval int64 + checkWorkloadRequestNum int64 + workloadToleranceFactor float64 } func NewLookAsideBalancer(clientMgr shardClientMgr) *LookAsideBalancer { balancer := &LookAsideBalancer{ clientMgr: clientMgr, - metricsMap: typeutil.NewConcurrentMap[int64, *internalpb.CostAggregation](), - metricsUpdateTs: typeutil.NewConcurrentMap[int64, int64](), - executingTaskTotalNQ: typeutil.NewConcurrentMap[int64, *atomic.Int64](), - unreachableQueryNodes: typeutil.NewConcurrentSet[int64](), + metricsMap: typeutil.NewConcurrentMap[int64, *CostMetrics](), failedHeartBeatCounter: typeutil.NewConcurrentMap[int64, *atomic.Int64](), closeCh: make(chan struct{}), } + balancer.metricExpireInterval = Params.ProxyCfg.CostMetricsExpireTime.GetAsInt64() + balancer.checkWorkloadRequestNum = Params.ProxyCfg.CheckWorkloadRequestNum.GetAsInt64() + balancer.workloadToleranceFactor = Params.ProxyCfg.WorkloadToleranceFactor.GetAsFloat() + return balancer } @@ -86,54 +88,82 @@ func (b *LookAsideBalancer) Close() { }) } -func (b *LookAsideBalancer) SelectNode(ctx context.Context, availableNodes []int64, cost int64) (int64, error) { - log := log.Ctx(ctx).WithRateGroup("proxy.LookAsideBalancer", 1, 60) +func (b *LookAsideBalancer) SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error) { targetNode := int64(-1) - targetScore := float64(math.MaxFloat64) - rand.Shuffle(len(availableNodes), func(i, j int) { - availableNodes[i], availableNodes[j] = availableNodes[j], availableNodes[i] - }) - for _, node := range availableNodes { - if b.unreachableQueryNodes.Contain(node) { - log.RatedWarn(5, "query node is unreachable, skip it", - zap.Int64("nodeID", node)) - continue + defer func() { + if targetNode != -1 { + metrics, _ := b.metricsMap.GetOrInsert(targetNode, &CostMetrics{}) + metrics.executingNQ.Add(nq) + } + }() + + // after assign n request, try to assign the task to a query node which has much less workload + idx := b.idx.Load() + if idx%b.checkWorkloadRequestNum != 0 { + for i := 0; i < len(availableNodes); i++ { + targetNode = availableNodes[int(idx)%len(availableNodes)] + targetMetrics, ok := b.metricsMap.Get(targetNode) + if !ok || !targetMetrics.unavailable.Load() { + break + } } - cost, _ := b.metricsMap.Get(node) - executingNQ, ok := b.executingTaskTotalNQ.Get(node) - if !ok { - executingNQ = atomic.NewInt64(0) - b.executingTaskTotalNQ.Insert(node, executingNQ) + if targetNode == -1 { + return targetNode, merr.WrapErrServiceUnavailable("all available nodes are unreachable") } - score := b.calculateScore(node, cost, executingNQ.Load()) - metrics.ProxyWorkLoadScore.WithLabelValues(strconv.FormatInt(node, 10)).Set(score) + b.idx.Inc() + return targetNode, nil + } - if targetNode == -1 || score < targetScore { - targetScore = score + // compute each query node's workload score, select the one with least workload score + minScore := int64(math.MaxInt64) + maxScore := int64(0) + nowTs := time.Now().UnixMilli() + for i := 0; i < len(availableNodes); i++ { + node := availableNodes[(int(idx)+i)%len(availableNodes)] + score := int64(0) + metrics, ok := b.metricsMap.Get(node) + if ok { + if metrics.unavailable.Load() { + continue + } + + executingNQ := metrics.executingNQ.Load() + // for multi-replica cases, when there are no task which waiting in queue, + // the response time will effect the score, to prevent the score based on a too old metrics + // we expire the cost metrics if no task in queue. + if executingNQ != 0 || nowTs-metrics.ts.Load() <= b.metricExpireInterval { + score = b.calculateScore(node, metrics.cost.Load(), executingNQ) + } + } + + if score < minScore || targetNode == -1 { + minScore = score targetNode = node } + if score > maxScore { + maxScore = score + } } - if targetNode == -1 { - return -1, merr.WrapErrServiceUnavailable("all available nodes are unreachable") + if float64(maxScore-minScore)/float64(minScore) <= b.workloadToleranceFactor { + // if all query node has nearly same workload, just fall back to round_robin + b.idx.Inc() } - // update executing task cost - totalNQ, _ := b.executingTaskTotalNQ.Get(targetNode) - nq := totalNQ.Add(cost) - metrics.ProxyExecutingTotalNq.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Set(float64(nq)) + if targetNode == -1 { + return targetNode, merr.WrapErrServiceUnavailable("all available nodes are unreachable") + } return targetNode, nil } // when task canceled, should reduce executing total nq cost func (b *LookAsideBalancer) CancelWorkload(node int64, nq int64) { - totalNQ, ok := b.executingTaskTotalNQ.Get(node) + metrics, ok := b.metricsMap.Get(node) if ok { - nq := totalNQ.Sub(nq) - metrics.ProxyExecutingTotalNq.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Set(float64(nq)) + metrics.executingNQ.Sub(nq) } } @@ -141,29 +171,29 @@ func (b *LookAsideBalancer) CancelWorkload(node int64, nq int64) { func (b *LookAsideBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) { // cache the latest query node cost metrics for updating the score if cost != nil { - b.metricsMap.Insert(node, cost) + metrics, ok := b.metricsMap.Get(node) + if !ok { + metrics = &CostMetrics{} + b.metricsMap.Insert(node, metrics) + } + metrics.cost.Store(cost) + metrics.ts.Store(time.Now().UnixMilli()) + metrics.unavailable.CompareAndSwap(true, false) } - b.metricsUpdateTs.Insert(node, time.Now().UnixMilli()) - - // one query/search succeed, we regard heartbeat succeed, clear heartbeat failed counter - b.trySetQueryNodeReachable(node) } // calculateScore compute the query node's workload score // https://www.usenix.org/conference/nsdi15/technical-sessions/presentation/suresh -func (b *LookAsideBalancer) calculateScore(node int64, cost *internalpb.CostAggregation, executingNQ int64) float64 { - if cost == nil || cost.GetResponseTime() == 0 { - return math.Pow(float64(executingNQ), 3.0) +func (b *LookAsideBalancer) calculateScore(node int64, cost *internalpb.CostAggregation, executingNQ int64) int64 { + pow3 := func(n int64) int64 { + return n * n * n } - // for multi-replica cases, when there are no task which waiting in queue, - // the response time will effect the score, to prevent the score based on a too old value - // we expire the cost metrics by second if no task in queue. - if executingNQ == 0 && b.isNodeCostMetricsTooOld(node) { - return 0 + if cost == nil || cost.GetResponseTime() == 0 { + return pow3(executingNQ) } - executeSpeed := float64(cost.ResponseTime) - float64(cost.ServiceTime) + executeSpeed := cost.ResponseTime - cost.ServiceTime if executingNQ < 0 { log.Warn("unexpected executing nq value", zap.Int64("executingNQ", executingNQ)) @@ -176,30 +206,21 @@ func (b *LookAsideBalancer) calculateScore(node int64, cost *internalpb.CostAggr return executeSpeed } - workload := math.Pow(float64(1+cost.GetTotalNQ()+executingNQ), 3.0) * float64(cost.ServiceTime) + // workload := math.Pow(float64(1+cost.GetTotalNQ()+executingNQ), 3.0) * float64(cost.ServiceTime) + workload := pow3(1+cost.GetTotalNQ()+executingNQ) * cost.ServiceTime if workload < 0 { - return math.MaxFloat64 + return math.MaxInt64 } return executeSpeed + workload } -// if the node cost metrics hasn't been updated for a second, we think the metrics is too old -func (b *LookAsideBalancer) isNodeCostMetricsTooOld(node int64) bool { - lastUpdateTs, ok := b.metricsUpdateTs.Get(node) - if !ok || lastUpdateTs == 0 { - return false - } - - return time.Now().UnixMilli()-lastUpdateTs > Params.ProxyCfg.CostMetricsExpireTime.GetAsInt64() -} - func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) { log := log.Ctx(ctx).WithRateGroup("proxy.LookAsideBalancer", 1, 60) defer b.wg.Done() - checkQueryNodeHealthInterval := Params.ProxyCfg.CheckQueryNodeHealthInterval.GetAsDuration(time.Millisecond) - ticker := time.NewTicker(checkQueryNodeHealthInterval) + checkHealthInterval := Params.ProxyCfg.CheckQueryNodeHealthInterval.GetAsDuration(time.Millisecond) + ticker := time.NewTicker(checkHealthInterval) defer ticker.Stop() log.Info("Start check query node health loop") pool := conc.NewDefaultPool[any]() @@ -210,15 +231,19 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) { return case <-ticker.C: - now := time.Now().UnixMilli() var futures []*conc.Future[any] - b.metricsUpdateTs.Range(func(node int64, lastUpdateTs int64) bool { - if now-lastUpdateTs > checkQueryNodeHealthInterval.Milliseconds() { - futures = append(futures, pool.Submit(func() (any, error) { - checkInterval := Params.ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond) - ctx, cancel := context.WithTimeout(context.Background(), checkInterval) + now := time.Now() + b.metricsMap.Range(func(node int64, metrics *CostMetrics) bool { + futures = append(futures, pool.Submit(func() (any, error) { + if now.UnixMilli()-metrics.ts.Load() > checkHealthInterval.Milliseconds() { + checkTimeout := Params.ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), checkTimeout) defer cancel() + if node == -1 { + panic("let it panic") + } + qn, err := b.clientMgr.GetClient(ctx, node) if err != nil { // get client from clientMgr failed, which means this qn isn't a shard leader anymore, skip it's health check @@ -228,26 +253,23 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) { resp, err := qn.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) if err != nil { - if b.trySetQueryNodeUnReachable(node, err) { - log.Warn("get component status failed, set node unreachable", zap.Int64("node", node), zap.Error(err)) - } + b.trySetQueryNodeUnReachable(node, err) + log.RatedWarn(10, "get component status failed, set node unreachable", zap.Int64("node", node), zap.Error(err)) return struct{}{}, nil } if resp.GetState().GetStateCode() != commonpb.StateCode_Healthy { - if b.trySetQueryNodeUnReachable(node, merr.ErrServiceUnavailable) { - log.Warn("component status unhealthy, set node unreachable", zap.Int64("node", node), zap.Error(err)) - } + b.trySetQueryNodeUnReachable(node, merr.ErrServiceUnavailable) + log.RatedWarn(10, "component status unhealthy, set node unreachable", zap.Int64("node", node), zap.Error(err)) + return struct{}{}, nil } + } - // check health successfully, try set query node reachable - b.metricsUpdateTs.Insert(node, time.Now().Local().UnixMilli()) - b.trySetQueryNodeReachable(node) - - return struct{}{}, nil - })) - } + // check health successfully, try set query node reachable + b.trySetQueryNodeReachable(node) + return struct{}{}, nil + })) return true }) @@ -256,7 +278,7 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) { } } -func (b *LookAsideBalancer) trySetQueryNodeUnReachable(node int64, err error) bool { +func (b *LookAsideBalancer) trySetQueryNodeUnReachable(node int64, err error) { failures, ok := b.failedHeartBeatCounter.Get(node) if !ok { failures = atomic.NewInt64(0) @@ -270,8 +292,9 @@ func (b *LookAsideBalancer) trySetQueryNodeUnReachable(node int64, err error) bo zap.Error(err)) if failures.Load() < Params.ProxyCfg.RetryTimesOnHealthCheck.GetAsInt64() { - return false + return } + // if the total time of consecutive heartbeat failures reach the session.ttl, remove the offline query node limit := Params.CommonCfg.SessionTTL.GetAsDuration(time.Second).Seconds() / Params.ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond).Seconds() @@ -279,14 +302,14 @@ func (b *LookAsideBalancer) trySetQueryNodeUnReachable(node int64, err error) bo log.Info("the heartbeat failures has reach it's upper limit, remove the query node", zap.Int64("nodeID", node)) // stop the heartbeat - b.metricsUpdateTs.GetAndRemove(node) - b.metricsMap.GetAndRemove(node) - b.executingTaskTotalNQ.GetAndRemove(node) - b.unreachableQueryNodes.Remove(node) - return false + b.metricsMap.Remove(node) + return } - return b.unreachableQueryNodes.Insert(node) + metrics, ok := b.metricsMap.Get(node) + if ok { + metrics.unavailable.Store(true) + } } func (b *LookAsideBalancer) trySetQueryNodeReachable(node int64) { @@ -295,7 +318,9 @@ func (b *LookAsideBalancer) trySetQueryNodeReachable(node int64) { if ok { failures.Store(0) } - if b.unreachableQueryNodes.TryRemove(node) { + + metrics, ok := b.metricsMap.Get(node) + if !ok || metrics.unavailable.CompareAndSwap(true, false) { log.Info("component recuperated, set node reachable", zap.Int64("node", node)) } } diff --git a/internal/proxy/look_aside_balancer_test.go b/internal/proxy/look_aside_balancer_test.go index e3db80dc7b73e..e80d31fc07b1a 100644 --- a/internal/proxy/look_aside_balancer_test.go +++ b/internal/proxy/look_aside_balancer_test.go @@ -64,9 +64,9 @@ func (suite *LookAsideBalancerSuite) TestUpdateMetrics() { suite.balancer.UpdateCostMetrics(1, costMetrics) - lastUpdateTs, ok := suite.balancer.metricsUpdateTs.Get(1) + metrics, ok := suite.balancer.metricsMap.Get(1) suite.True(ok) - suite.True(time.Now().UnixMilli()-lastUpdateTs <= 5) + suite.True(time.Now().UnixMilli()-metrics.ts.Load() <= 5) } func (suite *LookAsideBalancerSuite) TestCalculateScore() { @@ -98,19 +98,19 @@ func (suite *LookAsideBalancerSuite) TestCalculateScore() { score2 := suite.balancer.calculateScore(-1, costMetrics2, 0) score3 := suite.balancer.calculateScore(-1, costMetrics3, 0) score4 := suite.balancer.calculateScore(-1, costMetrics4, 0) - suite.Equal(float64(12), score1) - suite.Equal(float64(19), score2) - suite.Equal(float64(17), score3) - suite.Equal(float64(5), score4) + suite.Equal(int64(12), score1) + suite.Equal(int64(19), score2) + suite.Equal(int64(17), score3) + suite.Equal(int64(5), score4) score5 := suite.balancer.calculateScore(-1, costMetrics1, 5) score6 := suite.balancer.calculateScore(-1, costMetrics2, 5) score7 := suite.balancer.calculateScore(-1, costMetrics3, 5) score8 := suite.balancer.calculateScore(-1, costMetrics4, 5) - suite.Equal(float64(347), score5) - suite.Equal(float64(689), score6) - suite.Equal(float64(352), score7) - suite.Equal(float64(220), score8) + suite.Equal(int64(347), score5) + suite.Equal(int64(689), score6) + suite.Equal(int64(352), score7) + suite.Equal(int64(220), score8) // test score overflow costMetrics5 := &internalpb.CostAggregation{ @@ -120,15 +120,7 @@ func (suite *LookAsideBalancerSuite) TestCalculateScore() { } score9 := suite.balancer.calculateScore(-1, costMetrics5, math.MaxInt64) - suite.Equal(math.MaxFloat64, score9) - - // test metrics expire - suite.balancer.metricsUpdateTs.Insert(1, time.Now().UnixMilli()) - score10 := suite.balancer.calculateScore(1, costMetrics4, 0) - suite.Equal(float64(5), score10) - suite.balancer.metricsUpdateTs.Insert(1, time.Now().UnixMilli()-5000) - score11 := suite.balancer.calculateScore(1, costMetrics4, 0) - suite.Equal(float64(0), score11) + suite.Equal(int64(math.MaxInt64), score9) // test unexpected negative nq value costMetrics6 := &internalpb.CostAggregation{ @@ -137,14 +129,14 @@ func (suite *LookAsideBalancerSuite) TestCalculateScore() { TotalNQ: -1, } score12 := suite.balancer.calculateScore(-1, costMetrics6, math.MaxInt64) - suite.Equal(float64(4), score12) + suite.Equal(int64(4), score12) costMetrics7 := &internalpb.CostAggregation{ ResponseTime: 5, ServiceTime: 1, TotalNQ: 1, } score13 := suite.balancer.calculateScore(-1, costMetrics7, -1) - suite.Equal(float64(4), score13) + suite.Equal(int64(4), score13) } func (suite *LookAsideBalancerSuite) TestSelectNode() { @@ -279,7 +271,8 @@ func (suite *LookAsideBalancerSuite) TestSelectNode() { } for node, executingNQ := range c.executingNQ { - suite.balancer.executingTaskTotalNQ.Insert(node, atomic.NewInt64(executingNQ)) + metrics, _ := suite.balancer.metricsMap.Get(node) + metrics.executingNQ.Store(executingNQ) } counter := make(map[int64]int64) for i := 0; i < c.requestCount; i++ { @@ -300,25 +293,31 @@ func (suite *LookAsideBalancerSuite) TestCancelWorkload() { suite.NoError(err) suite.balancer.CancelWorkload(node, 10) - executingNQ, ok := suite.balancer.executingTaskTotalNQ.Get(node) + metrics, ok := suite.balancer.metricsMap.Get(node) suite.True(ok) - suite.Equal(int64(0), executingNQ.Load()) + suite.Equal(int64(0), metrics.executingNQ.Load()) } func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() { qn2 := mocks.NewMockQueryNodeClient(suite.T()) - suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(2)).Return(qn2, nil) + suite.clientMgr.EXPECT().GetClient(mock.Anything, int64(2)).Return(qn2, nil).Maybe() qn2.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{ State: &milvuspb.ComponentInfo{ StateCode: commonpb.StateCode_Healthy, }, - }, nil) - - suite.balancer.metricsUpdateTs.Insert(1, time.Now().UnixMilli()) - suite.balancer.metricsUpdateTs.Insert(2, time.Now().UnixMilli()) - suite.balancer.unreachableQueryNodes.Insert(2) + }, nil).Maybe() + + metrics1 := &CostMetrics{} + metrics1.ts.Store(time.Now().UnixMilli()) + metrics1.unavailable.Store(true) + suite.balancer.metricsMap.Insert(1, metrics1) + metrics2 := &CostMetrics{} + metrics2.ts.Store(time.Now().UnixMilli()) + metrics2.unavailable.Store(true) + suite.balancer.metricsMap.Insert(2, metrics2) suite.Eventually(func() bool { - return suite.balancer.unreachableQueryNodes.Contain(1) + metrics, ok := suite.balancer.metricsMap.Get(1) + return ok && metrics.unavailable.Load() }, 5*time.Second, 100*time.Millisecond) targetNode, err := suite.balancer.SelectNode(context.Background(), []int64{1}, 1) suite.ErrorIs(err, merr.ErrServiceUnavailable) @@ -326,16 +325,21 @@ func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() { suite.balancer.UpdateCostMetrics(1, &internalpb.CostAggregation{}) suite.Eventually(func() bool { - return !suite.balancer.unreachableQueryNodes.Contain(1) + metrics, ok := suite.balancer.metricsMap.Get(1) + return ok && !metrics.unavailable.Load() }, 3*time.Second, 100*time.Millisecond) suite.Eventually(func() bool { - return !suite.balancer.unreachableQueryNodes.Contain(2) + metrics, ok := suite.balancer.metricsMap.Get(2) + return ok && !metrics.unavailable.Load() }, 5*time.Second, 100*time.Millisecond) } func (suite *LookAsideBalancerSuite) TestGetClientFailed() { - suite.balancer.metricsUpdateTs.Insert(2, time.Now().UnixMilli()) + metrics1 := &CostMetrics{} + metrics1.ts.Store(time.Now().UnixMilli()) + metrics1.unavailable.Store(true) + suite.balancer.metricsMap.Insert(2, metrics1) // test get shard client from client mgr return nil suite.clientMgr.ExpectedCalls = nil @@ -364,13 +368,17 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() { }, }, nil) - suite.balancer.metricsUpdateTs.Insert(3, time.Now().UnixMilli()) + metrics1 := &CostMetrics{} + metrics1.ts.Store(time.Now().UnixMilli()) + suite.balancer.metricsMap.Insert(3, metrics1) suite.Eventually(func() bool { - return suite.balancer.unreachableQueryNodes.Contain(3) + metrics, ok := suite.balancer.metricsMap.Get(3) + return ok && metrics.unavailable.Load() }, 5*time.Second, 100*time.Millisecond) suite.Eventually(func() bool { - return !suite.balancer.unreachableQueryNodes.Contain(3) + metrics, ok := suite.balancer.metricsMap.Get(3) + return ok && !metrics.unavailable.Load() }, 5*time.Second, 100*time.Millisecond) } @@ -386,17 +394,82 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() { }, }, nil) - suite.balancer.metricsUpdateTs.Insert(3, time.Now().UnixMilli()) + metrics1 := &CostMetrics{} + metrics1.ts.Store(time.Now().UnixMilli()) + suite.balancer.metricsMap.Insert(3, metrics1) suite.Eventually(func() bool { - return suite.balancer.unreachableQueryNodes.Contain(3) + metrics, ok := suite.balancer.metricsMap.Get(3) + return ok && metrics.unavailable.Load() }, 5*time.Second, 100*time.Millisecond) suite.Eventually(func() bool { - return !suite.balancer.metricsUpdateTs.Contain(3) + _, ok := suite.balancer.metricsMap.Get(3) + return !ok }, 10*time.Second, 100*time.Millisecond) - suite.Eventually(func() bool { - return !suite.balancer.unreachableQueryNodes.Contain(3) - }, time.Second, 100*time.Millisecond) +} + +func BenchmarkSelectNode_QNWithSameWorkload(b *testing.B) { + balancer := NewLookAsideBalancer(nil) + + ctx := context.Background() + nodeList := make([]int64, 0o0) + + metrics := &internalpb.CostAggregation{ + ResponseTime: 100, + ServiceTime: 100, + TotalNQ: 100, + } + for i := 0; i < 16; i++ { + nodeID := int64(10000 + i) + nodeList = append(nodeList, nodeID) + } + cost := int64(7) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + node, _ := balancer.SelectNode(ctx, nodeList, cost) + balancer.CancelWorkload(node, cost) + balancer.UpdateCostMetrics(node, metrics) + } + }) +} + +func BenchmarkSelectNode_QNWithDifferentWorkload(b *testing.B) { + balancer := NewLookAsideBalancer(nil) + + ctx := context.Background() + nodeList := make([]int64, 0o0) + + metrics := &internalpb.CostAggregation{ + ResponseTime: 100, + ServiceTime: 100, + TotalNQ: 100, + } + + heavyMetric := &internalpb.CostAggregation{ + ResponseTime: 1000, + ServiceTime: 1000, + TotalNQ: 1000, + } + for i := 0; i < 16; i++ { + nodeID := int64(10000 + i) + nodeList = append(nodeList, nodeID) + } + cost := int64(7) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + var i int + for pb.Next() { + node, _ := balancer.SelectNode(ctx, nodeList, cost) + balancer.CancelWorkload(node, cost) + if i%2 == 0 { + balancer.UpdateCostMetrics(node, heavyMetric) + } else { + balancer.UpdateCostMetrics(node, metrics) + } + i++ + } + }) } func TestLookAsideBalancerSuite(t *testing.T) { diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index 70d7f914e290a..80695e460c9fb 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -1230,7 +1230,9 @@ type proxyConfig struct { ShardLeaderCacheInterval ParamItem `refreshable:"false"` ReplicaSelectionPolicy ParamItem `refreshable:"false"` CheckQueryNodeHealthInterval ParamItem `refreshable:"false"` - CostMetricsExpireTime ParamItem `refreshable:"true"` + CostMetricsExpireTime ParamItem `refreshable:"false"` + CheckWorkloadRequestNum ParamItem `refreshable:"false"` + WorkloadToleranceFactor ParamItem `refreshable:"false"` RetryTimesOnReplica ParamItem `refreshable:"true"` RetryTimesOnHealthCheck ParamItem `refreshable:"true"` PartitionNameRegexp ParamItem `refreshable:"true"` @@ -1551,6 +1553,23 @@ please adjust in embedded Milvus: false`, } p.CostMetricsExpireTime.Init(base.mgr) + p.CheckWorkloadRequestNum = ParamItem{ + Key: "proxy.checkWorkloadRequestNum", + Version: "2.4.12", + DefaultValue: "10", + Doc: "after every requestNum requests has been assigned, try to check workload for query node", + } + p.CheckWorkloadRequestNum.Init(base.mgr) + + p.WorkloadToleranceFactor = ParamItem{ + Key: "proxy.workloadToleranceFactor", + Version: "2.4.12", + DefaultValue: "0.1", + Doc: `tolerance factor for query node workload difference, default to 10%, which means if query node's workload diff is higher than this factor, + proxy will compute each querynode's workload score, and assign request to the lowest workload node; otherwise, it will assign request to the node by round robin`, + } + p.WorkloadToleranceFactor.Init(base.mgr) + p.RetryTimesOnReplica = ParamItem{ Key: "proxy.retryTimesOnReplica", Version: "2.3.0", diff --git a/pkg/util/paramtable/component_param_test.go b/pkg/util/paramtable/component_param_test.go index 03a7014d67211..4ce1e6fccf4f2 100644 --- a/pkg/util/paramtable/component_param_test.go +++ b/pkg/util/paramtable/component_param_test.go @@ -208,6 +208,9 @@ func TestComponentParam(t *testing.T) { assert.False(t, Params.SkipPartitionKeyCheck.GetAsBool()) params.Save("proxy.skipPartitionKeyCheck", "true") assert.True(t, Params.SkipPartitionKeyCheck.GetAsBool()) + + assert.Equal(t, int64(10), Params.CheckWorkloadRequestNum.GetAsInt64()) + assert.Equal(t, float64(0.1), Params.WorkloadToleranceFactor.GetAsFloat()) }) // t.Run("test proxyConfig panic", func(t *testing.T) {