-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[fleet_executor] Interceptor run op #37623
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h" | ||
|
||
#include "paddle/fluid/distributed/fleet_executor/task_node.h" | ||
#include "paddle/fluid/framework/operator.h" | ||
|
||
namespace paddle { | ||
namespace distributed { | ||
|
@@ -40,9 +41,22 @@ void ComputeInterceptor::PrepareDeps() { | |
for (auto down_id : downstream) { | ||
out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0)); | ||
} | ||
|
||
// source compute node, should we add a new SourceInterceptor? | ||
if (upstream.empty()) { | ||
is_source_ = true; | ||
PADDLE_ENFORCE_GT(node_->max_run_times(), 0, | ||
platform::errors::InvalidArgument( | ||
"Source ComputeInterceptor must run at least one " | ||
"times, but now max_run_times=%ld", | ||
node_->max_run_times())); | ||
} | ||
} | ||
|
||
void ComputeInterceptor::IncreaseReady(int64_t up_id) { | ||
// source node has no upstream, data_is_ready is send by carrier or others | ||
if (is_source_ && up_id == -1) return; | ||
|
||
auto it = in_readys_.find(up_id); | ||
PADDLE_ENFORCE_NE(it, in_readys_.end(), | ||
platform::errors::NotFound( | ||
|
@@ -93,6 +107,12 @@ bool ComputeInterceptor::CanWriteOutput() { | |
return true; | ||
} | ||
|
||
// only source node need reset | ||
bool ComputeInterceptor::ShouldReset() { | ||
if (is_source_ && step_ == node_->max_run_times()) return true; | ||
return false; | ||
} | ||
|
||
void ComputeInterceptor::SendDataReadyToDownStream() { | ||
for (auto& outs : out_buffs_) { | ||
auto down_id = outs.first; | ||
|
@@ -134,9 +154,27 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { | |
} | ||
|
||
void ComputeInterceptor::Run() { | ||
while (IsInputReady() && CanWriteOutput()) { | ||
// If there is no limit, source interceptor can be executed | ||
// an unlimited number of times. | ||
// Now source node can only run | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个annotation是不是没写完😂 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 额,写了的,不知道咋漏了 |
||
if (ShouldReset()) { | ||
for (auto& out_buff : out_buffs_) { | ||
// buffer is using | ||
if (out_buff.second.second != 0) return; | ||
} | ||
step_ = 0; // reset | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 咱们有 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 除了source和若干特殊interceptor,大部分interceptor都是不需要reset的。如果需要reset,也是需要像stop那样,否则还没跑完就给reset了。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个reset不应该放在carrier.start()里么,每一次外部调用fleet_executor.run()的时候,这个step就应该reset呀。这个step不是标明当前ministep里,跑了多少个micro step么。而且在运行过程中,reset source interceptor有什么特殊含义? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 那这个reset也得有一个数据流的流动,和当前的并不冲突。这一步的判断总归是要做的,只不过再加个reset的条件触发。否则刚执行完一个micro_step,来了个reset消息,还没执行完就给reset了。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的,应该加一个条件,不过,为啥会刚执行完一个mirco step就会收到reset 🤨 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 那在开头加上一个step的reset吗,没有下一步了咋办O_o |
||
return; | ||
} | ||
|
||
while (IsInputReady() && CanWriteOutput() && !ShouldReset()) { | ||
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; | ||
// TODO(wangxi): add op run | ||
|
||
// step_ %= node_->max_run_times(); | ||
for (auto op : node_->ops()) { | ||
auto* scope = microbatch_scopes_[step_ % node_->max_slot_nums()]; | ||
op->Run(*scope, place_); | ||
} | ||
++step_; | ||
|
||
// send to downstream and increase buff used | ||
SendDataReadyToDownStream(); | ||
|
@@ -149,7 +187,7 @@ void ComputeInterceptor::ReceivedStop(int64_t up_id) { | |
received_stop_ = true; | ||
|
||
// source node has no upstream, stop is send by carrier or others | ||
if (up_id == -1) return; | ||
if (is_source_ && up_id == -1) return; | ||
|
||
auto it = in_stops_.find(up_id); | ||
PADDLE_ENFORCE_NE(it, in_stops_.end(), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <iostream> | ||
#include <unordered_map> | ||
|
||
#include "gtest/gtest.h" | ||
|
||
#include "paddle/fluid/distributed/fleet_executor/carrier.h" | ||
#include "paddle/fluid/distributed/fleet_executor/interceptor.h" | ||
#include "paddle/fluid/distributed/fleet_executor/message_bus.h" | ||
#include "paddle/fluid/distributed/fleet_executor/task_node.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/program_desc.h" | ||
|
||
USE_OP(elementwise_add); | ||
USE_OP(fill_constant); | ||
|
||
namespace paddle { | ||
namespace distributed { | ||
|
||
std::vector<framework::OperatorBase*> GetOps() { | ||
framework::AttributeMap attrs; | ||
attrs["dtype"] = framework::proto::VarType::FP32; | ||
attrs["shape"] = framework::vectorize<int>({2, 3}); | ||
attrs["value"] = 1.0f; | ||
|
||
auto zero_op = framework::OpRegistry::CreateOp("fill_constant", {}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为啥叫zero呀🤔,值不是1么 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 从别处copy的,木有改 ╮(╯_╰)╭ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 😂,all right |
||
{{"Out", {"x"}}}, attrs); | ||
|
||
auto op = framework::OpRegistry::CreateOp( | ||
"elementwise_add", {{"X", {"x"}}, {"Y", {"x"}}}, {{"Out", {"out"}}}, | ||
framework::AttributeMap()); | ||
|
||
// NOTE: don't delete | ||
return {zero_op.release(), op.release()}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dangling ptr++ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 没有和线程或者carrier绑定起来,析构了就访问非法地址了-。- |
||
} | ||
|
||
framework::Scope* GetScope() { | ||
framework::Scope* scope = new framework::Scope(); | ||
|
||
scope->Var("x")->GetMutable<framework::LoDTensor>(); | ||
scope->Var("out")->GetMutable<framework::LoDTensor>(); | ||
return scope; | ||
} | ||
|
||
TEST(ComputeInterceptor, Compute) { | ||
std::vector<framework::OperatorBase*> ops = GetOps(); | ||
framework::Scope* scope = GetScope(); | ||
std::vector<framework::Scope*> scopes = {scope, scope}; | ||
platform::Place place = platform::CPUPlace(); | ||
|
||
MessageBus& msg_bus = MessageBus::Instance(); | ||
msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); | ||
|
||
Carrier& carrier = Carrier::Instance(); | ||
|
||
// FIXME: don't delete, otherwise interceptor will use undefined node | ||
TaskNode* node_a = | ||
new TaskNode(0, ops, 0, 0, 2, 2); // role, ops, rank, task_id | ||
TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0); | ||
|
||
// a->b | ||
node_a->AddDownstreamTask(1); | ||
node_b->AddUpstreamTask(0); | ||
|
||
auto* a = carrier.SetInterceptor( | ||
0, InterceptorFactory::Create("Compute", 0, node_a)); | ||
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); | ||
|
||
a->SetPlace(place); | ||
a->SetMicroBatchScope(scopes); | ||
|
||
carrier.SetCreatingFlag(false); | ||
|
||
// start | ||
InterceptorMessage msg; | ||
msg.set_message_type(DATA_IS_READY); | ||
msg.set_src_id(-1); | ||
msg.set_dst_id(0); | ||
carrier.EnqueueInterceptorMessage(msg); | ||
|
||
// stop | ||
InterceptorMessage stop; | ||
stop.set_message_type(STOP); | ||
stop.set_src_id(-1); | ||
stop.set_dst_id(0); | ||
carrier.EnqueueInterceptorMessage(stop); | ||
} | ||
|
||
} // namespace distributed | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为啥这儿有一个判断,这个东西不是上层指定的,他说跑0次就跑0次呗
而且,如果这个要判断,是不是只要有下游的interceptor都要判断一下,因为都需要至少跑一次才可以触发下游
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
source interceptor,in degree为0,没有上游依赖,需要指定跑多少次,跑0次肯定是错的,这个max_run_times其实也是不准确的,应该就是run_times。 其它的compute interceptor,并不需要知道跑多少次,上游给消息就跑,除非有特殊功能,是不需要max_run_times的。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reasonable