diff --git a/pkg/mcs/scheduling/server/config/config.go b/pkg/mcs/scheduling/server/config/config.go index cca2651d99e..45a606720a0 100644 --- a/pkg/mcs/scheduling/server/config/config.go +++ b/pkg/mcs/scheduling/server/config/config.go @@ -400,6 +400,11 @@ func (o *PersistConfig) GetHotRegionCacheHitsThreshold() int { return int(o.GetScheduleConfig().HotRegionCacheHitsThreshold) } +// GetPatrolRegionWorkerCount returns the worker count of the patrol. +func (o *PersistConfig) GetPatrolRegionWorkerCount() int { + return o.GetScheduleConfig().PatrolRegionWorkerCount +} + // GetMaxMovableHotPeerSize returns the max movable hot peer size. func (o *PersistConfig) GetMaxMovableHotPeerSize() int64 { return o.GetScheduleConfig().MaxMovableHotPeerSize diff --git a/pkg/schedule/checker/checker_controller.go b/pkg/schedule/checker/checker_controller.go index f9b75e942c9..5ac8e9e940e 100644 --- a/pkg/schedule/checker/checker_controller.go +++ b/pkg/schedule/checker/checker_controller.go @@ -18,6 +18,7 @@ import ( "bytes" "context" "strconv" + "sync" "time" "github.com/pingcap/failpoint" @@ -31,6 +32,7 @@ import ( "github.com/tikv/pd/pkg/schedule/operator" "github.com/tikv/pd/pkg/schedule/placement" "github.com/tikv/pd/pkg/utils/keyutil" + "github.com/tikv/pd/pkg/utils/logutil" "github.com/tikv/pd/pkg/utils/syncutil" "go.uber.org/zap" ) @@ -47,6 +49,7 @@ const ( // MaxPatrolScanRegionLimit is the max limit of regions to scan for a batch. MaxPatrolScanRegionLimit = 8192 patrolRegionPartition = 1024 + patrolRegionChanLen = MaxPatrolScanRegionLimit ) var ( @@ -71,6 +74,7 @@ type Controller struct { priorityInspector *PriorityInspector pendingProcessedRegions *cache.TTLUint64 suspectKeyRanges *cache.TTLString // suspect key-range regions that may need fix + patrolRegionContext *PatrolRegionContext // duration is the duration of the last patrol round. // It's exported, so it should be protected by a mutex. @@ -82,6 +86,8 @@ type Controller struct { // It's used to update the ticker, so we need to // record it to avoid updating the ticker frequently. interval time.Duration + // workerCount is the count of workers to patrol regions. + workerCount int // patrolRegionScanLimit is the limit of regions to scan. // It is calculated by the number of regions. patrolRegionScanLimit int @@ -104,6 +110,7 @@ func NewController(ctx context.Context, cluster sche.CheckerCluster, conf config priorityInspector: NewPriorityInspector(cluster, conf), pendingProcessedRegions: pendingProcessedRegions, suspectKeyRanges: cache.NewStringTTL(ctx, time.Minute, 3*time.Minute), + patrolRegionContext: &PatrolRegionContext{}, interval: cluster.GetCheckerConfig().GetPatrolRegionInterval(), patrolRegionScanLimit: calculateScanLimit(cluster), } @@ -112,6 +119,9 @@ func NewController(ctx context.Context, cluster sche.CheckerCluster, conf config // PatrolRegions is used to scan regions. // The checkers will check these regions to decide if they need to do some operations. 
func (c *Controller) PatrolRegions() { + c.patrolRegionContext.init(c.ctx) + c.patrolRegionContext.startPatrolRegionWorkers(c) + defer c.patrolRegionContext.stop() ticker := time.NewTicker(c.interval) defer ticker.Stop() start := time.Now() @@ -123,11 +133,20 @@ func (c *Controller) PatrolRegions() { select { case <-ticker.C: c.updateTickerIfNeeded(ticker) + c.updatePatrolWorkersIfNeeded() if c.cluster.IsSchedulingHalted() { + for len(c.patrolRegionContext.regionChan) > 0 { + <-c.patrolRegionContext.regionChan + } log.Debug("skip patrol regions due to scheduling is halted") continue } + // wait for the regionChan to be drained + if len(c.patrolRegionContext.regionChan) > 0 { + continue + } + // Check priority regions first. c.checkPriorityRegions() // Check pending processed regions first. @@ -150,6 +169,9 @@ func (c *Controller) PatrolRegions() { start = time.Now() } failpoint.Inject("breakPatrol", func() { + for !c.IsPatrolRegionChanEmpty() { + time.Sleep(time.Millisecond * 10) + } failpoint.Return() }) case <-c.ctx.Done(): @@ -160,6 +182,32 @@ func (c *Controller) PatrolRegions() { } } +func (c *Controller) updateTickerIfNeeded(ticker *time.Ticker) { + // Note: we reset the ticker here to support updating configuration dynamically. + newInterval := c.cluster.GetCheckerConfig().GetPatrolRegionInterval() + if c.interval != newInterval { + c.interval = newInterval + ticker.Reset(newInterval) + log.Info("checkers starts patrol regions with new interval", zap.Duration("interval", newInterval)) + } +} + +func (c *Controller) updatePatrolWorkersIfNeeded() { + newWorkersCount := c.cluster.GetCheckerConfig().GetPatrolRegionWorkerCount() + if c.workerCount != newWorkersCount { + oldWorkersCount := c.workerCount + c.workerCount = newWorkersCount + // Stop the old workers and start the new workers. + c.patrolRegionContext.workersCancel() + c.patrolRegionContext.wg.Wait() + c.patrolRegionContext.workersCtx, c.patrolRegionContext.workersCancel = context.WithCancel(c.ctx) + c.patrolRegionContext.startPatrolRegionWorkers(c) + log.Info("checkers starts patrol regions with new workers count", + zap.Int("old-workers-count", oldWorkersCount), + zap.Int("new-workers-count", newWorkersCount)) + } +} + // GetPatrolRegionsDuration returns the duration of the last patrol region round. func (c *Controller) GetPatrolRegionsDuration() time.Duration { c.mu.RLock() @@ -182,7 +230,7 @@ func (c *Controller) checkRegions(startKey []byte) (key []byte, regions []*core. } for _, region := range regions { - c.tryAddOperators(region) + c.patrolRegionContext.regionChan <- region key = region.GetEndKey() } return @@ -446,13 +494,55 @@ func (c *Controller) GetPauseController(name string) (*PauseController, error) { } } -func (c *Controller) updateTickerIfNeeded(ticker *time.Ticker) { - // Note: we reset the ticker here to support updating configuration dynamically. - newInterval := c.cluster.GetCheckerConfig().GetPatrolRegionInterval() - if c.interval != newInterval { - c.interval = newInterval - ticker.Reset(newInterval) - log.Info("checkers starts patrol regions with new interval", zap.Duration("interval", newInterval)) +// IsPatrolRegionChanEmpty returns whether the patrol region channel is empty. +func (c *Controller) IsPatrolRegionChanEmpty() bool { + if c.patrolRegionContext == nil { + return true + } + return len(c.patrolRegionContext.regionChan) == 0 +} + +// PatrolRegionContext is used to store the context of patrol regions. 
+type PatrolRegionContext struct { + workersCtx context.Context + workersCancel context.CancelFunc + regionChan chan *core.RegionInfo + wg sync.WaitGroup +} + +func (p *PatrolRegionContext) init(ctx context.Context) { + p.regionChan = make(chan *core.RegionInfo, patrolRegionChanLen) + p.workersCtx, p.workersCancel = context.WithCancel(ctx) +} + +func (p *PatrolRegionContext) stop() { + log.Debug("closing patrol region workers") + close(p.regionChan) + p.workersCancel() + p.wg.Wait() + log.Debug("patrol region workers are closed") +} + +func (p *PatrolRegionContext) startPatrolRegionWorkers(c *Controller) { + for i := range c.workerCount { + p.wg.Add(1) + go func(i int) { + defer logutil.LogPanic() + defer p.wg.Done() + for { + select { + case region, ok := <-p.regionChan: + if !ok { + log.Debug("region channel is closed", zap.Int("worker-id", i)) + return + } + c.tryAddOperators(region) + case <-p.workersCtx.Done(): + log.Debug("region worker is closed", zap.Int("worker-id", i)) + return + } + } + }(i) } } diff --git a/pkg/schedule/checker/priority_inspector.go b/pkg/schedule/checker/priority_inspector.go index 65f9fb5f3dc..4cfdcf7df2f 100644 --- a/pkg/schedule/checker/priority_inspector.go +++ b/pkg/schedule/checker/priority_inspector.go @@ -22,6 +22,7 @@ import ( "github.com/tikv/pd/pkg/schedule/config" sche "github.com/tikv/pd/pkg/schedule/core" "github.com/tikv/pd/pkg/schedule/placement" + "github.com/tikv/pd/pkg/utils/syncutil" ) // defaultPriorityQueueSize is the default value of priority queue size. @@ -31,16 +32,20 @@ const defaultPriorityQueueSize = 1280 type PriorityInspector struct { cluster sche.CheckerCluster conf config.CheckerConfigProvider - queue *cache.PriorityQueue + mu struct { + syncutil.RWMutex + queue *cache.PriorityQueue + } } // NewPriorityInspector creates a priority inspector. func NewPriorityInspector(cluster sche.CheckerCluster, conf config.CheckerConfigProvider) *PriorityInspector { - return &PriorityInspector{ + res := &PriorityInspector{ cluster: cluster, conf: conf, - queue: cache.NewPriorityQueue(defaultPriorityQueueSize), } + res.mu.queue = cache.NewPriorityQueue(defaultPriorityQueueSize) + return res } // RegionPriorityEntry records region priority info. @@ -99,24 +104,28 @@ func (p *PriorityInspector) inspectRegionInReplica(region *core.RegionInfo) (mak // It will remove if region's priority equal 0. // It's Attempt will increase if region's priority equal last. func (p *PriorityInspector) addOrRemoveRegion(priority int, regionID uint64) { + p.mu.Lock() + defer p.mu.Unlock() if priority < 0 { - if entry := p.queue.Get(regionID); entry != nil && entry.Priority == priority { + if entry := p.mu.queue.Get(regionID); entry != nil && entry.Priority == priority { e := entry.Value.(*RegionPriorityEntry) e.Attempt++ e.Last = time.Now() - p.queue.Put(priority, e) + p.mu.queue.Put(priority, e) } else { entry := NewRegionEntry(regionID) - p.queue.Put(priority, entry) + p.mu.queue.Put(priority, entry) } } else { - p.queue.Remove(regionID) + p.mu.queue.Remove(regionID) } } // GetPriorityRegions returns all regions in priority queue that needs rerun. 
func (p *PriorityInspector) GetPriorityRegions() (ids []uint64) { - entries := p.queue.Elems() + p.mu.RLock() + defer p.mu.RUnlock() + entries := p.mu.queue.Elems() for _, e := range entries { re := e.Value.(*RegionPriorityEntry) // avoid to some priority region occupy checker, region don't need check on next check interval @@ -130,11 +139,15 @@ func (p *PriorityInspector) GetPriorityRegions() (ids []uint64) { // RemovePriorityRegion removes priority region from priority queue. func (p *PriorityInspector) RemovePriorityRegion(regionID uint64) { - p.queue.Remove(regionID) + p.mu.Lock() + defer p.mu.Unlock() + p.mu.queue.Remove(regionID) } // getQueueLen returns the length of priority queue. // it's only used for test. func (p *PriorityInspector) getQueueLen() int { - return p.queue.Len() + p.mu.RLock() + defer p.mu.RUnlock() + return p.mu.queue.Len() } diff --git a/pkg/schedule/checker/rule_checker.go b/pkg/schedule/checker/rule_checker.go index 6b3f50d6d14..2d06f84fdfe 100644 --- a/pkg/schedule/checker/rule_checker.go +++ b/pkg/schedule/checker/rule_checker.go @@ -33,6 +33,7 @@ import ( "github.com/tikv/pd/pkg/schedule/operator" "github.com/tikv/pd/pkg/schedule/placement" "github.com/tikv/pd/pkg/schedule/types" + "github.com/tikv/pd/pkg/utils/syncutil" "github.com/tikv/pd/pkg/versioninfo" "go.uber.org/zap" ) @@ -654,6 +655,7 @@ func (c *RuleChecker) handleFilterState(region *core.RegionInfo, filterByTempSta } type recorder struct { + syncutil.RWMutex offlineLeaderCounter map[uint64]uint64 lastUpdateTime time.Time } @@ -666,10 +668,14 @@ func newRecord() *recorder { } func (o *recorder) getOfflineLeaderCount(storeID uint64) uint64 { + o.RLock() + defer o.RUnlock() return o.offlineLeaderCounter[storeID] } func (o *recorder) incOfflineLeaderCount(storeID uint64) { + o.Lock() + defer o.Unlock() o.offlineLeaderCounter[storeID] += 1 o.lastUpdateTime = time.Now() } diff --git a/pkg/schedule/config/config.go b/pkg/schedule/config/config.go index 344569d6460..5e8f2c587ac 100644 --- a/pkg/schedule/config/config.go +++ b/pkg/schedule/config/config.go @@ -67,6 +67,9 @@ const ( defaultRegionScoreFormulaVersion = "v2" defaultLeaderSchedulePolicy = "count" defaultStoreLimitVersion = "v1" + defaultPatrolRegionWorkerCount = 1 + maxPatrolRegionWorkerCount = 8 + // DefaultSplitMergeInterval is the default value of config split merge interval. DefaultSplitMergeInterval = time.Hour defaultSwitchWitnessInterval = time.Hour @@ -306,6 +309,9 @@ type ScheduleConfig struct { // HaltScheduling is the option to halt the scheduling. Once it's on, PD will halt the scheduling, // and any other scheduling configs will be ignored. HaltScheduling bool `toml:"halt-scheduling" json:"halt-scheduling,string,omitempty"` + + // PatrolRegionWorkerCount is the number of workers to patrol region. + PatrolRegionWorkerCount int `toml:"patrol-region-worker-count" json:"patrol-region-worker-count"` } // Clone returns a cloned scheduling configuration. 
@@ -374,6 +380,9 @@ func (c *ScheduleConfig) Adjust(meta *configutil.ConfigMetaData, reloading bool) if !meta.IsDefined("store-limit-version") { configutil.AdjustString(&c.StoreLimitVersion, defaultStoreLimitVersion) } + if !meta.IsDefined("patrol-region-worker-count") { + configutil.AdjustInt(&c.PatrolRegionWorkerCount, defaultPatrolRegionWorkerCount) + } if !meta.IsDefined("enable-joint-consensus") { c.EnableJointConsensus = defaultEnableJointConsensus @@ -518,6 +527,9 @@ func (c *ScheduleConfig) Validate() error { if c.SlowStoreEvictingAffectedStoreRatioThreshold == 0 { return errors.Errorf("slow-store-evicting-affected-store-ratio-threshold is not set") } + if c.PatrolRegionWorkerCount > maxPatrolRegionWorkerCount || c.PatrolRegionWorkerCount < 1 { + return errors.Errorf("patrol-region-worker-count should be between 1 and %d", maxPatrolRegionWorkerCount) + } return nil } diff --git a/pkg/schedule/config/config_provider.go b/pkg/schedule/config/config_provider.go index 95bcad5add0..5c1be1089e9 100644 --- a/pkg/schedule/config/config_provider.go +++ b/pkg/schedule/config/config_provider.go @@ -89,6 +89,7 @@ type CheckerConfigProvider interface { GetIsolationLevel() string GetSplitMergeInterval() time.Duration GetPatrolRegionInterval() time.Duration + GetPatrolRegionWorkerCount() int GetMaxMergeRegionSize() uint64 GetMaxMergeRegionKeys() uint64 GetReplicaScheduleLimit() uint64 diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index dcf91f71b59..46a525a3e09 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -114,6 +114,9 @@ const ( heartbeatTaskRunner = "heartbeat-async" miscTaskRunner = "misc-async" logTaskRunner = "log-async" + + // TODO: make it configurable + IsTSODynamicSwitchingEnabled = false ) // Server is the interface for cluster. @@ -409,11 +412,30 @@ func (c *RaftCluster) checkSchedulingService() { // checkTSOService checks the TSO service. func (c *RaftCluster) checkTSOService() { if c.isAPIServiceMode { + if IsTSODynamicSwitchingEnabled { + servers, err := discovery.Discover(c.etcdClient, constant.TSOServiceName) + if err != nil || len(servers) == 0 { + if err := c.startTSOJobsIfNeeded(); err != nil { + log.Error("failed to start TSO jobs", errs.ZapError(err)) + return + } + log.Info("TSO is provided by PD") + c.UnsetServiceIndependent(constant.TSOServiceName) + } else { + if err := c.stopTSOJobsIfNeeded(); err != nil { + log.Error("failed to stop TSO jobs", errs.ZapError(err)) + return + } + log.Info("TSO is provided by TSO server") + if !c.IsServiceIndependent(constant.TSOServiceName) { + c.SetServiceIndependent(constant.TSOServiceName) + } + } + } return } - if err := c.startTSOJobs(); err != nil { - // If there is an error, need to wait for the next check. 
+ if err := c.startTSOJobsIfNeeded(); err != nil { log.Error("failed to start TSO jobs", errs.ZapError(err)) return } @@ -428,6 +450,8 @@ func (c *RaftCluster) runServiceCheckJob() { schedulingTicker.Reset(time.Millisecond) }) defer schedulingTicker.Stop() + tsoTicker := time.NewTicker(tsoServiceCheckInterval) + defer tsoTicker.Stop() for { select { @@ -436,11 +460,13 @@ func (c *RaftCluster) runServiceCheckJob() { return case <-schedulingTicker.C: c.checkSchedulingService() + case <-tsoTicker.C: + c.checkTSOService() } } } -func (c *RaftCluster) startTSOJobs() error { +func (c *RaftCluster) startTSOJobsIfNeeded() error { allocator, err := c.tsoAllocator.GetAllocator(tso.GlobalDCLocation) if err != nil { log.Error("failed to get global TSO allocator", errs.ZapError(err)) @@ -456,7 +482,7 @@ func (c *RaftCluster) startTSOJobs() error { return nil } -func (c *RaftCluster) stopTSOJobs() error { +func (c *RaftCluster) stopTSOJobsIfNeeded() error { allocator, err := c.tsoAllocator.GetAllocator(tso.GlobalDCLocation) if err != nil { log.Error("failed to get global TSO allocator", errs.ZapError(err)) @@ -824,7 +850,7 @@ func (c *RaftCluster) Stop() { if !c.IsServiceIndependent(constant.SchedulingServiceName) { c.stopSchedulingJobs() } - if err := c.stopTSOJobs(); err != nil { + if err := c.stopTSOJobsIfNeeded(); err != nil { log.Error("failed to stop tso jobs", errs.ZapError(err)) } c.heartbeatRunner.Stop() diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index ac7bf5f1443..9d3a3d44590 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -2873,6 +2873,8 @@ func TestCheckCache(t *testing.T) { cfg.ReplicaScheduleLimit = 0 }, nil, nil, re) defer cleanup() + oc := co.GetOperatorController() + checker := co.GetCheckerController() re.NoError(tc.addRegionStore(1, 0)) re.NoError(tc.addRegionStore(2, 0)) @@ -2883,40 +2885,96 @@ func TestCheckCache(t *testing.T) { re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/schedule/checker/breakPatrol", `return`)) // case 1: operator cannot be created due to replica-schedule-limit restriction - co.GetWaitGroup().Add(1) - co.PatrolRegions() - re.Len(co.GetCheckerController().GetPendingProcessedRegions(), 1) + checker.PatrolRegions() + re.Empty(oc.GetOperators()) + re.Len(checker.GetPendingProcessedRegions(), 1) // cancel the replica-schedule-limit restriction cfg := tc.GetScheduleConfig() cfg.ReplicaScheduleLimit = 10 tc.SetScheduleConfig(cfg) - co.GetWaitGroup().Add(1) - co.PatrolRegions() - oc := co.GetOperatorController() + checker.PatrolRegions() re.Len(oc.GetOperators(), 1) - re.Empty(co.GetCheckerController().GetPendingProcessedRegions()) + re.Empty(checker.GetPendingProcessedRegions()) // case 2: operator cannot be created due to store limit restriction oc.RemoveOperator(oc.GetOperator(1)) tc.SetStoreLimit(1, storelimit.AddPeer, 0) - co.GetWaitGroup().Add(1) - co.PatrolRegions() - re.Len(co.GetCheckerController().GetPendingProcessedRegions(), 1) + checker.PatrolRegions() + re.Len(checker.GetPendingProcessedRegions(), 1) // cancel the store limit restriction tc.SetStoreLimit(1, storelimit.AddPeer, 10) time.Sleep(time.Second) - co.GetWaitGroup().Add(1) - co.PatrolRegions() + checker.PatrolRegions() re.Len(oc.GetOperators(), 1) - re.Empty(co.GetCheckerController().GetPendingProcessedRegions()) + re.Empty(checker.GetPendingProcessedRegions()) + + co.GetSchedulersController().Wait() + co.GetWaitGroup().Wait() + re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/schedule/checker/breakPatrol")) +} + +func 
TestPatrolRegionConcurrency(t *testing.T) { + re := require.New(t) + + regionNum := 10000 + mergeScheduleLimit := 15 + + tc, co, cleanup := prepare(func(cfg *sc.ScheduleConfig) { + cfg.PatrolRegionWorkerCount = 8 + cfg.MergeScheduleLimit = uint64(mergeScheduleLimit) + }, nil, nil, re) + defer cleanup() + oc := co.GetOperatorController() + checker := co.GetCheckerController() + + tc.opt.SetSplitMergeInterval(time.Duration(0)) + for i := range 3 { + if err := tc.addRegionStore(uint64(i+1), regionNum); err != nil { + return + } + } + for i := range regionNum { + if err := tc.addLeaderRegion(uint64(i), 1, 2, 3); err != nil { + return + } + } + // test patrol region concurrency + re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/schedule/checker/breakPatrol", `return`)) + checker.PatrolRegions() + testutil.Eventually(re, func() bool { + return len(oc.GetOperators()) >= mergeScheduleLimit + }) + checkOperatorDuplicate(re, oc.GetOperators()) + + // test patrol region concurrency with suspect regions + suspectRegions := make([]uint64, 0) + for i := range 10 { + suspectRegions = append(suspectRegions, uint64(i)) + } + checker.AddPendingProcessedRegions(false, suspectRegions...) + checker.PatrolRegions() + testutil.Eventually(re, func() bool { + return len(oc.GetOperators()) >= mergeScheduleLimit + }) + checkOperatorDuplicate(re, oc.GetOperators()) co.GetSchedulersController().Wait() co.GetWaitGroup().Wait() re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/schedule/checker/breakPatrol")) } +func checkOperatorDuplicate(re *require.Assertions, ops []*operator.Operator) { + regionMap := make(map[uint64]struct{}) + for _, op := range ops { + if _, ok := regionMap[op.RegionID()]; ok { + re.Fail("duplicate operator") + } + regionMap[op.RegionID()] = struct{}{} + } +} + func TestScanLimit(t *testing.T) { re := require.New(t) @@ -2951,8 +3009,7 @@ func checkScanLimit(re *require.Assertions, regionCount int, expectScanLimit ... re.NoError(tc.putRegion(region)) } - co.GetWaitGroup().Add(1) - co.PatrolRegions() + co.GetCheckerController().PatrolRegions() defer func() { co.GetSchedulersController().Wait() co.GetWaitGroup().Wait() @@ -3443,9 +3500,9 @@ func BenchmarkPatrolRegion(b *testing.B) { }() <-listen - co.GetWaitGroup().Add(1) b.ResetTimer() - co.PatrolRegions() + checker := co.GetCheckerController() + checker.PatrolRegions() } func waitOperator(re *require.Assertions, co *schedule.Coordinator, regionID uint64) { diff --git a/server/config/persist_options.go b/server/config/persist_options.go index 807e9699a25..c426e9d2420 100644 --- a/server/config/persist_options.go +++ b/server/config/persist_options.go @@ -659,6 +659,11 @@ func (o *PersistOptions) GetHotRegionCacheHitsThreshold() int { return int(o.GetScheduleConfig().HotRegionCacheHitsThreshold) } +// GetPatrolRegionWorkerCount returns the worker count of the patrol. +func (o *PersistOptions) GetPatrolRegionWorkerCount() int { + return o.GetScheduleConfig().PatrolRegionWorkerCount +} + // GetStoresLimit gets the stores' limit. 
func (o *PersistOptions) GetStoresLimit() map[uint64]sc.StoreLimitConfig { return o.GetScheduleConfig().StoreLimit diff --git a/server/grpc_service.go b/server/grpc_service.go index 25d5d3ed8e7..d5fd8ae3e32 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -529,10 +529,29 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { return s.forwardTSO(stream) } + tsDeadlineCh := make(chan *tsoutil.TSDeadline, 1) + go tsoutil.WatchTSDeadline(stream.Context(), tsDeadlineCh) + var ( doneCh chan struct{} errCh chan error + // The following are tso forward stream related variables. + forwardStream tsopb.TSO_TsoClient + cancelForward context.CancelFunc + forwardCtx context.Context + tsoStreamErr error + lastForwardedHost string ) + + defer func() { + if cancelForward != nil { + cancelForward() + } + if grpcutil.NeedRebuildConnection(tsoStreamErr) { + s.closeDelegateClient(lastForwardedHost) + } + }() + ctx, cancel := context.WithCancel(stream.Context()) defer cancel() for { @@ -570,6 +589,21 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error { continue } + if s.IsServiceIndependent(constant.TSOServiceName) { + if request.GetCount() == 0 { + err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive") + return status.Error(codes.Unknown, err.Error()) + } + forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, nil, request, tsDeadlineCh, lastForwardedHost, cancelForward) + if tsoStreamErr != nil { + return tsoStreamErr + } + if err != nil { + return err + } + continue + } + start := time.Now() // TSO uses leader lease to determine validity. No need to check leader here. if s.IsClosed() { diff --git a/server/server.go b/server/server.go index 760b185a6ff..c88871658dc 100644 --- a/server/server.go +++ b/server/server.go @@ -1411,8 +1411,7 @@ func (s *Server) GetRaftCluster() *cluster.RaftCluster { // IsServiceIndependent returns whether the service is independent. func (s *Server) IsServiceIndependent(name string) bool { if s.mode == APIServiceMode && !s.IsClosed() { - // TODO: remove it after we support tso discovery - if name == constant.TSOServiceName { + if name == constant.TSOServiceName && !cluster.IsTSODynamicSwitchingEnabled { return true } return s.cluster.IsServiceIndependent(name) diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index c0143760bdd..0620b624c39 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -1864,6 +1864,12 @@ func TestPatrolRegionConfigChange(t *testing.T) { leaderServer.GetServer().SetScheduleConfig(schedule) checkLog(re, fname, "starts patrol regions with new interval") + // test change patrol region worker count + schedule = leaderServer.GetConfig().Schedule + schedule.PatrolRegionWorkerCount = 8 + leaderServer.GetServer().SetScheduleConfig(schedule) + checkLog(re, fname, "starts patrol regions with new workers count") + // test change schedule halt schedule = leaderServer.GetConfig().Schedule schedule.HaltScheduling = true diff --git a/tools/pd-ctl/tests/config/config_test.go b/tools/pd-ctl/tests/config/config_test.go index c3697c065e7..7391f57366e 100644 --- a/tools/pd-ctl/tests/config/config_test.go +++ b/tools/pd-ctl/tests/config/config_test.go @@ -345,6 +345,23 @@ func (suite *configTestSuite) checkConfig(cluster *pdTests.TestCluster) { output, err = tests.ExecuteCommand(cmd, argsInvalid...) 
re.NoError(err) re.Contains(string(output), "is invalid") + + // config set patrol-region-worker-count + args = []string{"-u", pdAddr, "config", "set", "patrol-region-worker-count", "8"} + _, err = tests.ExecuteCommand(cmd, args...) + re.NoError(err) + re.Equal(8, svr.GetScheduleConfig().PatrolRegionWorkerCount) + // the max value of patrol-region-worker-count is 8 and the min value is 1 + args = []string{"-u", pdAddr, "config", "set", "patrol-region-worker-count", "9"} + output, err = tests.ExecuteCommand(cmd, args...) + re.NoError(err) + re.Contains(string(output), "patrol-region-worker-count should be between 1 and 8") + re.Equal(8, svr.GetScheduleConfig().PatrolRegionWorkerCount) + args = []string{"-u", pdAddr, "config", "set", "patrol-region-worker-count", "0"} + output, err = tests.ExecuteCommand(cmd, args...) + re.NoError(err) + re.Contains(string(output), "patrol-region-worker-count should be between 1 and 8") + re.Equal(8, svr.GetScheduleConfig().PatrolRegionWorkerCount) } func (suite *configTestSuite) TestConfigForwardControl() {