Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

distribute framework: restrict the task state transform rule #45932

Merged
merged 12 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion disttask/framework/dispatcher/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ go_test(
embed = [":dispatcher"],
flaky = True,
race = "on",
shard_count = 7,
shard_count = 8,
deps = [
"//disttask/framework/proto",
"//disttask/framework/storage",
Expand Down
73 changes: 70 additions & 3 deletions disttask/framework/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
retrySQLInterval = 500 * time.Millisecond
)

// TestSyncChan is a synchronization channel intended for tests only.
// When the "syncInUpdateTask" failpoint is enabled, updateTask sends one
// value on this channel and then waits to receive one back. The channel is
// unbuffered, so both operations block until the other side (presumably the
// test that enabled the failpoint) performs the matching receive and send.
var TestSyncChan = make(chan struct{})

// TaskHandle provides the interface for operations needed by task flow handles.
type TaskHandle interface {
// GetAllSchedulerIDs gets handles the task's all scheduler instances.
Expand Down Expand Up @@ -213,8 +216,17 @@
func (d *dispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) {
prevState := d.task.State
d.task.State = taskState
if !VerifyTaskStateTransform(prevState, taskState) {
return errors.Errorf("invalid task state transform, from %s to %s", prevState, taskState)
}

Check warning on line 221 in disttask/framework/dispatcher/dispatcher.go

View check run for this annotation

Codecov / codecov/patch

disttask/framework/dispatcher/dispatcher.go#L220-L221

Added lines #L220 - L221 were not covered by tests
failpoint.Inject("syncInUpdateTask", func(_ failpoint.Value) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems not used.

logutil.Logger(d.logCtx).Info("syncInUpdateTask called")
TestSyncChan <- struct{}{}
<-TestSyncChan
})

Check warning on line 226 in disttask/framework/dispatcher/dispatcher.go

View check run for this annotation

Codecov / codecov/patch

disttask/framework/dispatcher/dispatcher.go#L223-L226

Added lines #L223 - L226 were not covered by tests

for i := 0; i < retryTimes; i++ {
err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(d.task, newSubTasks)
err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(d.task, newSubTasks, prevState)
if err == nil {
break
}
Expand Down Expand Up @@ -304,7 +316,6 @@
// TODO: Consider using TS.
nowTime := time.Now().UTC()
task.StartTime = nowTime
task.State = proto.TaskStateRunning
task.StateUpdateTime = nowTime
retryTimes = nonRetrySQLTime
}
Expand Down Expand Up @@ -339,7 +350,7 @@
subTasks = append(subTasks, proto.NewSubtask(task.ID, task.Type, instanceID, meta))
}

return d.updateTask(task.State, subTasks, retrySQLTimes)
return d.updateTask(proto.TaskStateRunning, subTasks, retrySQLTimes)
}

// GenerateSchedulerNodes generates eligible TiDB nodes.
Expand Down Expand Up @@ -405,3 +416,59 @@
func (d *dispatcher) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error {
// Pure delegation: ctx and fn are handed to the task manager's WithNewTxn
// unchanged, and whatever error it returns is returned as-is.
return d.taskMgr.WithNewTxn(ctx, fn)
}

// VerifyTaskStateTransform reports whether a task may move from oldState to
// newState.
//
// A transform to the same state is always accepted. Otherwise the accepted
// targets, keyed by source state, are:
//
//   - TaskStatePending:    Running, Cancelling, Pausing, Succeed, Reverted
//   - TaskStateRunning:    Succeed, Reverting, Reverted, Cancelling, Pausing
//   - TaskStateReverting:  Reverted
//   - TaskStateCancelling: Reverting
//   - TaskStatePausing:    Paused
//   - TaskStatePaused:     Resuming
//   - TaskStateResuming:   Running
//
// TaskStateSucceed, TaskStateFailed, TaskStateRevertFailed,
// TaskStateCanceled, TaskStateRevertPending and TaskStateReverted have no
// accepted target, and neither does any state not listed here.
func VerifyTaskStateTransform(oldState, newState string) bool {
	// Staying in the current state is always valid, whatever that state is.
	if oldState == newState {
		return true
	}

	switch oldState {
	case proto.TaskStatePending:
		switch newState {
		case proto.TaskStateRunning,
			proto.TaskStateCancelling,
			proto.TaskStatePausing,
			proto.TaskStateSucceed,
			proto.TaskStateReverted:
			return true
		}
	case proto.TaskStateRunning:
		switch newState {
		case proto.TaskStateSucceed,
			proto.TaskStateReverting,
			proto.TaskStateReverted,
			proto.TaskStateCancelling,
			proto.TaskStatePausing:
			return true
		}
	case proto.TaskStateReverting:
		// There is no revert_failed transform for now, so
		// proto.TaskStateRevertFailed is deliberately not accepted here.
		return newState == proto.TaskStateReverted
	case proto.TaskStateCancelling:
		// There is no canceled transform for now, so
		// proto.TaskStateCanceled is deliberately not accepted here.
		return newState == proto.TaskStateReverting
	case proto.TaskStatePausing:
		return newState == proto.TaskStatePaused
	case proto.TaskStatePaused:
		return newState == proto.TaskStateResuming
	case proto.TaskStateResuming:
		return newState == proto.TaskStateRunning
	}

	// Either the source state has no outgoing transform, or newState is not
	// among the targets accepted for it.
	return false
}
22 changes: 22 additions & 0 deletions disttask/framework/dispatcher/dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,25 @@ func (NumberExampleHandle) GetEligibleInstances(ctx context.Context, _ *proto.Ta
func (NumberExampleHandle) IsRetryableErr(error) bool {
// The error argument is ignored: this example handle reports every error
// as retryable.
return true
}

// TestVerifyTaskStateTransform checks dispatcher.VerifyTaskStateTransform
// against a table of (oldState, newState) pairs. The table covers the
// same-state shortcut, the run/revert/cancel paths, the
// pausing -> paused -> resuming -> running chain, and states that have no
// outgoing transform.
func TestVerifyTaskStateTransform(t *testing.T) {
	testCases := []struct {
		oldState string
		newState string
		expect   bool
	}{
		// A transform to the same state is always accepted.
		{proto.TaskStateRunning, proto.TaskStateRunning, true},
		{proto.TaskStatePending, proto.TaskStateRunning, true},
		{proto.TaskStatePending, proto.TaskStateReverting, false},
		{proto.TaskStateRunning, proto.TaskStateReverting, true},
		{proto.TaskStateReverting, proto.TaskStateReverted, true},
		{proto.TaskStateReverting, proto.TaskStateSucceed, false},
		{proto.TaskStateRunning, proto.TaskStatePausing, true},
		{proto.TaskStateRunning, proto.TaskStateResuming, false},
		{proto.TaskStateCancelling, proto.TaskStateRunning, false},
		{proto.TaskStateCanceled, proto.TaskStateRunning, false},
		// Pause/resume chain.
		{proto.TaskStatePausing, proto.TaskStatePaused, true},
		{proto.TaskStatePaused, proto.TaskStateResuming, true},
		{proto.TaskStatePaused, proto.TaskStateRunning, false},
		{proto.TaskStateResuming, proto.TaskStateRunning, true},
		// States without any outgoing transform.
		{proto.TaskStateSucceed, proto.TaskStateRunning, false},
		{proto.TaskStateReverted, proto.TaskStateReverting, false},
	}
	for _, tc := range testCases {
		got := dispatcher.VerifyTaskStateTransform(tc.oldState, tc.newState)
		// Name the transition in the failure message so a failing row can be
		// identified without re-running under a debugger.
		require.Equal(t, tc.expect, got,
			"unexpected result for transform from %s to %s", tc.oldState, tc.newState)
	}
}
1 change: 1 addition & 0 deletions disttask/framework/proto/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ const (
TaskStateCanceled = "canceled"
TaskStatePausing = "pausing"
TaskStatePaused = "paused"
TaskStateResuming = "resuming"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need more discussion to introduce a new state.
For now, LGTM.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Furthermore, we can distinguish subtask state and task state later.

TaskStateRevertPending = "revert_pending"
TaskStateReverted = "reverted"
)
Expand Down
12 changes: 8 additions & 4 deletions disttask/framework/storage/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ func TestGlobalTaskTable(t *testing.T) {
require.Len(t, task4, 1)
require.Equal(t, task, task4[0])

prevState := task.State
task.State = proto.TaskStateRunning
err = gm.UpdateGlobalTaskAndAddSubTasks(task, nil)
err = gm.UpdateGlobalTaskAndAddSubTasks(task, nil, prevState)
require.NoError(t, err)

task5, err := gm.GetGlobalTasksInStates(proto.TaskStateRunning)
Expand Down Expand Up @@ -238,6 +239,7 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) {
require.Equal(t, proto.TaskStatePending, task.State)

// isSubTaskRevert: false
prevState := task.State
task.State = proto.TaskStateRunning
subTasks := []*proto.Subtask{
{
Expand All @@ -251,7 +253,7 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) {
Meta: []byte("m2"),
},
}
err = sm.UpdateGlobalTaskAndAddSubTasks(task, subTasks)
err = sm.UpdateGlobalTaskAndAddSubTasks(task, subTasks, prevState)
require.NoError(t, err)

task, err = sm.GetGlobalTaskByID(1)
Expand All @@ -275,6 +277,7 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) {
require.Equal(t, int64(2), cnt)

// isSubTaskRevert: true
prevState = task.State
task.State = proto.TaskStateReverting
subTasks = []*proto.Subtask{
{
Expand All @@ -288,7 +291,7 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) {
Meta: []byte("m4"),
},
}
err = sm.UpdateGlobalTaskAndAddSubTasks(task, subTasks)
err = sm.UpdateGlobalTaskAndAddSubTasks(task, subTasks, prevState)
require.NoError(t, err)

task, err = sm.GetGlobalTaskByID(1)
Expand Down Expand Up @@ -317,8 +320,9 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) {
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/storage/MockUpdateTaskErr"))
}()
prevState = task.State
task.State = proto.TaskStateFailed
err = sm.UpdateGlobalTaskAndAddSubTasks(task, subTasks)
err = sm.UpdateGlobalTaskAndAddSubTasks(task, subTasks, prevState)
require.EqualError(t, err, "updateTaskErr")

task, err = sm.GetGlobalTaskByID(1)
Expand Down
6 changes: 3 additions & 3 deletions disttask/framework/storage/task_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,10 +468,10 @@
}

// UpdateGlobalTaskAndAddSubTasks update the global task and add new subtasks
func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask) error {
func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask, prevState string) error {

Check warning on line 471 in disttask/framework/storage/task_table.go

View check run for this annotation

Codecov / codecov/patch

disttask/framework/storage/task_table.go#L471

Added line #L471 was not covered by tests
return stm.WithNewTxn(stm.ctx, func(se sessionctx.Context) error {
_, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %?, meta = %?, error = %? where id = %?",
gTask.State, gTask.DispatcherID, gTask.Step, gTask.StateUpdateTime.UTC().String(), gTask.Concurrency, gTask.Meta, serializeErr(gTask.Error), gTask.ID)
_, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %?, meta = %?, error = %? where id = %? and state = %?",
gTask.State, gTask.DispatcherID, gTask.Step, gTask.StateUpdateTime.UTC().String(), gTask.Concurrency, gTask.Meta, serializeErr(gTask.Error), gTask.ID, prevState)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to check how many rows are affected; if it is 0, we shouldn't update the subtasks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to check how many rows are affected; if it is 0, we shouldn't update the subtasks.

I will fix it.

if err != nil {
return err
}
Expand Down