[fleet_executor] Add compute interceptor #37376

Merged
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/fleet_executor/CMakeLists.txt
@@ -11,12 +11,13 @@ else()
endif()

cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc interceptor_message_service.cc message_bus.cc
interceptor.cc compute_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto ${BRPC_DEPS})

if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
2 changes: 2 additions & 0 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -21,6 +21,8 @@
namespace paddle {
namespace distributed {

USE_INTERCEPTOR(Compute);

void Carrier::Init(
const std::unordered_map<int64_t, TaskNode*>& interceptor_id_to_node) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
61 changes: 61 additions & 0 deletions paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
@@ -0,0 +1,61 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"

#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
PrepareDeps();
RegisterMsgHandle([this](const InterceptorMessage& msg) { Compute(msg); });
}

void ComputeInterceptor::PrepareDeps() {
auto& upstream = GetTaskNode()->upstream();
upstream_deps_.insert(upstream.begin(), upstream.end());
}

void ComputeInterceptor::SendDataReadyToDownStream() {
auto& downstream = GetTaskNode()->downstream();
for (auto dst_id : downstream) {
InterceptorMessage dst_msg;
dst_msg.set_message_type(DATA_IS_READY);
VLOG(3) << "ComputeInterceptor Send msg to " << dst_id;
Send(dst_id, dst_msg);
}
}

void ComputeInterceptor::Compute(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
auto src_id = msg.src_id();
upstream_deps_.erase(src_id);
Contributor:
The erase here doesn't feel quite right. Compute will be called many times, and since a node's upstream/downstream sets are fixed, clearing the set on one compute just means re-filling it before the next one.

Contributor (Author):
Agreed, I don't think it's great either. In particular, with multiple upstreams, it breaks if one upstream produces output more than once. For now this is just a demo, so it's not that polished.
Building a fresh empty set and filling it for comparison might be better, but the case where an upstream produces multiple times still needs a solution.

Contributor @FeixLiu (Nov 22, 2021):
I think completely ignoring micro steps may be a problem 🤨

For example, say C depends on both A and B, A runs faster than B, and two micro steps need to run:

| time | event | C's upstream_deps_ |
| ---- | ----- | ------------------ |
| 0 | C is initialized | (A, B) |
| 1 | A finishes micro step 0, A->C: DATA_IS_READY | (B) |
| 2 | A finishes micro step 1, A->C: DATA_IS_READY | (B) |
| 3 | B finishes micro step 0, B->C: DATA_IS_READY | (), C runs micro step 0 |
| 4 | C rebuilds upstream_deps_ | (A, B) |
| 5 | B finishes micro step 1, B->C: DATA_IS_READY | (A) |

So at time 2, how is the DATA_IS_READY that A sends for micro step 1 handled? And at the end, A and B have both finished two micro steps, but C will wait forever for the DATA_IS_READY of A's second micro step.

At this stage our upstream/downstream dependencies are simple; they should all be single dependencies.

Contributor (Author):
Yeah, what's written here is a demo Compute. Later we'll need buffers for flow control: writes go buffer by buffer, and a compute can only run once a buffer has been filled.
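
For reference, a minimal sketch of the counter-based alternative discussed above (hypothetical names, not part of this PR): each upstream accumulates one "credit" per DATA_IS_READY, and running a micro step consumes one credit from every upstream, so A's second message in the table above is no longer lost.

```cpp
// Hypothetical sketch, not this PR's code: count pending DATA_IS_READY
// messages per upstream id instead of erasing ids from a set. A fast
// upstream that runs ahead simply accumulates credits.
#include <cstdint>
#include <unordered_map>
#include <unordered_set>

class UpstreamCredits {
 public:
  explicit UpstreamCredits(const std::unordered_set<int64_t>& upstream) {
    for (int64_t id : upstream) pending_[id] = 0;
  }

  // Called on every DATA_IS_READY from src_id.
  void OnDataReady(int64_t src_id) { ++pending_.at(src_id); }

  // One micro step may run once every upstream has at least one credit.
  bool Ready() const {
    for (const auto& kv : pending_) {
      if (kv.second == 0) return false;
    }
    return true;
  }

  // Consume one credit from every upstream for the step that just ran.
  void ConsumeStep() {
    for (auto& kv : pending_) --kv.second;
  }

 private:
  std::unordered_map<int64_t, int64_t> pending_;
};
```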


// all input is ready
if (upstream_deps_.empty()) {
// TODO(wangxi): op run
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";
SendDataReadyToDownStream();
PrepareDeps();
}
}
}

REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);

} // namespace distributed
} // namespace paddle
37 changes: 37 additions & 0 deletions paddle/fluid/distributed/fleet_executor/compute_interceptor.h
@@ -0,0 +1,37 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/distributed/fleet_executor/interceptor.h"

namespace paddle {
namespace distributed {

class ComputeInterceptor : public Interceptor {
public:
ComputeInterceptor(int64_t interceptor_id, TaskNode* node);

void PrepareDeps();

void SendDataReadyToDownStream();

void Compute(const InterceptorMessage& msg);

private:
std::unordered_set<int64_t> upstream_deps_;
};

} // namespace distributed
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/fleet_executor/interceptor.cc
@@ -76,7 +76,7 @@ bool Interceptor::Send(int64_t dst_id, InterceptorMessage& msg)

void Interceptor::PoolTheMailbox() {
// pool the local mailbox, parse the Message
while (true) {
for (;;) {
Contributor:

😶‍🌫️

if (local_mailbox_.empty()) {
// local mailbox is empty, fetch the remote mailbox
VLOG(3) << interceptor_id_ << "'s local mailbox is empty. "
26 changes: 20 additions & 6 deletions paddle/fluid/distributed/fleet_executor/interceptor.h
@@ -62,6 +62,9 @@ class Interceptor {

DISABLE_COPY_AND_ASSIGN(Interceptor);

protected:
TaskNode* GetTaskNode() const { return node_; }

private:
// pool the local mailbox, parse the Message
void PoolTheMailbox();
@@ -114,19 +117,30 @@ class InterceptorFactory {
int64_t id, TaskNode* node);
};

template <typename InterceptorClass>
std::unique_ptr<Interceptor> CreatorInterceptor(int64_t id, TaskNode* node) {
return std::make_unique<InterceptorClass>(id, node);
}

#define REGISTER_INTERCEPTOR(interceptor_type, interceptor_class) \
std::unique_ptr<Interceptor> CreatorInterceptor_##interceptor_type( \
int64_t id, TaskNode* node) { \
return std::make_unique<interceptor_class>(id, node); \
} \
class __RegisterInterceptor_##interceptor_type { \
public: \
__RegisterInterceptor_##interceptor_type() { \
InterceptorFactory::Register(#interceptor_type, \
CreatorInterceptor_##interceptor_type); \
CreatorInterceptor<interceptor_class>); \
} \
void Touch() {} \
}; \
__RegisterInterceptor_##interceptor_type g_register_##interceptor_type;
__RegisterInterceptor_##interceptor_type g_register_##interceptor_type; \
int TouchRegisterInterceptor_##interceptor_type() { \
g_register_##interceptor_type.Touch(); \
return 0; \
}

#define USE_INTERCEPTOR(interceptor_type) \
extern int TouchRegisterInterceptor_##interceptor_type(); \
UNUSED static int use_interceptor_##interceptor_type = \
TouchRegisterInterceptor_##interceptor_type();

} // namespace distributed
} // namespace paddle
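
For context, the touch function added above is the usual static-registration idiom: when linking a static library, the linker may discard a translation unit whose only observable effect is a global registrar's constructor, so consumers force a reference into it. The macro pair is used in this PR as follows:

```cpp
// In compute_interceptor.cc: define the interceptor and register it
// under the type string "Compute".
REGISTER_INTERCEPTOR(Compute, ComputeInterceptor);

// In carrier.cc: reference the touch function so the linker keeps
// the registration object.
USE_INTERCEPTOR(Compute);

// After that, interceptors can be created by type string, as the
// test below does:
//   InterceptorFactory::Create("Compute", /*id=*/0, node_a);
```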
8 changes: 8 additions & 0 deletions paddle/fluid/distributed/fleet_executor/task_node.h
@@ -15,8 +15,10 @@
#pragma once
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/platform/macros.h"

namespace paddle {
@@ -33,16 +35,20 @@ class TaskNode final {
TaskNode(int32_t role, const std::vector<OperatorBase*>& ops, int64_t rank,
int64_t task_id, int64_t max_run_times, int64_t max_slot_nums);
~TaskNode() = default;

int64_t rank() const { return rank_; }
int64_t task_id() const { return task_id_; }
int32_t role() const { return role_; }
int64_t max_run_times() const { return max_run_times_; }
int64_t max_slot_nums() const { return max_slot_nums_; }
const std::unordered_set<int64_t>& upstream() const { return upstream_; }
const std::unordered_set<int64_t>& downstream() const { return downstream_; }
const std::string& type() const { return type_; }

void AddUpstreamTask(int64_t task_id);
void AddDownstreamTask(int64_t task_id);
std::string DebugString() const;

static std::unique_ptr<TaskNode> CreateEmptyTaskNode(int32_t role,
int64_t rank,
int64_t task_id,
@@ -63,6 +69,8 @@
int64_t task_id_;
int64_t max_run_times_;
int64_t max_slot_nums_;

std::string type_;
Contributor:
What is this for? Distinguishing interceptor types?

Contributor (Author):
Yes, it's for the carrier when it builds interceptors; that logic hasn't been added yet.
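
Presumably the carrier would then pick the interceptor class from this string, along the lines of the sketch below (hypothetical, that logic is not in this PR):

```cpp
// Hypothetical future Carrier logic, not part of this PR: build each
// interceptor from its task node's registered type string.
for (const auto& item : interceptor_id_to_node) {
  int64_t interceptor_id = item.first;
  TaskNode* node = item.second;
  SetInterceptor(interceptor_id,
                 InterceptorFactory::Create(node->type(), interceptor_id, node));
}
```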

};

} // namespace distributed
2 changes: 2 additions & 0 deletions paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
@@ -1,2 +1,4 @@
set_source_files_properties(interceptor_ping_pong_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet_executor ${BRPC_DEPS})
cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS})
75 changes: 75 additions & 0 deletions paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc
@@ -0,0 +1,75 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <iostream>
#include <unordered_map>

#include "gtest/gtest.h"

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

class StopInterceptor : public Interceptor {
public:
StopInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
RegisterMsgHandle([this](const InterceptorMessage& msg) { Stop(msg); });
}

void Stop(const InterceptorMessage& msg) {
std::cout << GetInterceptorId() << " recv msg from " << msg.src_id()
<< std::endl;
InterceptorMessage stop;
stop.set_message_type(STOP);
Send(0, stop);
Send(1, stop);
Send(2, stop);
}
};

TEST(ComputeInterceptor, Compute) {
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}}, {{0, "127.0.0.0:0"}}, "127.0.0.0:0");

Carrier& carrier = Carrier::Instance();

// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* node_a = new TaskNode(0, 0, 0, 0, 0); // role, rank, task_id, max_run_times, max_slot_nums
TaskNode* node_b = new TaskNode(0, 0, 1, 0, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 0, 0);

// a->b->c
node_a->AddDownstreamTask(1);
node_b->AddUpstreamTask(0);
node_b->AddDownstreamTask(2);

Interceptor* a = carrier.SetInterceptor(
0, InterceptorFactory::Create("Compute", 0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b));
carrier.SetInterceptor(2, std::make_unique<StopInterceptor>(2, node_c));

carrier.SetCreatingFlag(false);

InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
a->Send(1, msg);
Contributor:
Could we add a finished flag to StopInterceptor and wait on that flag here? Then the three new'd pointers could be deleted.

Contributor (Author):
Let's discuss this later; I think it should be handled through destructors.
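
For illustration, a minimal sketch of the flag-and-wait variant suggested above (hypothetical class, not part of this PR):

```cpp
// Hypothetical sketch: a StopInterceptor that signals completion so the
// test can wait on it and then safely delete the three TaskNodes.
#include <condition_variable>
#include <mutex>

class StopInterceptorWithFlag : public Interceptor {
 public:
  StopInterceptorWithFlag(int64_t interceptor_id, TaskNode* node)
      : Interceptor(interceptor_id, node) {
    RegisterMsgHandle([this](const InterceptorMessage& msg) { Stop(msg); });
  }

  void Stop(const InterceptorMessage& msg) {
    InterceptorMessage stop;
    stop.set_message_type(STOP);
    Send(0, stop);
    Send(1, stop);
    Send(2, stop);
    std::lock_guard<std::mutex> lock(mu_);
    finished_ = true;
    cv_.notify_all();
  }

  // Block until Stop() has run once.
  void WaitFinished() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return finished_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool finished_ = false;
};

// In the test, after a->Send(1, msg):
//   stop->WaitFinished();
//   delete node_a; delete node_b; delete node_c;
```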

}

} // namespace distributed
} // namespace paddle