disttask: add pending to failed state transform (#46357)
ref #46258
ywqzzy authored Aug 25, 2023
1 parent 66ddb7b commit c66d28f
Showing 7 changed files with 202 additions and 63 deletions.
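For context before the per-file diffs: previously, a task whose dist-plan could not be generated was pushed into the reverted state; this commit lets pending and running tasks move straight to a terminal failed state instead. A minimal, runnable sketch of the affected transition subset (state names follow disttask/framework/proto; the map is an illustrative reading of the VerifyTaskStateTransform hunk at the end of dispatcher.go, with pending -> running inferred from scheduleTask):

```go
package main

import "fmt"

// Illustrative subset of the task state machine after this commit:
// both "pending" and "running" may now move directly to "failed",
// replacing the direct move to "reverted" that each list allowed before.
var allowed = map[string][]string{
	"pending": {"running", "cancelling", "pausing", "succeed", "failed"},
	"running": {"succeed", "reverting", "failed", "cancelling", "pausing"},
}

// canTransform mirrors the spirit of dispatcher.VerifyTaskStateTransform
// for the two source states this diff touches.
func canTransform(from, to string) bool {
	for _, next := range allowed[from] {
		if next == to {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(canTransform("pending", "failed"))   // true: added by this commit
	fmt.Println(canTransform("pending", "reverted")) // false: removed by this commit
}
```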
3 changes: 2 additions & 1 deletion disttask/framework/BUILD.bazel
@@ -4,12 +4,13 @@ go_test(
name = "framework_test",
timeout = "short",
srcs = [
"framework_err_handling_test.go",
"framework_rollback_test.go",
"framework_test.go",
],
flaky = True,
race = "on",
shard_count = 11,
shard_count = 14,
deps = [
"//disttask/framework/dispatcher",
"//disttask/framework/proto",
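(The shard_count bump from 11 to 14 matches the three tests added in framework_err_handling_test.go below.)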
92 changes: 50 additions & 42 deletions disttask/framework/dispatcher/dispatcher.go
@@ -39,7 +39,7 @@ const (
)

var (
// DefaultDispatchConcurrency is the default concurrency for handling global task.
// DefaultDispatchConcurrency is the default concurrency for handling task.
DefaultDispatchConcurrency = 4
checkTaskFinishedInterval = 500 * time.Millisecond
checkTaskRunningInterval = 300 * time.Millisecond
@@ -73,11 +73,11 @@ var MockOwnerChange func()
func newDispatcher(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) *dispatcher {
logPrefix := fmt.Sprintf("task_id: %d, task_type: %s, server_id: %s", task.ID, task.Type, serverID)
return &dispatcher{
ctx,
taskMgr,
task,
logutil.WithKeyValue(context.Background(), "dispatcher", logPrefix),
serverID,
ctx: ctx,
taskMgr: taskMgr,
task: task,
logCtx: logutil.WithKeyValue(context.Background(), "dispatcher", logPrefix),
serverID: serverID,
}
}

@@ -122,14 +122,14 @@ func (d *dispatcher) scheduleTask() {
})
switch d.task.State {
case proto.TaskStateCancelling:
err = d.handleCancelling()
err = d.onCancelling()
case proto.TaskStateReverting:
err = d.handleReverting()
err = d.onReverting()
case proto.TaskStatePending:
err = d.handlePending()
err = d.onPending()
case proto.TaskStateRunning:
err = d.handleRunning()
case proto.TaskStateSucceed, proto.TaskStateReverted:
err = d.onRunning()
case proto.TaskStateSucceed, proto.TaskStateReverted, proto.TaskStateFailed:
logutil.Logger(d.logCtx).Info("schedule task, task is finished", zap.String("state", d.task.State))
return
}
@@ -149,15 +149,15 @@ }
}

// handle task in cancelling state, dispatch revert subtasks.
func (d *dispatcher) handleCancelling() error {
logutil.Logger(d.logCtx).Debug("handle cancelling state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step))
func (d *dispatcher) onCancelling() error {
logutil.Logger(d.logCtx).Debug("on cancelling state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step))
errs := []error{errors.New("cancel")}
return d.processErrFlow(errs)
return d.onErrHandlingStage(errs)
}

// handle task in reverting state, check whether all revert subtasks are finished.
func (d *dispatcher) handleReverting() error {
logutil.Logger(d.logCtx).Debug("handle reverting state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step))
func (d *dispatcher) onReverting() error {
logutil.Logger(d.logCtx).Debug("on reverting state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step))
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.task.ID, proto.TaskStateRevertPending, proto.TaskStateReverting)
if err != nil {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
@@ -171,28 +171,28 @@ func (d *dispatcher) handleReverting() error {
}
// Wait until all subtasks in this stage are finished.
GetTaskFlowHandle(d.task.Type).OnTicker(d.ctx, d.task)
logutil.Logger(d.logCtx).Debug("handle reverting state, this task keeps current state", zap.String("state", d.task.State))
logutil.Logger(d.logCtx).Debug("on reverting state, this task keeps current state", zap.String("state", d.task.State))
return nil
}

// handle task in pending state, dispatch subtasks.
func (d *dispatcher) handlePending() error {
logutil.Logger(d.logCtx).Debug("handle pending state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step))
return d.processNormalFlow()
func (d *dispatcher) onPending() error {
logutil.Logger(d.logCtx).Debug("on pending state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step))
return d.onNextStage()
}

// handle task in running state, check whether all running subtasks are finished.
// If they are, move on to the next stage.
func (d *dispatcher) handleRunning() error {
logutil.Logger(d.logCtx).Debug("handle running state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step))
func (d *dispatcher) onRunning() error {
logutil.Logger(d.logCtx).Debug("on running state", zap.String("state", d.task.State), zap.Int64("stage", d.task.Step))
subTaskErrs, err := d.taskMgr.CollectSubTaskError(d.task.ID)
if err != nil {
logutil.Logger(d.logCtx).Warn("collect subtask error failed", zap.Error(err))
return err
}
if len(subTaskErrs) > 0 {
logutil.Logger(d.logCtx).Warn("subtasks encounter errors")
return d.processErrFlow(subTaskErrs)
return d.onErrHandlingStage(subTaskErrs)
}
// Check whether the current stage is finished.
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.task.ID, proto.TaskStatePending, proto.TaskStateRunning)
@@ -204,11 +204,11 @@ func (d *dispatcher) handleRunning() error {
prevStageFinished := cnt == 0
if prevStageFinished {
logutil.Logger(d.logCtx).Info("previous stage finished, generate dist plan", zap.Int64("stage", d.task.Step))
return d.processNormalFlow()
return d.onNextStage()
}
// Wait until all subtasks in this stage are finished.
GetTaskFlowHandle(d.task.Type).OnTicker(d.ctx, d.task)
logutil.Logger(d.logCtx).Debug("handing running state, this task keeps current state", zap.String("state", d.task.State))
logutil.Logger(d.logCtx).Debug("on running state, this task keeps current state", zap.String("state", d.task.State))
return nil
}

@@ -232,28 +232,30 @@ func (d *dispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask,
break
}
if i%10 == 0 {
logutil.Logger(d.logCtx).Warn("updateTask first failed", zap.String("previous state", prevState), zap.String("curr state", d.task.State),
logutil.Logger(d.logCtx).Warn("updateTask first failed", zap.String("from", prevState), zap.String("to", d.task.State),
zap.Int("retry times", retryTimes), zap.Error(err))
}
time.Sleep(retrySQLInterval)
}
if err != nil && retryTimes != nonRetrySQLTime {
logutil.Logger(d.logCtx).Warn("updateTask failed",
zap.String("previous state", prevState), zap.String("curr state", d.task.State), zap.Int("retry times", retryTimes), zap.Error(err))
zap.String("from", prevState), zap.String("to", d.task.State), zap.Int("retry times", retryTimes), zap.Error(err))
}
return err
}

func (d *dispatcher) processErrFlow(receiveErr []error) error {
func (d *dispatcher) onErrHandlingStage(receiveErr []error) error {
// TODO: GetTaskFlowHandle may fail during a rolling upgrade.
// 1. generate the needed global task meta and subTask meta (dist-plan).
// 1. generate the needed task meta and subTask meta (dist-plan).
handle := GetTaskFlowHandle(d.task.Type)
if handle == nil {
logutil.Logger(d.logCtx).Warn("gen task flow handle failed, this type handle doesn't register")
return d.updateTask(proto.TaskStateReverted, nil, retrySQLTimes)
// state transform: pending --> running --> cancelling --> failed.
return d.updateTask(proto.TaskStateFailed, nil, retrySQLTimes)
}
meta, err := handle.ProcessErrFlow(d.ctx, d, d.task, receiveErr)
if err != nil {
// ProcessErrFlow must be retryable; if it is not, task resources will leak.
logutil.Logger(d.logCtx).Warn("handle error failed", zap.Error(err))
return err
}
@@ -276,22 +278,18 @@ func (d *dispatcher) dispatchSubTask4Revert(task *proto.Task, handle TaskFlowHan
return d.updateTask(proto.TaskStateReverting, subTasks, retrySQLTimes)
}

func (d *dispatcher) processNormalFlow() error {
func (d *dispatcher) onNextStage() error {
// 1. generate the needed global task meta and subTask meta (dist-plan).
handle := GetTaskFlowHandle(d.task.Type)
if handle == nil {
logutil.Logger(d.logCtx).Warn("gen task flow handle failed, this type handle doesn't register", zap.String("type", d.task.Type))
d.task.Error = errors.New("unsupported task type")
return d.updateTask(proto.TaskStateReverted, nil, retrySQLTimes)
// state transform: pending -> failed.
return d.updateTask(proto.TaskStateFailed, nil, retrySQLTimes)
}
metas, err := handle.ProcessNormalFlow(d.ctx, d, d.task)
if err != nil {
logutil.Logger(d.logCtx).Warn("generate dist-plan failed", zap.Error(err))
if handle.IsRetryableErr(err) {
return err
}
d.task.Error = err
return d.updateTask(proto.TaskStateReverted, nil, retrySQLTimes)
return d.handlePlanErr(handle, err)
}
// 2. dispatch dist-plan to EligibleInstances.
return d.dispatchSubTask(d.task, handle, metas)
@@ -322,13 +320,13 @@ func (d *dispatcher) dispatchSubTask(task *proto.Task, handle TaskFlowHandle, me
// Write the global task meta into the storage.
err := d.updateTask(proto.TaskStateSucceed, nil, retryTimes)
if err != nil {
logutil.Logger(d.logCtx).Warn("update global task failed", zap.Error(err))
logutil.Logger(d.logCtx).Warn("update task failed", zap.Error(err))
return err
}
return nil
}

// 3. select all available TiDB nodes for this global tasks.
// 3. select all available TiDB nodes for the task.
serverNodes, err := handle.GetEligibleInstances(d.ctx, task)
logutil.Logger(d.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes)))

@@ -349,6 +347,16 @@
return d.updateTask(proto.TaskStateRunning, subTasks, retrySQLTimes)
}

func (d *dispatcher) handlePlanErr(handle TaskFlowHandle, err error) error {
logutil.Logger(d.logCtx).Warn("generate plan failed", zap.Error(err), zap.String("state", d.task.State))
if handle.IsRetryableErr(err) {
return err
}
d.task.Error = err
// state transform: pending -> failed.
return d.updateTask(proto.TaskStateFailed, nil, retrySQLTimes)
}

// GenerateSchedulerNodes generates eligible TiDB nodes.
func GenerateSchedulerNodes(ctx context.Context) ([]*infosync.ServerInfo, error) {
serverInfos, err := infosync.GetAllServerInfo(ctx)
@@ -421,12 +429,12 @@ func VerifyTaskStateTransform(from, to string) bool {
proto.TaskStateCancelling,
proto.TaskStatePausing,
proto.TaskStateSucceed,
proto.TaskStateReverted,
proto.TaskStateFailed,
},
proto.TaskStateRunning: {
proto.TaskStateSucceed,
proto.TaskStateReverting,
proto.TaskStateReverted,
proto.TaskStateFailed,
proto.TaskStateCancelling,
proto.TaskStatePausing,
},
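The practical hook for task authors is IsRetryableErr: with handlePlanErr in place, a non-retryable plan error now lands the task in failed instead of looping forever. A minimal handle fragment showing the policy (myFlowHandle and its errors.Is rule are hypothetical; only the IsRetryableErr signature comes from the framework):

```go
package myflow

import (
	"context"
	"errors"
)

// myFlowHandle is a hypothetical task type's flow handle; only the retry
// policy is sketched here. Returning false below makes the dispatcher's new
// handlePlanErr mark the task failed (pending -> failed) instead of retrying
// ProcessNormalFlow on the next scheduling round.
type myFlowHandle struct{}

func (*myFlowHandle) IsRetryableErr(err error) bool {
	// Example policy: retry when the plan was interrupted by cancellation,
	// fail fast on everything else.
	return errors.Is(err, context.Canceled)
}
```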
130 changes: 130 additions & 0 deletions disttask/framework/framework_err_handling_test.go
@@ -0,0 +1,130 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package framework_test

import (
"context"
"errors"
"sync"
"testing"

"github.com/pingcap/tidb/disttask/framework/dispatcher"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/disttask/framework/scheduler"
"github.com/pingcap/tidb/domain/infosync"
"github.com/pingcap/tidb/testkit"
)

type planErrFlowHandle struct {
callTime int
}

var _ dispatcher.TaskFlowHandle = (*planErrFlowHandle)(nil)

func (*planErrFlowHandle) OnTicker(_ context.Context, _ *proto.Task) {
}

func (p *planErrFlowHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
if gTask.State == proto.TaskStatePending {
if p.callTime == 0 {
p.callTime++
return nil, errors.New("retryable err")
}
gTask.Step = proto.StepOne
return [][]byte{
[]byte("task1"),
[]byte("task2"),
[]byte("task3"),
}, nil
}
if gTask.Step == proto.StepOne {
gTask.Step = proto.StepTwo
return [][]byte{
[]byte("task4"),
}, nil
}
return nil, nil
}

func (p *planErrFlowHandle) ProcessErrFlow(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) {
if p.callTime == 1 {
p.callTime++
return nil, errors.New("not retryable err")
}
return []byte("planErrTask"), nil
}

func (*planErrFlowHandle) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
return generateSchedulerNodes4Test()
}

func (*planErrFlowHandle) IsRetryableErr(error) bool {
return true
}

type planNotRetryableErrFlowHandle struct {
}

func (*planNotRetryableErrFlowHandle) OnTicker(_ context.Context, _ *proto.Task) {
}

func (*planNotRetryableErrFlowHandle) ProcessNormalFlow(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task) (metas [][]byte, err error) {
return nil, errors.New("not retryable err")
}

func (*planNotRetryableErrFlowHandle) ProcessErrFlow(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) {
return nil, errors.New("not retryable err")
}

func (*planNotRetryableErrFlowHandle) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
return generateSchedulerNodes4Test()
}

func (*planNotRetryableErrFlowHandle) IsRetryableErr(error) bool {
return false
}

func TestPlanErr(t *testing.T) {
defer dispatcher.ClearTaskFlowHandle()
defer scheduler.ClearSchedulers()
m := sync.Map{}

RegisterTaskMeta(&m, &planErrFlowHandle{0})
distContext := testkit.NewDistExecutionContext(t, 2)
DispatchTaskAndCheckSuccess("key1", t, &m)
distContext.Close()
}

func TestRevertPlanErr(t *testing.T) {
defer dispatcher.ClearTaskFlowHandle()
defer scheduler.ClearSchedulers()
m := sync.Map{}

RegisterTaskMeta(&m, &planErrFlowHandle{0})
distContext := testkit.NewDistExecutionContext(t, 2)
DispatchTaskAndCheckSuccess("key1", t, &m)
distContext.Close()
}

func TestPlanNotRetryableErr(t *testing.T) {
defer dispatcher.ClearTaskFlowHandle()
defer scheduler.ClearSchedulers()
m := sync.Map{}

RegisterTaskMeta(&m, &planNotRetryableErrFlowHandle{})
distContext := testkit.NewDistExecutionContext(t, 2)
DispatchTaskAndCheckState("key1", t, &m, proto.TaskStateFailed)
distContext.Close()
}
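These three tests can be run in isolation with the standard toolchain, e.g. `go test ./disttask/framework/ -run 'TestPlanErr|TestRevertPlanErr|TestPlanNotRetryableErr'`, or through the Bazel target declared above, `bazel test //disttask/framework:framework_test` (invocations assume the repository layout shown; CI uses the Bazel path).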
2 changes: 1 addition & 1 deletion disttask/framework/framework_rollback_test.go
@@ -130,7 +130,7 @@ func TestFrameworkRollback(t *testing.T) {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/cancelTaskAfterRefreshTask"))
}()

DispatchTaskAndCheckFail("key2", t, &m)
DispatchTaskAndCheckState("key2", t, &m, proto.TaskStateReverted)
require.Equal(t, int32(2), rollbackCnt.Load())
rollbackCnt.Store(0)
distContext.Close()
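The switch from DispatchTaskAndCheckFail to DispatchTaskAndCheckState captures the distinction this commit introduces: TestFrameworkRollback still ends in reverted because its subtasks are dispatched and then rolled back (rollbackCnt reaches 2), while TestPlanNotRetryableErr above ends in the new failed state because the plan is rejected before any subtask is ever dispatched.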
(4 more changed files not shown.)