diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go
index 94844985a8e2b..bae1e6882d1f5 100644
--- a/disttask/framework/dispatcher/dispatcher.go
+++ b/disttask/framework/dispatcher/dispatcher.go
@@ -247,7 +247,11 @@ func (d *dispatcher) onRunning() error {
 
 func (d *dispatcher) replaceDeadNodesIfAny() error {
 	if len(d.taskNodes) == 0 {
-		return errors.Errorf("len(d.taskNodes) == 0, onNextStage is not invoked before onRunning")
+		var err error
+		d.taskNodes, err = d.taskMgr.GetSchedulerIDsByTaskIDAndStep(d.task.ID, d.task.Step) // taskNodes is empty: reload the IDs recorded for this task and step
+		if err != nil {
+			return err
+		}
 	}
 	d.liveNodeFetchTick++
 	if d.liveNodeFetchTick == d.liveNodeFetchInterval {
diff --git a/disttask/framework/storage/task_table.go b/disttask/framework/storage/task_table.go
index 9238db6d972b4..ca00f9e5eca2b 100644
--- a/disttask/framework/storage/task_table.go
+++ b/disttask/framework/storage/task_table.go
@@ -511,6 +511,26 @@ func (stm *TaskManager) GetSchedulerIDsByTaskID(taskID int64) ([]string, error) 
 	return instanceIDs, nil
 }
 
+// GetSchedulerIDsByTaskIDAndStep returns the distinct exec_id values of the subtasks stored for taskID at step; it returns (nil, nil) when no row matches.
+func (stm *TaskManager) GetSchedulerIDsByTaskIDAndStep(taskID int64, step int64) ([]string, error) {
+	rs, err := stm.executeSQLWithNewSession(stm.ctx, `select distinct(exec_id) from mysql.tidb_background_subtask
+		where task_key = %? and step = %?`, taskID, step)
+	if err != nil {
+		return nil, err
+	}
+	if len(rs) == 0 { // no subtask row matches this task ID and step
+		return nil, nil
+	}
+
+	instanceIDs := make([]string, 0, len(rs))
+	for _, r := range rs {
+		id := r.GetString(0) // column 0 is exec_id, the only column the query selects
+		instanceIDs = append(instanceIDs, id)
+	}
+
+	return instanceIDs, nil
+}
+
 // IsSchedulerCanceled checks if subtask 'execID' of task 'taskID' has been canceled somehow.
 func (stm *TaskManager) IsSchedulerCanceled(taskID int64, execID string) (bool, error) {
 	rs, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_background_subtask where task_key = %? and exec_id = %?", taskID, execID)