From 0c349409cf5037e488a380003596094753f1ab08 Mon Sep 17 00:00:00 2001
From: EasonBall <592838129@qq.com>
Date: Thu, 12 Oct 2023 13:24:25 +0800
Subject: [PATCH] disttask: refine scheduler error handling (#47313)

ref pingcap/tidb#46258
---
 disttask/framework/framework_test.go          |   6 +-
 disttask/framework/scheduler/BUILD.bazel      |   4 +
 disttask/framework/scheduler/manager.go       |  31 +--
 disttask/framework/scheduler/manager_test.go  |   8 +-
 disttask/framework/scheduler/scheduler.go     | 183 ++++++++++++++----
 .../framework/scheduler/scheduler_test.go     |  77 ++++----
 .../addindextest/add_index_test.go            |   5 +
 7 files changed, 210 insertions(+), 104 deletions(-)

diff --git a/disttask/framework/framework_test.go b/disttask/framework/framework_test.go
index 39349b3d950c4..bea0170ec4b19 100644
--- a/disttask/framework/framework_test.go
+++ b/disttask/framework/framework_test.go
@@ -633,8 +633,7 @@ func TestFrameworkSubtaskFinishedCancel(t *testing.T) {
 	defer ctrl.Finish()
 	RegisterTaskMeta(t, ctrl, &m, &testDispatcherExt{})
 	distContext := testkit.NewDistExecutionContext(t, 3)
-	err := failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/MockSubtaskFinishedCancel", "1*return(true)")
-	require.NoError(t, err)
+	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/MockSubtaskFinishedCancel", "1*return(true)"))
 	defer func() {
 		require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/MockSubtaskFinishedCancel"))
 	}()
@@ -649,8 +648,7 @@ func TestFrameworkRunSubtaskCancel(t *testing.T) {
 	defer ctrl.Finish()
 	RegisterTaskMeta(t, ctrl, &m, &testDispatcherExt{})
 	distContext := testkit.NewDistExecutionContext(t, 3)
-	err := failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/MockRunSubtaskCancel", "1*return(true)")
-	require.NoError(t, err)
+	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/MockRunSubtaskCancel", "1*return(true)"))
 	DispatchTaskAndCheckState("key1", t, &m, proto.TaskStateReverted)
 	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/MockRunSubtaskCancel"))
 	distContext.Close()
diff --git a/disttask/framework/scheduler/BUILD.bazel b/disttask/framework/scheduler/BUILD.bazel
index b068ea7bb840b..c92d7043ae39e 100644
--- a/disttask/framework/scheduler/BUILD.bazel
+++ b/disttask/framework/scheduler/BUILD.bazel
@@ -11,7 +11,10 @@ go_library(
     importpath = "github.com/pingcap/tidb/disttask/framework/scheduler",
     visibility = ["//visibility:public"],
     deps = [
+        "//br/pkg/lightning/common",
         "//config",
+        "//disttask/framework/dispatcher",
+        "//disttask/framework/handle",
         "//disttask/framework/proto",
         "//disttask/framework/scheduler/execute",
         "//disttask/framework/storage",
@@ -19,6 +22,7 @@ go_library(
         "//metrics",
         "//resourcemanager/pool/spool",
         "//resourcemanager/util",
+        "//util/backoff",
         "//util/logutil",
         "@com_github_pingcap_errors//:errors",
         "@com_github_pingcap_failpoint//:failpoint",
diff --git a/disttask/framework/scheduler/manager.go b/disttask/framework/scheduler/manager.go
index 2892ddd77d391..ff7682c99790c 100644
--- a/disttask/framework/scheduler/manager.go
+++ b/disttask/framework/scheduler/manager.go
@@ -32,8 +32,7 @@ import (
 )
 
 var (
-	schedulerPoolSize       int32 = 4
-	subtaskExecutorPoolSize int32 = 10
+	schedulerPoolSize int32 = 4
 	// same as dispatcher
 	checkTime     = 300 * time.Millisecond
 	retrySQLTimes = 3
@@ -65,9 +64,9 @@ type Manager struct {
 	schedulerPool Pool
 	mu struct {
 		sync.RWMutex
-		// taskID -> cancelFunc.
-		// cancelFunc is used to fast cancel the scheduler.Run.
-		handlingTasks map[int64]context.CancelFunc
+		// taskID -> CancelCauseFunc.
+		// CancelCauseFunc is used to fast cancel the scheduler.Run.
+		handlingTasks map[int64]context.CancelCauseFunc
 	}
 	// id, it's the same as server id now, i.e. host:port.
 	id string
@@ -87,7 +86,7 @@ func (b *ManagerBuilder) BuildManager(ctx context.Context, id string, taskTable
 		newPool:   b.newPool,
 	}
 	m.ctx, m.cancel = context.WithCancel(ctx)
-	m.mu.handlingTasks = make(map[int64]context.CancelFunc)
+	m.mu.handlingTasks = make(map[int64]context.CancelCauseFunc)
 
 	schedulerPool, err := m.newPool("scheduler_pool", schedulerPoolSize, util.DistTask)
 	if err != nil {
@@ -227,7 +226,8 @@ func (m *Manager) onCanceledTasks(_ context.Context, tasks []*proto.Task) {
 	for _, task := range tasks {
 		logutil.Logger(m.logCtx).Info("onCanceledTasks", zap.Int64("task-id", task.ID))
 		if cancel, ok := m.mu.handlingTasks[task.ID]; ok && cancel != nil {
-			cancel()
+			// subtask needs to change its state to canceled.
+			cancel(ErrCancelSubtask)
 		}
 	}
 }
@@ -239,7 +239,9 @@ func (m *Manager) onPausingTasks(tasks []*proto.Task) error {
 	for _, task := range tasks {
 		logutil.Logger(m.logCtx).Info("onPausingTasks", zap.Any("task_id", task.ID))
 		if cancel, ok := m.mu.handlingTasks[task.ID]; ok && cancel != nil {
-			cancel()
+			// Pause all running subtasks, don't mark subtasks as canceled.
+			// Should not change the subtask's state.
+			cancel(nil)
 		}
 		if err := m.taskTable.PauseSubtasks(m.id, task.ID); err != nil {
 			return err
@@ -255,7 +257,9 @@ func (m *Manager) cancelAllRunningTasks() {
 	for id, cancel := range m.mu.handlingTasks {
 		logutil.Logger(m.logCtx).Info("cancelAllRunningTasks", zap.Int64("task-id", id))
 		if cancel != nil {
-			cancel()
+			// tidb shutdown, don't mark subtask as canceled.
+			// Should not change the subtask's state.
+			cancel(nil)
 		}
 	}
 }
@@ -322,6 +326,9 @@ func (m *Manager) onRunnableTask(ctx context.Context, task *proto.Task) {
 			m.logErr(err)
 			return
 		}
+		if task == nil {
+			return
+		}
 		if task.State != proto.TaskStateRunning && task.State != proto.TaskStateReverting {
 			logutil.Logger(m.logCtx).Info("onRunnableTask exit",
 				zap.Int64("task-id", task.ID), zap.Int64("step", task.Step), zap.String("state", task.State))
@@ -338,10 +345,10 @@ func (m *Manager) onRunnableTask(ctx context.Context, task *proto.Task) {
 		}
 		switch task.State {
 		case proto.TaskStateRunning:
-			runCtx, runCancel := context.WithCancel(ctx)
+			runCtx, runCancel := context.WithCancelCause(ctx)
 			m.registerCancelFunc(task.ID, runCancel)
 			err = scheduler.Run(runCtx, task)
-			runCancel()
+			runCancel(nil)
 		case proto.TaskStatePausing:
 			err = scheduler.Pause(ctx, task)
 		case proto.TaskStateReverting:
@@ -361,7 +368,7 @@ func (m *Manager) addHandlingTask(id int64) {
 }
 
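The cancel(nil) / cancel(ErrCancelSubtask) split above relies on Go 1.20's cancellation causes: a CancelCauseFunc records an optional error that downstream code can read back with context.Cause, while ctx.Err() reports context.Canceled either way. A minimal stdlib-only sketch (ErrCancelSubtask here is a local stand-in for the framework's sentinel error):

    package main

    import (
    	"context"
    	"errors"
    	"fmt"
    )

    var ErrCancelSubtask = errors.New("cancel subtasks")

    func main() {
    	// cancel(nil): the cause defaults to context.Canceled, so callers cannot
    	// distinguish it from a plain cancellation -- which is exactly what pausing
    	// and shutdown want (the subtask's state stays untouched).
    	ctx1, cancel1 := context.WithCancelCause(context.Background())
    	cancel1(nil)
    	fmt.Println(ctx1.Err(), "/", context.Cause(ctx1)) // context canceled / context canceled

    	// cancel(ErrCancelSubtask): Err() is still context.Canceled, but the cause
    	// survives, letting the scheduler mark the subtask as canceled.
    	ctx2, cancel2 := context.WithCancelCause(context.Background())
    	cancel2(ErrCancelSubtask)
    	fmt.Println(ctx2.Err(), "/", context.Cause(ctx2)) // context canceled / cancel subtasks
    }
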
 // registerCancelFunc registers a cancel function for a task.
-func (m *Manager) registerCancelFunc(id int64, cancel context.CancelFunc) {
+func (m *Manager) registerCancelFunc(id int64, cancel context.CancelCauseFunc) {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 	m.mu.handlingTasks[id] = cancel
diff --git a/disttask/framework/scheduler/manager_test.go b/disttask/framework/scheduler/manager_test.go
index 7873aa59beaf2..323b1a01aded4 100644
--- a/disttask/framework/scheduler/manager_test.go
+++ b/disttask/framework/scheduler/manager_test.go
@@ -72,16 +72,16 @@ func TestManageTask(t *testing.T) {
 	newTasks = m.filterAlreadyHandlingTasks(tasks)
 	require.Equal(t, []*proto.Task{{ID: 1}}, newTasks)
 
-	ctx1, cancel1 := context.WithCancel(context.Background())
+	ctx1, cancel1 := context.WithCancelCause(context.Background())
 	m.registerCancelFunc(2, cancel1)
 	m.cancelAllRunningTasks()
 	require.Equal(t, context.Canceled, ctx1.Err())
 
 	// test cancel.
 	m.addHandlingTask(1)
-	ctx2, cancel2 := context.WithCancel(context.Background())
+	ctx2, cancel2 := context.WithCancelCause(context.Background())
 	m.registerCancelFunc(1, cancel2)
-	ctx3, cancel3 := context.WithCancel(context.Background())
+	ctx3, cancel3 := context.WithCancelCause(context.Background())
 	m.registerCancelFunc(2, cancel3)
 	m.onCanceledTasks(context.Background(), []*proto.Task{{ID: 1}})
 	require.Equal(t, context.Canceled, ctx2.Err())
@@ -89,7 +89,7 @@ func TestManageTask(t *testing.T) {
 
 	// test pause.
 	m.addHandlingTask(3)
-	ctx4, cancel4 := context.WithCancel(context.Background())
+	ctx4, cancel4 := context.WithCancelCause(context.Background())
 	m.registerCancelFunc(1, cancel4)
 	mockTaskTable.EXPECT().PauseSubtasks("test", int64(1)).Return(nil)
 	m.onPausingTasks([]*proto.Task{{ID: 1}})
diff --git a/disttask/framework/scheduler/scheduler.go b/disttask/framework/scheduler/scheduler.go
index 56f8aa6e619ba..d6f2074119a53 100644
--- a/disttask/framework/scheduler/scheduler.go
+++ b/disttask/framework/scheduler/scheduler.go
@@ -21,11 +21,15 @@ import (
 
 	"github.com/pingcap/errors"
 	"github.com/pingcap/failpoint"
+	"github.com/pingcap/tidb/br/pkg/lightning/common"
+	"github.com/pingcap/tidb/disttask/framework/dispatcher"
+	"github.com/pingcap/tidb/disttask/framework/handle"
 	"github.com/pingcap/tidb/disttask/framework/proto"
 	"github.com/pingcap/tidb/disttask/framework/scheduler/execute"
 	"github.com/pingcap/tidb/disttask/framework/storage"
 	"github.com/pingcap/tidb/domain/infosync"
 	"github.com/pingcap/tidb/metrics"
+	"github.com/pingcap/tidb/util/backoff"
 	"github.com/pingcap/tidb/util/logutil"
 	"go.uber.org/zap"
 )
@@ -35,8 +39,17 @@ const (
 	DefaultCheckSubtaskCanceledInterval = 2 * time.Second
 )
 
-// TestSyncChan is used to sync the test.
-var TestSyncChan = make(chan struct{})
+var (
+	// ErrCancelSubtask is the cancel cause when cancelling subtasks.
+	ErrCancelSubtask = errors.New("cancel subtasks")
+	// ErrFinishSubtask is the cancel cause when the scheduler has successfully processed subtasks.
+	ErrFinishSubtask = errors.New("finish subtasks")
+	// ErrFinishRollback is the cancel cause when the scheduler rolls back successfully.
+	ErrFinishRollback = errors.New("finish rollback")
+
+	// TestSyncChan is used to sync the test.
+	TestSyncChan = make(chan struct{})
+)
 
 // BaseScheduler is the base implementation of Scheduler.
 type BaseScheduler struct {
@@ -53,7 +66,7 @@ type BaseScheduler struct {
 		// handled indicates whether the error has been updated to one of the subtask.
 		handled bool
 		// runtimeCancel is used to cancel the Run/Rollback when error occurs.
-		runtimeCancel context.CancelFunc
+		runtimeCancel context.CancelCauseFunc
 	}
 }
 
@@ -68,7 +81,7 @@ func NewBaseScheduler(_ context.Context, id string, taskID int64, taskTable Task
 	return schedulerImpl
 }
 
-func (s *BaseScheduler) startCancelCheck(ctx context.Context, wg *sync.WaitGroup, cancelFn context.CancelFunc) {
+func (s *BaseScheduler) startCancelCheck(ctx context.Context, wg *sync.WaitGroup, cancelFn context.CancelCauseFunc) {
 	wg.Add(1)
 	go func() {
 		defer wg.Done()
@@ -87,7 +100,9 @@ func (s *BaseScheduler) startCancelCheck(ctx context.Context, wg *sync.WaitGroup
 				if canceled {
 					logutil.Logger(s.logCtx).Info("scheduler canceled")
 					if cancelFn != nil {
-						cancelFn()
+						// subtask transferred to another tidb, don't mark subtask as canceled.
+						// Should not change the subtask's state.
+						cancelFn(nil)
 					}
 				}
 			}
@@ -106,7 +121,7 @@ func (s *BaseScheduler) Run(ctx context.Context, task *proto.Task) (err error) {
 		if r := recover(); r != nil {
 			logutil.Logger(ctx).Error("BaseScheduler panicked", zap.Any("recover", r), zap.Stack("stack"))
 			err4Panic := errors.Errorf("%v", r)
-			err1 := s.taskTable.UpdateErrorToSubtask(s.id, task.ID, err4Panic)
+			err1 := s.updateErrorToSubtask(ctx, task.ID, err4Panic)
 			if err == nil {
 				err = err1
 			}
@@ -116,7 +131,10 @@ func (s *BaseScheduler) Run(ctx context.Context, task *proto.Task) (err error) {
 	if s.mu.handled {
 		return err
 	}
-	return s.taskTable.UpdateErrorToSubtask(s.id, task.ID, err)
+	if err == nil {
+		return nil
+	}
+	return s.updateErrorToSubtask(ctx, task.ID, err)
 }
 
 func (s *BaseScheduler) run(ctx context.Context, task *proto.Task) error {
@@ -124,8 +142,8 @@ func (s *BaseScheduler) run(ctx context.Context, task *proto.Task) error {
 		s.onError(ctx.Err())
 		return s.getError()
 	}
-	runCtx, runCancel := context.WithCancel(ctx)
-	defer runCancel()
+	runCtx, runCancel := context.WithCancelCause(ctx)
+	defer runCancel(ErrFinishSubtask)
 	s.registerCancelFunc(runCancel)
 	s.resetError()
 	logutil.Logger(s.logCtx).Info("scheduler run a step", zap.Any("step", task.Step), zap.Any("concurrency", task.Concurrency))
@@ -168,6 +186,10 @@ func (s *BaseScheduler) run(ctx context.Context, task *proto.Task) error {
 		proto.TaskStatePending, proto.TaskStateRunning)
 	if err != nil {
 		s.onError(err)
+		if common.IsRetryableError(err) {
+			logutil.Logger(s.logCtx).Warn("met retryable error", zap.Error(err))
+			return nil
+		}
 		return s.getError()
 	}
 	for _, subtask := range subtasks {
@@ -212,7 +234,7 @@ func (s *BaseScheduler) run(ctx context.Context, task *proto.Task) error {
 			}
 		} else {
 			// subtask.State == proto.TaskStatePending
-			s.startSubtaskAndUpdateState(subtask)
+			s.startSubtaskAndUpdateState(ctx, subtask)
 			if err := s.getError(); err != nil {
 				logutil.Logger(s.logCtx).Warn("startSubtaskAndUpdateState meets error", zap.Error(err))
 				continue
@@ -236,20 +258,26 @@ func (s *BaseScheduler) run(ctx context.Context, task *proto.Task) error {
 func (s *BaseScheduler) runSubtask(ctx context.Context, executor execute.SubtaskExecutor, subtask *proto.Subtask) {
 	err := executor.RunSubtask(ctx, subtask)
 	failpoint.Inject("MockRunSubtaskCancel", func(val failpoint.Value) {
+		if val.(bool) {
+			err = ErrCancelSubtask
+		}
+	})
+
+	failpoint.Inject("MockRunSubtaskContextCanceled", func(val failpoint.Value) {
 		if val.(bool) {
 			err = context.Canceled
 		}
 	})
+
 	if err != nil {
 		s.onError(err)
-		if errors.Cause(err) == context.Canceled {
-			s.updateSubtaskStateAndError(subtask, proto.TaskStateCanceled, s.getError())
-		} else {
-			s.updateSubtaskStateAndError(subtask, proto.TaskStateFailed, s.getError())
-		}
-		s.markErrorHandled()
+	}
+
+	finished := s.markSubTaskCanceledOrFailed(ctx, subtask)
+	if finished {
 		return
 	}
+
 	failpoint.Inject("mockTiDBDown", func(val failpoint.Value) {
 		logutil.Logger(s.logCtx).Info("trigger mockTiDBDown")
 		if s.id == val.(string) || s.id == ":4001" || s.id == ":4002" {
@@ -311,19 +339,22 @@ func (s *BaseScheduler) onSubtaskFinished(ctx context.Context, executor execute.
 	}
 	failpoint.Inject("MockSubtaskFinishedCancel", func(val failpoint.Value) {
 		if val.(bool) {
-			s.onError(context.Canceled)
+			s.onError(ErrCancelSubtask)
 		}
 	})
-	if err := s.getError(); err != nil {
-		if errors.Cause(err) == context.Canceled {
-			s.updateSubtaskStateAndError(subtask, proto.TaskStateCanceled, nil)
-		} else {
-			s.updateSubtaskStateAndError(subtask, proto.TaskStateFailed, s.getError())
-		}
-		s.markErrorHandled()
+
+	finished := s.markSubTaskCanceledOrFailed(ctx, subtask)
+	if finished {
 		return
 	}
-	s.finishSubtaskAndUpdateState(subtask)
+
+	s.finishSubtaskAndUpdateState(ctx, subtask)
+
+	finished = s.markSubTaskCanceledOrFailed(ctx, subtask)
+	if finished {
+		return
+	}
+
 	failpoint.Inject("syncAfterSubtaskFinish", func() {
 		TestSyncChan <- struct{}{}
 		<-TestSyncChan
@@ -332,8 +363,8 @@ func (s *BaseScheduler) onSubtaskFinished(ctx context.Context, executor execute.
 
 // Rollback rollbacks the scheduler task.
 func (s *BaseScheduler) Rollback(ctx context.Context, task *proto.Task) error {
-	rollbackCtx, rollbackCancel := context.WithCancel(ctx)
-	defer rollbackCancel()
+	rollbackCtx, rollbackCancel := context.WithCancelCause(ctx)
+	defer rollbackCancel(ErrFinishRollback)
 	s.registerCancelFunc(rollbackCancel)
 
 	s.resetError()
@@ -429,7 +460,7 @@ func runSummaryCollectLoop(
 	return nil, func() {}, nil
 }
 
-func (s *BaseScheduler) registerCancelFunc(cancel context.CancelFunc) {
+func (s *BaseScheduler) registerCancelFunc(cancel context.CancelCauseFunc) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	s.mu.runtimeCancel = cancel
@@ -450,7 +481,7 @@ func (s *BaseScheduler) onError(err error) {
 	}
 
 	if s.mu.runtimeCancel != nil {
-		s.mu.runtimeCancel()
+		s.mu.runtimeCancel(err)
 	}
 }
 
@@ -473,25 +504,61 @@ func (s *BaseScheduler) resetError() {
 	s.mu.handled = false
 }
 
-func (s *BaseScheduler) startSubtaskAndUpdateState(subtask *proto.Subtask) {
+func (s *BaseScheduler) startSubtaskAndUpdateState(ctx context.Context, subtask *proto.Subtask) {
 	metrics.DecDistTaskSubTaskCnt(subtask)
 	metrics.EndDistTaskSubTask(subtask)
-	err := s.taskTable.StartSubtask(subtask.ID)
-	if err != nil {
-		s.onError(err)
-	}
+	s.startSubtask(ctx, subtask.ID)
 	subtask.State = proto.TaskStateRunning
 	metrics.IncDistTaskSubTaskCnt(subtask)
 	metrics.StartDistTaskSubTask(subtask)
 }
 
-func (s *BaseScheduler) updateSubtaskStateAndError(subtask *proto.Subtask, state string, subTaskErr error) {
-	metrics.DecDistTaskSubTaskCnt(subtask)
-	metrics.EndDistTaskSubTask(subtask)
-	err := s.taskTable.UpdateSubtaskStateAndError(subtask.ID, state, subTaskErr)
+func (s *BaseScheduler) updateSubtaskStateAndErrorImpl(subtaskID int64, state string, subTaskErr error) {
+	// retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes
+	logger := logutil.Logger(s.logCtx)
+	backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval)
+	ctx := context.Background()
+	err := handle.RunWithRetry(ctx, dispatcher.RetrySQLTimes, backoffer, logger,
+		func(ctx context.Context) (bool, error) {
+			return true, s.taskTable.UpdateSubtaskStateAndError(subtaskID, state, subTaskErr)
+		},
+	)
+	if err != nil {
+		s.onError(err)
+	}
+}
+
+func (s *BaseScheduler) startSubtask(ctx context.Context, subtaskID int64) {
+	// retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes
+	logger := logutil.Logger(s.logCtx)
+	backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval)
+	err := handle.RunWithRetry(ctx, dispatcher.RetrySQLTimes, backoffer, logger,
+		func(ctx context.Context) (bool, error) {
+			return true, s.taskTable.StartSubtask(subtaskID)
+		},
+	)
 	if err != nil {
 		s.onError(err)
 	}
+}
+
+func (s *BaseScheduler) finishSubtask(ctx context.Context, subtask *proto.Subtask) {
+	logger := logutil.Logger(s.logCtx)
+	backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval)
+	err := handle.RunWithRetry(ctx, dispatcher.RetrySQLTimes, backoffer, logger,
+		func(ctx context.Context) (bool, error) {
+			return true, s.taskTable.FinishSubtask(subtask.ID, subtask.Meta)
+		},
+	)
+	if err != nil {
+		s.onError(err)
+	}
+}
+
+func (s *BaseScheduler) updateSubtaskStateAndError(subtask *proto.Subtask, state string, subTaskErr error) {
+	metrics.DecDistTaskSubTaskCnt(subtask)
+	metrics.EndDistTaskSubTask(subtask)
+	s.updateSubtaskStateAndErrorImpl(subtask.ID, state, subTaskErr)
 	subtask.State = state
 	metrics.IncDistTaskSubTaskCnt(subtask)
 	if !subtask.IsFinished() {
@@ -499,12 +566,44 @@ func (s *BaseScheduler) updateSubtaskStateAndError(subtask *proto.Subtask, state
 	}
 }
 
-func (s *BaseScheduler) finishSubtaskAndUpdateState(subtask *proto.Subtask) {
+func (s *BaseScheduler) finishSubtaskAndUpdateState(ctx context.Context, subtask *proto.Subtask) {
 	metrics.DecDistTaskSubTaskCnt(subtask)
 	metrics.EndDistTaskSubTask(subtask)
-	if err := s.taskTable.FinishSubtask(subtask.ID, subtask.Meta); err != nil {
-		s.onError(err)
-	}
+	s.finishSubtask(ctx, subtask)
 	subtask.State = proto.TaskStateSucceed
 	metrics.IncDistTaskSubTaskCnt(subtask)
 }
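The "~= 825s" retry budget in the comments above checks out, assuming dispatcher.RetrySQLInterval is 3s, dispatcher.RetrySQLMaxInterval is 30s, and dispatcher.RetrySQLTimes is 30 (values implied by the comment's arithmetic, not restated here from the dispatcher package): the wait doubles from 3s until capped at 30s, giving 3+6+12+24 = 45s for the first four attempts plus 26 more capped waits of 30s each. A throwaway sketch of that sum:

    package main

    import (
    	"fmt"
    	"time"
    )

    func main() {
    	// Assumed values: RetrySQLInterval=3s, RetrySQLMaxInterval=30s, RetrySQLTimes=30.
    	interval, max, times := 3*time.Second, 30*time.Second, 30
    	var total time.Duration
    	for i, d := 0, interval; i < times; i++ {
    		total += d
    		if d *= 2; d > max {
    			d = max
    		}
    	}
    	fmt.Println(total) // 13m45s (825s), i.e. ~14 minutes of retries before giving up
    }
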
+
+// markSubTaskCanceledOrFailed checks the error type and decides the subtask's state.
+// 1. Only cancel subtasks when meeting ErrCancelSubtask.
+// 2. Only fail subtasks when meeting a non-retryable error.
+// 3. For other errors, don't change the subtask's state.
+func (s *BaseScheduler) markSubTaskCanceledOrFailed(ctx context.Context, subtask *proto.Subtask) bool {
+	if err := s.getError(); err != nil {
+		if ctx.Err() != nil && context.Cause(ctx) == ErrCancelSubtask {
+			logutil.Logger(s.logCtx).Warn("subtask canceled", zap.Error(err))
+			s.updateSubtaskStateAndError(subtask, proto.TaskStateCanceled, nil)
+		} else if common.IsRetryableError(err) {
+			logutil.Logger(s.logCtx).Warn("met retryable error", zap.Error(err))
+		} else if errors.Cause(err) != context.Canceled {
+			logutil.Logger(s.logCtx).Warn("subtask failed", zap.Error(err))
+			s.updateSubtaskStateAndError(subtask, proto.TaskStateFailed, err)
+		} else {
+			logutil.Logger(s.logCtx).Info("met context canceled for graceful shutdown", zap.Error(err))
+		}
+		s.markErrorHandled()
+		return true
+	}
+	return false
+}
+
+func (s *BaseScheduler) updateErrorToSubtask(ctx context.Context, taskID int64, err error) error {
+	logger := logutil.Logger(s.logCtx)
+	backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval)
+	err1 := handle.RunWithRetry(ctx, dispatcher.RetrySQLTimes, backoffer, logger,
+		func(ctx context.Context) (bool, error) {
+			return true, s.taskTable.UpdateErrorToSubtask(s.id, taskID, err)
+		},
+	)
+	return err1
+}
diff --git a/disttask/framework/scheduler/scheduler_test.go b/disttask/framework/scheduler/scheduler_test.go
index b78389225dc20..0f3707fee0618 100644
--- a/disttask/framework/scheduler/scheduler_test.go
+++ b/disttask/framework/scheduler/scheduler_test.go
@@ -16,9 +16,7 @@ package scheduler
 
 import (
 	"context"
-	"sync"
 	"testing"
-	"time"
 
 	"github.com/pingcap/tidb/disttask/framework/mock"
 	mockexecute "github.com/pingcap/tidb/disttask/framework/mock/execute"
@@ -54,32 +52,28 @@ func TestSchedulerRun(t *testing.T) {
 
 	// 1. no scheduler constructor
 	schedulerRegisterErr := errors.Errorf("constructor of scheduler for key not found")
-	mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, schedulerRegisterErr)
+	mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, schedulerRegisterErr).Times(2)
 	scheduler := NewBaseScheduler(ctx, "id", 1, mockSubtaskTable)
 	scheduler.Extension = mockExtension
 
-	// UpdateErrorToSubtask won't return such errors, but since the error is not handled,
-	// it's saved by UpdateErrorToSubtask.
-	// here we use this to check the returned error of s.run.
-	forwardErrFn := func(_ string, _ int64, err error) error {
-		return err
-	}
-	mockSubtaskTable.EXPECT().UpdateErrorToSubtask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(forwardErrFn).AnyTimes()
-	err := scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp})
+	err := scheduler.run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp})
 	require.EqualError(t, err, schedulerRegisterErr.Error())
 
+	mockSubtaskTable.EXPECT().UpdateErrorToSubtask(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
+	err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp})
+	require.NoError(t, err)
+
+	// 2. init subtask exec env failed
 	mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil).AnyTimes()
 
-	// 2. init subtask exec env failed
 	initErr := errors.New("init error")
 	mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(initErr)
-	err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp})
+	err = scheduler.run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp})
 	require.EqualError(t, err, initErr.Error())
 
 	var taskID int64 = 1
 	var concurrency uint64 = 10
 	task := &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency}
 
-	// 5. run subtask failed
+	// 3. run subtask failed
 	runSubtaskErr := errors.New("run subtask error")
 	mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
 	mockSubtaskTable.EXPECT().GetSubtasksInStates("id", taskID, proto.StepOne,
@@ -92,10 +86,11 @@ func TestSchedulerRun(t *testing.T) {
 	mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(runSubtaskErr)
 	mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(taskID, proto.TaskStateFailed, gomock.Any()).Return(nil)
 	mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil)
+
 	err = scheduler.Run(runCtx, task)
 	require.EqualError(t, err, runSubtaskErr.Error())
 
-	// 6. run subtask success
+	// 4. run subtask success
 	mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
 	mockSubtaskTable.EXPECT().GetSubtasksInStates("id", taskID, proto.StepOne,
 		unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{
@@ -114,7 +109,7 @@ func TestSchedulerRun(t *testing.T) {
 	err = scheduler.Run(runCtx, task)
 	require.NoError(t, err)
 
-	// 7. run subtask one by one
+	// 5. run subtask one by one
 	mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
 	mockSubtaskTable.EXPECT().GetSubtasksInStates("id", taskID, proto.StepOne,
 		unfinishedNormalSubtaskStates...).Return(
@@ -182,7 +177,7 @@ func TestSchedulerRun(t *testing.T) {
 	err = scheduler.Run(runCtx, task)
 	require.NoError(t, err)
 
-	// 8. cancel
+	// 6. cancel
 	mockSubtaskTable.EXPECT().GetSubtasksInStates("id", taskID, proto.StepOne,
 		unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{
 		ID: 2, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending}}, nil)
@@ -191,20 +186,27 @@ func TestSchedulerRun(t *testing.T) {
 		unfinishedNormalSubtaskStates...).Return(&proto.Subtask{
 		ID: 1, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending}, nil)
 	mockSubtaskTable.EXPECT().StartSubtask(taskID).Return(nil)
-	mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(context.Canceled)
+	mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(ErrCancelSubtask)
 	mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(taskID, proto.TaskStateCanceled, gomock.Any()).Return(nil)
 	mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil)
+	err = scheduler.Run(runCtx, task)
+	require.EqualError(t, err, ErrCancelSubtask.Error())
+
+	// 7. RunSubtask returns context.Canceled
+	mockSubtaskTable.EXPECT().GetSubtasksInStates("id", taskID, proto.StepOne,
+		unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{
+		ID: 2, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending}}, nil)
+	mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
+	mockSubtaskTable.EXPECT().GetFirstSubtaskInStates("id", taskID, proto.StepOne,
+		unfinishedNormalSubtaskStates...).Return(&proto.Subtask{
+		ID: 1, Type: tp, Step: proto.StepOne, State: proto.TaskStatePending}, nil)
+	mockSubtaskTable.EXPECT().StartSubtask(taskID).Return(nil)
+	mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(context.Canceled)
+	mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil)
+	err = scheduler.Run(runCtx, task)
+	require.EqualError(t, err, context.Canceled.Error())
 
-	var wg sync.WaitGroup
-	wg.Add(1)
-	go func() {
-		defer wg.Done()
-		err = scheduler.Run(runCtx, task)
-		require.EqualError(t, err, context.Canceled.Error())
-	}()
-	time.Sleep(time.Second)
 	runCancel()
-	wg.Wait()
 }
 
 func TestSchedulerRollback(t *testing.T) {
@@ -249,17 +251,7 @@ func TestSchedulerRollback(t *testing.T) {
 	err = scheduler.Rollback(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID})
 	require.NoError(t, err)
 
-	// 4. update subtask error
-	updateSubtaskErr := errors.New("update subtask error")
-	mockSubtaskTable.EXPECT().GetFirstSubtaskInStates("id", taskID, proto.StepOne,
-		unfinishedNormalSubtaskStates...).Return(nil, nil)
-	mockSubtaskTable.EXPECT().GetFirstSubtaskInStates("id", taskID, proto.StepOne,
-		unfinishedRevertSubtaskStates...).Return(&proto.Subtask{ID: 1, State: proto.TaskStateRevertPending}, nil)
-	mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(taskID, proto.TaskStateReverting, nil).Return(updateSubtaskErr)
-	err = scheduler.Rollback(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID})
-	require.EqualError(t, err, updateSubtaskErr.Error())
-
-	// 5. rollback failed
+	// 4. rollback failed
 	rollbackErr := errors.New("rollback error")
 	mockSubtaskTable.EXPECT().GetFirstSubtaskInStates("id", taskID, proto.StepOne,
 		unfinishedNormalSubtaskStates...).Return(nil, nil)
@@ -271,7 +263,7 @@ func TestSchedulerRollback(t *testing.T) {
 	err = scheduler.Rollback(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID})
 	require.EqualError(t, err, rollbackErr.Error())
 
-	// 6. rollback success
+	// 5. rollback success
 	mockSubtaskTable.EXPECT().GetFirstSubtaskInStates("id", taskID, proto.StepOne,
 		unfinishedNormalSubtaskStates...).Return(&proto.Subtask{ID: 1}, nil)
 	mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(int64(1), proto.TaskStateCanceled, nil).Return(nil)
@@ -336,13 +328,14 @@ func TestScheduler(t *testing.T) {
 	mockSubtaskTable := mock.NewMockTaskTable(ctrl)
 	mockSubtaskExecutor := mockexecute.NewMockSubtaskExecutor(ctrl)
 	mockExtension := mock.NewMockExtension(ctrl)
+	mockSubtaskTable.EXPECT().IsSchedulerCanceled(gomock.Any(), gomock.Any()).Return(false, nil).AnyTimes()
 	mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil).AnyTimes()
 
 	scheduler := NewBaseScheduler(ctx, "id", 1, mockSubtaskTable)
 	scheduler.Extension = mockExtension
 
-	// run failed
+	// 1. run failed.
 	runSubtaskErr := errors.New("run subtask error")
 	mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
 	mockSubtaskTable.EXPECT().GetSubtasksInStates("id", taskID, proto.StepOne,
@@ -355,10 +348,10 @@ func TestScheduler(t *testing.T) {
 	mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).Return(runSubtaskErr)
 	mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(taskID, proto.TaskStateFailed, gomock.Any()).Return(nil)
 	mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil)
-	err := scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency})
+	err := scheduler.run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency})
 	require.EqualError(t, err, runSubtaskErr.Error())
 
-	// rollback success
+	// 2. rollback success.
 	mockSubtaskTable.EXPECT().GetFirstSubtaskInStates("id", taskID, proto.StepOne,
 		unfinishedNormalSubtaskStates...).Return(nil, nil)
 	mockSubtaskTable.EXPECT().GetFirstSubtaskInStates("id", taskID, proto.StepOne,
diff --git a/tests/realtikvtest/addindextest/add_index_test.go b/tests/realtikvtest/addindextest/add_index_test.go
index 6f13913b97ed4..f0ca7905cb496 100644
--- a/tests/realtikvtest/addindextest/add_index_test.go
+++ b/tests/realtikvtest/addindextest/add_index_test.go
@@ -164,6 +164,11 @@ func TestAddIndexDistBasic(t *testing.T) {
 	tk.MustExec("split table t1 between (3) and (8646911284551352360) regions 50;")
 	tk.MustExec("alter table t1 add index idx(a);")
 	tk.MustExec("admin check index t1 idx;")
+
+	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/MockRunSubtaskContextCanceled", "1*return(true)"))
+	tk.MustExec("alter table t1 add index idx1(a);")
+	tk.MustExec("admin check index t1 idx1;")
+	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/MockRunSubtaskContextCanceled"))
 	tk.MustExec(`set global tidb_enable_dist_task=0;`)
 }
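Taken together, the patch funnels every scheduler error through one classifier: onError records the first error and forwards it as the cancellation cause via runtimeCancel(err), and markSubTaskCanceledOrFailed later decides what the error means for the subtask. The stdlib-only sketch below condenses that decision order; classify, errCancelSubtask, and isRetryable are illustrative stand-ins for the framework's markSubTaskCanceledOrFailed, ErrCancelSubtask, and common.IsRetryableError, not the actual implementation:

    package main

    import (
    	"context"
    	"errors"
    	"fmt"
    )

    var errCancelSubtask = errors.New("cancel subtasks")

    // classify condenses markSubTaskCanceledOrFailed: cancellation cause first,
    // then retryable errors, then graceful shutdown, otherwise a real failure.
    func classify(ctx context.Context, err error, isRetryable func(error) bool) string {
    	switch {
    	case ctx.Err() != nil && context.Cause(ctx) == errCancelSubtask:
    		return "canceled" // persist state = canceled
    	case isRetryable(err):
    		return "retry" // keep state; a scheduler will pick the subtask up again
    	case errors.Is(err, context.Canceled):
    		return "shutdown" // keep state; tidb is shutting down gracefully
    	default:
    		return "failed" // persist state = failed together with the error
    	}
    }

    func main() {
    	ctx, cancel := context.WithCancelCause(context.Background())
    	cancel(errCancelSubtask)
    	never := func(error) bool { return false }
    	fmt.Println(classify(ctx, context.Cause(ctx), never)) // canceled
    }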