[fleet_executor] Interceptor run op #37623

Merged (4 commits) on Nov 29, 2021
44 changes: 41 additions & 3 deletions paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
@@ -15,6 +15,7 @@
#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"

#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace distributed {
@@ -40,9 +41,22 @@ void ComputeInterceptor::PrepareDeps() {
for (auto down_id : downstream) {
out_buffs_.emplace(down_id, std::make_pair(out_buff_size, 0));
}

// source compute node, should we add a new SourceInterceptor?
if (upstream.empty()) {
is_source_ = true;
PADDLE_ENFORCE_GT(node_->max_run_times(), 0,
Contributor: Why is there a check here? This value is specified by the upper layer; if it says run 0 times, then just run 0 times. And if this does need checking, shouldn't every interceptor with a downstream check it too, since each has to run at least once to trigger its downstream?

Contributor Author: A source interceptor has in-degree 0 and no upstream dependency, so it must be told how many times to run, and running 0 times is definitely wrong. Strictly speaking, max_run_times is a misnomer here; it should just be run_times. Other compute interceptors don't need to know how many times they run: they run whenever an upstream message arrives, and don't need max_run_times unless they implement special functionality.

Contributor: Reasonable.

platform::errors::InvalidArgument(
"Source ComputeInterceptor must run at least one "
"times, but now max_run_times=%ld",
node_->max_run_times()));
}
}
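The dependency setup above can be condensed into a standalone sketch (hypothetical names, not the actual Paddle classes): each upstream id gets a (max_ready_size, ready_size) counter, each downstream id a (max_buffer_size, used_size) counter, and a node with no upstreams is flagged as a source that must be given a positive run count.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <unordered_set>
#include <utility>

// Simplified model of PrepareDeps (names are illustrative only).
struct DepState {
  bool is_source = false;
  // upstream_id  -> (max_ready_size, ready_size)
  std::map<int64_t, std::pair<int64_t, int64_t>> in_readys;
  // downstream_id -> (max_buffer_size, used_size)
  std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs;
};

DepState PrepareDeps(const std::unordered_set<int64_t>& upstream,
                     const std::unordered_set<int64_t>& downstream,
                     int64_t run_times, int64_t buff_size) {
  DepState s;
  for (int64_t up : upstream)
    s.in_readys.emplace(up, std::make_pair(buff_size, int64_t{0}));
  for (int64_t down : downstream)
    s.out_buffs.emplace(down, std::make_pair(buff_size, int64_t{0}));
  if (upstream.empty()) {
    // In-degree 0: this is a source node, so it must run at least once,
    // mirroring the PADDLE_ENFORCE_GT check in the diff above.
    s.is_source = true;
    assert(run_times > 0 && "a source node must run at least once");
  }
  return s;
}
```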

void ComputeInterceptor::IncreaseReady(int64_t up_id) {
// source node has no upstream, data_is_ready is send by carrier or others
if (is_source_ && up_id == -1) return;

auto it = in_readys_.find(up_id);
PADDLE_ENFORCE_NE(it, in_readys_.end(),
platform::errors::NotFound(
@@ -93,6 +107,12 @@ bool ComputeInterceptor::CanWriteOutput() {
return true;
}

// only source node need reset
bool ComputeInterceptor::ShouldReset() {
if (is_source_ && step_ == node_->max_run_times()) return true;
return false;
}

void ComputeInterceptor::SendDataReadyToDownStream() {
for (auto& outs : out_buffs_) {
auto down_id = outs.first;
@@ -134,9 +154,27 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}

void ComputeInterceptor::Run() {
while (IsInputReady() && CanWriteOutput()) {
// If there is no limit, source interceptor can be executed
// an unlimited number of times.
// Now source node can only run
Contributor: Is this annotation unfinished? 😂

Contributor Author: Uh, I did write it; no idea how part of it went missing.

if (ShouldReset()) {
for (auto& out_buff : out_buffs_) {
// buffer is using
if (out_buff.second.second != 0) return;
}
step_ = 0; // reset
Contributor: We do have a RESET message; shouldn't this zeroing be message-driven?

Contributor Author: Apart from the source and a few special interceptors, most interceptors never need a reset. And if a reset is needed, it has to be handled the way stop is; otherwise an interceptor could be reset before it has finished running.

Contributor: Shouldn't this reset live in carrier.start()? Every time fleet_executor.run() is called from outside, step should be reset. Doesn't step record how many micro steps have run within the current mini step? And what would resetting the source interceptor mean in the middle of a run?

Contributor Author: Even then the reset has to propagate through the data flow, which doesn't conflict with the current approach. This check has to happen either way; a reset message would just be one more trigger condition. Otherwise a reset message arriving right after one micro_step finishes could reset the interceptor before it has run to completion.

Contributor: Right, a condition should be added. But why would a reset arrive right after a micro step finishes? 🤨

Contributor Author: So should the step reset go at the start instead? What happens when there is no next step? O_o

return;
}

while (IsInputReady() && CanWriteOutput() && !ShouldReset()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
// TODO(wangxi): add op run

// step_ %= node_->max_run_times();
for (auto op : node_->ops()) {
auto* scope = microbatch_scopes_[step_ % node_->max_slot_nums()];
op->Run(*scope, place_);
}
++step_;

// send to downstream and increase buff used
SendDataReadyToDownStream();
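The gating that Run() performs for a source node can be shown as a standalone model (a simplified sketch with hypothetical names, not the Paddle implementation): a source runs its configured number of micro steps, and may only zero its step counter once every downstream buffer has been drained, which is the early-return-while-buffers-are-in-use behavior discussed in the thread above.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <utility>

// Standalone model of the source-node step/reset gating (illustrative
// names; the real code also checks IsInputReady/CanWriteOutput).
struct SourceNodeModel {
  int64_t step = 0;
  int64_t run_times = 1;  // must be > 0 for a source node
  // downstream_id -> (max_buffer_size, used_size)
  std::map<int64_t, std::pair<int64_t, int64_t>> out_buffs;

  bool ShouldReset() const { return step == run_times; }

  // Returns true if one micro step ran, false if the node is gated.
  bool TryRunOnce() {
    if (ShouldReset()) {
      for (auto& out : out_buffs) {
        if (out.second.second != 0) return false;  // buffer still in use
      }
      step = 0;  // all downstream buffers drained: safe to reset
    }
    for (auto& out : out_buffs) ++out.second.second;  // occupy one slot each
    ++step;
    return true;
  }

  // Downstream replied "completed": free one buffer slot.
  void DownstreamConsumed(int64_t id) { --out_buffs[id].second; }
};
```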
@@ -149,7 +187,7 @@ void ComputeInterceptor::ReceivedStop(int64_t up_id) {
received_stop_ = true;

// source node has no upstream, stop is send by carrier or others
if (up_id == -1) return;
if (is_source_ && up_id == -1) return;

auto it = in_stops_.find(up_id);
PADDLE_ENFORCE_NE(it, in_stops_.end(),
paddle/fluid/distributed/fleet_executor/compute_interceptor.h
@@ -31,6 +31,7 @@ class ComputeInterceptor : public Interceptor {
void DecreaseBuff(int64_t down_id);
bool IsInputReady();
bool CanWriteOutput();
bool ShouldReset();

void SendDataReadyToDownStream();
void ReplyCompletedToUpStream();
@@ -43,8 +44,9 @@ class ComputeInterceptor : public Interceptor {
void TryStop();

private:
// FIXME(wangxi): if use step_ and max_steps_, how to restart step_ from 0
bool is_source_{false};
int64_t step_{0};

// upstream_id-->(max_ready_size, ready_size)
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
// downstream_id-->(max_buffer_size, used_size)
34 changes: 27 additions & 7 deletions paddle/fluid/distributed/fleet_executor/interceptor.h
@@ -26,8 +26,12 @@
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace framework {
class Scope;
}
namespace distributed {

class TaskNode;
@@ -64,12 +68,34 @@ class Interceptor {

bool Send(int64_t dst_id, InterceptorMessage& msg); // NOLINT

void SetPlace(const platform::Place& place) { place_ = place; }

void SetRootScope(framework::Scope* scope) { root_scope_ = scope; }
void SetMiniBatchScope(framework::Scope* scope) { minibatch_scope_ = scope; }
void SetMicroBatchScope(const std::vector<framework::Scope*>& scopes) {
microbatch_scopes_ = scopes;
}

TaskNode* GetTaskNode() const { return node_; }

DISABLE_COPY_AND_ASSIGN(Interceptor);

protected:
TaskNode* GetTaskNode() const { return node_; }
// interceptor id, handed from above layer
int64_t interceptor_id_;

// node need to be handled by this interceptor
TaskNode* node_;

// for stop
bool stop_{false};

// for runtime
platform::Place place_;
framework::Scope* root_scope_{nullptr};
framework::Scope* minibatch_scope_{nullptr};
std::vector<framework::Scope*> microbatch_scopes_{};

private:
// pool the local mailbox, parse the Message
void PoolTheMailbox();
@@ -78,12 +104,6 @@
// return true if remote mailbox not empty, otherwise return false
bool FetchRemoteMailbox();

// interceptor id, handed from above layer
int64_t interceptor_id_;

// node need to be handled by this interceptor
TaskNode* node_;

// interceptor handle which process message
MsgHandle handle_{nullptr};

1 change: 1 addition & 0 deletions paddle/fluid/distributed/fleet_executor/task_node.h
@@ -48,6 +48,7 @@ class TaskNode final {
const std::unordered_set<int64_t>& downstream() const { return downstream_; }
const std::string& type() const { return type_; }
const paddle::framework::ProgramDesc& program() const { return program_; }
const std::vector<OperatorBase*>& ops() const { return ops_; }

bool AddUpstreamTask(int64_t task_id);
bool AddDownstreamTask(int64_t task_id);
7 changes: 6 additions & 1 deletion paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
@@ -1,7 +1,12 @@
set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS})

set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS})

set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context)

if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set_source_files_properties(interceptor_ping_pong_with_brpc_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_with_brpc_test SRCS interceptor_ping_pong_with_brpc_test.cc DEPS fleet_executor ${BRPC_DEPS})
paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc
@@ -0,0 +1,103 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <iostream>
#include <unordered_map>

#include "gtest/gtest.h"

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"

USE_OP(elementwise_add);
USE_OP(fill_constant);

namespace paddle {
namespace distributed {

std::vector<framework::OperatorBase*> GetOps() {
framework::AttributeMap attrs;
attrs["dtype"] = framework::proto::VarType::FP32;
attrs["shape"] = framework::vectorize<int>({2, 3});
attrs["value"] = 1.0f;

auto zero_op = framework::OpRegistry::CreateOp("fill_constant", {},
Contributor: Why is it called zero 🤔? The value is 1, isn't it?

Contributor Author: Copied from elsewhere and never renamed ╮(╯_╰)╭

Contributor: 😂, all right

{{"Out", {"x"}}}, attrs);

auto op = framework::OpRegistry::CreateOp(
"elementwise_add", {{"X", {"x"}}, {"Y", {"x"}}}, {{"Out", {"out"}}},
framework::AttributeMap());

// NOTE: don't delete
return {zero_op.release(), op.release()};
Contributor: dangling ptr++

Contributor Author: They aren't tied to a thread or to the carrier, so if they were destructed, later accesses would hit freed addresses -。-

}
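The release() pattern debated above can be shown in isolation (generic `Op` stand-in, not Paddle's `OperatorBase`): ownership leaves the unique_ptrs, so the objects are deliberately kept alive past the function's return rather than being destructed while interceptors still hold pointers to them.

```cpp
#include <memory>
#include <vector>

// Generic stand-in type for the pattern above (not Paddle's OperatorBase).
struct Op { int id; };

std::vector<Op*> MakeOps() {
  std::unique_ptr<Op> a(new Op{1});
  std::unique_ptr<Op> b(new Op{2});
  // release() detaches the objects from the unique_ptrs, so they are NOT
  // destroyed when this function returns. If we returned a.get()/b.get()
  // instead, the ops would be destructed here and every stored pointer
  // would dangle; that is the bug flagged in the review thread.
  return {a.release(), b.release()};
}
```

Whoever receives the raw pointers (or nobody, if a process-lifetime leak is acceptable, as in the test above) is now responsible for deleting them.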

framework::Scope* GetScope() {
framework::Scope* scope = new framework::Scope();

scope->Var("x")->GetMutable<framework::LoDTensor>();
scope->Var("out")->GetMutable<framework::LoDTensor>();
return scope;
}

TEST(ComputeInterceptor, Compute) {
std::vector<framework::OperatorBase*> ops = GetOps();
framework::Scope* scope = GetScope();
std::vector<framework::Scope*> scopes = {scope, scope};
platform::Place place = platform::CPUPlace();

MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");

Carrier& carrier = Carrier::Instance();

// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a =
new TaskNode(0, ops, 0, 0, 2, 2); // role, ops, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);

// a->b
node_a->AddDownstreamTask(1);
node_b->AddUpstreamTask(0);

auto* a = carrier.SetInterceptor(
0, InterceptorFactory::Create("Compute", 0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));

a->SetPlace(place);
a->SetMicroBatchScope(scopes);

carrier.SetCreatingFlag(false);

// start
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1);
msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg);

// stop
InterceptorMessage stop;
stop.set_message_type(STOP);
stop.set_src_id(-1);
stop.set_dst_id(0);
carrier.EnqueueInterceptorMessage(stop);
}

} // namespace distributed
} // namespace paddle