Skip to content

Commit

Permalink
optimize performance of dygraph (#42137)
Browse files Browse the repository at this point in the history
  • Loading branch information
zyfncg committed Apr 24, 2022
1 parent 778155a commit 1051469
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 100 deletions.
9 changes: 4 additions & 5 deletions paddle/fluid/framework/infershape_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,12 +402,11 @@ std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) {
// 1. get kernel args
auto arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
PADDLE_ENFORCE_NOT_NULL(
arg_map_fn, platform::errors::NotFound(
"The ArgumentMappingFn of %s op is not found.", op_type));
auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
InferShapeArgumentMappingContext arg_map_context(*ctx);
auto signature = arg_map_fn(arg_map_context);
KernelSignature signature =
arg_map_fn ? (*arg_map_fn)(arg_map_context)
: phi::DefaultKernelSignatureMap::Instance().Get(op_type);
VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;

// 2. build infermeta context
Expand Down
12 changes: 10 additions & 2 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2119,8 +2119,16 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
const ExecutionContext& ctx) const {
ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
if (arg_map_fn_ == nullptr) {
arg_map_fn_.reset(new phi::ArgumentMappingFn(
phi::OpUtilsMap::Instance().GetArgumentMappingFn(Type())));
auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(type_);
if (arg_map_fn) {
arg_map_fn_.reset(new phi::ArgumentMappingFn(*arg_map_fn));
} else {
auto func =
[this](const phi::ArgumentMappingContext& ctx) -> KernelSignature {
return phi::DefaultKernelSignatureMap::Instance().Get(type_);
};
arg_map_fn_.reset(new phi::ArgumentMappingFn(func));
}
}
return (*arg_map_fn_)(arg_mapping_ctx);
}
Expand Down
36 changes: 26 additions & 10 deletions paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ namespace paddle {
namespace imperative {

static const phi::Kernel empty_kernel;
static const framework::RuntimeContext empty_ctx({}, {});
static const framework::Scope empty_scope;

const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var) {
Expand Down Expand Up @@ -138,8 +140,6 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);

framework::RuntimeContext ctx({}, {});

#ifdef PADDLE_WITH_MKLDNN
// MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
// GetKernelType functions, so we need to copy the attributes there.
Expand All @@ -158,7 +158,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,

// 1. get expected kernel key
auto dygraph_exe_ctx = DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs, default_attrs);
op, empty_scope, *dev_ctx, empty_ctx, ins, outs, attrs, default_attrs);
auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx);

framework::KernelSignature pt_kernel_signature;
Expand All @@ -172,11 +172,26 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
paddle::platform::is_in_xpu_black_list(op.Type());

#endif
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
pt_kernel_signature =
std::move(op.GetExpectedPhiKernelArgs(dygraph_exe_ctx));
VLOG(6) << pt_kernel_signature;

bool has_phi_kernel = false;

const auto* arg_map_fn =
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op.Type());
if (arg_map_fn) {
has_phi_kernel = true;
pt_kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else {
const auto* kernel_sig =
phi::DefaultKernelSignatureMap::Instance().GetNullable(op.Type());
if (kernel_sig) {
has_phi_kernel = true;
pt_kernel_signature = *kernel_sig;
}
}

if (has_phi_kernel) {
VLOG(6) << pt_kernel_signature;
pt_kernel_name = pt_kernel_signature.name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
Expand Down Expand Up @@ -231,7 +246,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
dev_ctx = pool.Get(expected_kernel_key.place_);
}

return PreparedOp(op, ctx, expected_kernel_key,
return PreparedOp(op, empty_ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_kernel, dev_ctx);
} else {
VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << pt_kernel_name
Expand Down Expand Up @@ -280,7 +295,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< " | kernel key: " << pt_cpu_kernel_key
<< " | kernel: " << pt_cpu_kernel;
auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
return PreparedOp(op, ctx, expected_kernel_key,
return PreparedOp(op, empty_ctx, expected_kernel_key,
std::move(pt_kernel_signature), pt_cpu_kernel,
cpu_ctx);
}
Expand Down Expand Up @@ -373,7 +388,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
dev_ctx = pool.Get(expected_kernel_key.place_);
}

return PreparedOp(op, ctx, expected_kernel_key, kernel_iter->second, dev_ctx);
return PreparedOp(op, empty_ctx, expected_kernel_key, kernel_iter->second,
dev_ctx);
}

PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
Expand Down
2 changes: 1 addition & 1 deletion paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void PhiOpConvertPass::convertStage() {
op->replaceAllUsesWith(kernel_op.getResults());
} else {
::phi::KernelSignature kernel_sign =
::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
(*::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name))(
infrt::ProtoArgumentMappingContext(op));
VLOG(3) << "IncompatiblePhiKernel: op(" << op_name << "), kernel("
<< kernel_sign.name << ")";
Expand Down
19 changes: 12 additions & 7 deletions paddle/phi/core/compat/op_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ class DefaultKernelSignatureMap {
return it->second;
}

const KernelSignature* GetNullable(const std::string& op_type) const {
auto it = map_.find(op_type);
if (it != map_.end()) {
return &it->second;
}
return nullptr;
}

void Insert(std::string op_type, KernelSignature signature) {
PADDLE_ENFORCE_NE(
Has(op_type),
Expand Down Expand Up @@ -148,16 +156,13 @@ class OpUtilsMap {
}
}

ArgumentMappingFn GetArgumentMappingFn(const std::string& op_type) const {
const ArgumentMappingFn* GetArgumentMappingFn(
const std::string& op_type) const {
auto it = arg_mapping_fn_map_.find(op_type);
if (it == arg_mapping_fn_map_.end()) {
auto func =
[&op_type](const ArgumentMappingContext& ctx) -> KernelSignature {
return DefaultKernelSignatureMap::Instance().Get(op_type);
};
return func;
return nullptr;
} else {
return it->second;
return &it->second;
}
}

Expand Down
Loading

1 comment on commit 1051469

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.