diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto b/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
index 766463eceae56..1b12f1239dcbd 100644
--- a/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
+++ b/paddle/fluid/distributed/fleet_executor/fleet_executor_desc.proto
@@ -27,4 +27,6 @@ message FleetExecutorDesc {
   optional int32 dp_degree = 4 [ default = 1 ];
   optional int32 mp_degree = 5 [ default = 1 ];
   optional int32 pp_degree = 6 [ default = 1 ];
+  optional int64 num_micro_batches = 7 [ default = 1 ];
+  optional int64 num_slots = 8 [ default = 1 ];
 }
diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.cc b/paddle/fluid/distributed/fleet_executor/interceptor.cc
index 696f7dd752eec..e4ae04be53a9a 100644
--- a/paddle/fluid/distributed/fleet_executor/interceptor.cc
+++ b/paddle/fluid/distributed/fleet_executor/interceptor.cc
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/distributed/fleet_executor/interceptor.h"
 #include "paddle/fluid/distributed/fleet_executor/message_bus.h"
+#include "paddle/fluid/distributed/fleet_executor/task_node.h"
 
 namespace paddle {
 namespace distributed {
@@ -21,7 +22,8 @@ namespace distributed {
 Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
     : interceptor_id_(interceptor_id), node_(node) {
   interceptor_thread_ = std::thread([this]() {
-    VLOG(3) << "Start pooling local mailbox's thread.";
+    VLOG(3) << "Interceptor " << interceptor_id_
+            << " starts the thread pooling it's local mailbox.";
     PoolTheMailbox();
   });
 }
diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h
index 2e86dc2fe525d..adc10022beb43 100644
--- a/paddle/fluid/distributed/fleet_executor/interceptor.h
+++ b/paddle/fluid/distributed/fleet_executor/interceptor.h
@@ -96,6 +96,9 @@ class Interceptor {
   // local mailbox, written by FetchRemoteMailbox()
   // read by PoolTheMailbox()
   std::queue<InterceptorMessage> local_mailbox_;
+
+  int64_t already_run_times_{0};
+  int64_t used_slot_nums_{0};
 };
 
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/fleet_executor/runtime_graph.cc b/paddle/fluid/distributed/fleet_executor/runtime_graph.cc
index e0fbecf2ca995..5a98704fe9c01 100644
--- a/paddle/fluid/distributed/fleet_executor/runtime_graph.cc
+++ b/paddle/fluid/distributed/fleet_executor/runtime_graph.cc
@@ -136,16 +136,31 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
     role_to_ops.at(new_op_role_id).emplace_back(op.get());
   }
   int64_t cur_rank = exe_desc_.cur_rank();
+  DistCoordSys coord_sys(exe_desc_.dp_degree(), exe_desc_.pp_degree(),
+                         exe_desc_.mp_degree());
+  const auto& coord = coord_sys.RankToCoord(cur_rank);
+  int pipeline_stage = coord.pp_idx;
+  int64_t num_pipeline_stages = exe_desc_.pp_degree();
+  // TODO(fleet_executor dev): start up steps should be a config `num_slots`
+  int64_t start_up_steps = num_pipeline_stages - pipeline_stage - 1;
+  int64_t num_micro_batches = exe_desc_.num_micro_batches();
   int64_t task_id = cur_rank * functionality_order.size();
   for (std::size_t i = 0; i < functionality_order.size(); ++i) {
     OpRole role = functionality_order[i];
     int64_t role_id = static_cast<int64_t>(role);
+    int64_t max_run_times = num_micro_batches;
+    int64_t max_slot_nums = start_up_steps;
+    if (IsLRSched(role_id) || IsOptimize(role_id)) {
+      max_run_times = 1;
+      max_slot_nums = 1;
+    }
     if (role_to_ops.find(role_id) == role_to_ops.end()) {
-      task_nodes_.emplace_back(
-          TaskNode::CreateEmptyTaskNode(role_id, cur_rank, task_id));
+      task_nodes_.emplace_back(TaskNode::CreateEmptyTaskNode(
+          role_id, cur_rank, task_id, max_run_times, max_slot_nums));
     } else {
-      task_nodes_.emplace_back(TaskNode::CreateTaskNode(
-          role_id, role_to_ops.at(role_id), cur_rank, task_id));
+      task_nodes_.emplace_back(
+          TaskNode::CreateTaskNode(role_id, role_to_ops.at(role_id), cur_rank,
+                                   task_id, max_run_times, max_slot_nums));
     }
     ++task_id;
   }
diff --git a/paddle/fluid/distributed/fleet_executor/task_node.cc b/paddle/fluid/distributed/fleet_executor/task_node.cc
index de85871af5181..1a20b4c32b505 100644
--- a/paddle/fluid/distributed/fleet_executor/task_node.cc
+++ b/paddle/fluid/distributed/fleet_executor/task_node.cc
@@ -22,22 +22,37 @@ using OperatorBase = TaskNode::OperatorBase;
 }
 
 TaskNode::TaskNode(int64_t role, const std::vector<OperatorBase*>& ops,
-                   int64_t rank, int64_t task_id)
-    : ops_(ops), role_(role), rank_(rank), task_id_(task_id) {}
+                   int64_t rank, int64_t task_id, int64_t max_run_times,
+                   int64_t max_slot_nums)
+    : ops_(ops),
+      role_(role),
+      rank_(rank),
+      task_id_(task_id),
+      max_run_times_(max_run_times),
+      max_slot_nums_(max_slot_nums) {}
 
-TaskNode::TaskNode(int64_t role, int64_t rank, int64_t task_id)
-    : role_(role), rank_(rank), task_id_(task_id) {}
+TaskNode::TaskNode(int64_t role, int64_t rank, int64_t task_id,
+                   int64_t max_run_times, int64_t max_slot_nums)
+    : role_(role),
+      rank_(rank),
+      task_id_(task_id),
+      max_run_times_(max_run_times),
+      max_slot_nums_(max_slot_nums) {}
 
 std::unique_ptr<TaskNode> TaskNode::CreateEmptyTaskNode(int64_t role,
                                                         int64_t rank,
-                                                        int64_t task_id) {
-  return std::make_unique<TaskNode>(role, rank, task_id);
+                                                        int64_t task_id,
+                                                        int64_t max_run_times,
+                                                        int64_t max_slot_nums) {
+  return std::make_unique<TaskNode>(role, rank, task_id, max_run_times,
+                                    max_slot_nums);
 }
 
 std::unique_ptr<TaskNode> TaskNode::CreateTaskNode(
     int64_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
-    int64_t task_id) {
-  return std::make_unique<TaskNode>(role, ops, rank, task_id);
+    int64_t task_id, int64_t max_run_times, int64_t max_slot_nums) {
+  return std::make_unique<TaskNode>(role, ops, rank, task_id, max_run_times,
+                                    max_slot_nums);
 }
 
 void TaskNode::AddUpstreamTask(int64_t task_id) { upstream_.insert(task_id); }
diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h
index e341f52507144..a90106d01d26d 100644
--- a/paddle/fluid/distributed/fleet_executor/task_node.h
+++ b/paddle/fluid/distributed/fleet_executor/task_node.h
@@ -28,23 +28,28 @@ namespace distributed {
 class TaskNode final {
  public:
   using OperatorBase = paddle::framework::OperatorBase;
-  TaskNode(int64_t role, int64_t rank, int64_t task_id);
+  TaskNode(int64_t role, int64_t rank, int64_t task_id, int64_t max_run_times,
+           int64_t max_slot_nums);
   TaskNode(int64_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
-           int64_t task_id);
+           int64_t task_id, int64_t max_run_times, int64_t max_slot_nums);
   ~TaskNode() = default;
 
   int64_t rank() const { return rank_; }
   int64_t task_id() const { return task_id_; }
   int64_t role() const { return role_; }
+  int64_t max_run_times() const { return max_run_times_; }
+  int64_t max_slot_nums() const { return max_slot_nums_; }
   const std::unordered_set<int64_t>& upstream() const { return upstream_; }
   const std::unordered_set<int64_t>& downstream() const { return downstream_; }
   void AddUpstreamTask(int64_t task_id);
   void AddDownstreamTask(int64_t task_id);
   static std::unique_ptr<TaskNode> CreateEmptyTaskNode(int64_t role,
                                                        int64_t rank,
-                                                       int64_t task_id);
+                                                       int64_t task_id,
+                                                       int64_t max_run_times,
+                                                       int64_t max_slot_nums);
   static std::unique_ptr<TaskNode> CreateTaskNode(
       int64_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
-      int64_t task_id);
+      int64_t task_id, int64_t max_run_times, int64_t max_slot_nums);
 
  private:
   DISABLE_COPY_AND_ASSIGN(TaskNode);
@@ -55,6 +60,8 @@ class TaskNode final {
   int64_t role_;
   int64_t rank_;
   int64_t task_id_;
+  int64_t max_run_times_;
+  int64_t max_slot_nums_;
 };
 
 }  // namespace distributed
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index c493a420b946b..aba64a51e3b95 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -1969,6 +1969,8 @@ def _run_using_fleet_executor(self,
         fleet_exe_desc.dp_degree = fleet_opt["dist_strategy"]["dp_degree"]
         fleet_exe_desc.mp_degree = fleet_opt["dist_strategy"]["mp_degree"]
         fleet_exe_desc.pp_degree = fleet_opt["dist_strategy"]["pp_degree"]
+        if "num_micro_batches" in fleet_opt:
+            fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]
         num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
         assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
         fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor_multi_devices.py b/python/paddle/fluid/tests/unittests/test_fleet_executor_multi_devices.py
index 473d49fea4878..adffd228591b9 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_executor_multi_devices.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_executor_multi_devices.py
@@ -43,7 +43,11 @@ def test_dist_executor_on_multi_devices(self):
             "mp_degree": 2,
             "pp_degree": 2
         }
-        fleet_opt = {"dist_strategy": strategy.sharding_configs}
+        strategy.pipeline_configs = {"accumulate_steps": 8}
+        fleet_opt = {
+            "dist_strategy": strategy.sharding_configs,
+            "num_micro_batches": strategy.pipeline_configs["accumulate_steps"]
+        }
         if fluid.is_compiled_with_cuda():
             self.run_fleet_executor(fluid.CUDAPlace(0), fleet_opt)
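Note (not part of the diff): as context for reviewers, below is a minimal Python sketch of the per-stage limits that the new logic in `SplitProgramBasedFunctionality` assigns. Compute tasks get `max_run_times = num_micro_batches` and `max_slot_nums` equal to what looks like a 1F1B-style start-up count, `num_pipeline_stages - pipeline_stage - 1`, while LRSched/Optimize tasks are pinned to one run and one slot. The helper name `task_limits` and the printed layout are illustrative only and do not appear in the patch.

```python
def task_limits(pp_degree, num_micro_batches):
    """Sketch of the limits chosen per pipeline stage in runtime_graph.cc."""
    limits = []
    for pipeline_stage in range(pp_degree):
        # The patch notes (TODO) that start-up steps may later come from `num_slots`.
        start_up_steps = pp_degree - pipeline_stage - 1
        limits.append({
            "pipeline_stage": pipeline_stage,
            # Compute tasks: once per micro batch, with warm-up slots that
            # shrink toward the last stage.
            "compute": {"max_run_times": num_micro_batches,
                        "max_slot_nums": start_up_steps},
            # LRSched / Optimize tasks: once per step, single slot.
            "lr_sched_or_optimize": {"max_run_times": 1, "max_slot_nums": 1},
        })
    return limits


if __name__ == "__main__":
    # Matches the updated unit test: pp_degree=2, accumulate_steps=8.
    for entry in task_limits(pp_degree=2, num_micro_batches=8):
        print(entry)
```

With `pp_degree=2` and `num_micro_batches=8`, stage 0 gets one warm-up slot and stage 1 gets none, while both stages run their compute tasks eight times per step.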