Skip to content

Commit

Permalink
enhance: Enable dynamic update replica selection policy (#35860)
Browse files Browse the repository at this point in the history
issue: #35859

---------

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
  • Loading branch information
weiliu1031 authored Sep 13, 2024
1 parent c03eb6f commit bd658a6
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 136 deletions.
107 changes: 53 additions & 54 deletions internal/proxy/lb_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,77 +61,75 @@ type LBPolicy interface {
Close()
}

const (
RoundRobin = "round_robin"
LookAside = "look_aside"
)

type LBPolicyImpl struct {
balancer LBBalancer
clientMgr shardClientMgr
getBalancer func() LBBalancer
clientMgr shardClientMgr
balancerMap map[string]LBBalancer
}

func NewLBPolicyImpl(clientMgr shardClientMgr) *LBPolicyImpl {
balancePolicy := params.Params.ProxyCfg.ReplicaSelectionPolicy.GetValue()

var balancer LBBalancer
switch balancePolicy {
case "round_robin":
log.Info("use round_robin policy on replica selection")
balancer = NewRoundRobinBalancer()
default:
log.Info("use look_aside policy on replica selection")
balancer = NewLookAsideBalancer(clientMgr)
balancerMap := make(map[string]LBBalancer)
balancerMap[LookAside] = NewLookAsideBalancer(clientMgr)
balancerMap[RoundRobin] = NewRoundRobinBalancer()

getBalancer := func() LBBalancer {
balancePolicy := params.Params.ProxyCfg.ReplicaSelectionPolicy.GetValue()
if _, ok := balancerMap[balancePolicy]; !ok {
return balancerMap[LookAside]
}
return balancerMap[balancePolicy]
}

return &LBPolicyImpl{
balancer: balancer,
clientMgr: clientMgr,
getBalancer: getBalancer,
clientMgr: clientMgr,
balancerMap: balancerMap,
}
}

func (lb *LBPolicyImpl) Start(ctx context.Context) {
lb.balancer.Start(ctx)
for _, lb := range lb.balancerMap {
lb.Start(ctx)
}
}

// try to select the best node from the available nodes
func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", workload.collectionID),
zap.String("collectionName", workload.collectionName),
zap.String("channelName", workload.channel),
)

filterAvailableNodes := func(node int64, _ int) bool {
return !excludeNodes.Contain(node)
}

getShardLeaders := func() ([]int64, error) {
shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collectionName, workload.collectionID)
if err != nil {
return nil, err
}

return lo.Map(shardLeaders[workload.channel], func(node nodeInfo, _ int) int64 { return node.nodeID }), nil
}

availableNodes := lo.Filter(workload.shardLeaders, filterAvailableNodes)
targetNode, err := lb.balancer.SelectNode(ctx, availableNodes, workload.nq)
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes typeutil.UniqueSet) (int64, error) {
availableNodes := lo.FilterMap(workload.shardLeaders, func(node int64, _ int) (int64, bool) { return node, !excludeNodes.Contain(node) })
targetNode, err := balancer.SelectNode(ctx, availableNodes, workload.nq)
if err != nil {
log := log.Ctx(ctx)
globalMetaCache.DeprecateShardCache(workload.db, workload.collectionName)
nodes, err := getShardLeaders()
if err != nil || len(nodes) == 0 {
shardLeaders, err := globalMetaCache.GetShards(ctx, false, workload.db, workload.collectionName, workload.collectionID)
if err != nil {
log.Warn("failed to get shard delegator",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Error(err))
return -1, err
}

availableNodes := lo.Filter(nodes, filterAvailableNodes)
availableNodes := lo.FilterMap(shardLeaders[workload.channel], func(node nodeInfo, _ int) (int64, bool) { return node.nodeID, !excludeNodes.Contain(node.nodeID) })
if len(availableNodes) == 0 {
nodes := lo.Map(shardLeaders[workload.channel], func(node nodeInfo, _ int) int64 { return node.nodeID })
log.Warn("no available shard delegator found",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64s("nodes", nodes),
zap.Int64s("excluded", excludeNodes.Collect()))
return -1, merr.WrapErrChannelNotAvailable("no available shard delegator found")
}

targetNode, err = lb.balancer.SelectNode(ctx, availableNodes, workload.nq)
targetNode, err = balancer.SelectNode(ctx, availableNodes, workload.nq)
if err != nil {
log.Warn("failed to select shard",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64s("availableNodes", availableNodes),
zap.Error(err))
return -1, err
Expand All @@ -144,17 +142,15 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, workload ChannelWorkload
// ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes.
func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
excludeNodes := typeutil.NewUniqueSet()
log := log.Ctx(ctx).With(
zap.Int64("collectionID", workload.collectionID),
zap.String("collectionName", workload.collectionName),
zap.String("channelName", workload.channel),
)

var lastErr error
err := retry.Do(ctx, func() error {
targetNode, err := lb.selectNode(ctx, workload, excludeNodes)
balancer := lb.getBalancer()
targetNode, err := lb.selectNode(ctx, balancer, workload, excludeNodes)
if err != nil {
log.Warn("failed to select node for shard",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64("nodeID", targetNode),
zap.Error(err),
)
Expand All @@ -163,33 +159,34 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
}
return err
}
// cancel work load which assign to the target node
defer balancer.CancelWorkload(targetNode, workload.nq)

client, err := lb.clientMgr.GetClient(ctx, targetNode)
if err != nil {
log.Warn("search/query channel failed, node not available",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64("nodeID", targetNode),
zap.Error(err))
excludeNodes.Insert(targetNode)

// cancel work load which assign to the target node
lb.balancer.CancelWorkload(targetNode, workload.nq)
lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode, workload.channel)
return lastErr
}

err = workload.exec(ctx, targetNode, client, workload.channel)
if err != nil {
log.Warn("search/query channel failed",
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64("nodeID", targetNode),
zap.Error(err))
excludeNodes.Insert(targetNode)
lb.balancer.CancelWorkload(targetNode, workload.nq)

lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode, workload.channel)
return lastErr
}

lb.balancer.CancelWorkload(targetNode, workload.nq)
return nil
}, retry.Attempts(workload.retryTimes))

Expand Down Expand Up @@ -232,9 +229,11 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
}

func (lb *LBPolicyImpl) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
lb.balancer.UpdateCostMetrics(node, cost)
lb.getBalancer().UpdateCostMetrics(node, cost)
}

func (lb *LBPolicyImpl) Close() {
lb.balancer.Close()
for _, lb := range lb.balancerMap {
lb.Close()
}
}
20 changes: 11 additions & 9 deletions internal/proxy/lb_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ func (s *LBPolicySuite) SetupTest() {
s.lbBalancer.EXPECT().Start(context.Background()).Maybe()
s.lbPolicy = NewLBPolicyImpl(s.mgr)
s.lbPolicy.Start(context.Background())
s.lbPolicy.balancer = s.lbBalancer
s.lbPolicy.getBalancer = func() LBBalancer {
return s.lbBalancer
}

err := InitMetaCache(context.Background(), s.rc, s.qc, s.mgr)
s.NoError(err)
Expand Down Expand Up @@ -163,7 +165,7 @@ func (s *LBPolicySuite) loadCollection() {
func (s *LBPolicySuite) TestSelectNode() {
ctx := context.Background()
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
targetNode, err := s.lbPolicy.selectNode(ctx, ChannelWorkload{
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
Expand All @@ -178,7 +180,7 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
Expand All @@ -192,7 +194,7 @@ func (s *LBPolicySuite) TestSelectNode() {
// test select node always fails, expected failure
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
Expand All @@ -206,7 +208,7 @@ func (s *LBPolicySuite) TestSelectNode() {
// test all nodes has been excluded, expected failure
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
Expand All @@ -222,7 +224,7 @@ func (s *LBPolicySuite) TestSelectNode() {
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
s.qc.ExpectedCalls = nil
s.qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(nil, merr.ErrServiceUnavailable)
targetNode, err = s.lbPolicy.selectNode(ctx, ChannelWorkload{
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
Expand Down Expand Up @@ -419,17 +421,17 @@ func (s *LBPolicySuite) TestUpdateCostMetrics() {

func (s *LBPolicySuite) TestNewLBPolicy() {
policy := NewLBPolicyImpl(s.mgr)
s.Equal(reflect.TypeOf(policy.balancer).String(), "*proxy.LookAsideBalancer")
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.LookAsideBalancer")
policy.Close()

Params.Save(Params.ProxyCfg.ReplicaSelectionPolicy.Key, "round_robin")
policy = NewLBPolicyImpl(s.mgr)
s.Equal(reflect.TypeOf(policy.balancer).String(), "*proxy.RoundRobinBalancer")
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.RoundRobinBalancer")
policy.Close()

Params.Save(Params.ProxyCfg.ReplicaSelectionPolicy.Key, "look_aside")
policy = NewLBPolicyImpl(s.mgr)
s.Equal(reflect.TypeOf(policy.balancer).String(), "*proxy.LookAsideBalancer")
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.LookAsideBalancer")
policy.Close()
}

Expand Down
11 changes: 6 additions & 5 deletions internal/proxy/meta_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -952,11 +952,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
zap.String("collectionName", collectionName),
zap.Int64("collectionID", collectionID))

info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}

cacheShardLeaders, ok := m.getCollectionShardLeader(database, collectionName)
if withCache {
if ok {
Expand All @@ -968,6 +963,12 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, database, col
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method, metrics.CacheMissLabel).Inc()
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord")
}

info, err := m.getFullCollectionInfo(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}

req := &querypb.GetShardLeadersRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders),
Expand Down
33 changes: 4 additions & 29 deletions internal/proxy/roundrobin_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,51 +22,26 @@ import (

"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)

type RoundRobinBalancer struct {
// request num send to each node
nodeWorkload *typeutil.ConcurrentMap[int64, *atomic.Int64]
idx atomic.Int64
}

func NewRoundRobinBalancer() *RoundRobinBalancer {
return &RoundRobinBalancer{
nodeWorkload: typeutil.NewConcurrentMap[int64, *atomic.Int64](),
}
return &RoundRobinBalancer{}
}

func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []int64, cost int64) (int64, error) {
if len(availableNodes) == 0 {
return -1, merr.ErrNodeNotAvailable
}

targetNode := int64(-1)
var targetNodeWorkload *atomic.Int64
for _, node := range availableNodes {
workload, ok := b.nodeWorkload.Get(node)

if !ok {
workload = atomic.NewInt64(0)
b.nodeWorkload.Insert(node, workload)
}

if targetNodeWorkload == nil || workload.Load() < targetNodeWorkload.Load() {
targetNode = node
targetNodeWorkload = workload
}
}

targetNodeWorkload.Add(cost)
return targetNode, nil
idx := b.idx.Inc()
return availableNodes[int(idx)%len(availableNodes)], nil
}

func (b *RoundRobinBalancer) CancelWorkload(node int64, nq int64) {
load, ok := b.nodeWorkload.Get(node)

if ok {
load.Sub(nq)
}
}

func (b *RoundRobinBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {}
Expand Down
Loading

0 comments on commit bd658a6

Please sign in to comment.