diff --git a/disttask/framework/dispatcher/BUILD.bazel b/disttask/framework/dispatcher/BUILD.bazel
index d9a6a23857790..5675ebaaac738 100644
--- a/disttask/framework/dispatcher/BUILD.bazel
+++ b/disttask/framework/dispatcher/BUILD.bazel
@@ -40,7 +40,7 @@ go_test(
     embed = [":dispatcher"],
     flaky = True,
     race = "on",
-    shard_count = 7,
+    shard_count = 8,
     deps = [
         "//disttask/framework/proto",
         "//disttask/framework/storage",
diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go
index 5eb8daa6eb7cf..653630f348361 100644
--- a/disttask/framework/dispatcher/dispatcher.go
+++ b/disttask/framework/dispatcher/dispatcher.go
@@ -213,8 +213,12 @@ func (d *dispatcher) handleRunning() error {
 func (d *dispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) {
 	prevState := d.task.State
-	d.task.State = taskState
+	if !VerifyTaskStateTransform(prevState, taskState) {
+		return errors.Errorf("invalid task state transform, from %s to %s", prevState, taskState)
+	}
+	// Assign only after validation so a rejected transform leaves d.task.State untouched.
+	d.task.State = taskState
 	for i := 0; i < retryTimes; i++ {
-		err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(d.task, newSubTasks)
+		err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(d.task, newSubTasks, prevState)
 		if err == nil {
 			break
 		}
@@ -304,7 +308,6 @@ func (d *dispatcher) dispatchSubTask(task *proto.Task, handle TaskFlowHandle, me
 		// TODO: Consider using TS.
 		nowTime := time.Now().UTC()
 		task.StartTime = nowTime
-		task.State = proto.TaskStateRunning
 		task.StateUpdateTime = nowTime
 		retryTimes = nonRetrySQLTime
 	}
@@ -339,7 +342,7 @@ func (d *dispatcher) dispatchSubTask(task *proto.Task, handle TaskFlowHandle, me
 		subTasks = append(subTasks, proto.NewSubtask(task.ID, task.Type, instanceID, meta))
 	}
 
-	return d.updateTask(task.State, subTasks, retrySQLTimes)
+	return d.updateTask(proto.TaskStateRunning, subTasks, retrySQLTimes)
 }
 
 // GenerateSchedulerNodes generate a eligible TiDB nodes.
@@ -405,3 +408,64 @@ func (d *dispatcher) WithNewSession(fn func(se sessionctx.Context) error) error
 func (d *dispatcher) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error {
 	return d.taskMgr.WithNewTxn(ctx, fn)
 }
+
+// validTaskStateTransforms lists, for each task state, the states it may move to directly.
+// A state mapped to an empty list has no outgoing transform. The table is built once at
+// package initialization instead of on every call, and must be treated as read-only.
+var validTaskStateTransforms = map[string][]string{
+	proto.TaskStatePending: {
+		proto.TaskStateRunning,
+		proto.TaskStateCancelling,
+		proto.TaskStatePausing,
+		proto.TaskStateSucceed,
+		proto.TaskStateReverted,
+	},
+	proto.TaskStateRunning: {
+		proto.TaskStateSucceed,
+		proto.TaskStateReverting,
+		proto.TaskStateReverted,
+		proto.TaskStateCancelling,
+		proto.TaskStatePausing,
+	},
+	proto.TaskStateSucceed: {},
+	proto.TaskStateReverting: {
+		proto.TaskStateReverted,
+		// no revert_failed now
+		// proto.TaskStateRevertFailed,
+	},
+	proto.TaskStateFailed:       {},
+	proto.TaskStateRevertFailed: {},
+	proto.TaskStateCancelling: {
+		proto.TaskStateReverting,
+		// no canceled now
+		// proto.TaskStateCanceled,
+	},
+	proto.TaskStateCanceled: {},
+	proto.TaskStatePausing: {
+		proto.TaskStatePaused,
+	},
+	proto.TaskStatePaused: {
+		proto.TaskStateResuming,
+	},
+	proto.TaskStateResuming: {
+		proto.TaskStateRunning,
+	},
+	proto.TaskStateRevertPending: {},
+	proto.TaskStateReverted:      {},
+}
+
+// VerifyTaskStateTransform reports whether a task may move from oldState to newState.
+// Keeping the same state is always accepted; any other move must be listed in
+// validTaskStateTransforms, so a state missing from the table can only "move" to itself.
+func VerifyTaskStateTransform(oldState, newState string) bool {
+	if oldState == newState {
+		return true
+	}
+
+	for _, state := range validTaskStateTransforms[oldState] {
+		if state == newState {
+			return true
+		}
+	}
+	return false
+}
diff --git a/disttask/framework/dispatcher/dispatcher_test.go b/disttask/framework/dispatcher/dispatcher_test.go
index 99923fe1c4cad..3d06bec277d2f 100644
--- a/disttask/framework/dispatcher/dispatcher_test.go
+++ b/disttask/framework/dispatcher/dispatcher_test.go
@@ -355,3 +355,25 @@ func (NumberExampleHandle) GetEligibleInstances(ctx context.Context, _ *proto.Ta
 func (NumberExampleHandle) IsRetryableErr(error) bool {
 	return true
 }
+
+func TestVerifyTaskStateTransform(t
*testing.T) { + testCases := []struct { + oldState string + newState string + expect bool + }{ + {proto.TaskStateRunning, proto.TaskStateRunning, true}, + {proto.TaskStatePending, proto.TaskStateRunning, true}, + {proto.TaskStatePending, proto.TaskStateReverting, false}, + {proto.TaskStateRunning, proto.TaskStateReverting, true}, + {proto.TaskStateReverting, proto.TaskStateReverted, true}, + {proto.TaskStateReverting, proto.TaskStateSucceed, false}, + {proto.TaskStateRunning, proto.TaskStatePausing, true}, + {proto.TaskStateRunning, proto.TaskStateResuming, false}, + {proto.TaskStateCancelling, proto.TaskStateRunning, false}, + {proto.TaskStateCanceled, proto.TaskStateRunning, false}, + } + for _, tc := range testCases { + require.Equal(t, tc.expect, dispatcher.VerifyTaskStateTransform(tc.oldState, tc.newState)) + } +} diff --git a/disttask/framework/proto/task.go b/disttask/framework/proto/task.go index c900af8da519e..a05eba5592a88 100644 --- a/disttask/framework/proto/task.go +++ b/disttask/framework/proto/task.go @@ -41,6 +41,7 @@ const ( TaskStateCanceled = "canceled" TaskStatePausing = "pausing" TaskStatePaused = "paused" + TaskStateResuming = "resuming" TaskStateRevertPending = "revert_pending" TaskStateReverted = "reverted" ) diff --git a/disttask/framework/storage/table_test.go b/disttask/framework/storage/table_test.go index d9c4ce718a5ef..1acd01f012d38 100644 --- a/disttask/framework/storage/table_test.go +++ b/disttask/framework/storage/table_test.go @@ -80,8 +80,9 @@ func TestGlobalTaskTable(t *testing.T) { require.Len(t, task4, 1) require.Equal(t, task, task4[0]) + prevState := task.State task.State = proto.TaskStateRunning - err = gm.UpdateGlobalTaskAndAddSubTasks(task, nil) + err = gm.UpdateGlobalTaskAndAddSubTasks(task, nil, prevState) require.NoError(t, err) task5, err := gm.GetGlobalTasksInStates(proto.TaskStateRunning) @@ -238,6 +239,7 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) { require.Equal(t, proto.TaskStatePending, task.State) // 
isSubTaskRevert: false + prevState := task.State task.State = proto.TaskStateRunning subTasks := []*proto.Subtask{ { @@ -251,7 +253,7 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) { Meta: []byte("m2"), }, } - err = sm.UpdateGlobalTaskAndAddSubTasks(task, subTasks) + err = sm.UpdateGlobalTaskAndAddSubTasks(task, subTasks, prevState) require.NoError(t, err) task, err = sm.GetGlobalTaskByID(1) @@ -275,6 +277,7 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) { require.Equal(t, int64(2), cnt) // isSubTaskRevert: true + prevState = task.State task.State = proto.TaskStateReverting subTasks = []*proto.Subtask{ { @@ -288,7 +291,7 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) { Meta: []byte("m4"), }, } - err = sm.UpdateGlobalTaskAndAddSubTasks(task, subTasks) + err = sm.UpdateGlobalTaskAndAddSubTasks(task, subTasks, prevState) require.NoError(t, err) task, err = sm.GetGlobalTaskByID(1) @@ -317,8 +320,9 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) { defer func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/storage/MockUpdateTaskErr")) }() + prevState = task.State task.State = proto.TaskStateFailed - err = sm.UpdateGlobalTaskAndAddSubTasks(task, subTasks) + err = sm.UpdateGlobalTaskAndAddSubTasks(task, subTasks, prevState) require.EqualError(t, err, "updateTaskErr") task, err = sm.GetGlobalTaskByID(1) diff --git a/disttask/framework/storage/task_table.go b/disttask/framework/storage/task_table.go index f585c4af8abde..5264da9561b12 100644 --- a/disttask/framework/storage/task_table.go +++ b/disttask/framework/storage/task_table.go @@ -468,10 +468,10 @@ func (stm *TaskManager) GetSchedulerIDsByTaskID(taskID int64) ([]string, error) } // UpdateGlobalTaskAndAddSubTasks update the global task and add new subtasks -func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask) error { +func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask, 
prevState string) error { return stm.WithNewTxn(stm.ctx, func(se sessionctx.Context) error { - _, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %?, meta = %?, error = %? where id = %?", - gTask.State, gTask.DispatcherID, gTask.Step, gTask.StateUpdateTime.UTC().String(), gTask.Concurrency, gTask.Meta, serializeErr(gTask.Error), gTask.ID) + _, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %?, meta = %?, error = %? where id = %? and state = %?", + gTask.State, gTask.DispatcherID, gTask.Step, gTask.StateUpdateTime.UTC().String(), gTask.Concurrency, gTask.Meta, serializeErr(gTask.Error), gTask.ID, prevState) if err != nil { return err }