diff --git a/pkg/disttask/framework/dispatcher/BUILD.bazel b/pkg/disttask/framework/dispatcher/BUILD.bazel index 4e08190d52257..e4992f060b0a4 100644 --- a/pkg/disttask/framework/dispatcher/BUILD.bazel +++ b/pkg/disttask/framework/dispatcher/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "dispatcher.go", "dispatcher_manager.go", "interface.go", + "slots.go", "state_transform.go", ], importpath = "github.com/pingcap/tidb/pkg/disttask/framework/dispatcher", @@ -19,6 +20,7 @@ go_library( "//pkg/resourcemanager/util", "//pkg/sessionctx", "//pkg/util", + "//pkg/util/cpu", "//pkg/util/disttask", "//pkg/util/intest", "//pkg/util/logutil", @@ -37,20 +39,24 @@ go_test( "dispatcher_test.go", "main_test.go", "rebalance_test.go", + "slots_test.go", ], embed = [":dispatcher"], flaky = True, race = "off", - shard_count = 19, + shard_count = 22, deps = [ "//pkg/disttask/framework/dispatcher/mock", "//pkg/disttask/framework/mock", "//pkg/disttask/framework/proto", "//pkg/disttask/framework/storage", + "//pkg/disttask/framework/testutil", "//pkg/domain/infosync", "//pkg/kv", + "//pkg/sessionctx", "//pkg/testkit", "//pkg/testkit/testsetup", + "//pkg/util/disttask", "//pkg/util/logutil", "@com_github_ngaut_pools//:pools", "@com_github_pingcap_errors//:errors", diff --git a/pkg/disttask/framework/dispatcher/dispatcher.go b/pkg/disttask/framework/dispatcher/dispatcher.go index e8c1671531406..c0c0a77b7d703 100644 --- a/pkg/disttask/framework/dispatcher/dispatcher.go +++ b/pkg/disttask/framework/dispatcher/dispatcher.go @@ -133,7 +133,7 @@ func (*BaseDispatcher) Init() error { // ExecuteTask implements the Dispatcher interface. func (d *BaseDispatcher) ExecuteTask() { logutil.Logger(d.logCtx).Info("execute one task", - zap.Stringer("state", d.Task.State), zap.Uint64("concurrency", d.Task.Concurrency)) + zap.Stringer("state", d.Task.State), zap.Int("concurrency", d.Task.Concurrency)) d.scheduleTask() } @@ -570,7 +570,7 @@ func (d *BaseDispatcher) onErrHandlingStage(receiveErrs []error) error { subTasks = make([]*proto.Subtask, 0, len(instanceIDs)) for _, id := range instanceIDs { // reverting subtasks belong to the same step as current active step. 
- subTasks = append(subTasks, proto.NewSubtask(d.Task.Step, d.Task.ID, d.Task.Type, id, int(d.Task.Concurrency), []byte("{}"))) + subTasks = append(subTasks, proto.NewSubtask(d.Task.Step, d.Task.ID, d.Task.Type, id, d.Task.Concurrency, []byte("{}"))) } } return d.updateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes) @@ -686,7 +686,11 @@ func (d *BaseDispatcher) dispatchSubTask( subtaskStep proto.Step, metas [][]byte, serverNodes []*infosync.ServerInfo) error { - logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.Stringer("state", d.Task.State), zap.Int64("step", int64(d.Task.Step)), zap.Uint64("concurrency", d.Task.Concurrency), zap.Int("subtasks", len(metas))) + logutil.Logger(d.logCtx).Info("dispatch subtasks", + zap.Stringer("state", d.Task.State), + zap.Int64("step", int64(d.Task.Step)), + zap.Int("concurrency", d.Task.Concurrency), + zap.Int("subtasks", len(metas))) d.TaskNodes = make([]string, len(serverNodes)) for i := range serverNodes { d.TaskNodes[i] = disttaskutil.GenerateExecID(serverNodes[i].IP, serverNodes[i].Port) @@ -698,7 +702,7 @@ func (d *BaseDispatcher) dispatchSubTask( pos := i % len(serverNodes) instanceID := disttaskutil.GenerateExecID(serverNodes[pos].IP, serverNodes[pos].Port) logutil.Logger(d.logCtx).Debug("create subtasks", zap.String("instanceID", instanceID)) - subTasks = append(subTasks, proto.NewSubtask(subtaskStep, d.Task.ID, d.Task.Type, instanceID, int(d.Task.Concurrency), meta)) + subTasks = append(subTasks, proto.NewSubtask(subtaskStep, d.Task.ID, d.Task.Type, instanceID, d.Task.Concurrency, meta)) } failpoint.Inject("cancelBeforeUpdateTask", func() { _ = d.updateTask(proto.TaskStateCancelling, subTasks, RetrySQLTimes) @@ -750,22 +754,19 @@ func GenerateTaskExecutorNodes(ctx context.Context) (serverNodes []*infosync.Ser } func (d *BaseDispatcher) filterByRole(infos []*infosync.ServerInfo) ([]*infosync.ServerInfo, error) { - nodes, err := d.taskMgr.GetNodesByRole(d.ctx, "background") + nodes, err := d.taskMgr.GetManagedNodes(d.ctx) if err != nil { return nil, err } - if len(nodes) == 0 { - nodes, err = d.taskMgr.GetNodesByRole(d.ctx, "") - } - - if err != nil { - return nil, err + nodeMap := make(map[string]struct{}, len(nodes)) + for _, node := range nodes { + nodeMap[node] = struct{}{} } res := make([]*infosync.ServerInfo, 0, len(nodes)) for _, info := range infos { - _, ok := nodes[disttaskutil.GenerateExecID(info.IP, info.Port)] + _, ok := nodeMap[disttaskutil.GenerateExecID(info.IP, info.Port)] if ok { res = append(res, info) } diff --git a/pkg/disttask/framework/dispatcher/dispatcher_manager.go b/pkg/disttask/framework/dispatcher/dispatcher_manager.go index 92d86958b456a..d0c7e6b6a3214 100644 --- a/pkg/disttask/framework/dispatcher/dispatcher_manager.go +++ b/pkg/disttask/framework/dispatcher/dispatcher_manager.go @@ -32,8 +32,6 @@ import ( ) var ( - // DefaultDispatchConcurrency is the default concurrency for dispatching task. - DefaultDispatchConcurrency = 4 // checkTaskRunningInterval is the interval for loading tasks. checkTaskRunningInterval = 3 * time.Second // defaultHistorySubtaskTableGcInterval is the interval of gc history subtask table. @@ -45,63 +43,55 @@ var ( // WaitTaskFinished is used to sync the test. 
var WaitTaskFinished = make(chan struct{}) -func (dm *Manager) getRunningTaskCnt() int { - dm.runningTasks.RLock() - defer dm.runningTasks.RUnlock() - return len(dm.runningTasks.taskIDs) +func (dm *Manager) getDispatcherCount() int { + dm.mu.RLock() + defer dm.mu.RUnlock() + return len(dm.mu.dispatchers) } -func (dm *Manager) setRunningTask(task *proto.Task, dispatcher Dispatcher) { - dm.runningTasks.Lock() - defer dm.runningTasks.Unlock() - dm.runningTasks.taskIDs[task.ID] = struct{}{} - dm.runningTasks.dispatchers[task.ID] = dispatcher - metrics.UpdateMetricsForRunTask(task) +func (dm *Manager) addDispatcher(taskID int64, dispatcher Dispatcher) { + dm.mu.Lock() + defer dm.mu.Unlock() + dm.mu.dispatchers[taskID] = dispatcher } -func (dm *Manager) isRunningTask(taskID int64) bool { - dm.runningTasks.Lock() - defer dm.runningTasks.Unlock() - _, ok := dm.runningTasks.taskIDs[taskID] +func (dm *Manager) hasDispatcher(taskID int64) bool { + dm.mu.Lock() + defer dm.mu.Unlock() + _, ok := dm.mu.dispatchers[taskID] return ok } -func (dm *Manager) delRunningTask(taskID int64) { - dm.runningTasks.Lock() - defer dm.runningTasks.Unlock() - delete(dm.runningTasks.taskIDs, taskID) - delete(dm.runningTasks.dispatchers, taskID) +func (dm *Manager) delDispatcher(taskID int64) { + dm.mu.Lock() + defer dm.mu.Unlock() + delete(dm.mu.dispatchers, taskID) } -func (dm *Manager) clearRunningTasks() { - dm.runningTasks.Lock() - defer dm.runningTasks.Unlock() - for id := range dm.runningTasks.dispatchers { - delete(dm.runningTasks.dispatchers, id) - } - for id := range dm.runningTasks.taskIDs { - delete(dm.runningTasks.taskIDs, id) - } +func (dm *Manager) clearDispatchers() { + dm.mu.Lock() + defer dm.mu.Unlock() + dm.mu.dispatchers = make(map[int64]Dispatcher) } // Manager manage a bunch of dispatchers. // Dispatcher schedule and monitor tasks. // The scheduling task number is limited by size of gPool. type Manager struct { - ctx context.Context - cancel context.CancelFunc - taskMgr TaskManager - wg tidbutil.WaitGroupWrapper - gPool *spool.Pool - inited bool + ctx context.Context + cancel context.CancelFunc + taskMgr TaskManager + wg tidbutil.WaitGroupWrapper + gPool *spool.Pool + slotMgr *slotManager + initialized bool // serverID, it's value is ip:port now. 
serverID string finishCh chan struct{} - runningTasks struct { + mu struct { syncutil.RWMutex - taskIDs map[int64]struct{} dispatchers map[int64]Dispatcher } } @@ -111,16 +101,16 @@ func NewManager(ctx context.Context, taskMgr TaskManager, serverID string) (*Man dispatcherManager := &Manager{ taskMgr: taskMgr, serverID: serverID, + slotMgr: newSlotManager(), } - gPool, err := spool.NewPool("dispatch_pool", int32(DefaultDispatchConcurrency), util.DistTask, spool.WithBlocking(true)) + gPool, err := spool.NewPool("dispatch_pool", int32(proto.MaxConcurrentTask), util.DistTask, spool.WithBlocking(true)) if err != nil { return nil, err } dispatcherManager.gPool = gPool dispatcherManager.ctx, dispatcherManager.cancel = context.WithCancel(ctx) - dispatcherManager.runningTasks.taskIDs = make(map[int64]struct{}) - dispatcherManager.runningTasks.dispatchers = make(map[int64]Dispatcher) - dispatcherManager.finishCh = make(chan struct{}, DefaultDispatchConcurrency) + dispatcherManager.mu.dispatchers = make(map[int64]Dispatcher) + dispatcherManager.finishCh = make(chan struct{}, proto.MaxConcurrentTask) return dispatcherManager, nil } @@ -133,7 +123,7 @@ func (dm *Manager) Start() { dm.wg.Run(dm.dispatchTaskLoop) dm.wg.Run(dm.gcSubtaskHistoryTableLoop) dm.wg.Run(dm.cleanUpLoop) - dm.inited = true + dm.initialized = true } // Stop the dispatcherManager. @@ -141,14 +131,14 @@ func (dm *Manager) Stop() { dm.cancel() dm.gPool.ReleaseAndWait() dm.wg.Wait() - dm.clearRunningTasks() - dm.inited = false + dm.clearDispatchers() + dm.initialized = false close(dm.finishCh) } -// Inited check the manager inited. -func (dm *Manager) Inited() bool { - return dm.inited +// Initialized check the manager initialized. +func (dm *Manager) Initialized() bool { + return dm.initialized } // dispatchTaskLoop dispatches the tasks. @@ -162,71 +152,67 @@ func (dm *Manager) dispatchTaskLoop() { logutil.BgLogger().Info("dispatch task loop exits", zap.Error(dm.ctx.Err()), zap.Int64("interval", int64(checkTaskRunningInterval)/1000000)) return case <-ticker.C: - cnt := dm.getRunningTaskCnt() - if dm.checkConcurrencyOverflow(cnt) { - break - } + } - // TODO: Consider getting these tasks, in addition to the task being worked on.. - tasks, err := dm.taskMgr.GetTasksInStates( - dm.ctx, - proto.TaskStatePending, - proto.TaskStateRunning, - proto.TaskStateReverting, - proto.TaskStateCancelling, - proto.TaskStateResuming, - ) - if err != nil { - logutil.BgLogger().Warn("get unfinished(pending, running, reverting, cancelling, resuming) tasks failed", zap.Error(err)) - break + taskCnt := dm.getDispatcherCount() + if taskCnt >= proto.MaxConcurrentTask { + logutil.BgLogger().Info("dispatched tasks reached limit", + zap.Int("current", taskCnt), zap.Int("max", proto.MaxConcurrentTask)) + continue + } + + tasks, err := dm.taskMgr.GetTopUnfinishedTasks(dm.ctx) + if err != nil { + logutil.BgLogger().Warn("get unfinished tasks failed", zap.Error(err)) + continue + } + + dispatchableTasks := make([]*proto.Task, 0, len(tasks)) + for _, task := range tasks { + if dm.hasDispatcher(task.ID) { + continue } + // we check it before start dispatcher, so no need to check it again. + // see startDispatcher. + // this should not happen normally, unless user modify system table + // directly. 
+ if getDispatcherFactory(task.Type) == nil { + logutil.BgLogger().Warn("unknown task type", zap.Int64("task-id", task.ID), + zap.Stringer("task-type", task.Type)) + dm.failTask(task.ID, task.State, errors.New("unknown task type")) + continue + } + dispatchableTasks = append(dispatchableTasks, task) + } + if len(dispatchableTasks) == 0 { + continue + } - // There are currently no tasks to work on. - if len(tasks) == 0 { + if err = dm.slotMgr.update(dm.ctx, dm.taskMgr); err != nil { + logutil.BgLogger().Warn("update used slot failed", zap.Error(err)) + continue + } + for _, task := range dispatchableTasks { + taskCnt = dm.getDispatcherCount() + if taskCnt >= proto.MaxConcurrentTask { break } - for _, task := range tasks { - // This task is running, so no need to reprocess it. - if dm.isRunningTask(task.ID) { - continue - } - metrics.DistTaskGauge.WithLabelValues(task.Type.String(), metrics.DispatchingStatus).Inc() - // we check it before start dispatcher, so no need to check it again. - // see startDispatcher. - // this should not happen normally, unless user modify system table - // directly. - if getDispatcherFactory(task.Type) == nil { - logutil.BgLogger().Warn("unknown task type", zap.Int64("task-id", task.ID), - zap.Stringer("task-type", task.Type)) - dm.failTask(task, errors.New("unknown task type")) - continue - } - // the task is not in runningTasks set when: - // owner changed or task is cancelled when status is pending. - if task.State == proto.TaskStateRunning || task.State == proto.TaskStateReverting || task.State == proto.TaskStateCancelling { - metrics.UpdateMetricsForDispatchTask(task) - dm.startDispatcher(task) - cnt++ - continue - } - if dm.checkConcurrencyOverflow(cnt) { - break - } - metrics.UpdateMetricsForDispatchTask(task) - dm.startDispatcher(task) - cnt++ + reservedExecID, ok := dm.slotMgr.canReserve(task) + if !ok { + // task of lower priority might be able to be dispatched. 
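+ // (it may require fewer slots), so keep scanning the remaining tasks
+ // instead of breaking out of the loop.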
+ continue } + metrics.DistTaskGauge.WithLabelValues(task.Type.String(), metrics.DispatchingStatus).Inc() + metrics.UpdateMetricsForDispatchTask(task.ID, task.Type) + dm.startDispatcher(task, reservedExecID) } } } -func (dm *Manager) failTask(task *proto.Task, err error) { - prevState := task.State - task.State = proto.TaskStateFailed - task.Error = err - if _, err2 := dm.taskMgr.UpdateTaskAndAddSubTasks(dm.ctx, task, nil, prevState); err2 != nil { +func (dm *Manager) failTask(id int64, currState proto.TaskState, err error) { + if err2 := dm.taskMgr.FailTask(dm.ctx, id, currState, err); err2 != nil { logutil.BgLogger().Warn("failed to update task state to failed", - zap.Int64("task-id", task.ID), zap.Error(err2)) + zap.Int64("task-id", id), zap.Error(err2)) } } @@ -259,30 +245,30 @@ func (dm *Manager) gcSubtaskHistoryTableLoop() { } } -func (*Manager) checkConcurrencyOverflow(cnt int) bool { - if cnt >= DefaultDispatchConcurrency { - logutil.BgLogger().Info("dispatch task loop, running task cnt is more than concurrency limitation", - zap.Int("running cnt", cnt), zap.Int("concurrency", DefaultDispatchConcurrency)) - return true +func (dm *Manager) startDispatcher(basicTask *proto.Task, reservedExecID string) { + task, err := dm.taskMgr.GetTaskByID(dm.ctx, basicTask.ID) + if err != nil { + logutil.BgLogger().Error("get task failed", zap.Error(err)) + return } - return false -} -func (dm *Manager) startDispatcher(task *proto.Task) { + dispatcherFactory := getDispatcherFactory(task.Type) + dispatcher := dispatcherFactory(dm.ctx, dm.taskMgr, dm.serverID, task) + if err = dispatcher.Init(); err != nil { + logutil.BgLogger().Error("init dispatcher failed", zap.Error(err)) + dm.failTask(task.ID, task.State, err) + return + } + dm.addDispatcher(task.ID, dispatcher) + dm.slotMgr.reserve(basicTask, reservedExecID) // Using the pool with block, so it wouldn't return an error. 
_ = dm.gPool.Run(func() { - dispatcherFactory := getDispatcherFactory(task.Type) - dispatcher := dispatcherFactory(dm.ctx, dm.taskMgr, dm.serverID, task) - if err := dispatcher.Init(); err != nil { - logutil.BgLogger().Error("init dispatcher failed", zap.Error(err)) - dm.failTask(task, err) - return - } defer func() { dispatcher.Close() - dm.delRunningTask(task.ID) + dm.delDispatcher(task.ID) + dm.slotMgr.unReserve(basicTask, reservedExecID) }() - dm.setRunningTask(task, dispatcher) + metrics.UpdateMetricsForRunTask(task) dispatcher.ExecuteTask() logutil.BgLogger().Info("task finished", zap.Int64("task-id", task.ID)) dm.finishCh <- struct{}{} diff --git a/pkg/disttask/framework/dispatcher/dispatcher_test.go b/pkg/disttask/framework/dispatcher/dispatcher_test.go index e7679217b0b0c..42115628194ac 100644 --- a/pkg/disttask/framework/dispatcher/dispatcher_test.go +++ b/pkg/disttask/framework/dispatcher/dispatcher_test.go @@ -17,7 +17,9 @@ package dispatcher_test import ( "context" "fmt" + "slices" "strings" + "sync/atomic" "testing" "time" @@ -29,9 +31,12 @@ import ( "github.com/pingcap/tidb/pkg/disttask/framework/mock" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/testutil" "github.com/pingcap/tidb/pkg/domain/infosync" "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/testkit" + disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" "github.com/pingcap/tidb/pkg/util/logutil" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" @@ -198,8 +203,7 @@ func TestGetInstance(t *testing.T) { TaskID: task.ID, ExecID: serverIDs[1], } - err = mgr.CreateSubTask(ctx, task.ID, proto.StepInit, subtask.ExecID, nil, subtask.Type, true) - require.NoError(t, err) + testutil.CreateSubTask(t, mgr, task.ID, proto.StepInit, subtask.ExecID, nil, subtask.Type, 11, true) instanceIDs, err = dsp.GetAllTaskExecutorIDs(ctx, task) require.NoError(t, err) require.Equal(t, []string{serverIDs[1]}, instanceIDs) @@ -210,8 +214,7 @@ func TestGetInstance(t *testing.T) { TaskID: task.ID, ExecID: serverIDs[0], } - err = mgr.CreateSubTask(ctx, task.ID, proto.StepInit, subtask.ExecID, nil, subtask.Type, true) - require.NoError(t, err) + testutil.CreateSubTask(t, mgr, task.ID, proto.StepInit, subtask.ExecID, nil, subtask.Type, 11, true) instanceIDs, err = dsp.GetAllTaskExecutorIDs(ctx, task) require.NoError(t, err) require.Len(t, instanceIDs, len(serverIDs)) @@ -271,8 +274,8 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, // test parallelism control var originalConcurrency int if taskCnt == 1 { - originalConcurrency = dispatcher.DefaultDispatchConcurrency - dispatcher.DefaultDispatchConcurrency = 1 + originalConcurrency = proto.MaxConcurrentTask + proto.MaxConcurrentTask = 1 } store := testkit.CreateMockStore(t) @@ -293,7 +296,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, dsp.Stop() // make data race happy if taskCnt == 1 { - dispatcher.DefaultDispatchConcurrency = originalConcurrency + proto.MaxConcurrentTask = originalConcurrency } }() @@ -513,3 +516,108 @@ func TestIsCancelledErr(t *testing.T) { require.False(t, dispatcher.IsCancelledErr(context.Canceled)) require.True(t, dispatcher.IsCancelledErr(errors.New("cancelled by user"))) } + +func TestManagerDispatchLoop(t *testing.T) { + // Mock 16 cpu node. 
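+ // newSlotManager picks this up via cpu.GetCPUCount, so the dispatcher's
+ // slot capacity is 16 per node in this test.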
+ require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(16)")) + t.Cleanup(func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu")) + }) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockDispatcher := mock.NewMockDispatcher(ctrl) + + _ = testkit.CreateMockStore(t) + require.Eventually(t, func() bool { + taskMgr, err := storage.GetTaskManager() + return err == nil && taskMgr != nil + }, 10*time.Second, 100*time.Millisecond) + + ctx := context.Background() + ctx = util.WithInternalSourceType(ctx, "dispatcher") + taskMgr, err := storage.GetTaskManager() + require.NoError(t, err) + require.NotNil(t, taskMgr) + + // in this test, we only test dispatcher manager, so we add a subtask takes 16 + // slots to avoid reserve by slots, and make sure below test cases works. + serverInfos, err := infosync.GetAllServerInfo(ctx) + require.NoError(t, err) + for _, s := range serverInfos { + execID := disttaskutil.GenerateExecID(s.IP, s.Port) + testutil.InsertSubtask(t, taskMgr, 1000000, proto.StepOne, execID, []byte(""), proto.TaskStatePending, proto.TaskTypeExample, 16) + } + concurrencies := []int{4, 6, 16, 2, 4, 4} + waitChannels := make([]chan struct{}, len(concurrencies)) + for i := range waitChannels { + waitChannels[i] = make(chan struct{}) + } + var counter atomic.Int32 + dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample, + func(ctx context.Context, taskMgr dispatcher.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher { + idx := counter.Load() + mockDispatcher = mock.NewMockDispatcher(ctrl) + mockDispatcher.EXPECT().Init().Return(nil) + mockDispatcher.EXPECT().ExecuteTask().Do(func() { + require.NoError(t, taskMgr.WithNewSession(func(se sessionctx.Context) error { + _, err := storage.ExecSQL(ctx, se, "update mysql.tidb_global_task set state=%?, step=%? where id=%?", + proto.TaskStateRunning, proto.StepOne, task.ID) + return err + })) + <-waitChannels[idx] + require.NoError(t, taskMgr.WithNewSession(func(se sessionctx.Context) error { + _, err := storage.ExecSQL(ctx, se, "update mysql.tidb_global_task set state=%?, step=%? 
where id=%?", + proto.TaskStateSucceed, proto.StepDone, task.ID) + return err + })) + }) + mockDispatcher.EXPECT().Close() + counter.Add(1) + return mockDispatcher + }, + ) + for i := 0; i < len(concurrencies); i++ { + _, err := taskMgr.CreateTask(ctx, fmt.Sprintf("key/%d", i), proto.TaskTypeExample, concurrencies[i], []byte("{}")) + require.NoError(t, err) + } + getRunningTaskKeys := func() []string { + tasks, err := taskMgr.GetTasksInStates(ctx, proto.TaskStateRunning) + require.NoError(t, err) + taskKeys := make([]string, len(tasks)) + for i, task := range tasks { + taskKeys[i] = task.Key + } + slices.Sort(taskKeys) + return taskKeys + } + require.Eventually(t, func() bool { + taskKeys := getRunningTaskKeys() + return err == nil && len(taskKeys) == 4 && + taskKeys[0] == "key/0" && taskKeys[1] == "key/1" && + taskKeys[2] == "key/3" && taskKeys[3] == "key/4" + }, time.Second*10, time.Millisecond*100) + // finish the first task + close(waitChannels[0]) + require.Eventually(t, func() bool { + taskKeys := getRunningTaskKeys() + return err == nil && len(taskKeys) == 4 && + taskKeys[0] == "key/1" && taskKeys[1] == "key/3" && + taskKeys[2] == "key/4" && taskKeys[3] == "key/5" + }, time.Second*10, time.Millisecond*100) + // finish the second task + close(waitChannels[1]) + require.Eventually(t, func() bool { + taskKeys := getRunningTaskKeys() + return err == nil && len(taskKeys) == 4 && + taskKeys[0] == "key/2" && taskKeys[1] == "key/3" && + taskKeys[2] == "key/4" && taskKeys[3] == "key/5" + }, time.Second*10, time.Millisecond*100) + // close others + for i := 2; i < len(concurrencies); i++ { + close(waitChannels[i]) + } + require.Eventually(t, func() bool { + taskKeys := getRunningTaskKeys() + return err == nil && len(taskKeys) == 0 + }, time.Second*10, time.Millisecond*100) +} diff --git a/pkg/disttask/framework/dispatcher/interface.go b/pkg/disttask/framework/dispatcher/interface.go index 05a387f175c86..3e42dd6faac4a 100644 --- a/pkg/disttask/framework/dispatcher/interface.go +++ b/pkg/disttask/framework/dispatcher/interface.go @@ -25,6 +25,11 @@ import ( // TaskManager defines the interface to access task table. type TaskManager interface { + // GetTopUnfinishedTasks returns unfinished tasks, limited by MaxConcurrentTask*2, + // to make sure lower priority tasks can be scheduled if resource is enough. + // The returned tasks are sorted by task order, see proto.Task, and only contains + // some fields, see row2TaskBasic. + GetTopUnfinishedTasks(ctx context.Context) ([]*proto.Task, error) GetTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error) GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) UpdateTaskAndAddSubTasks(ctx context.Context, task *proto.Task, subtasks []*proto.Subtask, prevState proto.TaskState) (bool, error) @@ -33,13 +38,25 @@ type TaskManager interface { CleanUpMeta(ctx context.Context, nodes []string) error TransferTasks2History(ctx context.Context, tasks []*proto.Task) error CancelTask(ctx context.Context, taskID int64) error + // FailTask updates task state to Failed and updates task error. + FailTask(ctx context.Context, taskID int64, currentState proto.TaskState, taskErr error) error PauseTask(ctx context.Context, taskKey string) (bool, error) + // GetUsedSlotsOnNodes returns the used slots on nodes that have subtask scheduled. + // subtasks of each task on one node is only accounted once as we don't support + // running them concurrently. 
+ // we only consider pending/running subtasks, subtasks related to revert are + // not considered. + GetUsedSlotsOnNodes(ctx context.Context) (map[string]int, error) GetSubtaskInStatesCnt(ctx context.Context, taskID int64, states ...interface{}) (int64, error) ResumeSubtasks(ctx context.Context, taskID int64) error CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error) TransferSubTasks2History(ctx context.Context, taskID int64) error UpdateSubtasksExecIDs(ctx context.Context, taskID int64, subtasks []*proto.Subtask) error - GetNodesByRole(ctx context.Context, role string) (map[string]bool, error) + // GetManagedNodes returns the nodes managed by dist framework and can be used + // to execute tasks. If there are any nodes with background role, we use them, + // else we use nodes without role. + // returned nodes are sorted by node id(host:port). + GetManagedNodes(ctx context.Context) ([]string, error) GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error) GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) diff --git a/pkg/disttask/framework/dispatcher/main_test.go b/pkg/disttask/framework/dispatcher/main_test.go index ca8b90c315026..7f452b8b8a319 100644 --- a/pkg/disttask/framework/dispatcher/main_test.go +++ b/pkg/disttask/framework/dispatcher/main_test.go @@ -30,12 +30,12 @@ type DispatcherManagerForTest interface { // GetRunningGTaskCnt implements Dispatcher.GetRunningGTaskCnt interface. func (dm *Manager) GetRunningTaskCnt() int { - return dm.getRunningTaskCnt() + return dm.getDispatcherCount() } // DelRunningGTask implements Dispatcher.DelRunningGTask interface. func (dm *Manager) DelRunningTask(id int64) { - dm.delRunningTask(id) + dm.delDispatcher(id) } // DoCleanUpRoutine implements Dispatcher.DoCleanUpRoutine interface. diff --git a/pkg/disttask/framework/dispatcher/slots.go b/pkg/disttask/framework/dispatcher/slots.go new file mode 100644 index 0000000000000..72fe68ef72edc --- /dev/null +++ b/pkg/disttask/framework/dispatcher/slots.go @@ -0,0 +1,180 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dispatcher + +import ( + "context" + "slices" + "sync" + + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/util/cpu" +) + +type taskStripes struct { + task *proto.Task + stripes int +} + +// slotManager is used to manage the resource slots and stripes. +// +// Slot is the resource unit of dist framework on each node, each slot represents +// 1 cpu core, 1/total-core of memory, 1/total-core of disk, etc. +// +// Stripe is the resource unit of dist framework, regardless of the node, each +// stripe means 1 slot on all nodes managed by dist framework. 
+// Number of stripes is equal to the number of slots on each node, as we assume
+// that all nodes managed by dist framework are isomorphic.
+// The stripes reserved for a task define the maximum resource the task can use,
+// but the task might not use all of them. To maximize resource utilization, we
+// try to dispatch as many tasks as possible, depending on the used slots on each
+// node and the minimum resource required by the tasks, and in this case we don't
+// consider task order.
+//
+// Dist framework tries to allocate resources by slots and stripes, and gives a
+// quota to each subtask, but the subtask decides how much of the quota to use.
+type slotManager struct {
+    // capacity is the total number of slots and stripes.
+    // TODO: we assume that all nodes managed by dist framework are isomorphic,
+    // but the dist owner might run on a normal node whose capacity might not be
+    // enough to run any task.
+    capacity int
+
+    mu sync.RWMutex
+    // reservedStripes is the number of stripes reserved by each task; when we
+    // reserve by the minimum resource required by the task, we still append to
+    // it, so its summed value might be larger than capacity.
+    // this slice is kept in task order.
+    reservedStripes []taskStripes
+    // task2Index maps task ID to its index in reservedStripes, for fast delete.
+    task2Index map[int64]int
+    // reservedSlots is the number of slots reserved by tasks on each node; the
+    // execID is only used to reserve the minimum resource when starting the
+    // dispatcher, the subtasks may or may not be scheduled on this node.
+    reservedSlots map[string]int
+    // usedSlots is the number of slots taken by tasks on each node.
+    // in some cases it might be larger than capacity:
+    //   the current step of a higher priority task A has few subtasks, so we
+    //   start to dispatch a lower priority task, but the next step of A has
+    //   many subtasks.
+    // once initialized, the length of usedSlots should be equal to the number
+    // of nodes managed by dist framework.
+    usedSlots map[string]int
+}
+
+// newSlotManager creates a new slotManager.
+func newSlotManager() *slotManager {
+    return &slotManager{
+        capacity:      cpu.GetCPUCount(),
+        task2Index:    make(map[int64]int),
+        reservedSlots: make(map[string]int),
+        usedSlots:     make(map[string]int),
+    }
+}
+
+// update refreshes the used slots on each node.
+// TODO: on concurrent calls, update only once.
+func (sm *slotManager) update(ctx context.Context, taskMgr TaskManager) error {
+    nodes, err := taskMgr.GetManagedNodes(ctx)
+    if err != nil {
+        return err
+    }
+    slotsOnNodes, err := taskMgr.GetUsedSlotsOnNodes(ctx)
+    if err != nil {
+        return err
+    }
+    newUsedSlots := make(map[string]int, len(nodes))
+    for _, node := range nodes {
+        newUsedSlots[node] = slotsOnNodes[node]
+    }
+    sm.mu.Lock()
+    defer sm.mu.Unlock()
+    sm.usedSlots = newUsedSlots
+    return nil
+}
+
+// canReserve checks whether there are enough resources for a task.
+// If the resource is reserved by slots, it returns the execID of the chosen node;
+// if the resource is reserved by stripes, it returns "".
+// As usedSlots is updated asynchronously, it might return false even if there
+// are enough resources, or return true on resource shortage right after some
+// task dispatched its subtasks.
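+//
+// For example (the scenario exercised in slots_test.go): with capacity=16 and
+// stripes of 4+8 already reserved for higher-priority tasks, a task with
+// Concurrency=8 cannot be reserved by stripes (8+12 > 16); if some node has
+// usedSlots+reservedSlots=8, the task is reserved by slots on that node
+// instead (8+8 <= 16) and that node's execID is returned.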
+func (sm *slotManager) canReserve(task *proto.Task) (execID string, ok bool) { + sm.mu.RLock() + defer sm.mu.RUnlock() + if len(sm.usedSlots) == 0 { + // no node managed by dist framework + return "", false + } + + reservedForHigherPriority := 0 + for _, s := range sm.reservedStripes { + if s.task.Compare(task) >= 0 { + break + } + reservedForHigherPriority += s.stripes + } + if task.Concurrency+reservedForHigherPriority <= sm.capacity { + return "", true + } + + for id, count := range sm.usedSlots { + if count+sm.reservedSlots[id]+task.Concurrency <= sm.capacity { + return id, true + } + } + return "", false +} + +// Reserve reserves resources for a task. +// Reserve and UnReserve should be called in pair with same parameters. +func (sm *slotManager) reserve(task *proto.Task, execID string) { + taskClone := *task + + sm.mu.Lock() + defer sm.mu.Unlock() + sm.reservedStripes = append(sm.reservedStripes, taskStripes{&taskClone, taskClone.Concurrency}) + slices.SortFunc(sm.reservedStripes, func(a, b taskStripes) int { + return a.task.Compare(b.task) + }) + for i, s := range sm.reservedStripes { + sm.task2Index[s.task.ID] = i + } + + if execID != "" { + sm.reservedSlots[execID] += taskClone.Concurrency + } +} + +// UnReserve un-reserve resources for a task. +func (sm *slotManager) unReserve(task *proto.Task, execID string) { + sm.mu.Lock() + defer sm.mu.Unlock() + idx, ok := sm.task2Index[task.ID] + if !ok { + return + } + sm.reservedStripes = append(sm.reservedStripes[:idx], sm.reservedStripes[idx+1:]...) + delete(sm.task2Index, task.ID) + for i, s := range sm.reservedStripes { + sm.task2Index[s.task.ID] = i + } + + if execID != "" { + sm.reservedSlots[execID] -= task.Concurrency + if sm.reservedSlots[execID] == 0 { + delete(sm.reservedSlots, execID) + } + } +} diff --git a/pkg/disttask/framework/dispatcher/slots_test.go b/pkg/disttask/framework/dispatcher/slots_test.go new file mode 100644 index 0000000000000..d3eb0162075cc --- /dev/null +++ b/pkg/disttask/framework/dispatcher/slots_test.go @@ -0,0 +1,225 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package dispatcher + +import ( + "context" + "testing" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/disttask/framework/mock" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestSlotManagerReserve(t *testing.T) { + sm := newSlotManager() + sm.capacity = 16 + // no node + _, ok := sm.canReserve(&proto.Task{Concurrency: 1}) + require.False(t, ok) + + // reserve by stripes + sm.usedSlots = map[string]int{ + "tidb-1": 16, + } + task := proto.Task{ + Priority: proto.NormalPriority, + Concurrency: 16, + CreateTime: time.Now(), + } + task10 := task + task10.ID = 10 + task10.Concurrency = 4 + execID10, ok := sm.canReserve(&task10) + require.Equal(t, "", execID10) + require.True(t, ok) + sm.reserve(&task10, execID10) + + task20 := task + task20.ID = 20 + task20.Concurrency = 8 + execID20, ok := sm.canReserve(&task20) + require.Equal(t, "", execID20) + require.True(t, ok) + sm.reserve(&task20, execID20) + + task30 := task + task30.ID = 30 + task30.Concurrency = 8 + execID30, ok := sm.canReserve(&task30) + require.Equal(t, "", execID30) + require.False(t, ok) + require.Len(t, sm.reservedStripes, 2) + require.Equal(t, 4, sm.reservedStripes[0].stripes) + require.Equal(t, 8, sm.reservedStripes[1].stripes) + require.Equal(t, map[int64]int{10: 0, 20: 1}, sm.task2Index) + require.Empty(t, sm.reservedSlots) + // higher priority task can preempt lower priority task + task9 := task + task9.ID = 9 + task9.Concurrency = 16 + _, ok = sm.canReserve(&task9) + require.True(t, ok) + // 4 slots are reserved for high priority tasks, so cannot reserve. + task11 := task + task11.ID = 11 + _, ok = sm.canReserve(&task11) + require.False(t, ok) + // lower concurrency + task11.Concurrency = 12 + _, ok = sm.canReserve(&task11) + require.True(t, ok) + + // reserve by slots + sm.usedSlots = map[string]int{ + "tidb-1": 12, + "tidb-2": 8, + } + task40 := task + task40.ID = 40 + task40.Concurrency = 16 + execID40, ok := sm.canReserve(&task40) + require.Equal(t, "", execID40) + require.False(t, ok) + task40.Concurrency = 8 + execID40, ok = sm.canReserve(&task40) + require.Equal(t, "tidb-2", execID40) + require.True(t, ok) + sm.reserve(&task40, execID40) + require.Len(t, sm.reservedStripes, 3) + require.Equal(t, 4, sm.reservedStripes[0].stripes) + require.Equal(t, 8, sm.reservedStripes[1].stripes) + require.Equal(t, 8, sm.reservedStripes[2].stripes) + require.Equal(t, map[int64]int{10: 0, 20: 1, 40: 2}, sm.task2Index) + require.Equal(t, map[string]int{"tidb-2": 8}, sm.reservedSlots) + // higher priority task stop task 15 to run + task15 := task + task15.ID = 15 + task15.Concurrency = 16 + execID15, ok := sm.canReserve(&task15) + require.Equal(t, "", execID15) + require.False(t, ok) + // finish task of id 10 + sm.unReserve(&task10, execID10) + require.Len(t, sm.reservedStripes, 2) + require.Equal(t, 8, sm.reservedStripes[0].stripes) + require.Equal(t, 8, sm.reservedStripes[1].stripes) + require.Equal(t, map[int64]int{20: 0, 40: 1}, sm.task2Index) + require.Equal(t, map[string]int{"tidb-2": 8}, sm.reservedSlots) + // now task 15 can run + execID15, ok = sm.canReserve(&task15) + require.Equal(t, "", execID15) + require.True(t, ok) + sm.reserve(&task15, execID15) + require.Len(t, sm.reservedStripes, 3) + require.Equal(t, 16, sm.reservedStripes[0].stripes) + require.Equal(t, 8, sm.reservedStripes[1].stripes) + require.Equal(t, 8, sm.reservedStripes[2].stripes) + require.Equal(t, map[int64]int{15: 0, 20: 1, 40: 2}, 
sm.task2Index) + require.Equal(t, map[string]int{"tidb-2": 8}, sm.reservedSlots) + // task 50 cannot run + task50 := task + task50.ID = 50 + task50.Concurrency = 8 + _, ok = sm.canReserve(&task50) + require.False(t, ok) + // finish task 40 + sm.unReserve(&task40, execID40) + require.Len(t, sm.reservedStripes, 2) + require.Equal(t, 16, sm.reservedStripes[0].stripes) + require.Equal(t, 8, sm.reservedStripes[1].stripes) + require.Equal(t, map[int64]int{15: 0, 20: 1}, sm.task2Index) + require.Empty(t, sm.reservedSlots) + // now task 50 can run + execID50, ok := sm.canReserve(&task50) + require.Equal(t, "tidb-2", execID50) + require.True(t, ok) + sm.reserve(&task50, execID50) + // task 60 can run too + task60 := task + task60.ID = 60 + task60.Concurrency = 4 + execID60, ok := sm.canReserve(&task60) + require.Equal(t, "tidb-1", execID60) + require.True(t, ok) + sm.reserve(&task60, execID60) + require.Len(t, sm.reservedStripes, 4) + require.Equal(t, 16, sm.reservedStripes[0].stripes) + require.Equal(t, 8, sm.reservedStripes[1].stripes) + require.Equal(t, 8, sm.reservedStripes[2].stripes) + require.Equal(t, 4, sm.reservedStripes[3].stripes) + require.Equal(t, map[int64]int{15: 0, 20: 1, 50: 2, 60: 3}, sm.task2Index) + require.Equal(t, map[string]int{"tidb-1": 4, "tidb-2": 8}, sm.reservedSlots) + + // un-reserve all tasks + sm.unReserve(&task15, execID15) + sm.unReserve(&task20, execID20) + sm.unReserve(&task50, execID50) + sm.unReserve(&task60, execID60) + require.Empty(t, sm.reservedStripes) + require.Empty(t, sm.task2Index) + require.Empty(t, sm.reservedSlots) +} + +func TestSlotManagerUpdate(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + taskMgr := mock.NewMockTaskManager(ctrl) + taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]string{"tidb-1", "tidb-2", "tidb-3"}, nil) + taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(map[string]int{ + "tidb-1": 12, + "tidb-2": 8, + }, nil) + sm := newSlotManager() + sm.capacity = 16 + require.Empty(t, sm.usedSlots) + require.Empty(t, sm.reservedSlots) + require.NoError(t, sm.update(context.Background(), taskMgr)) + require.Empty(t, sm.reservedSlots) + require.Equal(t, map[string]int{ + "tidb-1": 12, + "tidb-2": 8, + "tidb-3": 0, + }, sm.usedSlots) + // some node scaled in, should be reflected + taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]string{"tidb-1"}, nil) + taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(map[string]int{ + "tidb-1": 12, + "tidb-2": 8, + }, nil) + require.NoError(t, sm.update(context.Background(), taskMgr)) + require.Empty(t, sm.reservedSlots) + require.Equal(t, map[string]int{ + "tidb-1": 12, + }, sm.usedSlots) + // on error, the usedSlots should not be changed + taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return(nil, errors.New("mock err")) + require.ErrorContains(t, sm.update(context.Background(), taskMgr), "mock err") + require.Empty(t, sm.reservedSlots) + require.Equal(t, map[string]int{ + "tidb-1": 12, + }, sm.usedSlots) + taskMgr.EXPECT().GetManagedNodes(gomock.Any()).Return([]string{"tidb-1"}, nil) + taskMgr.EXPECT().GetUsedSlotsOnNodes(gomock.Any()).Return(nil, errors.New("mock err")) + require.ErrorContains(t, sm.update(context.Background(), taskMgr), "mock err") + require.Empty(t, sm.reservedSlots) + require.Equal(t, map[string]int{ + "tidb-1": 12, + }, sm.usedSlots) +} diff --git a/pkg/disttask/framework/handle/handle_test.go b/pkg/disttask/framework/handle/handle_test.go index 32d299569c27e..5ff6c6c0ddb30 100644 --- 
a/pkg/disttask/framework/handle/handle_test.go +++ b/pkg/disttask/framework/handle/handle_test.go @@ -58,7 +58,7 @@ func TestHandle(t *testing.T) { // no dispatcher registered require.Equal(t, proto.TaskStateFailed, task.State) require.Equal(t, proto.StepInit, task.Step) - require.Equal(t, uint64(2), task.Concurrency) + require.Equal(t, 2, task.Concurrency) require.Equal(t, []byte("byte"), task.Meta) require.NoError(t, handle.CancelTask(ctx, "1")) diff --git a/pkg/disttask/framework/mock/dispatcher_mock.go b/pkg/disttask/framework/mock/dispatcher_mock.go index 5f55bdfc35b2a..0e2fb12a6be74 100644 --- a/pkg/disttask/framework/mock/dispatcher_mock.go +++ b/pkg/disttask/framework/mock/dispatcher_mock.go @@ -181,6 +181,20 @@ func (mr *MockTaskManagerMockRecorder) CollectSubTaskError(arg0, arg1 any) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CollectSubTaskError", reflect.TypeOf((*MockTaskManager)(nil).CollectSubTaskError), arg0, arg1) } +// FailTask mocks base method. +func (m *MockTaskManager) FailTask(arg0 context.Context, arg1 int64, arg2 proto.TaskState, arg3 error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FailTask", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// FailTask indicates an expected call of FailTask. +func (mr *MockTaskManagerMockRecorder) FailTask(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FailTask", reflect.TypeOf((*MockTaskManager)(nil).FailTask), arg0, arg1, arg2, arg3) +} + // GCSubtasks mocks base method. func (m *MockTaskManager) GCSubtasks(arg0 context.Context) error { m.ctrl.T.Helper() @@ -210,19 +224,19 @@ func (mr *MockTaskManagerMockRecorder) GetAllNodes(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllNodes", reflect.TypeOf((*MockTaskManager)(nil).GetAllNodes), arg0) } -// GetNodesByRole mocks base method. -func (m *MockTaskManager) GetNodesByRole(arg0 context.Context, arg1 string) (map[string]bool, error) { +// GetManagedNodes mocks base method. +func (m *MockTaskManager) GetManagedNodes(arg0 context.Context) ([]string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetNodesByRole", arg0, arg1) - ret0, _ := ret[0].(map[string]bool) + ret := m.ctrl.Call(m, "GetManagedNodes", arg0) + ret0, _ := ret[0].([]string) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetNodesByRole indicates an expected call of GetNodesByRole. -func (mr *MockTaskManagerMockRecorder) GetNodesByRole(arg0, arg1 any) *gomock.Call { +// GetManagedNodes indicates an expected call of GetManagedNodes. +func (mr *MockTaskManagerMockRecorder) GetManagedNodes(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNodesByRole", reflect.TypeOf((*MockTaskManager)(nil).GetNodesByRole), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetManagedNodes", reflect.TypeOf((*MockTaskManager)(nil).GetManagedNodes), arg0) } // GetSubtaskInStatesCnt mocks base method. @@ -340,6 +354,36 @@ func (mr *MockTaskManagerMockRecorder) GetTasksInStates(arg0 any, arg1 ...any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTasksInStates", reflect.TypeOf((*MockTaskManager)(nil).GetTasksInStates), varargs...) } +// GetTopUnfinishedTasks mocks base method. 
+func (m *MockTaskManager) GetTopUnfinishedTasks(arg0 context.Context) ([]*proto.Task, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTopUnfinishedTasks", arg0) + ret0, _ := ret[0].([]*proto.Task) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTopUnfinishedTasks indicates an expected call of GetTopUnfinishedTasks. +func (mr *MockTaskManagerMockRecorder) GetTopUnfinishedTasks(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTopUnfinishedTasks", reflect.TypeOf((*MockTaskManager)(nil).GetTopUnfinishedTasks), arg0) +} + +// GetUsedSlotsOnNodes mocks base method. +func (m *MockTaskManager) GetUsedSlotsOnNodes(arg0 context.Context) (map[string]int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUsedSlotsOnNodes", arg0) + ret0, _ := ret[0].(map[string]int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUsedSlotsOnNodes indicates an expected call of GetUsedSlotsOnNodes. +func (mr *MockTaskManagerMockRecorder) GetUsedSlotsOnNodes(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsedSlotsOnNodes", reflect.TypeOf((*MockTaskManager)(nil).GetUsedSlotsOnNodes), arg0) +} + // PauseTask mocks base method. func (m *MockTaskManager) PauseTask(arg0 context.Context, arg1 string) (bool, error) { m.ctrl.T.Helper() diff --git a/pkg/disttask/framework/proto/BUILD.bazel b/pkg/disttask/framework/proto/BUILD.bazel index a03b24ec142c0..6b73cc11152d4 100644 --- a/pkg/disttask/framework/proto/BUILD.bazel +++ b/pkg/disttask/framework/proto/BUILD.bazel @@ -13,5 +13,6 @@ go_test( srcs = ["task_test.go"], embed = [":proto"], flaky = True, + shard_count = 3, deps = ["@com_github_stretchr_testify//require"], ) diff --git a/pkg/disttask/framework/proto/task.go b/pkg/disttask/framework/proto/task.go index 886340781c961..22ee3991bb403 100644 --- a/pkg/disttask/framework/proto/task.go +++ b/pkg/disttask/framework/proto/task.go @@ -21,7 +21,9 @@ import ( // task state machine // -// ┌──────────────────────────────┐ +// ┌────────┐ +// ┌───────────│resuming│◄────────┐ +// │ └────────┘ │ // │ ┌───────┐ ┌──┴───┐ // │ ┌────────►│pausing├──────►│paused│ // │ │ └───────┘ └──────┘ @@ -122,24 +124,31 @@ const ( // TaskIDLabelName is the label name of task id. TaskIDLabelName = "task_id" // NormalPriority represents the normal priority of task. - NormalPriority = 100 + NormalPriority = 512 ) +// MaxConcurrentTask is the max concurrency of task. +// TODO: remove this limit later. +var MaxConcurrentTask = 4 + // Task represents the task of distributed framework. -// tasks are run in the order of: priority desc, create_time asc, id asc. +// tasks are run in the order of: priority asc, create_time asc, id asc. type Task struct { ID int64 Key string Type TaskType State TaskState Step Step - // Priority is the priority of task, the larger value means the higher priority. + // Priority is the priority of task, the smaller value means the higher priority. // valid range is [1, 1024], default is NormalPriority. - Priority int + Priority int + Concurrency int + CreateTime time.Time + + // depends on query, below fields might not be filled. + // DispatcherID is not used now. DispatcherID string - Concurrency uint64 - CreateTime time.Time StartTime time.Time StateUpdateTime time.Time Meta []byte @@ -152,6 +161,20 @@ func (t *Task) IsDone() bool { t.State == TaskStateFailed } +// Compare compares two tasks by task order. 
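+// A negative result orders t before other: smaller Priority value first, then
+// earlier CreateTime, then smaller ID, matching the order in which tasks are
+// scheduled (priority asc, create_time asc, id asc).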
+func (t *Task) Compare(other *Task) int { + if t.Priority != other.Priority { + return t.Priority - other.Priority + } + if t.CreateTime != other.CreateTime { + if t.CreateTime.Before(other.CreateTime) { + return -1 + } + return 1 + } + return int(t.ID - other.ID) +} + // Subtask represents the subtask of distribute framework. // Each task is divided into multiple subtasks by dispatcher. type Subtask struct { diff --git a/pkg/disttask/framework/proto/task_test.go b/pkg/disttask/framework/proto/task_test.go index 627fef6a0b635..f8b6861afa07a 100644 --- a/pkg/disttask/framework/proto/task_test.go +++ b/pkg/disttask/framework/proto/task_test.go @@ -16,6 +16,7 @@ package proto import ( "testing" + "time" "github.com/stretchr/testify/require" ) @@ -47,3 +48,29 @@ func TestTaskIsDone(t *testing.T) { require.Equal(t, c.done, (&Task{State: c.state}).IsDone()) } } + +func TestTaskCompare(t *testing.T) { + taskA := Task{ + ID: 100, + Priority: NormalPriority, + CreateTime: time.Date(2023, time.December, 5, 15, 53, 30, 0, time.UTC), + } + taskB := taskA + require.Equal(t, 0, taskA.Compare(&taskB)) + taskB.Priority = 100 + require.Greater(t, taskA.Compare(&taskB), 0) + taskB.Priority = taskA.Priority + 100 + require.Less(t, taskA.Compare(&taskB), 0) + + taskB.Priority = taskA.Priority + taskB.CreateTime = time.Date(2023, time.December, 5, 15, 53, 10, 0, time.UTC) + require.Greater(t, taskA.Compare(&taskB), 0) + taskB.CreateTime = time.Date(2023, time.December, 5, 15, 53, 40, 0, time.UTC) + require.Less(t, taskA.Compare(&taskB), 0) + + taskB.CreateTime = taskA.CreateTime + taskB.ID = taskA.ID - 10 + require.Greater(t, taskA.Compare(&taskB), 0) + taskB.ID = taskA.ID + 10 + require.Less(t, taskA.Compare(&taskB), 0) +} diff --git a/pkg/disttask/framework/storage/BUILD.bazel b/pkg/disttask/framework/storage/BUILD.bazel index 5486695a078a4..cda18d01cee1d 100644 --- a/pkg/disttask/framework/storage/BUILD.bazel +++ b/pkg/disttask/framework/storage/BUILD.bazel @@ -32,12 +32,15 @@ go_test( srcs = ["table_test.go"], flaky = True, race = "on", - shard_count = 8, + shard_count = 10, deps = [ ":storage", "//pkg/disttask/framework/proto", + "//pkg/disttask/framework/testutil", + "//pkg/sessionctx", "//pkg/testkit", "//pkg/testkit/testsetup", + "//pkg/util/sqlexec", "@com_github_ngaut_pools//:pools", "@com_github_pingcap_errors//:errors", "@com_github_pingcap_failpoint//:failpoint", diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index a4a83cf345dac..3d5cc6ea4ecd4 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -16,6 +16,7 @@ package storage_test import ( "context" + "fmt" "testing" "time" @@ -24,8 +25,11 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/testutil" + "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/testkit/testsetup" + "github.com/pingcap/tidb/pkg/util/sqlexec" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" "go.uber.org/goleak" @@ -78,7 +82,7 @@ func TestTaskTable(t *testing.T) { require.Equal(t, proto.TaskType("test"), task.Type) require.Equal(t, proto.TaskStatePending, task.State) require.Equal(t, proto.NormalPriority, task.Priority) - require.Equal(t, uint64(4), task.Concurrency) + require.Equal(t, 4, task.Concurrency) require.Equal(t, 
proto.StepInit, task.Step) require.Equal(t, []byte("test"), task.Meta) require.GreaterOrEqual(t, task.CreateTime, timeBeforeCreate) @@ -131,6 +135,110 @@ func TestTaskTable(t *testing.T) { cancelling, err = gm.IsTaskCancelling(ctx, id) require.NoError(t, err) require.True(t, cancelling) + + id, err = gm.CreateTask(ctx, "key-fail", "test2", 4, []byte("test2")) + require.NoError(t, err) + // state not right, update nothing + require.NoError(t, gm.FailTask(ctx, id, proto.TaskStateRunning, errors.New("test error"))) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, proto.TaskStatePending, task.State) + require.Nil(t, task.Error) + require.NoError(t, gm.FailTask(ctx, id, proto.TaskStatePending, errors.New("test error"))) + task, err = gm.GetTaskByID(ctx, id) + require.NoError(t, err) + require.Equal(t, proto.TaskStateFailed, task.State) + require.ErrorContains(t, task.Error, "test error") +} + +func TestGetTopUnfinishedTasks(t *testing.T) { + pool := GetResourcePool(t) + gm := GetTaskManager(t, pool) + defer pool.Close() + ctx := context.Background() + ctx = util.WithInternalSourceType(ctx, "table_test") + + taskStates := []proto.TaskState{ + proto.TaskStateSucceed, + proto.TaskStatePending, + proto.TaskStateRunning, + proto.TaskStateReverting, + proto.TaskStateCancelling, + proto.TaskStatePausing, + proto.TaskStateResuming, + proto.TaskStateFailed, + proto.TaskStatePending, + proto.TaskStatePending, + proto.TaskStatePending, + proto.TaskStatePending, + } + for i, state := range taskStates { + taskKey := fmt.Sprintf("key/%d", i) + _, err := gm.CreateTask(ctx, taskKey, "test", 4, []byte("test")) + require.NoError(t, err) + require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { + _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, ` + update mysql.tidb_global_task set state = %? 
where task_key = %?`, + state, taskKey) + return err + })) + } + // adjust task order + require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { + _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, ` + update mysql.tidb_global_task set create_time = current_timestamp`) + return err + })) + require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { + _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, ` + update mysql.tidb_global_task + set create_time = timestampadd(minute, -10, current_timestamp) + where task_key = 'key/5'`) + return err + })) + require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { + _, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, ` + update mysql.tidb_global_task set priority = 100 where task_key = 'key/6'`) + return err + })) + require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { + rs, err := storage.ExecSQL(ctx, se, ` + select count(1) from mysql.tidb_global_task`) + require.Len(t, rs, 1) + require.Equal(t, int64(12), rs[0].GetInt64(0)) + return err + })) + tasks, err := gm.GetTopUnfinishedTasks(ctx) + require.NoError(t, err) + require.Len(t, tasks, 8) + taskKeys := make([]string, 0, len(tasks)) + for _, task := range tasks { + taskKeys = append(taskKeys, task.Key) + // not filled + require.Empty(t, task.Meta) + } + require.Equal(t, []string{"key/6", "key/5", "key/1", "key/2", "key/3", "key/4", "key/8", "key/9"}, taskKeys) +} + +func TestGetUsedSlotsOnNodes(t *testing.T) { + pool := GetResourcePool(t) + sm := GetTaskManager(t, pool) + defer pool.Close() + ctx := context.Background() + ctx = util.WithInternalSourceType(ctx, "table_test") + + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb-1", []byte(""), proto.TaskStateRunning, "test", 12) + testutil.InsertSubtask(t, sm, 1, proto.StepOne, "tidb-2", []byte(""), proto.TaskStatePending, "test", 12) + testutil.InsertSubtask(t, sm, 2, proto.StepOne, "tidb-2", []byte(""), proto.TaskStatePending, "test", 8) + testutil.InsertSubtask(t, sm, 3, proto.StepOne, "tidb-3", []byte(""), proto.TaskStatePending, "test", 8) + testutil.InsertSubtask(t, sm, 4, proto.StepOne, "tidb-3", []byte(""), proto.TaskStateFailed, "test", 8) + slotsOnNodes, err := sm.GetUsedSlotsOnNodes(ctx) + require.NoError(t, err) + require.Equal(t, map[string]int{ + "tidb-1": 12, + "tidb-2": 20, + "tidb-3": 8, + }, slotsOnNodes) } func TestSubTaskTable(t *testing.T) { @@ -247,8 +355,7 @@ func TestSubTaskTable(t *testing.T) { require.NoError(t, err) require.False(t, ok) - err = sm.CreateSubTask(ctx, 2, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, true) - require.NoError(t, err) + testutil.CreateSubTask(t, sm, 2, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, true) cnt, err = sm.GetSubtaskInStatesCnt(ctx, 2, proto.TaskStateRevertPending) require.NoError(t, err) @@ -275,8 +382,7 @@ func TestSubTaskTable(t *testing.T) { require.Equal(t, int64(100), rowCount) // test UpdateErrorToSubtask do update start/update time - err = sm.CreateSubTask(ctx, 3, proto.StepInit, "for_test", []byte("test"), proto.TaskTypeExample, false) - require.NoError(t, err) + testutil.CreateSubTask(t, sm, 3, proto.StepInit, "for_test", []byte("test"), proto.TaskTypeExample, 11, false) require.NoError(t, sm.UpdateErrorToSubtask(ctx, "for_test", 3, errors.New("fail"))) subtask, err = sm.GetFirstSubtaskInStates(ctx, "for_test", 3, proto.StepInit, proto.TaskStateFailed) require.NoError(t, err) @@ -285,8 +391,7 @@ func TestSubTaskTable(t *testing.T) { require.Greater(t, 
subtask.UpdateTime, ts) // test FinishSubtask do update update time - err = sm.CreateSubTask(ctx, 4, proto.StepInit, "for_test1", []byte("test"), proto.TaskTypeExample, false) - require.NoError(t, err) + testutil.CreateSubTask(t, sm, 4, proto.StepInit, "for_test1", []byte("test"), proto.TaskTypeExample, 11, false) subtask, err = sm.GetFirstSubtaskInStates(ctx, "for_test1", 4, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) require.NoError(t, sm.StartSubtask(ctx, subtask.ID)) @@ -314,7 +419,7 @@ func TestSubTaskTable(t *testing.T) { // test UpdateSubtasksExecIDs // 1. update one subtask - require.NoError(t, sm.CreateSubTask(ctx, 5, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, false)) + testutil.CreateSubTask(t, sm, 5, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) subtasks[0].ExecID = "tidb2" @@ -323,7 +428,7 @@ func TestSubTaskTable(t *testing.T) { require.NoError(t, err) require.Equal(t, "tidb2", subtasks[0].ExecID) // 2. update 2 subtasks - require.NoError(t, sm.CreateSubTask(ctx, 5, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, false)) + testutil.CreateSubTask(t, sm, 5, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) subtasks[0].ExecID = "tidb3" @@ -345,8 +450,8 @@ func TestSubTaskTable(t *testing.T) { require.Equal(t, "tidb2", subtasks[0].ExecID) // test GetSubtasksByExecIdsAndStepAndState - require.NoError(t, sm.CreateSubTask(ctx, 6, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, false)) - require.NoError(t, sm.CreateSubTask(ctx, 6, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, false)) + testutil.CreateSubTask(t, sm, 6, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, sm, 6, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) subtasks, err = sm.GetSubtasksByStepAndState(ctx, 6, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, "tidb1", subtasks[0].ID, proto.TaskStateRunning, nil)) @@ -356,7 +461,7 @@ func TestSubTaskTable(t *testing.T) { subtasks, err = sm.GetSubtasksByExecIdsAndStepAndState(ctx, []string{"tidb1"}, 6, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) require.Equal(t, 1, len(subtasks)) - require.NoError(t, sm.CreateSubTask(ctx, 6, proto.StepInit, "tidb2", []byte("test"), proto.TaskTypeExample, false)) + testutil.CreateSubTask(t, sm, 6, proto.StepInit, "tidb2", []byte("test"), proto.TaskTypeExample, 11, false) subtasks, err = sm.GetSubtasksByExecIdsAndStepAndState(ctx, []string{"tidb1", "tidb2"}, 6, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) require.Equal(t, 2, len(subtasks)) @@ -484,6 +589,11 @@ func TestBothTaskAndSubTaskTable(t *testing.T) { } func TestDistFrameworkMeta(t *testing.T) { + // to avoid inserted nodes be cleaned by dispatcher + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask", "return(true)")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/domain/MockDisableDistTask")) + }() pool := GetResourcePool(t) sm := GetTaskManager(t, pool) defer pool.Close() @@ -492,37 +602,28 @@ func TestDistFrameworkMeta(t *testing.T) { 
require.NoError(t, sm.StartManager(ctx, ":4000", "background")) require.NoError(t, sm.StartManager(ctx, ":4001", "")) - require.NoError(t, sm.StartManager(ctx, ":4002", "")) + // will be replaced by below one require.NoError(t, sm.StartManager(ctx, ":4002", "background")) + require.NoError(t, sm.StartManager(ctx, ":4002", "")) + require.NoError(t, sm.StartManager(ctx, ":4003", "background")) allNodes, err := sm.GetAllNodes(ctx) require.NoError(t, err) - require.Equal(t, []string{":4000", ":4001", ":4002"}, allNodes) - - nodes, err := sm.GetNodesByRole(ctx, "background") - require.NoError(t, err) - require.Equal(t, map[string]bool{ - ":4000": true, - ":4002": true, - }, nodes) + require.Equal(t, []string{":4000", ":4001", ":4002", ":4003"}, allNodes) - nodes, err = sm.GetNodesByRole(ctx, "") + nodes, err := sm.GetManagedNodes(ctx) require.NoError(t, err) - require.Equal(t, map[string]bool{ - ":4001": true, - }, nodes) + require.Equal(t, []string{":4000", ":4003"}, nodes) require.NoError(t, sm.CleanUpMeta(ctx, []string{":4000"})) - nodes, err = sm.GetNodesByRole(ctx, "background") + nodes, err = sm.GetManagedNodes(ctx) require.NoError(t, err) - require.Equal(t, map[string]bool{ - ":4002": true, - }, nodes) + require.Equal(t, []string{":4003"}, nodes) - require.NoError(t, sm.CleanUpMeta(ctx, []string{":4002"})) - nodes, err = sm.GetNodesByRole(ctx, "background") + require.NoError(t, sm.CleanUpMeta(ctx, []string{":4003"})) + nodes, err = sm.GetManagedNodes(ctx) require.NoError(t, err) - require.Equal(t, map[string]bool{}, nodes) + require.Equal(t, []string{":4001", ":4002"}, nodes) } func TestSubtaskHistoryTable(t *testing.T) { @@ -546,11 +647,11 @@ func TestSubtaskHistoryTable(t *testing.T) { finishedMeta = "finished" ) - require.NoError(t, sm.CreateSubTask(ctx, taskID, proto.StepInit, tidb1, []byte(meta), proto.TaskTypeExample, false)) + testutil.CreateSubTask(t, sm, taskID, proto.StepInit, tidb1, []byte(meta), proto.TaskTypeExample, 11, false) require.NoError(t, sm.FinishSubtask(ctx, tidb1, subTask1, []byte(finishedMeta))) - require.NoError(t, sm.CreateSubTask(ctx, taskID, proto.StepInit, tidb2, []byte(meta), proto.TaskTypeExample, false)) + testutil.CreateSubTask(t, sm, taskID, proto.StepInit, tidb2, []byte(meta), proto.TaskTypeExample, 11, false) require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, tidb2, subTask2, proto.TaskStateCanceled, nil)) - require.NoError(t, sm.CreateSubTask(ctx, taskID, proto.StepInit, tidb3, []byte(meta), proto.TaskTypeExample, false)) + testutil.CreateSubTask(t, sm, taskID, proto.StepInit, tidb3, []byte(meta), proto.TaskTypeExample, 11, false) require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, tidb3, subTask3, proto.TaskStateFailed, nil)) subTasks, err := storage.GetSubtasksByTaskIDForTest(ctx, sm, taskID) @@ -583,7 +684,7 @@ func TestSubtaskHistoryTable(t *testing.T) { }() time.Sleep(2 * time.Second) - require.NoError(t, sm.CreateSubTask(ctx, taskID2, proto.StepInit, tidb1, []byte(meta), proto.TaskTypeExample, false)) + testutil.CreateSubTask(t, sm, taskID2, proto.StepInit, tidb1, []byte(meta), proto.TaskTypeExample, 11, false) require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, tidb1, subTask4, proto.TaskStateFailed, nil)) require.NoError(t, sm.TransferSubTasks2History(ctx, taskID2)) @@ -647,9 +748,9 @@ func TestPauseAndResume(t *testing.T) { ctx := context.Background() ctx = util.WithInternalSourceType(ctx, "table_test") - require.NoError(t, sm.CreateSubTask(ctx, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, false)) - 
require.NoError(t, sm.CreateSubTask(ctx, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, false)) - require.NoError(t, sm.CreateSubTask(ctx, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, false)) + testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) + testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) // 1.1 pause all subtasks. require.NoError(t, sm.PauseSubtasks(ctx, "tidb1", 1)) cnt, err := sm.GetSubtaskInStatesCnt(ctx, 1, proto.TaskStatePaused) @@ -681,7 +782,7 @@ func TestCancelAndExecIdChanged(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) ctx = util.WithInternalSourceType(ctx, "table_test") - require.NoError(t, sm.CreateSubTask(ctx, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, false)) + testutil.CreateSubTask(t, sm, 1, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) subtask, err := sm.GetFirstSubtaskInStates(ctx, "tidb1", 1, proto.StepInit, proto.TaskStatePending) require.NoError(t, err) // 1. cancel the ctx, then update subtask state. diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 0315e266803b1..b200ef259f6e6 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -41,12 +41,15 @@ import ( const ( defaultSubtaskKeepDays = 14 - taskColumns = `id, task_key, type, dispatcher_id, state, start_time, state_update_time, - meta, concurrency, step, error, priority, create_time` + basicTaskColumns = `id, task_key, type, state, step, priority, concurrency, create_time` + taskColumns = basicTaskColumns + `, start_time, state_update_time, meta, dispatcher_id, error` + // InsertTaskColumns is the columns used in insert task. + InsertTaskColumns = `task_key, type, state, priority, concurrency, step, meta, create_time, start_time, state_update_time` + subtaskColumns = `id, step, task_key, type, exec_id, state, concurrency, create_time, start_time, state_update_time, meta, summary` - insertSubtaskBasic = `insert into mysql.tidb_background_subtask( - step, task_key, exec_id, meta, state, type, concurrency, create_time, checkpoint, summary) values ` + // InsertSubtaskColumns is the columns used in insert subtask. + InsertSubtaskColumns = `step, task_key, exec_id, meta, state, type, concurrency, create_time, checkpoint, summary` ) // SessionExecutor defines the interface for executing SQLs in a session. @@ -111,21 +114,29 @@ func ExecSQL(ctx context.Context, se sessionctx.Context, sql string, args ...int return nil, nil } +func row2TaskBasic(r chunk.Row) *proto.Task { + task := &proto.Task{ + ID: r.GetInt64(0), + Key: r.GetString(1), + Type: proto.TaskType(r.GetString(2)), + State: proto.TaskState(r.GetString(3)), + Step: proto.Step(r.GetInt64(4)), + Priority: int(r.GetInt64(5)), + Concurrency: int(r.GetInt64(6)), + } + task.CreateTime, _ = r.GetTime(7).GoTime(time.Local) + return task +} + // row2Task converts a row to a task. 
func row2Task(r chunk.Row) *proto.Task { - task := &proto.Task{ - ID: r.GetInt64(0), - Key: r.GetString(1), - Type: proto.TaskType(r.GetString(2)), - DispatcherID: r.GetString(3), - State: proto.TaskState(r.GetString(4)), - Meta: r.GetBytes(7), - Concurrency: uint64(r.GetInt64(8)), - Step: proto.Step(r.GetInt64(9)), - Priority: int(r.GetInt64(11)), - } - if !r.IsNull(10) { - errBytes := r.GetBytes(10) + task := row2TaskBasic(r) + task.StartTime, _ = r.GetTime(8).GoTime(time.Local) + task.StateUpdateTime, _ = r.GetTime(9).GoTime(time.Local) + task.Meta = r.GetBytes(10) + task.DispatcherID = r.GetString(11) + if !r.IsNull(12) { + errBytes := r.GetBytes(12) stdErr := errors.Normalize("") err := stdErr.UnmarshalJSON(errBytes) if err != nil { @@ -135,9 +146,6 @@ func row2Task(r chunk.Row) *proto.Task { task.Error = stdErr } } - task.CreateTime, _ = r.GetTime(12).GoTime(time.Local) - task.StartTime, _ = r.GetTime(5).GoTime(time.Local) - task.StateUpdateTime, _ = r.GetTime(6).GoTime(time.Local) return task } @@ -206,8 +214,8 @@ func (stm *TaskManager) CreateTask(ctx context.Context, key string, tp proto.Tas // CreateTaskWithSession adds a new task to task table with session. func (*TaskManager) CreateTaskWithSession(ctx context.Context, se sessionctx.Context, key string, tp proto.TaskType, concurrency int, meta []byte) (taskID int64, err error) { - _, err = ExecSQL(ctx, se, `insert into mysql.tidb_global_task( - task_key, type, state, priority, concurrency, step, meta, create_time, start_time, state_update_time) + _, err = ExecSQL(ctx, se, ` + insert into mysql.tidb_global_task(`+InsertTaskColumns+`) values (%?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), CURRENT_TIMESTAMP(), CURRENT_TIMESTAMP())`, key, tp, proto.TaskStatePending, proto.NormalPriority, concurrency, proto.StepInit, meta) if err != nil { @@ -239,6 +247,31 @@ func (stm *TaskManager) GetOneTask(ctx context.Context) (task *proto.Task, err e return row2Task(rs[0]), nil } +// GetTopUnfinishedTasks implements the dispatcher.TaskManager interface. +func (stm *TaskManager) GetTopUnfinishedTasks(ctx context.Context) (task []*proto.Task, err error) { + rs, err := stm.executeSQLWithNewSession(ctx, + `select `+basicTaskColumns+` from mysql.tidb_global_task + where state in (%?, %?, %?, %?, %?, %?) + order by priority asc, create_time asc, id asc + limit %?`, + proto.TaskStatePending, + proto.TaskStateRunning, + proto.TaskStateReverting, + proto.TaskStateCancelling, + proto.TaskStatePausing, + proto.TaskStateResuming, + proto.MaxConcurrentTask*2, + ) + if err != nil { + return task, err + } + + for _, r := range rs { + task = append(task, row2TaskBasic(r)) + } + return task, nil +} + // GetTasksInStates gets the tasks in the states. func (stm *TaskManager) GetTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error) { if len(states) == 0 { @@ -327,6 +360,47 @@ func (stm *TaskManager) GetTaskByKeyWithHistory(ctx context.Context, key string) return row2Task(rs[0]), nil } +// FailTask implements the dispatcher.TaskManager interface. +func (stm *TaskManager) FailTask(ctx context.Context, taskID int64, currentState proto.TaskState, taskErr error) error { + _, err := stm.executeSQLWithNewSession(ctx, + `update mysql.tidb_global_task + set state=%?, + error = %?, + state_update_time = CURRENT_TIMESTAMP() + where id=%? and state=%?`, + proto.TaskStateFailed, serializeErr(taskErr), taskID, currentState, + ) + return err +} + +// GetUsedSlotsOnNodes implements the dispatcher.TaskManager interface. 
+func (stm *TaskManager) GetUsedSlotsOnNodes(ctx context.Context) (map[string]int, error) {
+	// subtasks of the same step share the same concurrency, so we use max(concurrency)
+	// to make the group by work.
+	rs, err := stm.executeSQLWithNewSession(ctx, `
+		select
+			exec_id, sum(concurrency)
+		from (
+			select exec_id, task_key, max(concurrency) concurrency
+			from mysql.tidb_background_subtask
+			where state in (%?, %?)
+			group by exec_id, task_key
+		) a
+		group by exec_id`,
+		proto.TaskStatePending, proto.TaskStateRunning,
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	slots := make(map[string]int, len(rs))
+	for _, r := range rs {
+		val, _ := r.GetMyDecimal(1).ToInt()
+		slots[r.GetString(0)] = int(val)
+	}
+	return slots, nil
+}
+
 // row2SubTask converts a row to a subtask.
 func row2SubTask(r chunk.Row) *proto.Subtask {
 	// subtask defines start/update time as bigint, to ensure backward compatible,
@@ -363,23 +437,6 @@ func row2SubTask(r chunk.Row) *proto.Subtask {
 	return subtask
 }
 
-// CreateSubTask adds a new task to subtask table.
-// used for testing.
-func (stm *TaskManager) CreateSubTask(ctx context.Context, taskID int64, step proto.Step, execID string, meta []byte, tp proto.TaskType, isRevert bool) error {
-	state := proto.TaskStatePending
-	if isRevert {
-		state = proto.TaskStateRevertPending
-	}
-
-	_, err := stm.executeSQLWithNewSession(ctx, insertSubtaskBasic+`(%?, %?, %?, %?, %?, %?, 11, CURRENT_TIMESTAMP(), '{}', '{}')`,
-		step, taskID, execID, meta, state, proto.Type2Int(tp))
-	if err != nil {
-		return err
-	}
-
-	return nil
-}
-
 // GetSubtasksInStates gets all subtasks by given states.
 func (stm *TaskManager) GetSubtasksInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...interface{}) ([]*proto.Subtask, error) {
 	args := []interface{}{tidbID, taskID, step}
@@ -777,7 +834,7 @@ func (stm *TaskManager) UpdateTaskAndAddSubTasks(ctx context.Context, task *prot
 	}
 
 	sql := new(strings.Builder)
-	if err := sqlescape.FormatSQL(sql, insertSubtaskBasic); err != nil {
+	if err := sqlescape.FormatSQL(sql, `insert into mysql.tidb_background_subtask(`+InsertSubtaskColumns+`) values`); err != nil {
 		return err
 	}
 	for i, subtask := range subtasks {
@@ -1015,18 +1072,25 @@ func (stm *TaskManager) TransferTasks2History(ctx context.Context, tasks []*prot
 	})
 }
 
-// GetNodesByRole gets nodes map from dist_framework_meta by role.
-func (stm *TaskManager) GetNodesByRole(ctx context.Context, role string) (map[string]bool, error) {
-	rs, err := stm.executeSQLWithNewSession(ctx,
-		"select host from mysql.dist_framework_meta where role = %?", role)
+// GetManagedNodes implements the dispatcher.TaskManager interface.
+func (stm *TaskManager) GetManagedNodes(ctx context.Context) ([]string, error) {
+	rs, err := stm.executeSQLWithNewSession(ctx, `
+		select host, role
+		from mysql.dist_framework_meta
+		where role = 'background' or role = ''
+		order by host`)
 	if err != nil {
 		return nil, err
 	}
-	nodes := make(map[string]bool, len(rs))
+	nodes := make(map[string][]string, 2)
 	for _, r := range rs {
-		nodes[r.GetString(0)] = true
+		role := r.GetString(1)
+		nodes[role] = append(nodes[role], r.GetString(0))
 	}
-	return nodes, nil
+	if len(nodes["background"]) == 0 {
+		return nodes[""], nil
+	}
+	return nodes["background"], nil
 }
 
 // GetAllNodes gets nodes in dist_framework_meta.
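The new GetUsedSlotsOnNodes query aggregates in two levels: subtasks of the same task step share one concurrency value, so the inner query collapses each (exec_id, task_key) group to max(concurrency), and the outer query sums those maxima per exec_id. A minimal in-memory sketch of the same computation follows; the subtaskRow type and helper name are illustrative only, and the sample rows mirror the expectations in TestGetUsedSlotsOnNodes (12, 20 and 8 used slots).

package main

import "fmt"

// subtaskRow is an illustrative stand-in for one pending or running row of
// mysql.tidb_background_subtask.
type subtaskRow struct {
	ExecID      string
	TaskKey     string
	Concurrency int
}

// usedSlotsOnNodes mirrors the SQL: max(concurrency) per (exec_id, task_key),
// then the sum of those maxima per exec_id.
func usedSlotsOnNodes(rows []subtaskRow) map[string]int {
	type key struct{ execID, taskKey string }
	perTask := make(map[key]int)
	for _, r := range rows {
		k := key{r.ExecID, r.TaskKey}
		if r.Concurrency > perTask[k] {
			perTask[k] = r.Concurrency
		}
	}
	slots := make(map[string]int)
	for k, c := range perTask {
		slots[k.execID] += c
	}
	return slots
}

func main() {
	rows := []subtaskRow{
		{"tidb-1", "key/1", 12},
		{"tidb-2", "key/1", 12},
		{"tidb-2", "key/2", 8},
		{"tidb-3", "key/3", 8},
	}
	fmt.Println(usedSlotsOnNodes(rows)) // map[tidb-1:12 tidb-2:20 tidb-3:8]
}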
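GetManagedNodes also changes the node-selection contract compared with the old GetNodesByRole/filterByRole pair: hosts registered with the "background" role are preferred, and the unlabelled hosts are used only as a fallback when no "background" host exists, which is what the reworked TestDistFrameworkMeta asserts. A rough sketch of that rule, with a map-based input assumed purely for illustration (the real method reads mysql.dist_framework_meta):

package main

import (
	"fmt"
	"sort"
)

// pickManagedNodes sketches the GetManagedNodes fallback rule.
func pickManagedNodes(hostRoles map[string]string) []string {
	byRole := make(map[string][]string, 2)
	for host, role := range hostRoles {
		if role == "background" || role == "" { // other roles are never managed
			byRole[role] = append(byRole[role], host)
		}
	}
	nodes := byRole["background"]
	if len(nodes) == 0 { // fall back only when no background host is registered
		nodes = byRole[""]
	}
	sort.Strings(nodes) // the SQL orders by host
	return nodes
}

func main() {
	fmt.Println(pickManagedNodes(map[string]string{
		":4000": "background", ":4001": "", ":4002": "", ":4003": "background",
	})) // [:4000 :4003]
}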
diff --git a/pkg/disttask/framework/taskexecutor/task_executor.go b/pkg/disttask/framework/taskexecutor/task_executor.go index a4ec25f391c23..8f29fe2044865 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor.go +++ b/pkg/disttask/framework/taskexecutor/task_executor.go @@ -154,7 +154,7 @@ func (s *BaseTaskExecutor) run(ctx context.Context, task *proto.Task) (resErr er s.resetError() stepLogger := log.BeginTask(logutil.Logger(s.logCtx).With( zap.Any("step", task.Step), - zap.Uint64("concurrency", task.Concurrency), + zap.Int("concurrency", task.Concurrency), zap.Float64("mem-limit-percent", gctuner.GlobalMemoryLimitTuner.GetPercentage()), zap.String("server-mem-limit", memory.ServerMemoryLimitOriginText.Load()), ), "schedule step") diff --git a/pkg/disttask/framework/taskexecutor/task_executor_test.go b/pkg/disttask/framework/taskexecutor/task_executor_test.go index 68713b0aadbe5..1f831e5eb9d67 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor_test.go +++ b/pkg/disttask/framework/taskexecutor/task_executor_test.go @@ -73,7 +73,7 @@ func TestTaskExecutorRun(t *testing.T) { require.EqualError(t, err, initErr.Error()) var taskID int64 = 1 - var concurrency uint64 = 10 + var concurrency = 10 task := &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency} // 3. run subtask failed @@ -356,7 +356,7 @@ func TestTaskExecutorPause(t *testing.T) { func TestTaskExecutor(t *testing.T) { var tp proto.TaskType = "test_task_executor" var taskID int64 = 1 - var concurrency uint64 = 10 + var concurrency = 10 ctx, cancel := context.WithCancel(context.Background()) defer cancel() runCtx, runCancel := context.WithCancel(ctx) diff --git a/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go b/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go index 09f5249ba9653..25bbf96cdac11 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go +++ b/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go @@ -45,7 +45,7 @@ func runOneTask(ctx context.Context, t *testing.T, mgr *storage.TaskManager, tas _, err = mgr.UpdateTaskAndAddSubTasks(ctx, task, nil, proto.TaskStatePending) require.NoError(t, err) for i := 0; i < subtaskCnt; i++ { - require.NoError(t, mgr.CreateSubTask(ctx, taskID, proto.StepOne, "test", nil, proto.TaskTypeExample, false)) + testutil.CreateSubTask(t, mgr, taskID, proto.StepOne, "test", nil, proto.TaskTypeExample, 11, false) } task, err = mgr.GetTaskByID(ctx, taskID) require.NoError(t, err) @@ -58,7 +58,7 @@ func runOneTask(ctx context.Context, t *testing.T, mgr *storage.TaskManager, tas _, err = mgr.UpdateTaskAndAddSubTasks(ctx, task, nil, proto.TaskStateRunning) require.NoError(t, err) for i := 0; i < subtaskCnt; i++ { - require.NoError(t, mgr.CreateSubTask(ctx, taskID, proto.StepTwo, "test", nil, proto.TaskTypeExample, false)) + testutil.CreateSubTask(t, mgr, taskID, proto.StepTwo, "test", nil, proto.TaskTypeExample, 11, false) } task, err = mgr.GetTaskByID(ctx, taskID) require.NoError(t, err) diff --git a/pkg/disttask/framework/testutil/BUILD.bazel b/pkg/disttask/framework/testutil/BUILD.bazel index 645e8aa407c36..a84f828752684 100644 --- a/pkg/disttask/framework/testutil/BUILD.bazel +++ b/pkg/disttask/framework/testutil/BUILD.bazel @@ -7,6 +7,7 @@ go_library( "dispatcher_util.go", "disttest_util.go", "executor_util.go", + "task_util.go", ], importpath = "github.com/pingcap/tidb/pkg/disttask/framework/testutil", visibility = ["//visibility:public"], @@ -19,6 +20,7 @@ go_library( 
"//pkg/disttask/framework/storage", "//pkg/disttask/framework/taskexecutor", "//pkg/domain/infosync", + "//pkg/sessionctx", "//pkg/testkit", "@com_github_pingcap_failpoint//:failpoint", "@com_github_stretchr_testify//require", diff --git a/pkg/disttask/framework/testutil/task_util.go b/pkg/disttask/framework/testutil/task_util.go new file mode 100644 index 0000000000000..2a4f11b6578b9 --- /dev/null +++ b/pkg/disttask/framework/testutil/task_util.go @@ -0,0 +1,49 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "context" + "testing" + + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/stretchr/testify/require" + "github.com/tikv/client-go/v2/util" +) + +// CreateSubTask adds a new task to subtask table. +// used for testing. +func CreateSubTask(t *testing.T, gm *storage.TaskManager, taskID int64, step proto.Step, execID string, meta []byte, tp proto.TaskType, concurrency int, isRevert bool) { + state := proto.TaskStatePending + if isRevert { + state = proto.TaskStateRevertPending + } + InsertSubtask(t, gm, taskID, step, execID, meta, state, tp, concurrency) +} + +// InsertSubtask adds a new subtask of any state to subtask table. 
+func InsertSubtask(t *testing.T, gm *storage.TaskManager, taskID int64, step proto.Step, execID string, meta []byte, state proto.TaskState, tp proto.TaskType, concurrency int) { + ctx := context.Background() + ctx = util.WithInternalSourceType(ctx, "table_test") + require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { + _, err := storage.ExecSQL(ctx, se, ` + insert into mysql.tidb_background_subtask(`+storage.InsertSubtaskColumns+`) values`+ + `(%?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), '{}', '{}')`, + step, taskID, execID, meta, state, proto.Type2Int(tp), concurrency) + return err + })) +} diff --git a/pkg/disttask/importinto/BUILD.bazel b/pkg/disttask/importinto/BUILD.bazel index 6fa4a3f079852..ef229b5f347d8 100644 --- a/pkg/disttask/importinto/BUILD.bazel +++ b/pkg/disttask/importinto/BUILD.bazel @@ -109,6 +109,7 @@ go_test( "//pkg/disttask/framework/planner", "//pkg/disttask/framework/proto", "//pkg/disttask/framework/storage", + "//pkg/disttask/framework/testutil", "//pkg/disttask/importinto/mock", "//pkg/disttask/operator", "//pkg/domain/infosync", diff --git a/pkg/disttask/importinto/job_testkit_test.go b/pkg/disttask/importinto/job_testkit_test.go index a4a35e9241d02..3267bc8711a12 100644 --- a/pkg/disttask/importinto/job_testkit_test.go +++ b/pkg/disttask/importinto/job_testkit_test.go @@ -23,6 +23,7 @@ import ( "github.com/ngaut/pools" "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/testutil" "github.com/pingcap/tidb/pkg/disttask/importinto" "github.com/pingcap/tidb/pkg/executor/importer" "github.com/pingcap/tidb/pkg/kv" @@ -69,8 +70,8 @@ func TestGetTaskImportedRows(t *testing.T) { for _, m := range importStepMetas { bytes, err := json.Marshal(m) require.NoError(t, err) - require.NoError(t, manager.CreateSubTask(ctx, taskID, importinto.StepImport, - "", bytes, proto.ImportInto, false)) + testutil.CreateSubTask(t, manager, taskID, importinto.StepImport, + "", bytes, proto.ImportInto, 11, false) } rows, err := importinto.GetTaskImportedRows(ctx, 111) require.NoError(t, err) @@ -101,8 +102,8 @@ func TestGetTaskImportedRows(t *testing.T) { for _, m := range ingestStepMetas { bytes, err := json.Marshal(m) require.NoError(t, err) - require.NoError(t, manager.CreateSubTask(ctx, taskID, importinto.StepWriteAndIngest, - "", bytes, proto.ImportInto, false)) + testutil.CreateSubTask(t, manager, taskID, importinto.StepWriteAndIngest, + "", bytes, proto.ImportInto, 11, false) } rows, err = importinto.GetTaskImportedRows(ctx, 222) require.NoError(t, err) diff --git a/pkg/domain/domain.go b/pkg/domain/domain.go index 4337ce9473d47..a2f6f513d4b13 100644 --- a/pkg/domain/domain.go +++ b/pkg/domain/domain.go @@ -1507,7 +1507,7 @@ func (do *Domain) distTaskFrameworkLoop(ctx context.Context, taskManager *storag var dispatcherManager *dispatcher.Manager startDispatchIfNeeded := func() { - if dispatcherManager != nil && dispatcherManager.Inited() { + if dispatcherManager != nil && dispatcherManager.Initialized() { return } var err error @@ -1519,7 +1519,7 @@ func (do *Domain) distTaskFrameworkLoop(ctx context.Context, taskManager *storag dispatcherManager.Start() } stopDispatchIfNeeded := func() { - if dispatcherManager != nil && dispatcherManager.Inited() { + if dispatcherManager != nil && dispatcherManager.Initialized() { logutil.BgLogger().Info("stopping dist task dispatcher manager because the current node is not DDL owner anymore", zap.String("id", 
do.ddl.GetID())) dispatcherManager.Stop() logutil.BgLogger().Info("dist task dispatcher manager stopped", zap.String("id", do.ddl.GetID())) diff --git a/pkg/executor/importer/BUILD.bazel b/pkg/executor/importer/BUILD.bazel index 5e1afc2d8332b..8bf7adde8bfce 100644 --- a/pkg/executor/importer/BUILD.bazel +++ b/pkg/executor/importer/BUILD.bazel @@ -51,6 +51,7 @@ go_library( "//pkg/types", "//pkg/util", "//pkg/util/chunk", + "//pkg/util/cpu", "//pkg/util/dbterror", "//pkg/util/dbterror/exeerrors", "//pkg/util/etcd", diff --git a/pkg/executor/importer/import.go b/pkg/executor/importer/import.go index 2dd931e828658..0ab47f80d96c8 100644 --- a/pkg/executor/importer/import.go +++ b/pkg/executor/importer/import.go @@ -28,7 +28,6 @@ import ( "unicode/utf8" "github.com/pingcap/errors" - "github.com/pingcap/failpoint" "github.com/pingcap/log" "github.com/pingcap/tidb/br/pkg/lightning/backend/local" "github.com/pingcap/tidb/br/pkg/lightning/common" @@ -54,6 +53,7 @@ import ( "github.com/pingcap/tidb/pkg/table" tidbutil "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/cpu" "github.com/pingcap/tidb/pkg/util/dbterror" "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" "github.com/pingcap/tidb/pkg/util/filter" @@ -502,10 +502,7 @@ func (e *LoadDataController) checkFieldParams() error { } func (p *Plan) initDefaultOptions() { - threadCnt := runtime.GOMAXPROCS(0) - failpoint.Inject("mockNumCpu", func(val failpoint.Value) { - threadCnt = val.(int) - }) + threadCnt := cpu.GetCPUCount() threadCnt = int(math.Max(1, float64(threadCnt)*0.5)) p.Checksum = config.OpLevelRequired diff --git a/pkg/executor/importer/import_test.go b/pkg/executor/importer/import_test.go index 35860ae3f3d91..db25955c080aa 100644 --- a/pkg/executor/importer/import_test.go +++ b/pkg/executor/importer/import_test.go @@ -47,7 +47,7 @@ import ( func TestInitDefaultOptions(t *testing.T) { plan := &Plan{} - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/importer/mockNumCpu", "return(1)")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(1)")) variable.CloudStorageURI.Store("s3://bucket/path") t.Cleanup(func() { variable.CloudStorageURI.Store("") @@ -65,7 +65,7 @@ func TestInitDefaultOptions(t *testing.T) { require.Equal(t, config.ByteSize(defaultMaxEngineSize), plan.MaxEngineSize) require.Equal(t, "s3://bucket/path", plan.CloudStorageURI) - require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/importer/mockNumCpu", "return(10)")) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(10)")) plan.initDefaultOptions() require.Equal(t, int64(5), plan.ThreadCnt) } diff --git a/pkg/metrics/disttask.go b/pkg/metrics/disttask.go index cf916b5215c3d..97f45322a9d29 100644 --- a/pkg/metrics/disttask.go +++ b/pkg/metrics/disttask.go @@ -129,10 +129,10 @@ func UpdateMetricsForAddTask(task *proto.Task) { } // UpdateMetricsForDispatchTask update metrics when a task is added -func UpdateMetricsForDispatchTask(task *proto.Task) { - DistTaskGauge.WithLabelValues(task.Type.String(), WaitingStatus).Dec() - DistTaskStarttimeGauge.DeleteLabelValues(task.Type.String(), WaitingStatus, fmt.Sprint(task.ID)) - DistTaskStarttimeGauge.WithLabelValues(task.Type.String(), DispatchingStatus, fmt.Sprint(task.ID)).SetToCurrentTime() +func UpdateMetricsForDispatchTask(id int64, taskType proto.TaskType) { + DistTaskGauge.WithLabelValues(taskType.String(), WaitingStatus).Dec() + 
DistTaskStarttimeGauge.DeleteLabelValues(taskType.String(), WaitingStatus, fmt.Sprint(id))
+	DistTaskStarttimeGauge.WithLabelValues(taskType.String(), DispatchingStatus, fmt.Sprint(id)).SetToCurrentTime()
 }
 
 // UpdateMetricsForRunTask update metrics when a task starts running
diff --git a/pkg/util/cpu/BUILD.bazel b/pkg/util/cpu/BUILD.bazel
index 50332f6fbaeec..d0284eea2b4b4 100644
--- a/pkg/util/cpu/BUILD.bazel
+++ b/pkg/util/cpu/BUILD.bazel
@@ -10,6 +10,7 @@ go_library(
         "//pkg/util/cgroup",
         "//pkg/util/mathutil",
         "@com_github_cloudfoundry_gosigar//:gosigar",
+        "@com_github_pingcap_failpoint//:failpoint",
         "@com_github_pingcap_log//:log",
         "@org_uber_go_atomic//:atomic",
         "@org_uber_go_zap//:zap",
diff --git a/pkg/util/cpu/cpu.go b/pkg/util/cpu/cpu.go
index 4be66d05421a7..6d14cbfada58d 100644
--- a/pkg/util/cpu/cpu.go
+++ b/pkg/util/cpu/cpu.go
@@ -16,10 +16,12 @@ package cpu
 
 import (
 	"os"
+	"runtime"
 	"sync"
 	"time"
 
 	sigar "github.com/cloudfoundry/gosigar"
+	"github.com/pingcap/failpoint"
 	"github.com/pingcap/log"
 	"github.com/pingcap/tidb/pkg/metrics"
 	"github.com/pingcap/tidb/pkg/util/cgroup"
@@ -120,3 +122,11 @@ func getCPUTime() (userTimeMillis, sysTimeMillis int64, err error) {
 	}
 	return int64(cpuTime.User), int64(cpuTime.Sys), nil
 }
+
+// GetCPUCount returns the number of logical CPUs usable by the current process.
+func GetCPUCount() int {
+	failpoint.Inject("mockNumCpu", func(val failpoint.Value) {
+		failpoint.Return(val.(int))
+	})
+	return runtime.GOMAXPROCS(0)
+}
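With the mockNumCpu failpoint relocated to pkg/util/cpu, every caller of cpu.GetCPUCount(), such as the IMPORT INTO default thread count above, can be tested through the same hook that import_test.go now uses. A minimal sketch, assuming a failpoint-enabled test build; the test name is illustrative only.

package cpu_test

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/pingcap/tidb/pkg/util/cpu"
	"github.com/stretchr/testify/require"
)

func TestGetCPUCountMocked(t *testing.T) {
	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu", "return(8)"))
	defer func() {
		require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/cpu/mockNumCpu"))
	}()
	// With the failpoint enabled, GetCPUCount returns the mocked value;
	// otherwise it falls back to runtime.GOMAXPROCS(0).
	require.Equal(t, 8, cpu.GetCPUCount())
}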