Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Browse files Browse the repository at this point in the history
… eager_dygraph_trace_op_refactor
  • Loading branch information
jim19930609 committed Dec 22, 2021
2 parents 0200dfb + 242ef2b commit ff4b331
Show file tree
Hide file tree
Showing 133 changed files with 2,612 additions and 985 deletions.
17 changes: 8 additions & 9 deletions cmake/infrt_lib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

set(PADDLE_INFRT_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_infrt_install_dir" CACHE STRING
set(INFRT_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_infrt_install_dir" CACHE STRING
"A path setting paddle infrt shared and static libraries")

function(copy TARGET)
Expand Down Expand Up @@ -52,18 +52,17 @@ add_custom_target(infrt_lib_dist DEPENDS ${infrt_lib_deps})
# CMakeCache Info
copy(infrt_lib_dist
SRCS ${CMAKE_BINARY_DIR}/CMakeCache.txt
DSTS ${PADDLE_INFRT_INSTALL_DIR})
DSTS ${INFRT_INSTALL_DIR})

set(src_dir "${PADDLE_SOURCE_DIR}/paddle/infrt")
set(paddle_infrt_lib ${PADDLE_BINARY_DIR}/paddle/infrt/libinfrt.*)
set(infrt_lib ${INFRT_BINARY_DIR}/libinfrt.*)
copy(infrt_lib_dist
SRCS ${src_dir}/api/infrt_api.h ${paddle_infrt_lib}
DSTS ${PADDLE_INFRT_INSTALL_DIR}/infrt/include ${PADDLE_INFRT_INSTALL_DIR}/infrt/lib)
SRCS ${INFRT_SOURCE_DIR}/api/infrt_api.h ${infrt_lib}
DSTS ${INFRT_INSTALL_DIR}/infrt/include ${INFRT_INSTALL_DIR}/infrt/lib)


copy(infrt_lib_dist
SRCS ${CMAKE_BINARY_DIR}/paddle/infrt/paddle/framework.pb.h
DSTS ${PADDLE_INFRT_INSTALL_DIR}/infrt/include/internal)
SRCS ${INFRT_BINARY_DIR}/paddle/framework.pb.h
DSTS ${INFRT_INSTALL_DIR}/infrt/include/internal)

# paddle fluid version
function(version version_file)
Expand All @@ -74,4 +73,4 @@ function(version version_file)
file(WRITE ${version_file} "GIT COMMIT ID: ${PADDLE_GIT_COMMIT}\n")
file(APPEND ${version_file} "CXX compiler version: ${CMAKE_CXX_COMPILER_VERSION}\n")
endfunction()
version(${PADDLE_INFRT_INSTALL_DIR}/version.txt)
version(${INFRT_INSTALL_DIR}/version.txt)
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/fleet_executor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ endif()
proto_library(interceptor_message_proto SRCS interceptor_message.proto)

if(WITH_DISTRIBUTE AND WITH_PSCORE AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog)
set(BRPC_DEPS brpc ssl crypto protobuf zlib leveldb snappy gflags glog)
else()
set(BRPC_DEPS "")
endif()

cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper op_registry
executor_gc_helper ${BRPC_DEPS})
executor_gc_helper gflags glog ${BRPC_DEPS})

if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
Expand Down
71 changes: 48 additions & 23 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ namespace distributed {
USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);

void Carrier::Init(std::shared_ptr<RuntimeGraph> runtime_graph,
void Carrier::Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope,
framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place) {
PADDLE_ENFORCE_EQ(is_init_, false, platform::errors::AlreadyExists(
"Carrier is already init."));
rank_ = rank;
runtime_graph_ = runtime_graph;
interceptor_id_to_rank_ = runtime_graph_->interceptor_id_to_rank();
minibatch_scope_ = minibatch_scope;
microbatch_scopes_ = microbatch_scopes;
place_ = place;
Expand All @@ -48,12 +50,6 @@ void Carrier::Release() {
// NOTE(wangxi): must join before `Derived Interceptor` destruct,
// otherwise Derived object will be destructed before thread complete.

// Sending STOP msg to the source interceptor
PADDLE_ENFORCE_EQ(msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
for (int64_t id : source_interceptor_ids_) {
VLOG(3) << "Carrier Release is sending stop to source interceptor " << id
<< ".";
Expand All @@ -75,10 +71,10 @@ Carrier::~Carrier() { VLOG(3) << "Carrier's destructor."; }

bool Carrier::EnqueueInterceptorMessage(
const InterceptorMessage& interceptor_message) {
// enqueue message to interceptor
if (interceptor_message.ctrl_message()) {
// handle control message
return true;
VLOG(3) << "Receiving control message from rank "
<< interceptor_message.src_id() << " to rank "
<< interceptor_message.dst_id();
} else {
{
std::unique_lock<std::mutex> lock_creating(creating_flag_mutex_);
Expand All @@ -93,15 +89,9 @@ bool Carrier::EnqueueInterceptorMessage(
}
int64_t dst_id = interceptor_message.dst_id();
Interceptor* dst_interceptor = GetInterceptor(dst_id);
bool rst =
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
if (rst) {
std::condition_variable& interceptor_cond_var =
dst_interceptor->GetCondVar();
interceptor_cond_var.notify_all();
}
return rst;
dst_interceptor->EnqueueRemoteInterceptorMessage(interceptor_message);
}
return true;
}

Interceptor* Carrier::GetInterceptor(int64_t interceptor_id) {
Expand Down Expand Up @@ -144,9 +134,44 @@ std::condition_variable& Carrier::GetCondVar() { return cond_var_; }

bool Carrier::IsInit() const { return is_init_; }

// TODO(liyurui): Move SendIntra into carrier
bool Carrier::Send(const InterceptorMessage& msg) const {
return msg_bus_->Send(msg);
int64_t Carrier::GetRank(int64_t interceptor_id) const {
PADDLE_ENFORCE_NE(
interceptor_id_to_rank_.find(interceptor_id),
interceptor_id_to_rank_.end(),
platform::errors::NotFound("Cannot find rank for interceptor id %lld.",
interceptor_id));
return interceptor_id_to_rank_.at(interceptor_id);
}

bool Carrier::Send(const InterceptorMessage& msg) {
int64_t src_id = (msg.src_id() == -1) ? msg.dst_id() : msg.src_id();
int64_t dst_id = msg.dst_id();
int64_t src_rank = GetRank(src_id);
int64_t dst_rank = GetRank(dst_id);
PADDLE_ENFORCE_EQ(
src_rank, rank_,
platform::errors::Fatal("The source rank id %lld, which is not equal to "
"the carrier rank id %lld.",
src_rank, rank_));
if (src_rank == dst_rank) {
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id << ", which are in the same ranks.";
return EnqueueInterceptorMessage(msg);
} else {
PADDLE_ENFORCE_NOT_NULL(
msg_bus_.get(),
platform::errors::Unavailable("Message bus is released accidently"));
PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true,
platform::errors::PreconditionNotMet(
"Using message bus since it has not been initialized. "
"Please invoke MessageBus::Init() before using it or "
"neccessary components are not ready."));
VLOG(3) << "Send a message from interceptor " << src_id
<< " to interceptor " << dst_id
<< ", which are in different ranks.";
return msg_bus_->Send(dst_rank, msg);
}
}

Interceptor* Carrier::SetInterceptor(int64_t interceptor_id,
Expand Down Expand Up @@ -222,13 +247,13 @@ static std::shared_ptr<framework::GarbageCollector> GetGC(
}

void Carrier::CreateInterceptors() {
if (runtime_graph_->intercepter_id_to_node().empty()) return;
if (runtime_graph_->interceptor_id_to_node().empty()) return;

auto gc = GetGC(place_);

// create each Interceptor
// no auto init since there is no config
for (const auto& item : runtime_graph_->intercepter_id_to_node()) {
for (const auto& item : runtime_graph_->interceptor_id_to_node()) {
int64_t interceptor_id = item.first;
TaskNode* task_node = item.second;

Expand Down
10 changes: 8 additions & 2 deletions paddle/fluid/distributed/fleet_executor/carrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ class MessageBus;
class Carrier final {
public:
Carrier() = default;
Carrier(int64_t rank,
const std::unordered_map<int64_t, int64_t>& interceptor_id_to_rank)
: rank_(rank), interceptor_id_to_rank_(interceptor_id_to_rank) {}
~Carrier();
void Init(std::shared_ptr<RuntimeGraph> runtime_graph,
void Init(int64_t rank, std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope, framework::Scope* minibatch_scope,
const std::vector<framework::Scope*>& microbatch_scopes,
const platform::Place& place);
Expand Down Expand Up @@ -75,7 +78,7 @@ class Carrier final {

bool IsInit() const;

bool Send(const InterceptorMessage& msg) const;
bool Send(const InterceptorMessage& msg);

// NOTE: This mutex will be used in interceptor's RunOps function.
// This mutex is used for avoiding forward ops and backward ops run
Expand All @@ -90,6 +93,8 @@ class Carrier final {

void HandleTmpMessages();

int64_t GetRank(int64_t interceptor_id) const;

// interceptor logic id to actually interceptor
std::unordered_map<int64_t, std::unique_ptr<Interceptor>>
interceptor_idx_to_interceptor_;
Expand All @@ -111,6 +116,7 @@ class Carrier final {
paddle::platform::DeviceContext* dev_ctx_{nullptr};
std::shared_ptr<RuntimeGraph> runtime_graph_;
std::shared_ptr<MessageBus> msg_bus_;
int64_t rank_;
std::unordered_map<int64_t, int64_t> interceptor_id_to_rank_;
};

Expand Down
30 changes: 15 additions & 15 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/runtime_graph.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
Expand All @@ -28,6 +27,8 @@
namespace paddle {
namespace distributed {

std::unique_ptr<Carrier> FleetExecutor::carrier_;

FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
bool parse_flag = exe_desc_.ParseFromString(exe_desc_str);
PADDLE_ENFORCE(parse_flag, platform::errors::PreconditionNotMet(
Expand All @@ -36,12 +37,13 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {

FleetExecutor::~FleetExecutor() {
root_scope_->DropKids();
GetCarrier().Release();
GetCarrier()->Release();
}

Carrier& FleetExecutor::GetCarrier() {
static Carrier carrier;
return carrier;
Carrier* FleetExecutor::GetCarrier() {
PADDLE_ENFORCE_NOT_NULL(carrier_.get(), platform::errors::NotFound(
"Carrier has not been created."));
return carrier_.get();
}

void FleetExecutor::Init(
Expand Down Expand Up @@ -84,16 +86,16 @@ void FleetExecutor::Init(
}
VLOG(5) << runtime_graph_->DebugString();
msg_bus_ = std::make_shared<MessageBus>();
CreateCarrier();
InitCarrier();
InitMessageBus();
}

void FleetExecutor::InitCarrier() {
Carrier& carrier = GetCarrier();
if (!carrier.IsInit()) {
carrier.SetMsgBus(msg_bus_);
carrier.Init(runtime_graph_, root_scope_, minibatch_scope_,
microbatch_scopes_, place_);
if (!GetCarrier()->IsInit()) {
GetCarrier()->SetMsgBus(msg_bus_);
GetCarrier()->Init(exe_desc_.cur_rank(), runtime_graph_, root_scope_,
minibatch_scope_, microbatch_scopes_, place_);
}
}

Expand Down Expand Up @@ -128,21 +130,19 @@ void FleetExecutor::InitMessageBus() {
<< (rank_to_addr.size() == 0 ? 1 : rank_to_addr.size()) << ".";
VLOG(5) << ss.str();
if (!msg_bus_->IsInit()) {
msg_bus_->Init(runtime_graph_->intercepter_id_to_rank(), rank_to_addr,
addr);
msg_bus_->Init(cur_rank, rank_to_addr, addr);
}
}

void FleetExecutor::Run() {
// Run
Carrier& carrier = GetCarrier();
PADDLE_ENFORCE_EQ(
carrier.IsInit(), true,
GetCarrier()->IsInit(), true,
platform::errors::Unavailable("Carrier has not been init yet."));
PADDLE_ENFORCE_EQ(
msg_bus_->IsInit(), true,
platform::errors::Unavailable("MessageBus has not been init yet."));
carrier.Start();
GetCarrier()->Start();
for (auto* micro_scop : microbatch_scopes_) {
// By default, we should delete all kid scopes after run executor because
// some operators may create local scope when running, such as while_op.
Expand Down
13 changes: 11 additions & 2 deletions paddle/fluid/distributed/fleet_executor/fleet_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <memory>
#include <string>

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
Expand All @@ -30,7 +31,6 @@ namespace distributed {
class RuntimeGraph;
class MessageBus;
class TaskNode;
class Carrier;

class FleetExecutor final {
public:
Expand All @@ -43,7 +43,15 @@ class FleetExecutor final {
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
void Run();
// TODO(liyurui): Change to use registry table for multi-carrier.
static Carrier& GetCarrier();
static Carrier* GetCarrier();
template <typename... Args>
static Carrier* CreateCarrier(Args&&... args) {
PADDLE_ENFORCE_EQ(
carrier_.get(), nullptr,
platform::errors::AlreadyExists("Carrier has been created already."));
carrier_ = std::make_unique<Carrier>(std::forward<Args>(args)...);
return carrier_.get();
}

private:
DISABLE_COPY_AND_ASSIGN(FleetExecutor);
Expand All @@ -59,6 +67,7 @@ class FleetExecutor final {
// The carriers under FleetExecutor will share message bus,
// using shared_ptr to manage lifetime and condition race.
std::shared_ptr<MessageBus> msg_bus_;
static std::unique_ptr<Carrier> carrier_;
};

} // namespace distributed
Expand Down
Loading

0 comments on commit ff4b331

Please sign in to comment.