Skip to content

Commit

Permalink
Support CUDA Graph on ParallelExecutor (#36250)
Browse files Browse the repository at this point in the history
* support CUDA Graph on PE

* add ut, fix CI compile

* reduce memory consumption

* fix CUDA 10 CI

* improve coverage

* improve python coverage
  • Loading branch information
sneaxiy authored Oct 8, 2021
1 parent ca16e8f commit f9591bb
Show file tree
Hide file tree
Showing 18 changed files with 368 additions and 42 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/build_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ struct BuildStrategy {
// Turn off inplace addto by default.
bool enable_addto_{false};

bool allow_cuda_graph_capture_{false};

// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if
// it's distributed model.
Expand Down
19 changes: 14 additions & 5 deletions paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,28 @@ struct ScaleLossGradFunctor {
}
};

std::string ScaleLossGradOpHandle::LossGradName() const {
return static_cast<VarHandle *>(this->outputs_[0])->name();
}

void ScaleLossGradOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
// Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name();
RunOnVar(local_exec_scopes_[0]->FindVar(LossGradName()), true);
}

auto *tensor =
local_exec_scopes_[0]->FindVar(var_name)->GetMutable<LoDTensor>();
void ScaleLossGradOpHandle::RunOnVar(Variable *var, bool record_event) {
auto *tensor = var->GetMutable<LoDTensor>();
tensor->Resize(make_ddim({1}));

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
ScaleLossGradFunctor func(coeff_, tensor, place_, out_dtype_,
this->dev_ctxes_.at(place_));
this->RunAndRecordEvent([&] { framework::VisitDataType(out_dtype_, func); });
if (record_event) {
this->RunAndRecordEvent(
[&] { framework::VisitDataType(out_dtype_, func); });
} else {
framework::VisitDataType(out_dtype_, func);
}
#else
ScaleLossGradFunctor func(coeff_, tensor, place_, out_dtype_, nullptr);
framework::VisitDataType(out_dtype_, func);
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/framework/details/scale_loss_grad_op_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ struct ScaleLossGradOpHandle : public OpHandleBase {

std::string Name() const override;

platform::Place GetPlace() const { return place_; }

void RunOnVar(Variable *var, bool record_event = false);

std::string LossGradName() const;

protected:
void RunImpl() override;

Expand Down
53 changes: 31 additions & 22 deletions paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace framework {
namespace details {
Expand All @@ -49,8 +51,29 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
PrepareLocalExeScopes();
}

static void RunProgramDescs(const ProgramDescs &programs,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places) {
for (auto &program : programs) {
for (auto &op_desc : program.Block(0).AllOps()) {
for (size_t i = 0; i < local_exec_scopes.size(); ++i) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_exec_scopes[i], places[i]);
}
}
}
}

FetchResultType ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors, bool return_merged) {
#ifdef PADDLE_WITH_CUDA
if (platform::IsCUDAGraphCapturing()) {
strategy_.num_iteration_per_drop_scope_ =
std::numeric_limits<size_t>::max();
DropLocalExeScopes(/*need_wait=*/false);
}
#endif

if (drop_scope_counter_ == 0) {
platform::RecordEvent e("InitLocalVars");
InitVariables();
Expand Down Expand Up @@ -84,7 +107,7 @@ FetchResultType ScopeBufferedSSAGraphExecutor::Run(
++drop_scope_counter_;
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_ ||
DropScopeOrNot()) {
DropLocalExeScopes();
DropLocalExeScopes(!platform::IsCUDAGraphCapturing());
}

if (VLOG_IS_ON(5)) {
Expand Down Expand Up @@ -128,39 +151,25 @@ void ScopeBufferedSSAGraphExecutor::InitVariables() {
if (graph.Has(details::kStartupProgramDescs)) {
auto &program_descs =
graph.Get<details::ProgramDescs>(details::kStartupProgramDescs);

for (auto &program_desc : program_descs) {
for (auto &op_desc : program_desc.Block(0).AllOps()) {
for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_exec_scopes_[i], places_[i]);
}
}
}
RunProgramDescs(program_descs, local_exec_scopes_, places_);
}
is_initialized_ = true;
}

if (graph.Has(details::kProgramDescs)) {
auto &program_descs =
graph.Get<details::ProgramDescs>(details::kProgramDescs);

for (auto &program_desc : program_descs) {
for (auto &op_desc : program_desc.Block(0).AllOps()) {
for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
auto op = OpRegistry::CreateOp(*op_desc);
op->Run(*local_exec_scopes_[i], places_[i]);
}
}
}
RunProgramDescs(program_descs, local_exec_scopes_, places_);
}
}

void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes(bool need_wait) {
platform::RecordEvent drop_scope_event("DropLocalExeScopes");
drop_scope_counter_ = 0;
for (auto &p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
if (need_wait) {
for (auto &p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
}
scope_monitor_.ClearHistoryLocalExecScopes();
for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
FetchResultType Run(const std::vector<std::string>& fetch_tensors,
bool return_merged) override;

void DropLocalExeScopes();
void DropLocalExeScopes(bool need_wait = true);

bool NeedCreateLocalExeScope();

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ message BuildStrategy {
optional bool enable_auto_fusion = 11 [ default = false ];
optional bool enable_addto = 12 [ default = false ];
optional bool fix_op_run_order = 13 [ default = false ];
optional bool allow_cuda_graph_capture = 14 [ default = false ];
}

message ExecutionStrategy {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle op_graph_view multi_devices_helper)

cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,31 @@

#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"

namespace paddle {
namespace framework {
namespace ir {

template <typename T>
static bool IsMatchedPlaceSingleDeviceOp(details::OpHandleBase *op_base,
const platform::Place &place) {
auto *op = dynamic_cast<T *>(op_base);
return op && op->GetPlace() == place;
}

static bool IsLockAndRecordEventFreeComputationOpHandle(
details::ComputationOpHandle *op, const OpGraphView &graph_view) {
if (!platform::is_gpu_place(op->GetPlace()) &&
!platform::is_xpu_place(op->GetPlace()))
return false;
for (auto &pending_op : graph_view.PendingOps(op)) {
auto *tmp = dynamic_cast<details::ComputationOpHandle *>(pending_op);
if (tmp == nullptr || !(tmp->GetPlace() == op->GetPlace())) {
if (!IsMatchedPlaceSingleDeviceOp<details::ComputationOpHandle>(
pending_op, op->GetPlace()) &&
!IsMatchedPlaceSingleDeviceOp<details::ScaleLossGradOpHandle>(
pending_op, op->GetPlace())) {
return false;
}
}
Expand Down
Loading

0 comments on commit f9591bb

Please sign in to comment.