-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[fleet_executor] Add compute interceptor #37376
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h" | ||
|
||
#include "paddle/fluid/distributed/fleet_executor/task_node.h" | ||
|
||
namespace paddle { | ||
namespace distributed { | ||
|
||
// Builds a compute interceptor bound to `node` in the task graph.
// Dependencies are snapshotted before the message handler is installed, so
// no DATA_IS_READY message can be processed against an uninitialized set.
ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
    : Interceptor(interceptor_id, node) {
  PrepareDeps();
  // Route every message delivered to this interceptor into Compute().
  auto handler = [this](const InterceptorMessage& msg) { Compute(msg); };
  RegisterMsgHandle(handler);
}
|
||
void ComputeInterceptor::PrepareDeps() { | ||
auto& upstream = GetTaskNode()->upstream(); | ||
upstream_deps_.insert(upstream.begin(), upstream.end()); | ||
} | ||
|
||
void ComputeInterceptor::SendDataReadyToDownStream() { | ||
auto& downstream = GetTaskNode()->downstream(); | ||
for (auto dst_id : downstream) { | ||
InterceptorMessage dst_msg; | ||
dst_msg.set_message_type(DATA_IS_READY); | ||
VLOG(3) << "ComputeInterceptor Send msg to " << dst_id; | ||
Send(dst_id, dst_msg); | ||
} | ||
} | ||
|
||
void ComputeInterceptor::Compute(const InterceptorMessage& msg) { | ||
if (msg.message_type() == DATA_IS_READY) { | ||
auto src_id = msg.src_id(); | ||
upstream_deps_.erase(src_id); | ||
|
||
// all input is ready | ||
if (upstream_deps_.empty()) { | ||
// TODO(wangxi): op run | ||
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; | ||
SendDataReadyToDownStream(); | ||
PrepareDeps(); | ||
} | ||
} | ||
} | ||
|
||
// Registers this class under the "Compute" key so that
// InterceptorFactory::Create("Compute", ...) can construct it.
REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);
|
||
} // namespace distributed | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" | ||
|
||
namespace paddle { | ||
namespace distributed { | ||
|
||
class ComputeInterceptor : public Interceptor { | ||
public: | ||
ComputeInterceptor(int64_t interceptor_id, TaskNode* node); | ||
|
||
void PrepareDeps(); | ||
|
||
void SendDataReadyToDownStream(); | ||
|
||
void Compute(const InterceptorMessage& msg); | ||
|
||
private: | ||
std::unordered_set<int64_t> upstream_deps_; | ||
}; | ||
|
||
} // namespace distributed | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,7 +76,7 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg) { | |
|
||
void Interceptor::PoolTheMailbox() { | ||
// pool the local mailbox, parse the Message | ||
while (true) { | ||
for (;;) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 😶🌫️ |
||
if (local_mailbox_.empty()) { | ||
// local mailbox is empty, fetch the remote mailbox | ||
VLOG(3) << interceptor_id_ << "'s local mailbox is empty. " | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,8 +15,10 @@ | |
#pragma once | ||
#include <cstdint> | ||
#include <memory> | ||
#include <string> | ||
#include <unordered_set> | ||
#include <vector> | ||
|
||
#include "paddle/fluid/platform/macros.h" | ||
|
||
namespace paddle { | ||
|
@@ -33,16 +35,20 @@ class TaskNode final { | |
TaskNode(int32_t role, const std::vector<OperatorBase*>& ops, int64_t rank, | ||
int64_t task_id, int64_t max_run_times, int64_t max_slot_nums); | ||
~TaskNode() = default; | ||
|
||
int64_t rank() const { return rank_; } | ||
int64_t task_id() const { return task_id_; } | ||
int32_t role() const { return role_; } | ||
int64_t max_run_times() const { return max_run_times_; } | ||
int64_t max_slot_nums() const { return max_slot_nums_; } | ||
const std::unordered_set<int64_t>& upstream() const { return upstream_; } | ||
const std::unordered_set<int64_t>& downstream() const { return downstream_; } | ||
const std::string& type() const { return type_; } | ||
|
||
void AddUpstreamTask(int64_t task_id); | ||
void AddDownstreamTask(int64_t task_id); | ||
std::string DebugString() const; | ||
|
||
static std::unique_ptr<TaskNode> CreateEmptyTaskNode(int32_t role, | ||
int64_t rank, | ||
int64_t task_id, | ||
|
@@ -63,6 +69,8 @@ class TaskNode final { | |
int64_t task_id_; | ||
int64_t max_run_times_; | ||
int64_t max_slot_nums_; | ||
|
||
std::string type_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个是做什么的?区分interceptor种类用的? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的,carrier构建用的,这段逻辑还没有加 |
||
}; | ||
|
||
} // namespace distributed | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
# Distributed test sources need the dedicated distribute compile flags.
set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS})
cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS})
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <iostream> | ||
#include <unordered_map> | ||
|
||
#include "gtest/gtest.h" | ||
|
||
#include "paddle/fluid/distributed/fleet_executor/carrier.h" | ||
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" | ||
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" | ||
#include "paddle/fluid/distributed/fleet_executor/task_node.h" | ||
|
||
namespace paddle { | ||
namespace distributed { | ||
|
||
class StopInterceptor : public Interceptor { | ||
public: | ||
StopInterceptor(int64_t interceptor_id, TaskNode* node) | ||
: Interceptor(interceptor_id, node) { | ||
RegisterMsgHandle([this](const InterceptorMessage& msg) { Stop(msg); }); | ||
} | ||
|
||
void Stop(const InterceptorMessage& msg) { | ||
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() | ||
<< std::endl; | ||
InterceptorMessage stop; | ||
stop.set_message_type(STOP); | ||
Send(0, stop); | ||
Send(1, stop); | ||
Send(2, stop); | ||
} | ||
}; | ||
|
||
TEST(ComputeInterceptor, Compute) {
  // Three interceptors (ids 0-2) on a single rank with a dummy address.
  MessageBus& bus = MessageBus::Instance();
  bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");

  Carrier& carrier = Carrier::Instance();

  // NOTE: don't delete, otherwise interceptor will use undefined node
  TaskNode* task_a = new TaskNode(0, 0, 0, 0, 0);  // role, rank, task_id
  TaskNode* task_b = new TaskNode(0, 0, 1, 0, 0);
  TaskNode* task_c = new TaskNode(0, 0, 2, 0, 0);

  // Wire the chain a->b->c.
  task_a->AddDownstreamTask(1);
  task_b->AddUpstreamTask(0);
  task_b->AddDownstreamTask(2);

  Interceptor* head = carrier.SetInterceptor(
      0, InterceptorFactory::Create("Compute", 0, task_a));
  carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, task_b));
  carrier.SetInterceptor(2, std::make_unique<StopInterceptor>(2, task_c));

  carrier.SetCreatingFlag(false);

  // Kick the pipeline off by telling interceptor 1 its input is ready.
  InterceptorMessage trigger;
  trigger.set_message_type(DATA_IS_READY);
  head->Send(1, trigger);
}
|
||
} // namespace distributed | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里感觉erase不太好，这个Compute会被调用很多次，它的上下游是固定的话，一次compute清空，下一次计算还要加回来，感觉不太高效。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,我也觉得不太好。特别是有多个上游情况下,如果有个上游产出多次,这个还会出错。不过目前是
demo
,所以没那么讲究。可能还是新建一个空的,然后填充对比比较好一些,但这样还是有上游产出多次的问题需要考虑下怎么解决。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我觉得完全不考虑micro scope可能有问题🤨
比如C同时依赖A,B的情况,且A的运行速度比B快。需要跑两个micro steps。
这样时间2这一刻,A给C发送的micro step1的DATA_IS_READY怎么处理?且在最后,AB都结束了两个micro steps的运行,但是C永远会在等A的第二个micro step的DATA_IS_READY。
现阶段我们的上下游依赖很简单,应该都是单依赖的。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
嗯,现在这个写的是个demo compute,后续需要有buffer作为流控,需要一个个buffer写,一个buffer写满了才能计算一个