diff --git a/paddle/fluid/distributed/fleet_executor/runtime_graph.cc b/paddle/fluid/distributed/fleet_executor/runtime_graph.cc
index 616178888b751..5a98704fe9c01 100644
--- a/paddle/fluid/distributed/fleet_executor/runtime_graph.cc
+++ b/paddle/fluid/distributed/fleet_executor/runtime_graph.cc
@@ -150,7 +150,7 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
     int64_t role_id = static_cast<int64_t>(role);
     int64_t max_run_times = num_micro_batches;
     int64_t max_slot_nums = start_up_steps;
-    if (role_id == 2 || role_id == 16) {
+    if (IsLRSched(role_id) || IsOptimize(role_id)) {
      max_run_times = 1;
      max_slot_nums = 1;
    }
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index b241f5d8c4da1..aba64a51e3b95 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -1969,6 +1969,7 @@ def _run_using_fleet_executor(self,
             fleet_exe_desc.dp_degree = fleet_opt["dist_strategy"]["dp_degree"]
             fleet_exe_desc.mp_degree = fleet_opt["dist_strategy"]["mp_degree"]
             fleet_exe_desc.pp_degree = fleet_opt["dist_strategy"]["pp_degree"]
+        if "num_micro_batches" in fleet_opt:
             fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]
         num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
         assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
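
Note on the runtime_graph.cc change: in Paddle's OpRole enum, kOptimize is 0x0002 and kLRSched is 0x0010, which is why the removed condition compared role_id against the literals 2 and 16. The following is a minimal sketch of what IsLRSched/IsOptimize could look like, assuming the helpers simply name the old equality checks; the actual definitions live elsewhere in the fleet_executor sources and are not part of this hunk:

    // Hypothetical sketch only: names IsOptimize/IsLRSched come from the diff,
    // but their bodies here are an assumption that they preserve the removed
    // equality semantics. The constants mirror paddle::framework::OpRole
    // (kOptimize == 0x0002, kLRSched == 0x0010).
    #include <cstdint>

    inline bool IsOptimize(int64_t role_id) {
      return role_id == 0x0002;  // OpRole::kOptimize, formerly "role_id == 2"
    }

    inline bool IsLRSched(int64_t role_id) {
      return role_id == 0x0010;  // OpRole::kLRSched, formerly "role_id == 16"
    }

Either way, LR-scheduler and optimizer sections run once per step (max_run_times = max_slot_nums = 1) rather than once per micro-batch, and the named helpers make that intent readable where the magic numbers did not.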
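
Note on the executor.py change: guarding the read with `"num_micro_batches" in fleet_opt` makes the key optional, so a caller that does not set it keeps the descriptor's default value instead of raising a KeyError. An illustrative fleet_opt is sketched below; the key names are taken from the diff, but the values and the overall shape of the dict are assumptions, not a documented schema:

    # Illustrative only: exercises the new guard in _run_using_fleet_executor.
    fleet_opt = {
        "dist_strategy": {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2},
        "num_micro_batches": 4,  # optional now; omit to keep the default
    }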