diff --git a/ddl/BUILD.bazel b/ddl/BUILD.bazel
index 5325601fee7ec..f7fda21cb115e 100644
--- a/ddl/BUILD.bazel
+++ b/ddl/BUILD.bazel
@@ -13,7 +13,10 @@ go_library(
     srcs = [
         "backfilling.go",
         "backfilling_dispatcher.go",
+        "backfilling_import_cloud.go",
+        "backfilling_import_local.go",
         "backfilling_operator.go",
+        "backfilling_read_index.go",
         "backfilling_scheduler.go",
         "callback.go",
         "cluster.go",
@@ -47,9 +50,6 @@ go_library(
         "schema.go",
         "sequence.go",
         "split_region.go",
-        "stage_ingest_index.go",
-        "stage_merge_sort.go",
-        "stage_read_index.go",
         "stage_scheduler.go",
         "stat.go",
         "table.go",
diff --git a/ddl/stage_merge_sort.go b/ddl/backfilling_import_cloud.go
similarity index 84%
rename from ddl/stage_merge_sort.go
rename to ddl/backfilling_import_cloud.go
index 46b17193f780c..693af13232f10 100644
--- a/ddl/stage_merge_sort.go
+++ b/ddl/backfilling_import_cloud.go
@@ -29,7 +29,7 @@ import (
 	"go.uber.org/zap"
 )
 
-type mergeSortStage struct {
+type cloudImportExecutor struct {
 	jobID         int64
 	index         *model.IndexInfo
 	ptbl          table.PhysicalTable
@@ -37,14 +37,14 @@ type mergeSortStage struct {
 	cloudStoreURI string
 }
 
-func newMergeSortStage(
+func newCloudImportExecutor(
 	jobID int64,
 	index *model.IndexInfo,
 	ptbl table.PhysicalTable,
 	bc ingest.BackendCtx,
 	cloudStoreURI string,
-) (*mergeSortStage, error) {
-	return &mergeSortStage{
+) (*cloudImportExecutor, error) {
+	return &cloudImportExecutor{
 		jobID:         jobID,
 		index:         index,
 		ptbl:          ptbl,
@@ -53,12 +53,12 @@ func newMergeSortStage(
 	}, nil
 }
 
-func (*mergeSortStage) Init(ctx context.Context) error {
+func (*cloudImportExecutor) Init(ctx context.Context) error {
 	logutil.Logger(ctx).Info("merge sort stage init subtask exec env")
 	return nil
 }
 
-func (m *mergeSortStage) RunSubtask(ctx context.Context, subtask *proto.Subtask) error {
+func (m *cloudImportExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) error {
 	logutil.Logger(ctx).Info("merge sort stage split subtask")
 
 	sm := &BackfillSubTaskMeta{}
@@ -94,18 +94,17 @@ func (m *mergeSortStage) RunSubtask(ctx context.Context, subtask *proto.Subtask)
 	return err
 }
 
-func (m *mergeSortStage) Cleanup(ctx context.Context) error {
+func (*cloudImportExecutor) Cleanup(ctx context.Context) error {
 	logutil.Logger(ctx).Info("merge sort stage clean up subtask env")
-	ingest.LitBackCtxMgr.Unregister(m.jobID)
 	return nil
 }
 
-func (*mergeSortStage) OnFinished(ctx context.Context, _ *proto.Subtask) error {
+func (*cloudImportExecutor) OnFinished(ctx context.Context, _ *proto.Subtask) error {
 	logutil.Logger(ctx).Info("merge sort stage finish subtask")
 	return nil
 }
 
-func (*mergeSortStage) Rollback(ctx context.Context) error {
+func (*cloudImportExecutor) Rollback(ctx context.Context) error {
 	logutil.Logger(ctx).Info("merge sort stage rollback subtask")
 	return nil
 }
diff --git a/ddl/stage_ingest_index.go b/ddl/backfilling_import_local.go
similarity index 77%
rename from ddl/stage_ingest_index.go
rename to ddl/backfilling_import_local.go
index 13fd8eb3e9612..40f5b8a8a6b3f 100644
--- a/ddl/stage_ingest_index.go
+++ b/ddl/backfilling_import_local.go
@@ -26,20 +26,20 @@ import (
 	"go.uber.org/zap"
 )
 
-type ingestIndexStage struct {
+type localImportExecutor struct {
 	jobID int64
 	index *model.IndexInfo
 	ptbl  table.PhysicalTable
 	bc    ingest.BackendCtx
 }
 
-func newIngestIndexStage(
+func newImportFromLocalStepExecutor(
 	jobID int64,
 	index *model.IndexInfo,
 	ptbl table.PhysicalTable,
 	bc ingest.BackendCtx,
-) *ingestIndexStage {
-	return &ingestIndexStage{
+) *localImportExecutor {
+	return &localImportExecutor{
 		jobID: jobID,
 		index: index,
 		ptbl:  ptbl,
@@ -47,7 +47,7 @@ func newIngestIndexStage(
 	}
 }
 
-func (i *ingestIndexStage) Init(ctx context.Context) error {
+func (i *localImportExecutor) Init(ctx context.Context) error {
 	logutil.Logger(ctx).Info("ingest index stage init subtask exec env")
 	_, _, err := i.bc.Flush(i.index.ID, ingest.FlushModeForceGlobal)
 	if err != nil {
@@ -61,24 +61,22 @@ func (i *ingestIndexStage) Init(ctx context.Context) error {
 	return err
 }
 
-func (*ingestIndexStage) RunSubtask(ctx context.Context, _ *proto.Subtask) error {
+func (*localImportExecutor) RunSubtask(ctx context.Context, _ *proto.Subtask) error {
 	logutil.Logger(ctx).Info("ingest index stage split subtask")
 	return nil
 }
 
-func (i *ingestIndexStage) Cleanup(ctx context.Context) error {
+func (*localImportExecutor) Cleanup(ctx context.Context) error {
 	logutil.Logger(ctx).Info("ingest index stage cleanup subtask exec env")
-	ingest.LitBackCtxMgr.Unregister(i.jobID)
 	return nil
 }
 
-func (*ingestIndexStage) OnFinished(ctx context.Context, _ *proto.Subtask) error {
+func (*localImportExecutor) OnFinished(ctx context.Context, _ *proto.Subtask) error {
 	logutil.Logger(ctx).Info("ingest index stage finish subtask")
 	return nil
 }
 
-func (i *ingestIndexStage) Rollback(ctx context.Context) error {
+func (*localImportExecutor) Rollback(ctx context.Context) error {
 	logutil.Logger(ctx).Info("ingest index stage rollback backfill add index task")
-	ingest.LitBackCtxMgr.Unregister(i.jobID)
 	return nil
 }
diff --git a/ddl/stage_read_index.go b/ddl/backfilling_read_index.go
similarity index 90%
rename from ddl/stage_read_index.go
rename to ddl/backfilling_read_index.go
index 5032a916ab2ac..1d86f9a83f737 100644
--- a/ddl/stage_read_index.go
+++ b/ddl/backfilling_read_index.go
@@ -37,7 +37,7 @@ import (
 	"go.uber.org/zap"
 )
 
-type readIndexStage struct {
+type readIndexExecutor struct {
 	d     *ddl
 	job   *model.Job
 	index *model.IndexInfo
@@ -61,7 +61,7 @@ type readIndexSummary struct {
 	mu sync.Mutex
 }
 
-func newReadIndexStage(
+func newReadIndexExecutor(
 	d *ddl,
 	job *model.Job,
 	index *model.IndexInfo,
@@ -70,8 +70,8 @@ func newReadIndexStage(
 	bc ingest.BackendCtx,
 	summary *execute.Summary,
 	cloudStorageURI string,
-) *readIndexStage {
-	return &readIndexStage{
+) *readIndexExecutor {
+	return &readIndexExecutor{
 		d:     d,
 		job:   job,
 		index: index,
@@ -83,13 +83,13 @@ func newReadIndexStage(
 	}
 }
 
-func (*readIndexStage) Init(_ context.Context) error {
+func (*readIndexExecutor) Init(_ context.Context) error {
 	logutil.BgLogger().Info("read index stage init subtask exec env",
 		zap.String("category", "ddl"))
 	return nil
 }
 
-func (r *readIndexStage) RunSubtask(ctx context.Context, subtask *proto.Subtask) error {
+func (r *readIndexExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) error {
 	logutil.BgLogger().Info("read index stage run subtask",
 		zap.String("category", "ddl"))
 
@@ -146,19 +146,16 @@ func (r *readIndexStage) RunSubtask(ctx context.Context, subtask *proto.Subtask)
 	return nil
 }
 
-func (r *readIndexStage) Cleanup(ctx context.Context) error {
+func (*readIndexExecutor) Cleanup(ctx context.Context) error {
 	logutil.Logger(ctx).Info("read index stage cleanup subtask exec env",
 		zap.String("category", "ddl"))
-	if _, ok := r.ptbl.(table.PartitionedTable); ok {
-		ingest.LitBackCtxMgr.Unregister(r.job.ID)
-	}
 	return nil
 }
 
 // MockDMLExecutionAddIndexSubTaskFinish is used to mock DML execution during distributed add index.
 var MockDMLExecutionAddIndexSubTaskFinish func()
 
-func (r *readIndexStage) OnFinished(ctx context.Context, subtask *proto.Subtask) error {
+func (r *readIndexExecutor) OnFinished(ctx context.Context, subtask *proto.Subtask) error {
 	failpoint.Inject("mockDMLExecutionAddIndexSubTaskFinish", func(val failpoint.Value) {
 		//nolint:forcetypeassert
 		if val.(bool) && MockDMLExecutionAddIndexSubTaskFinish != nil {
@@ -194,14 +191,13 @@ func (r *readIndexStage) OnFinished(ctx context.Context, subtask *proto.Subtask)
 	return nil
 }
 
-func (r *readIndexStage) Rollback(ctx context.Context) error {
+func (r *readIndexExecutor) Rollback(ctx context.Context) error {
 	logutil.Logger(ctx).Info("read index stage rollback backfill add index task",
 		zap.String("category", "ddl"), zap.Int64("jobID", r.job.ID))
-	ingest.LitBackCtxMgr.Unregister(r.job.ID)
 	return nil
 }
 
-func (r *readIndexStage) getTableStartEndKey(sm *BackfillSubTaskMeta) (
+func (r *readIndexExecutor) getTableStartEndKey(sm *BackfillSubTaskMeta) (
 	start, end kv.Key, tbl table.PhysicalTable, err error) {
 	currentVer, err1 := getValidCurrentVersion(r.d.store)
 	if err1 != nil {
@@ -224,7 +220,7 @@ func (r *readIndexStage) getTableStartEndKey(sm *BackfillSubTaskMeta) (
 	return start, end, tbl, nil
 }
 
-func (r *readIndexStage) buildLocalStorePipeline(
+func (r *readIndexExecutor) buildLocalStorePipeline(
 	opCtx *OperatorCtx,
 	d *ddl,
 	sessCtx sessionctx.Context,
@@ -244,7 +240,7 @@ func (r *readIndexStage) buildLocalStorePipeline(
 		opCtx, d.store, d.sessPool, r.bc, ei, sessCtx, tbl, r.index, start, end, totalRowCount, counter)
 }
 
-func (r *readIndexStage) buildExternalStorePipeline(
+func (r *readIndexExecutor) buildExternalStorePipeline(
 	opCtx *OperatorCtx,
 	d *ddl,
 	subtaskID int64,
diff --git a/ddl/ddl.go b/ddl/ddl.go
index 8abbe3c8a8d53..0b0e9619be899 100644
--- a/ddl/ddl.go
+++ b/ddl/ddl.go
@@ -673,8 +673,8 @@ func newDDL(ctx context.Context, options ...Option) *ddl {
 	}
 
 	scheduler.RegisterTaskType(BackfillTaskType,
-		func(ctx context.Context, id string, taskID int64, taskTable scheduler.TaskTable) scheduler.Scheduler {
-			return newBackfillDistScheduler(ctx, id, taskID, taskTable, d)
+		func(ctx context.Context, id string, task *proto.Task, taskTable scheduler.TaskTable) scheduler.Scheduler {
+			return newBackfillDistScheduler(ctx, id, task, taskTable, d)
 		}, scheduler.WithSummary,
 	)
diff --git a/ddl/ingest/backend_mgr.go b/ddl/ingest/backend_mgr.go
index d1e9f907e460c..f63e7e91dcee9 100644
--- a/ddl/ingest/backend_mgr.go
+++ b/ddl/ingest/backend_mgr.go
@@ -160,7 +160,7 @@ func newBackendContext(ctx context.Context, jobID int64, be *local.Backend, cfg
 
 // Unregister removes a backend context from the backend context manager.
 func (m *litBackendCtxMgr) Unregister(jobID int64) {
-	bc, exist := m.SyncMap.Load(jobID)
+	bc, exist := m.SyncMap.Delete(jobID)
 	if !exist {
 		return
 	}
@@ -170,7 +170,6 @@ func (m *litBackendCtxMgr) Unregister(jobID int64) {
 		bc.checkpointMgr.Close()
 	}
 	m.memRoot.Release(StructSizeBackendCtx)
-	m.Delete(jobID)
 	m.memRoot.ReleaseWithTag(EncodeBackendTag(jobID))
 	logutil.Logger(bc.ctx).Info(LitInfoCloseBackend, zap.Int64("job ID", jobID),
 		zap.Int64("current memory usage", m.memRoot.CurrentUsage()),
diff --git a/ddl/stage_scheduler.go b/ddl/stage_scheduler.go
index f720a1a06cdb3..833750e6ef50a 100644
--- a/ddl/stage_scheduler.go
+++ b/ddl/stage_scheduler.go
@@ -52,9 +52,9 @@ type BackfillSubTaskMeta struct {
 	TotalKVSize uint64 `json:"total_kv_size"`
 }
 
-// NewBackfillSchedulerHandle creates a new backfill scheduler.
-func NewBackfillSchedulerHandle(ctx context.Context, taskMeta []byte, d *ddl,
-	stage int64, summary *execute.Summary) (execute.SubtaskExecutor, error) {
+// NewBackfillSubtaskExecutor creates a new backfill subtask executor.
+func NewBackfillSubtaskExecutor(_ context.Context, taskMeta []byte, d *ddl,
+	bc ingest.BackendCtx, stage int64, summary *execute.Summary) (execute.SubtaskExecutor, error) {
 	bgm := &BackfillGlobalMeta{}
 	err := json.Unmarshal(taskMeta, bgm)
 	if err != nil {
@@ -73,23 +73,18 @@ func NewBackfillSchedulerHandle(ctx context.Context, taskMeta []byte, d *ddl,
 		return nil, errors.New("index info not found")
 	}
 
-	bc, err := ingest.LitBackCtxMgr.Register(ctx, indexInfo.Unique, jobMeta.ID, d.etcdCli, jobMeta.ReorgMeta.ResourceGroupName)
-	if err != nil {
-		return nil, errors.Trace(err)
-	}
-
 	switch stage {
 	case proto.StepInit:
 		jc := d.jobContext(jobMeta.ID, jobMeta.ReorgMeta)
 		d.setDDLLabelForTopSQL(jobMeta.ID, jobMeta.Query)
 		d.setDDLSourceForDiagnosis(jobMeta.ID, jobMeta.Type)
-		return newReadIndexStage(
+		return newReadIndexExecutor(
 			d, &bgm.Job, indexInfo, tbl.(table.PhysicalTable), jc, bc, summary, bgm.CloudStorageURI), nil
 	case proto.StepOne:
 		if len(bgm.CloudStorageURI) > 0 {
-			return newMergeSortStage(jobMeta.ID, indexInfo, tbl.(table.PhysicalTable), bc, bgm.CloudStorageURI)
+			return newCloudImportExecutor(jobMeta.ID, indexInfo, tbl.(table.PhysicalTable), bc, bgm.CloudStorageURI)
 		}
-		return newIngestIndexStage(jobMeta.ID, indexInfo, tbl.(table.PhysicalTable), bc), nil
+		return newImportFromLocalStepExecutor(jobMeta.ID, indexInfo, tbl.(table.PhysicalTable), bc), nil
 	default:
 		return nil, errors.Errorf("unknown step %d for job %d", stage, jobMeta.ID)
 	}
@@ -100,23 +95,66 @@ const BackfillTaskType = "backfill"
 
 type backfillDistScheduler struct {
 	*scheduler.BaseScheduler
-	d *ddl
+	d          *ddl
+	task       *proto.Task
+	taskTable  scheduler.TaskTable
+	backendCtx ingest.BackendCtx
+	jobID      int64
 }
 
-func newBackfillDistScheduler(ctx context.Context, id string, taskID int64, taskTable scheduler.TaskTable, d *ddl) scheduler.Scheduler {
+func newBackfillDistScheduler(ctx context.Context, id string, task *proto.Task, taskTable scheduler.TaskTable, d *ddl) scheduler.Scheduler {
 	s := &backfillDistScheduler{
-		BaseScheduler: scheduler.NewBaseScheduler(ctx, id, taskID, taskTable),
+		BaseScheduler: scheduler.NewBaseScheduler(ctx, id, task.ID, taskTable),
 		d:             d,
+		task:          task,
+		taskTable:     taskTable,
 	}
 	s.BaseScheduler.Extension = s
 	return s
 }
 
+func (s *backfillDistScheduler) Init(ctx context.Context) error {
+	err := s.BaseScheduler.Init(ctx)
+	if err != nil {
+		return err
+	}
+	d := s.d
+
+	bgm := &BackfillGlobalMeta{}
+	err = json.Unmarshal(s.task.Meta, bgm)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	job := &bgm.Job
+	_, tbl, err := d.getTableByTxn(d.store, job.SchemaID, job.TableID)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	idx := model.FindIndexInfoByID(tbl.Meta().Indices, bgm.EleID)
+	if idx == nil {
+		return errors.Trace(errors.New("index info not found"))
+	}
+	bc, err := ingest.LitBackCtxMgr.Register(ctx, idx.Unique, job.ID, d.etcdCli, job.ReorgMeta.ResourceGroupName)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	s.backendCtx = bc
+	s.jobID = job.ID
+	return nil
+}
+
 func (s *backfillDistScheduler) GetSubtaskExecutor(ctx context.Context, task *proto.Task, summary *execute.Summary) (execute.SubtaskExecutor, error) {
 	switch task.Step {
 	case proto.StepInit, proto.StepOne:
-		return NewBackfillSchedulerHandle(ctx, task.Meta, s.d, task.Step, summary)
+		return NewBackfillSubtaskExecutor(ctx, task.Meta, s.d, s.backendCtx, task.Step, summary)
 	default:
 		return nil, errors.Errorf("unknown backfill step %d for task %d", task.Step, task.ID)
 	}
 }
+
+func (s *backfillDistScheduler) Close() {
+	if s.backendCtx != nil {
+		ingest.LitBackCtxMgr.Unregister(s.jobID)
+	}
+	s.BaseScheduler.Close()
+}
diff --git a/disttask/framework/framework_test.go b/disttask/framework/framework_test.go
index 78fc7bfe7f70d..5b8cbc862b945 100644
--- a/disttask/framework/framework_test.go
+++ b/disttask/framework/framework_test.go
@@ -141,8 +141,8 @@ func registerTaskMetaInner(t *testing.T, taskType string, mockExtension schedule
 		return baseDispatcher
 	})
 	scheduler.RegisterTaskType(taskType,
-		func(ctx context.Context, id string, taskID int64, taskTable scheduler.TaskTable) scheduler.Scheduler {
-			s := scheduler.NewBaseScheduler(ctx, id, taskID, taskTable)
+		func(ctx context.Context, id string, task *proto.Task, taskTable scheduler.TaskTable) scheduler.Scheduler {
+			s := scheduler.NewBaseScheduler(ctx, id, task.ID, taskTable)
 			s.Extension = mockExtension
 			return s
 		},
diff --git a/disttask/framework/mock/scheduler_mock.go b/disttask/framework/mock/scheduler_mock.go
index 20287b77ce705..e4b99663e46d4 100644
--- a/disttask/framework/mock/scheduler_mock.go
+++ b/disttask/framework/mock/scheduler_mock.go
@@ -281,6 +281,32 @@ func (m *MockScheduler) EXPECT() *MockSchedulerMockRecorder {
 	return m.recorder
 }
 
+// Close mocks base method.
+func (m *MockScheduler) Close() {
+	m.ctrl.T.Helper()
+	m.ctrl.Call(m, "Close")
+}
+
+// Close indicates an expected call of Close.
+func (mr *MockSchedulerMockRecorder) Close() *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockScheduler)(nil).Close))
+}
+
+// Init mocks base method.
+func (m *MockScheduler) Init(arg0 context.Context) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "Init", arg0)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// Init indicates an expected call of Init.
+func (mr *MockSchedulerMockRecorder) Init(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockScheduler)(nil).Init), arg0)
+}
+
 // Rollback mocks base method.
 func (m *MockScheduler) Rollback(arg0 context.Context, arg1 *proto.Task) error {
 	m.ctrl.T.Helper()
diff --git a/disttask/framework/scheduler/interface.go b/disttask/framework/scheduler/interface.go
index 5a46c7a23b8a9..fe20abea7a69e 100644
--- a/disttask/framework/scheduler/interface.go
+++ b/disttask/framework/scheduler/interface.go
@@ -45,10 +45,12 @@ type Pool interface {
 }
 
 // Scheduler is the subtask scheduler for a task.
-// each task type should implement this interface.
+// Each task type should implement this interface.
 type Scheduler interface {
+	Init(context.Context) error
 	Run(context.Context, *proto.Task) error
 	Rollback(context.Context, *proto.Task) error
+	Close()
 }
 
 // Extension extends the scheduler.
diff --git a/disttask/framework/scheduler/manager.go b/disttask/framework/scheduler/manager.go
index 96d675f8d06b8..edf9c47b53c26 100644
--- a/disttask/framework/scheduler/manager.go
+++ b/disttask/framework/scheduler/manager.go
@@ -194,8 +194,8 @@ func (m *Manager) onRunnableTasks(ctx context.Context, tasks []*proto.Task) {
 		m.addHandlingTask(task.ID)
 		t := task
 		err = m.schedulerPool.Run(func() {
-			m.onRunnableTask(ctx, t.ID, t.Type)
-			m.removeHandlingTask(task.ID)
+			m.onRunnableTask(ctx, t)
+			m.removeHandlingTask(t.ID)
 		})
 		// pool closed.
 		if err != nil {
@@ -254,15 +254,21 @@ type TestContext struct {
 var testContexts sync.Map
 
 // onRunnableTask handles a runnable task.
-func (m *Manager) onRunnableTask(ctx context.Context, taskID int64, taskType string) {
-	logutil.Logger(m.logCtx).Info("onRunnableTask", zap.Any("task_id", taskID), zap.Any("type", taskType))
+func (m *Manager) onRunnableTask(ctx context.Context, task *proto.Task) {
+	logutil.Logger(m.logCtx).Info("onRunnableTask", zap.Int64("task_id", task.ID), zap.String("type", task.Type))
 	// runCtx only used in scheduler.Run, cancel in m.fetchAndFastCancelTasks.
-	factory := getSchedulerFactory(taskType)
+	factory := getSchedulerFactory(task.Type)
 	if factory == nil {
-		m.onError(errors.Errorf("task type %s not found", taskType))
+		m.onError(errors.Errorf("task type %s not found", task.Type))
+		return
+	}
+	scheduler := factory(ctx, m.id, task, m.taskTable)
+	err := scheduler.Init(ctx)
+	if err != nil {
+		m.onError(err)
 		return
 	}
-	scheduler := factory(ctx, m.id, taskID, m.taskTable)
+	defer scheduler.Close()
 	for {
 		select {
 		case <-ctx.Done():
@@ -280,13 +286,14 @@ func (m *Manager) onRunnableTask(ctx context.Context, taskID int64, taskType str
 				}
 			}()
 		})
-		task, err := m.taskTable.GetGlobalTaskByID(taskID)
+		task, err := m.taskTable.GetGlobalTaskByID(task.ID)
 		if err != nil {
 			m.onError(err)
 			return
 		}
 		if task.State != proto.TaskStateRunning && task.State != proto.TaskStateReverting {
-			logutil.Logger(m.logCtx).Info("onRunnableTask exit", zap.Any("task_id", taskID), zap.Int64("step", task.Step), zap.Any("state", task.State))
+			logutil.Logger(m.logCtx).Info("onRunnableTask exit",
+				zap.Int64("task_id", task.ID), zap.Int64("step", task.Step), zap.String("state", task.State))
 			return
 		}
 		if exist, err := m.taskTable.HasSubtasksInStates(m.id, task.ID, task.Step, proto.TaskStatePending, proto.TaskStateRevertPending); err != nil {
diff --git a/disttask/framework/scheduler/manager_test.go b/disttask/framework/scheduler/manager_test.go
index dd8b383464e9e..de0dfa341f00a 100644
--- a/disttask/framework/scheduler/manager_test.go
+++ b/disttask/framework/scheduler/manager_test.go
@@ -101,14 +101,16 @@ func TestOnRunnableTasks(t *testing.T) {
 	m.onRunnableTasks(context.Background(), nil)
 
 	RegisterTaskType("type",
-		func(ctx context.Context, id string, taskID int64, taskTable TaskTable) Scheduler {
+		func(ctx context.Context, id string, task *proto.Task, taskTable TaskTable) Scheduler {
 			return mockInternalScheduler
 		})
 
 	// get subtask failed
+	mockInternalScheduler.EXPECT().Init(gomock.Any()).Return(nil)
 	mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID, proto.StepOne,
 		[]interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).
 		Return(false, errors.New("get subtask failed"))
+	mockInternalScheduler.EXPECT().Close()
 	m.onRunnableTasks(context.Background(), []*proto.Task{task})
 
 	// no subtask
@@ -164,7 +166,7 @@ func TestManager(t *testing.T) {
 		return mockPool, nil
 	})
 	RegisterTaskType("type",
-		func(ctx context.Context, id string, taskID int64, taskTable TaskTable) Scheduler {
+		func(ctx context.Context, id string, task *proto.Task, taskTable TaskTable) Scheduler {
 			return mockInternalScheduler
 		})
 	id := "test"
@@ -177,6 +179,7 @@ func TestManager(t *testing.T) {
 		Return([]*proto.Task{task1, task2}, nil).AnyTimes()
 	mockTaskTable.EXPECT().GetGlobalTasksInStates(proto.TaskStateReverting).
 		Return([]*proto.Task{task2}, nil).AnyTimes()
+	mockInternalScheduler.EXPECT().Init(gomock.Any()).Return(nil)
 
 	// task1
 	mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID1, proto.StepOne,
@@ -188,9 +191,11 @@ func TestManager(t *testing.T) {
 		[]interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).
 		Return(true, nil)
 	mockInternalScheduler.EXPECT().Run(gomock.Any(), task1).Return(nil)
+
 	mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID1, proto.StepOne,
 		[]interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).
 		Return(false, nil).AnyTimes()
+	mockInternalScheduler.EXPECT().Close()
 	// task2
 	mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID2, proto.StepOne,
 		[]interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).
 		Return(false, nil)
@@ -200,10 +205,12 @@ func TestManager(t *testing.T) {
 	mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID2, proto.StepOne,
 		[]interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).
 		Return(true, nil)
+	mockInternalScheduler.EXPECT().Init(gomock.Any()).Return(nil)
 	mockInternalScheduler.EXPECT().Rollback(gomock.Any(), task2).Return(nil)
 	mockTaskTable.EXPECT().HasSubtasksInStates(id, taskID2, proto.StepOne,
 		[]interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).
 		Return(false, nil).AnyTimes()
+	mockInternalScheduler.EXPECT().Close()
 	// for scheduler pool
 	mockPool.EXPECT().ReleaseAndWait().Do(func() {
 		wg.Wait()
diff --git a/disttask/framework/scheduler/register.go b/disttask/framework/scheduler/register.go
index 250d33fe0d9cb..9d55455f94a2b 100644
--- a/disttask/framework/scheduler/register.go
+++ b/disttask/framework/scheduler/register.go
@@ -17,6 +17,7 @@ package scheduler
 import (
 	"context"
 
+	"github.com/pingcap/tidb/disttask/framework/proto"
 	"github.com/pingcap/tidb/disttask/framework/scheduler/execute"
 )
 
@@ -35,7 +36,7 @@ var (
 	taskSchedulerFactories = make(map[string]schedulerFactoryFn)
 )
 
-type schedulerFactoryFn func(ctx context.Context, id string, taskID int64, taskTable TaskTable) Scheduler
+type schedulerFactoryFn func(ctx context.Context, id string, task *proto.Task, taskTable TaskTable) Scheduler
 
 // RegisterTaskType registers the task type.
 func RegisterTaskType(taskType string, factory schedulerFactoryFn, opts ...TaskTypeOption) {
diff --git a/disttask/framework/scheduler/register_test.go b/disttask/framework/scheduler/register_test.go
index b612129757eee..5111a936a7a76 100644
--- a/disttask/framework/scheduler/register_test.go
+++ b/disttask/framework/scheduler/register_test.go
@@ -18,13 +18,14 @@ import (
 	"context"
 	"testing"
 
+	"github.com/pingcap/tidb/disttask/framework/proto"
 	"github.com/stretchr/testify/require"
 )
 
 func TestRegisterTaskType(t *testing.T) {
 	// other case might add task types, so we need to clear it first
 	ClearSchedulers()
-	factoryFn := func(ctx context.Context, id string, taskID int64, taskTable TaskTable) Scheduler {
+	factoryFn := func(ctx context.Context, id string, task *proto.Task, taskTable TaskTable) Scheduler {
 		return nil
 	}
 	RegisterTaskType("test1", factoryFn)
diff --git a/disttask/framework/scheduler/scheduler.go b/disttask/framework/scheduler/scheduler.go
index 0ee1f421abaef..96a41e6d350b7 100644
--- a/disttask/framework/scheduler/scheduler.go
+++ b/disttask/framework/scheduler/scheduler.go
@@ -96,6 +96,11 @@ func (s *BaseScheduler) startCancelCheck(ctx context.Context, wg *sync.WaitGroup
 	}()
 }
 
+// Init implements the Scheduler interface.
+func (*BaseScheduler) Init(_ context.Context) error {
+	return nil
+}
+
 // Run runs the scheduler task.
 func (s *BaseScheduler) Run(ctx context.Context, task *proto.Task) (err error) {
 	defer func() {
@@ -350,6 +355,10 @@ func (s *BaseScheduler) Rollback(ctx context.Context, task *proto.Task) error {
 	return s.getError()
 }
 
+// Close closes the scheduler when all the subtasks are complete.
+func (*BaseScheduler) Close() {
+}
+
 func runSummaryCollectLoop(
 	ctx context.Context,
 	task *proto.Task,
diff --git a/disttask/importinto/scheduler.go b/disttask/importinto/scheduler.go
index f64fb638e2281..ef65e1524fe4c 100644
--- a/disttask/importinto/scheduler.go
+++ b/disttask/importinto/scheduler.go
@@ -249,9 +249,9 @@ type importScheduler struct {
 	*scheduler.BaseScheduler
 }
 
-func newImportScheduler(ctx context.Context, id string, taskID int64, taskTable scheduler.TaskTable) scheduler.Scheduler {
+func newImportScheduler(ctx context.Context, id string, task *proto.Task, taskTable scheduler.TaskTable) scheduler.Scheduler {
 	s := &importScheduler{
-		BaseScheduler: scheduler.NewBaseScheduler(ctx, id, taskID, taskTable),
+		BaseScheduler: scheduler.NewBaseScheduler(ctx, id, task.ID, taskTable),
 	}
 	s.BaseScheduler.Extension = s
 	return s
diff --git a/util/generic/sync_map.go b/util/generic/sync_map.go
index 0a3d0f734d76f..26ca29de21c50 100644
--- a/util/generic/sync_map.go
+++ b/util/generic/sync_map.go
@@ -44,11 +44,16 @@ func (m *SyncMap[K, V]) Load(key K) (V, bool) {
 	return val, exist
 }
 
-// Delete deletes a key value.
-func (m *SyncMap[K, V]) Delete(key K) {
+// Delete deletes the value for a key, returning the previous value if any.
+// The exist result reports whether the key was present.
+func (m *SyncMap[K, V]) Delete(key K) (val V, exist bool) {
 	m.mu.Lock()
-	delete(m.item, key)
+	val, exist = m.item[key]
+	if exist {
+		delete(m.item, key)
+	}
 	m.mu.Unlock()
+	return val, exist
 }
 
 // Keys returns all the keys in the map.
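Note, not part of the diff: the central change above is that the per-task ingest backend context is now owned by the scheduler lifecycle, registered once in backfillDistScheduler.Init and released once in Close, instead of being registered in NewBackfillSubtaskExecutor and unregistered by each executor's Cleanup/Rollback. Below is a minimal, self-contained Go sketch of that Init/Close ownership pattern only; the names fakeBackend, backfillLikeScheduler, register and unregister are illustrative stand-ins and do not exist in the repository.

package main

import (
	"context"
	"fmt"
)

// Scheduler mirrors the lifecycle added in this diff: Init acquires per-task
// resources once, Run uses them for every subtask, Close releases them.
type Scheduler interface {
	Init(context.Context) error
	Run(context.Context, string) error
	Close()
}

// fakeBackend stands in for something like ingest.BackendCtx: a resource that
// must be registered exactly once per task and unregistered exactly once.
type fakeBackend struct{ jobID int64 }

func register(jobID int64) (*fakeBackend, error) {
	fmt.Println("register backend for job", jobID)
	return &fakeBackend{jobID: jobID}, nil
}

func unregister(b *fakeBackend) { fmt.Println("unregister backend for job", b.jobID) }

// backfillLikeScheduler owns the backend for the whole task, the way
// backfillDistScheduler now does in ddl/stage_scheduler.go.
type backfillLikeScheduler struct {
	jobID   int64
	backend *fakeBackend
}

func (s *backfillLikeScheduler) Init(context.Context) error {
	b, err := register(s.jobID)
	if err != nil {
		return err
	}
	s.backend = b
	return nil
}

func (s *backfillLikeScheduler) Run(_ context.Context, subtask string) error {
	// Every subtask reuses the backend acquired in Init instead of
	// registering and later unregistering its own copy.
	fmt.Printf("run %s with backend of job %d\n", subtask, s.backend.jobID)
	return nil
}

func (s *backfillLikeScheduler) Close() {
	if s.backend != nil {
		unregister(s.backend)
	}
}

func main() {
	// The caller drives the lifecycle: Init once, then the subtasks, then
	// Close, matching what Manager.onRunnableTask does after this change.
	var s Scheduler = &backfillLikeScheduler{jobID: 1}
	ctx := context.Background()
	if err := s.Init(ctx); err != nil {
		return
	}
	defer s.Close()
	_ = s.Run(ctx, "read-index subtask")
	_ = s.Run(ctx, "ingest subtask")
}

With this split the backend is released exactly once per task rather than by whichever subtask executor happens to finish or roll back first.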
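Also not part of the diff: util/generic.SyncMap.Delete now removes the entry and reports the previous value under a single lock acquisition, which is what lets litBackendCtxMgr.Unregister above replace its Load-then-Delete pair with one call. A small self-contained sketch of the new contract; the type here is a trimmed copy kept only for illustration and the usage in main is hypothetical.

package main

import (
	"fmt"
	"sync"
)

// SyncMap is a trimmed-down copy of util/generic.SyncMap, shown only to
// demonstrate the new Delete signature.
type SyncMap[K comparable, V any] struct {
	mu   sync.Mutex
	item map[K]V
}

func NewSyncMap[K comparable, V any]() SyncMap[K, V] {
	return SyncMap[K, V]{item: make(map[K]V)}
}

func (m *SyncMap[K, V]) Store(key K, value V) {
	m.mu.Lock()
	m.item[key] = value
	m.mu.Unlock()
}

// Delete deletes the value for a key, returning the previous value if any.
func (m *SyncMap[K, V]) Delete(key K) (val V, exist bool) {
	m.mu.Lock()
	val, exist = m.item[key]
	if exist {
		delete(m.item, key)
	}
	m.mu.Unlock()
	return val, exist
}

func main() {
	m := NewSyncMap[int64, string]()
	m.Store(1, "backend for job 1")
	// One call both removes the entry and reports what was there, so a caller
	// no longer needs a separate Load followed by Delete.
	if bc, ok := m.Delete(1); ok {
		fmt.Println("released:", bc)
	}
	if _, ok := m.Delete(1); !ok {
		fmt.Println("second delete is a no-op")
	}
}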