Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fleet_executor] Parse pipeline config #37319

Merged
merged 5 commits into from
Nov 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ message FleetExecutorDesc {
optional int32 dp_degree = 4 [ default = 1 ];
optional int32 mp_degree = 5 [ default = 1 ];
optional int32 pp_degree = 6 [ default = 1 ];
optional int64 num_micro_batches = 7 [ default = 1 ];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

大概不是num_micro_steps?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个在python端就是global batch size / micro batch size,所以叫num_mircro_batches?一共有多少个mirco batch?其实就是num_micro_steps

optional int64 num_slots = 8 [ default = 1 ];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个是啥

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

就是一次最多能跑多少步

}
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/fleet_executor/interceptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

Interceptor::Interceptor(int64_t interceptor_id, TaskNode* node)
: interceptor_id_(interceptor_id), node_(node) {
interceptor_thread_ = std::thread([this]() {
VLOG(3) << "Start pooling local mailbox's thread.";
VLOG(3) << "Interceptor " << interceptor_id_
<< " starts the thread pooling it's local mailbox.";
PoolTheMailbox();
});
}
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/fleet_executor/interceptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ class Interceptor {
// local mailbox, written by FetchRemoteMailbox()
// read by PoolTheMailbox()
std::queue<InterceptorMessage> local_mailbox_;

int64_t already_run_times_{0};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个建议加到后面的compute_interceptor中

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个是为了fake run准备的,留着吧,后面的子类可以不用?

int64_t used_slot_nums_{0};
};

} // namespace distributed
Expand Down
23 changes: 19 additions & 4 deletions paddle/fluid/distributed/fleet_executor/runtime_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,16 +136,31 @@ void RuntimeGraph::SplitProgramBasedFunctionality(const ProgramDesc& program) {
role_to_ops.at(new_op_role_id).emplace_back(op.get());
}
int64_t cur_rank = exe_desc_.cur_rank();
DistCoordSys coord_sys(exe_desc_.dp_degree(), exe_desc_.pp_degree(),
exe_desc_.mp_degree());
const auto& coord = coord_sys.RankToCoord(cur_rank);
int pipeline_stage = coord.pp_idx;
int64_t num_pipeline_stages = exe_desc_.pp_degree();
// TODO(fleet_executor dev): start up steps should be a config `num_slots`
int64_t start_up_steps = num_pipeline_stages - pipeline_stage - 1;
int64_t num_micro_batches = exe_desc_.num_micro_batches();
int64_t task_id = cur_rank * functionality_order.size();
for (std::size_t i = 0; i < functionality_order.size(); ++i) {
OpRole role = functionality_order[i];
int64_t role_id = static_cast<int64_t>(role);
int64_t max_run_times = num_micro_batches;
int64_t max_slot_nums = start_up_steps;
if (IsLRSched(role_id) || IsOptimize(role_id)) {
max_run_times = 1;
max_slot_nums = 1;
}
if (role_to_ops.find(role_id) == role_to_ops.end()) {
task_nodes_.emplace_back(
TaskNode::CreateEmptyTaskNode(role_id, cur_rank, task_id));
task_nodes_.emplace_back(TaskNode::CreateEmptyTaskNode(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续可能需要有ComputeTaskNode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没理解,为啥要单搞一个新的tasknode出来?

role_id, cur_rank, task_id, max_run_times, max_slot_nums));
} else {
task_nodes_.emplace_back(TaskNode::CreateTaskNode(
role_id, role_to_ops.at(role_id), cur_rank, task_id));
task_nodes_.emplace_back(
TaskNode::CreateTaskNode(role_id, role_to_ops.at(role_id), cur_rank,
task_id, max_run_times, max_slot_nums));
}
++task_id;
}
Expand Down
31 changes: 23 additions & 8 deletions paddle/fluid/distributed/fleet_executor/task_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,37 @@ using OperatorBase = TaskNode::OperatorBase;
}

TaskNode::TaskNode(int64_t role, const std::vector<OperatorBase*>& ops,
int64_t rank, int64_t task_id)
: ops_(ops), role_(role), rank_(rank), task_id_(task_id) {}
int64_t rank, int64_t task_id, int64_t max_run_times,
int64_t max_slot_nums)
: ops_(ops),
role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}

TaskNode::TaskNode(int64_t role, int64_t rank, int64_t task_id)
: role_(role), rank_(rank), task_id_(task_id) {}
TaskNode::TaskNode(int64_t role, int64_t rank, int64_t task_id,
int64_t max_run_times, int64_t max_slot_nums)
: role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}

std::unique_ptr<TaskNode> TaskNode::CreateEmptyTaskNode(int64_t role,
int64_t rank,
int64_t task_id) {
return std::make_unique<TaskNode>(role, rank, task_id);
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums) {
return std::make_unique<TaskNode>(role, rank, task_id, max_run_times,
max_slot_nums);
}

std::unique_ptr<TaskNode> TaskNode::CreateTaskNode(
int64_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
int64_t task_id) {
return std::make_unique<TaskNode>(role, ops, rank, task_id);
int64_t task_id, int64_t max_run_times, int64_t max_slot_nums) {
return std::make_unique<TaskNode>(role, ops, rank, task_id, max_run_times,
max_slot_nums);
}

void TaskNode::AddUpstreamTask(int64_t task_id) { upstream_.insert(task_id); }
Expand Down
15 changes: 11 additions & 4 deletions paddle/fluid/distributed/fleet_executor/task_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,28 @@ namespace distributed {
class TaskNode final {
public:
using OperatorBase = paddle::framework::OperatorBase;
TaskNode(int64_t role, int64_t rank, int64_t task_id);
TaskNode(int64_t role, int64_t rank, int64_t task_id, int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(int64_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
int64_t task_id);
int64_t task_id, int64_t max_run_times, int64_t max_slot_nums);
~TaskNode() = default;
int64_t rank() const { return rank_; }
int64_t task_id() const { return task_id_; }
int64_t role() const { return role_; }
int64_t max_run_times() const { return max_run_times_; }
int64_t max_slot_nums() const { return max_slot_nums_; }
const std::unordered_set<int64_t>& upstream() const { return upstream_; }
const std::unordered_set<int64_t>& downstream() const { return downstream_; }
void AddUpstreamTask(int64_t task_id);
void AddDownstreamTask(int64_t task_id);
static std::unique_ptr<TaskNode> CreateEmptyTaskNode(int64_t role,
int64_t rank,
int64_t task_id);
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
static std::unique_ptr<TaskNode> CreateTaskNode(
int64_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
int64_t task_id);
int64_t task_id, int64_t max_run_times, int64_t max_slot_nums);

private:
DISABLE_COPY_AND_ASSIGN(TaskNode);
Expand All @@ -55,6 +60,8 @@ class TaskNode final {
int64_t role_;
int64_t rank_;
int64_t task_id_;
int64_t max_run_times_;
int64_t max_slot_nums_;
};

} // namespace distributed
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1969,6 +1969,8 @@ def _run_using_fleet_executor(self,
fleet_exe_desc.dp_degree = fleet_opt["dist_strategy"]["dp_degree"]
fleet_exe_desc.mp_degree = fleet_opt["dist_strategy"]["mp_degree"]
fleet_exe_desc.pp_degree = fleet_opt["dist_strategy"]["pp_degree"]
if "num_micro_batches" in fleet_opt:
fleet_exe_desc.num_micro_batches = fleet_opt["num_micro_batches"]
FeixLiu marked this conversation as resolved.
Show resolved Hide resolved
num_of_gpu = fleet_exe_desc.dp_degree * fleet_exe_desc.mp_degree * fleet_exe_desc.pp_degree
assert nrank == num_of_gpu, "The number of rank is not equal to the number of gpu."
fleet_exe = core.FleetExecutor(fleet_exe_desc.SerializeToString())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ def test_dist_executor_on_multi_devices(self):
"mp_degree": 2,
"pp_degree": 2
}
fleet_opt = {"dist_strategy": strategy.sharding_configs}
strategy.pipeline_configs = {"accumulate_steps": 8}
fleet_opt = {
"dist_strategy": strategy.sharding_configs,
"num_micro_batches": strategy.pipeline_configs["accumulate_steps"]
}
if fluid.is_compiled_with_cuda():
self.run_fleet_executor(fluid.CUDAPlace(0), fleet_opt)

Expand Down