[Inference] Predictor support pir and new executor #58452

Merged (17 commits, Nov 3, 2023)
12 changes: 12 additions & 0 deletions paddle/fluid/framework/naive_executor.cc
@@ -60,6 +60,18 @@ void NaiveExecutor::PrepareInterpreterCore(
place_, program_desc.Block(0), scope, execution_config);
}

+void NaiveExecutor::PrepareInterpreterCore(
+    Scope *scope,
+    const ::pir::Program &pir_program,
+    const framework::interpreter::ExecutionConfig &execution_config) {
+  interpreter_core_ =
+      std::make_unique<framework::InterpreterCore>(place_,
+                                                   std::vector<std::string>{},
+                                                   pir_program.block(),
+                                                   scope,
+                                                   execution_config);
+}
+
void NaiveExecutor::RunInterpreterCore(
const std::vector<std::string> &feed_names, bool need_fetch) {
platform::ScopedFlushDenormal flush;
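For orientation, a minimal sketch of how a predictor might drive the two entry points above. The RunOnce wrapper and the feed name "x" are hypothetical; PrepareInterpreterCore and RunInterpreterCore use the signatures visible in this diff, and the new overload falls back to a default-constructed ExecutionConfig when none is passed.

```cpp
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/pir/core/program.h"

// Hypothetical helper (not part of this PR): prepare a NaiveExecutor from a
// PIR program, then run it once without fetching outputs.
void RunOnce(paddle::framework::NaiveExecutor* executor,
             paddle::framework::Scope* scope,
             const ::pir::Program& pir_program) {
  // New overload from this PR: builds an InterpreterCore over the PIR
  // program's top-level block; the ExecutionConfig argument is defaulted.
  executor->PrepareInterpreterCore(scope, pir_program);

  // Existing entry point: feed variables named here must already live in
  // `scope`; pass need_fetch=true to collect fetch targets instead.
  executor->RunInterpreterCore({"x"}, /*need_fetch=*/false);
}
```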
8 changes: 8 additions & 0 deletions paddle/fluid/framework/naive_executor.h
@@ -29,6 +29,8 @@
#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"

+#include "paddle/pir/core/program.h"
+
namespace paddle {
namespace framework {

@@ -61,6 +63,12 @@ class NaiveExecutor {
const framework::interpreter::ExecutionConfig& execution_config =
framework::interpreter::ExecutionConfig{});

+  void PrepareInterpreterCore(
+      Scope* scope,
+      const ::pir::Program& pir_program,
+      const framework::interpreter::ExecutionConfig& execution_config =
+          framework::interpreter::ExecutionConfig{});
+
// Create variables before head.
// Create parameters if persistable is true, or create the temporary variables
// instead.
@@ -1207,7 +1207,7 @@ const paddle::framework::Variable* GetVariableByName(
return nullptr;
}

-std::vector<std::string> GetOriginInputNames(std::string op_name) {
+std::vector<std::string> GetOriginInputNames(const std::string& op_name) {
std::vector<std::string> ret;
pir::IrContext* ctx = pir::IrContext::Instance();
pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
@@ -1220,7 +1220,7 @@ std::vector<std::string> GetOriginInputNames(std::string op_name) {
return ret;
}

-std::vector<std::string> GetOriginOutputNames(std::string op_name) {
+std::vector<std::string> GetOriginOutputNames(const std::string& op_name) {
std::vector<std::string> ret;
pir::IrContext* ctx = pir::IrContext::Instance();
pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
@@ -126,9 +126,9 @@ const paddle::framework::Variable* GetVariableByName(
const std::unordered_map<const paddle::framework::Variable*, std::string>&
variable_2_var_name);

-std::vector<std::string> GetOriginInputNames(std::string op_name);
+std::vector<std::string> GetOriginInputNames(const std::string& op_name);

-std::vector<std::string> GetOriginOutputNames(std::string op_name);
+std::vector<std::string> GetOriginOutputNames(const std::string& op_name);

void PrintValuesAndVariables(
const pir::Block& block,
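A recurring cleanup in this PR is switching std::string parameters from pass-by-value to pass-by-const-reference, as in the declarations above. A standalone sketch of the difference (illustrative only, not PR code):

```cpp
#include <string>
#include <vector>

// Pass-by-value: the argument is copied (possibly heap-allocating) on
// every call.
std::vector<std::string> InputNamesByValue(std::string op_name) {
  return {op_name + "_x"};
}

// Pass-by-const-reference: binds directly to the caller's string with no
// copy, and still accepts temporaries such as string literals.
std::vector<std::string> InputNamesByRef(const std::string& op_name) {
  return {op_name + "_x"};
}

int main() {
  const std::string op = "pd_op.matmul";  // placeholder op name
  return InputNamesByValue(op) == InputNamesByRef(op) ? 0 : 1;
}
```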
41 changes: 12 additions & 29 deletions paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc
@@ -58,7 +58,7 @@ std::shared_ptr<ValueExecutionInfo> ValueExecutionInfo::NewChild(Scope* scope) {
return info;
}

-void ValueExecutionInfo::Add(::pir::Value value, std::string var_name) {
+void ValueExecutionInfo::Add(::pir::Value value, const std::string& var_name) {
auto* var = scope_->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Cannot find %s in scope.", var_name));
Expand All @@ -84,8 +84,8 @@ void ValueExecutionInfo::Add(::pir::Value value, std::string var_name) {
}

void ValueExecutionInfo::Rename(pir::Value value,
-                                std::string new_name,
-                                std::string orig_name) {
+                                const std::string& new_name,
+                                const std::string& orig_name) {
value_2_var_name_[value] = new_name;

for (auto kv : value_2_var_name_) {
@@ -344,9 +344,7 @@ void HandleForSpecialOp(pir::Operation* op,
auto value = op->result(0);

value_exe_info->Add(value, fetch_var_name);
-  }
-
-  if (op_name == "pd_op.feed" || op_name == "pd_op.data") {
+  } else if (op_name == "pd_op.feed" || op_name == "pd_op.data") {
VLOG(6) << "Handle for" << op_name;
auto value = op->result(0);
VLOG(6) << "link feed output to feed in variable"
@@ -360,9 +358,7 @@ void HandleForSpecialOp(pir::Operation* op,
"The variable %s shoud exist", name));

value_exe_info->Add(value, name);
-  }
-
-  if (op_name == "builtin.combine") {
+  } else if (op_name == "builtin.combine") {
auto out_value = op->result(0);

Variable* var = nullptr;
@@ -386,9 +382,7 @@ void HandleForSpecialOp(pir::Operation* op,
tensor_array->emplace_back(
value_exe_info->GetScope()->FindVar(value_2_var_name.at(value)));
}
-  }
-
-  if (op_name == "builtin.set_parameter") {
+  } else if (op_name == "builtin.set_parameter") {
VLOG(6) << "Handle for builtin.set_parameter:";
auto param_name = op->attributes()
.at("parameter_name")
@@ -413,8 +407,7 @@ void HandleForSpecialOp(pir::Operation* op,
}

value_exe_info->Rename(value, param_name, orig_name);
-  }
-  if (op_name.compare(pir::ShadowOutputOp::name()) == 0) {
+  } else if (op_name == "builtin.shadow_output") {
VLOG(6) << "Handle for builtin.shadow_ouptut";
auto var_name = op->attributes()
.at("output_name")
@@ -433,9 +426,7 @@ void HandleForSpecialOp(pir::Operation* op,
VLOG(8) << "var " << orig_name << " has been renamed to " << var_name;

value_exe_info->Rename(value, var_name, orig_name);
-  }
-
-  if (op_name == "builtin.get_parameter") {
+  } else if (op_name == "builtin.get_parameter") {
VLOG(6) << "Handle for builtin.get_parameter:";
auto param_name = op->attributes()
.at("parameter_name")
@@ -444,9 +435,7 @@ void HandleForSpecialOp(pir::Operation* op,
auto value = op->result(0);

value_exe_info->Add(value, param_name);
-  }
-
-  if (op_name == "builtin.slice") {
+  } else if (op_name == "builtin.slice") {
VLOG(6) << "Handle for builtin.slice";
auto out_value = op->result(0);
auto in_value = op->operand_source(0);
@@ -471,9 +460,7 @@ void HandleForSpecialOp(pir::Operation* op,
std::string var_name =
value_exe_info->GetVar2VarName().at(variable_array[index]);
value_exe_info->AddValue2VarName(out_value, var_name);
-  }
-
-  if (op_name == "builtin.split") {
+  } else if (op_name == "builtin.split") {
VLOG(6) << "Handle for builtin.split";
auto in_value = op->operand_source(0);
PADDLE_ENFORCE_EQ(value_exe_info->GetValue2VarName().count(in_value),
@@ -497,17 +484,13 @@ void HandleForSpecialOp(pir::Operation* op,
value_exe_info->GetVar2VarName().at(variable_array[idx]);
value_exe_info->AddValue2VarName(out_value, var_name);
}
-  }
-
-  if (op_name == "pd_op.if") {
+  } else if (op_name == "pd_op.if") {
auto if_op = op->dyn_cast<paddle::dialect::IfOp>();
for (size_t i = 0; i < if_op->num_results(); ++i) {
auto if_op_out_value = if_op->result(i);
BuildValue(if_op_out_value, var_name_prefix, value_exe_info);
}
-  }
-
-  if (op_name == "pd_op.while") {
+  } else if (op_name == "pd_op.while") {
auto while_op = op->dyn_cast<paddle::dialect::WhileOp>();

for (size_t i = 0; i < while_op->num_results(); ++i) {
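The HandleForSpecialOp edits above fold a run of independent if statements over the same op_name into a single else-if chain, so dispatch stops at the first match and the branches are explicitly mutually exclusive; the shadow-output branch also trades op_name.compare(pir::ShadowOutputOp::name()) == 0 for a plain string comparison. A condensed sketch of the resulting shape (handlers are placeholders, not PR code):

```cpp
#include <iostream>
#include <string>

// Placeholder handlers standing in for the real scope/value bookkeeping.
void HandleFetch() { std::cout << "fetch\n"; }
void HandleFeedOrData() { std::cout << "feed/data\n"; }
void HandleShadowOutput() { std::cout << "shadow_output\n"; }

// One mutually exclusive chain: at most one branch body runs per op.
void HandleForSpecialOpSketch(const std::string& op_name) {
  if (op_name == "pd_op.fetch") {
    HandleFetch();
  } else if (op_name == "pd_op.feed" || op_name == "pd_op.data") {
    HandleFeedOrData();
  } else if (op_name == "builtin.shadow_output") {
    // Was: op_name.compare(pir::ShadowOutputOp::name()) == 0
    HandleShadowOutput();
  }  // ...the remaining builtin.* and pd_op.* branches follow the pattern.
}

int main() { HandleForSpecialOpSketch("pd_op.feed"); }
```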
@@ -56,9 +56,11 @@ class ValueExecutionInfo {

Scope* GetScope() const { return scope_; }

-  void Add(::pir::Value value, std::string var_name);
+  void Add(::pir::Value value, const std::string& var_name);

-  void Rename(pir::Value value, std::string new_name, std::string orig_name);
+  void Rename(pir::Value value,
+              const std::string& new_name,
+              const std::string& orig_name);

int GetIdByName(const std::string& name) const;

99 changes: 35 additions & 64 deletions paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -1070,6 +1070,7 @@ paddle::framework::FetchList PirInterpreter::Run(

// Run
if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
+      execution_config_.used_for_inference ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
(sync_op_num_ == 0))) {
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
@@ -1085,6 +1086,7 @@
is_shared_results_build_ = true;
} else {
if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
+        execution_config_.used_for_inference ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
(sync_op_num_ == 0))) {
TraceRunImpl();
@@ -1096,39 +1098,20 @@
if (HasLocalScope()) {
ClearLoDTensorArrayInLocalScope();
}

  // return Fetch Tensors
  Scope* inner_scope = InnerScope();
-  if (FLAGS_enable_new_ir_in_executor) {
-    framework::FetchList fetch_res;
-
-    if (need_fetch) {
-      for (auto& var_name : fetch_var_names_) {
-        auto* var = inner_scope->FindVar(var_name);
-        VLOG(4) << "fetch " << var_name << "[" << var << "]";
-        fetch_res.push_back(var->Get<phi::DenseTensor>());
-      }
-    }
-
-    VLOG(4) << "get fetch list size: " << fetch_res.size();
-    return fetch_res;
-  } else {
-    auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
-    if (fetch_var) {
-      auto fetch_list =
-          std::move(*fetch_var->GetMutable<framework::FetchList>());
-#ifdef PADDLE_WITH_CUDA
-      if (platform::IsCUDAGraphCapturing()) {
-        PADDLE_ENFORCE_EQ(fetch_list.empty(),
-                          true,
-                          platform::errors::InvalidArgument(
-                              "Cannot fetch data when using CUDA Graph."));
-      }
-#endif
-      return fetch_list;
-    } else {
-      return {};
-    }
-  }
+  framework::FetchList fetch_res;
+
+  if (need_fetch) {
+    for (auto& var_name : fetch_var_names_) {
+      auto* var = inner_scope->FindVar(var_name);
+      VLOG(4) << "fetch " << var_name << "[" << var << "]";
+      fetch_res.push_back(var->Get<phi::DenseTensor>());
+    }
+  }
+
+  VLOG(4) << "get fetch list size: " << fetch_res.size();
+  return fetch_res;
}

FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
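Both Run overloads choose between the serial trace path and the multithreaded scheduler with the same condition, and this PR adds used_for_inference to it, so inference always takes TraceRunImpl(). A hypothetical standalone restatement of the condition (the helper itself is not in the PR; parameter names mirror the members and FLAGS used above):

```cpp
#include <cassert>

// trace_flag mirrors FLAGS_enable_new_ir_in_executor_trace_run;
// used_for_inference is the ExecutionConfig flag added by this PR.
bool ShouldTraceRun(bool trace_flag,
                    int nccl_op_num,
                    bool used_for_inference,
                    bool used_for_jit,
                    bool used_for_cinn,
                    int sync_op_num) {
  return trace_flag || nccl_op_num > 1 || used_for_inference ||
         ((used_for_jit || used_for_cinn) && sync_op_num == 0);
}

int main() {
  // With this PR, inference alone is enough to select the trace path.
  assert(ShouldTraceRun(false, 0, /*used_for_inference=*/true,
                        false, false, /*sync_op_num=*/3));
  return 0;
}
```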
@@ -1161,6 +1144,7 @@ FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,

// Run
if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
+      execution_config_.used_for_inference ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
(sync_op_num_ == 0))) {
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
@@ -1176,6 +1160,7 @@ FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
is_shared_results_build_ = true;
} else {
if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
+        execution_config_.used_for_inference ||
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
(sync_op_num_ == 0))) {
TraceRunImpl();
@@ -1187,38 +1172,21 @@ FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
if (HasLocalScope()) {
ClearLoDTensorArrayInLocalScope();
}
-  // return Fetch Tensors
-  Scope* inner_scope = InnerScope();
-  if (FLAGS_enable_new_ir_in_executor) {
-    framework::FetchList fetch_res;
-
-    if (need_fetch) {
-      for (auto& var_name : fetch_var_names_) {
-        auto* var = inner_scope->FindVar(var_name);
-        VLOG(4) << "fetch " << var_name << "[" << var << "]";
-        fetch_res.push_back(var->Get<phi::DenseTensor>());
-      }
-    }
-
-    VLOG(4) << "get fetch list size: " << fetch_res.size();
-    return fetch_res;
-  } else {
-    auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
-    if (fetch_var && need_fetch) {
-      auto fetch_list =
-          std::move(*fetch_var->GetMutable<framework::FetchList>());
-#ifdef PADDLE_WITH_CUDA
-      if (platform::IsCUDAGraphCapturing()) {
-        PADDLE_ENFORCE_EQ(fetch_list.empty(),
-                          true,
-                          platform::errors::InvalidArgument(
-                              "Cannot fetch data when using CUDA Graph."));
-      }
-#endif
-      return fetch_list;
-    } else {
-      return {};
-    }
-  }
+  framework::FetchList fetch_res;
+  if (need_fetch) {
+    // return Fetch Tensors
+    Scope* inner_scope = InnerScope();
+
+    for (auto& var_name : fetch_var_names_) {
+      auto* var = inner_scope->FindVar(var_name);
+      VLOG(4) << "fetch " << var_name << "[" << var << "]";
+      fetch_res.push_back(var->Get<phi::DenseTensor>());
+    }
+  }
+
+  VLOG(4) << "get fetch list size: " << fetch_res.size();
+  return fetch_res;
}

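After this change both overloads converge on one fetch path: walk fetch_var_names_, look each variable up in the inner scope, and copy its DenseTensor into the returned FetchList; the old FLAGS_enable_new_ir_in_executor branch and the kFetchVarName fallback are gone. A minimal sketch of that loop with stand-in types (not Paddle's real Scope or DenseTensor):

```cpp
#include <map>
#include <string>
#include <vector>

// Stand-ins for phi::DenseTensor and Scope, for illustration only.
using DenseTensor = std::vector<float>;
using Scope = std::map<std::string, DenseTensor>;

// Unified fetch path: empty result unless need_fetch, otherwise one entry
// per fetch name, looked up by name in the scope.
std::vector<DenseTensor> Fetch(const Scope& scope,
                               const std::vector<std::string>& fetch_names,
                               bool need_fetch) {
  std::vector<DenseTensor> fetch_res;
  if (need_fetch) {
    for (const auto& name : fetch_names) {
      fetch_res.push_back(scope.at(name));  // throws if the name is missing
    }
  }
  return fetch_res;
}
```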
void PirInterpreter::TraceRunImpl() {
@@ -1437,10 +1405,11 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
platform::RecordEvent instruction_event(
instr_node->Name(), platform::TracerEventType::Operator, 1);

-  SetDeviceId(instr_node->DeviceContext().GetPlace());
+  auto cur_place = instr_node->DeviceContext().GetPlace();
+  SetDeviceId(cur_place);

try {
-    instr_node->WaitEvent(place_);
+    instr_node->WaitEvent(cur_place);
VLOG(4) << "begin to run op " << instr_node->Name();
VLOG(4) << "begin: " << __func__ << " OP id:" << instr_node->Id()
<< " name:" << instr_node->Name() << " type:"
@@ -1450,7 +1419,8 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
? "kGpuSync"
: "kGpuAsync"))
<< " runs on " << platform::GetCurrentThreadName();
-    VLOG(4) << place_ << " Before:"
+
+    VLOG(4) << cur_place << " Before:"
<< instr_node->DebugStringEx(scope_, value_exe_info_.get());
if (!instr_node->IsArtificial()) {
instr_node->Run();
@@ -1472,14 +1442,15 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
? "kGpuSync"
: "kGpuAsync"))
<< " runs on " << platform::GetCurrentThreadName();
-    VLOG(4) << place_ << " After:"
+
+    VLOG(4) << cur_place << " After:"
<< instr_node->DebugStringEx(scope_, value_exe_info_.get());
CheckGC(instr_node);
VLOG(4) << "done CheckGC";
-    memory::LogDeviceMemoryStats(place_, instr_node->Name());
+    memory::LogDeviceMemoryStats(cur_place, instr_node->Name());
}
VLOG(5) << "after run kernel";
-  instr_node->RecordEvent(place_);
+  instr_node->RecordEvent(cur_place);
} catch (platform::EnforceNotMet& ex) {
auto* op = instr_node->Operation();
const std::vector<std::string> op_callstack_attr =
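The RunInstructionBase edits swap the interpreter-wide place_ for the place of each instruction's own device context, so the event wait, event record, and memory logging all target the device the instruction actually runs on. A minimal sketch of the pattern with stub types (not the real InstructionBase API):

```cpp
#include <iostream>
#include <string>

// Stubs standing in for platform::Place and InstructionBase.
struct Place { std::string name; };

struct Instruction {
  Place place;  // each instruction carries its own device context place
  Place GetPlace() const { return place; }
  void WaitEvent(const Place& p) const { std::cout << "wait on " << p.name << "\n"; }
  void Run() const { std::cout << "run\n"; }
  void RecordEvent(const Place& p) const { std::cout << "record on " << p.name << "\n"; }
};

// Before: every step used one interpreter-wide place_.
// After (as in the diff): each step uses the instruction's own place.
void RunInstruction(const Instruction& instr) {
  auto cur_place = instr.GetPlace();
  instr.WaitEvent(cur_place);
  instr.Run();
  instr.RecordEvent(cur_place);
}

int main() { RunInstruction(Instruction{Place{"gpu:0"}}); }
```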