Skip to content

Commit

Permalink
[cherry-pick] Optimize performance of dygraph (#42196) (#42329)
Browse files Browse the repository at this point in the history
* Optimize performance of dygraph (v4)  (#42196)

* optimize performance of dygraph

* optimize performance of dygraph and elementwise_add

* optimize the trace op

* fix bug

* fix bug

* fix unittest bug

* fix code format

* fix cherry-pick problem
  • Loading branch information
zyfncg authored Apr 28, 2022
1 parent fe4646d commit 2ea56c9
Show file tree
Hide file tree
Showing 14 changed files with 100 additions and 76 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/framework/data_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ size_t SizeOfType(proto::VarType::Type type) {
}

// Now only supports promotion of complex type
bool NeedPromoteTypes(const proto::VarType::Type a,
const proto::VarType::Type b) {
inline bool NeedPromoteTypes(const proto::VarType::Type& a,
const proto::VarType::Type& b) {
return (IsComplexType(a) || IsComplexType(b));
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ inline std::ostream& operator<<(std::ostream& out,
return out;
}

extern inline bool IsComplexType(const proto::VarType::Type type) {
extern inline bool IsComplexType(const proto::VarType::Type& type) {
return (type == proto::VarType::COMPLEX64 ||
type == proto::VarType::COMPLEX128);
}
Expand Down
12 changes: 8 additions & 4 deletions paddle/fluid/framework/op_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@ namespace framework {

std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, AttributeMap attrs, bool attr_check) {
const VariableNameMap& outputs, const AttributeMap& attrs,
bool attr_check) {
auto& info = OpInfoMap::Instance().Get(type);
if (attr_check && info.Checker() != nullptr) {
info.Checker()->Check(&attrs);
auto tmp_attrs = attrs;
info.Checker()->Check(&tmp_attrs);
return std::unique_ptr<OperatorBase>(
info.Creator()(type, inputs, outputs, tmp_attrs));
}
auto op = info.Creator()(type, inputs, outputs, attrs);
return std::unique_ptr<OperatorBase>(op);
return std::unique_ptr<OperatorBase>(
info.Creator()(type, inputs, outputs, attrs));
}

static VariableNameMap ConvertOpDescVarsToVarNameMap(
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
AttributeMap attrs,
const AttributeMap& attrs,
bool attr_check = true);

static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
Expand Down
26 changes: 14 additions & 12 deletions paddle/fluid/framework/phi_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,21 @@ OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
phi::KernelKey TransOpKernelTypeToPhiKernelKey(
const OpKernelType& kernel_type) {
phi::Backend backend = phi::TransToPhiBackend(kernel_type.place_);
if (kernel_type.library_type_ == LibraryType::kMKLDNN) {
backend = phi::Backend::MKLDNN;
} else if (kernel_type.library_type_ == LibraryType::kCUDNN) {
backend = phi::Backend::GPUDNN;
} else if (kernel_type.library_type_ == LibraryType::kKP) {
backend = phi::Backend::KPS;
} else {
// do nothing
switch (kernel_type.library_type_) {
case LibraryType::kCUDNN:
backend = phi::Backend::GPUDNN;
break;
case LibraryType::kMKLDNN:
backend = phi::Backend::MKLDNN;
break;
case LibraryType::kKP:
backend = phi::Backend::KPS;
break;
default:
break;
}
paddle::experimental::DataLayout layout = kernel_type.data_layout_;
paddle::experimental::DataType dtype =
paddle::framework::TransToPhiDataType(kernel_type.data_type_);
return phi::KernelKey(backend, layout, dtype);
return phi::KernelKey(backend, kernel_type.data_layout_,
framework::TransToPhiDataType(kernel_type.data_type_));
}

phi::KernelKey FallBackToCpu(const OpKernelType& expected_kernel_key,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/layer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) {
auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
auto* op_kernel = static_cast<const framework::OperatorWithKernel*>(&op);
PADDLE_ENFORCE_NOT_NULL(
op_kernel, platform::errors::PermissionDenied(
"Only support operator with kernel in Dygraph mode."));
Expand Down
63 changes: 37 additions & 26 deletions paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ static const phi::Kernel empty_kernel;
static const framework::RuntimeContext empty_ctx({}, {});
static const framework::Scope empty_scope;

const phi::KernelFactory& PreparedOp::phi_kernel_factory =
phi::KernelFactory::Instance();
const phi::OpUtilsMap& PreparedOp::phi_op_utils_map =
phi::OpUtilsMap::Instance();
const phi::DefaultKernelSignatureMap& PreparedOp::default_phi_kernel_sig_map =
phi::DefaultKernelSignatureMap::Instance();

const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
const std::shared_ptr<paddle::imperative::VarBase>& var) {
return var->SharedVar();
Expand Down Expand Up @@ -139,12 +146,14 @@ PreparedOp::PreparedOp(const framework::OperatorBase& op,
phi_kernel_(phi_kernel) {}

template <typename VarType>
PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs,
const framework::OperatorWithKernel& op,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
PreparedOp PrepareImpl(
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& outs,
const framework::OperatorWithKernel& op, const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const phi::KernelFactory& phi_kernel_factory,
const phi::OpUtilsMap& phi_op_utils_map,
const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);

Expand Down Expand Up @@ -184,15 +193,15 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,

bool has_phi_kernel = false;

const auto* arg_map_fn =
phi::OpUtilsMap::Instance().GetArgumentMappingFn(op.Type());
const auto* arg_map_fn = phi_op_utils_map.GetArgumentMappingFn(op.Type());

if (arg_map_fn) {
has_phi_kernel = true;
kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else {
default_kernel_signature =
phi::DefaultKernelSignatureMap::Instance().GetNullable(op.Type());
default_phi_kernel_sig_map.GetNullable(op.Type());
if (default_kernel_signature) {
has_phi_kernel = true;
kernel_signature = *default_kernel_signature;
Expand Down Expand Up @@ -228,8 +237,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
<< ", using_kernel_key:" << expected_kernel_key;
phi::KernelKey try_pt_kernel_key =
TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
if (!phi::KernelFactory::Instance().HasKernel(pt_kernel_name,
try_pt_kernel_key)) {
if (!phi_kernel_factory.HasKernel(pt_kernel_name, try_pt_kernel_key)) {
expected_kernel_key.library_type_ = expected_kernel_key_library_type;
VLOG(3) << "modify XPU KP kernel: " << op.Type() << " is failed "
<< expected_kernel_key;
Expand All @@ -239,8 +247,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif

pt_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
auto& phi_kernel = phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_kernel_key);
auto& phi_kernel =
phi_kernel_factory.SelectKernel(pt_kernel_name, pt_kernel_key);

if (phi_kernel.IsValid()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
Expand Down Expand Up @@ -295,11 +303,11 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
|| (is_xpu_unsupport && !is_xpu_kp_support)
#endif
) {
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op.Type())) {
if (has_phi_kernel) {
auto pt_cpu_kernel_key =
FallBackToCpu(expected_kernel_key, pt_kernel_key, op);
auto& pt_cpu_kernel = phi::KernelFactory::Instance().SelectKernel(
pt_kernel_name, pt_cpu_kernel_key);
auto& pt_cpu_kernel =
phi_kernel_factory.SelectKernel(pt_kernel_name, pt_cpu_kernel_key);
if (pt_cpu_kernel.IsValid()) {
VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << pt_kernel_name
<< " | kernel key: " << pt_cpu_kernel_key
Expand Down Expand Up @@ -408,7 +416,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs);
return PrepareImpl<VarBase>(ins, outs, op, place, attrs, default_attrs,
phi_kernel_factory, phi_op_utils_map,
default_phi_kernel_sig_map);
}

PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
Expand All @@ -417,8 +427,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
return PrepareImpl<VariableWrapper>(ins, outs, op, place, attrs,
default_attrs);
return PrepareImpl<VariableWrapper>(
ins, outs, op, place, attrs, default_attrs, phi_kernel_factory,
phi_op_utils_map, default_phi_kernel_sig_map);
}

PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins,
Expand All @@ -427,8 +438,9 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins,
const platform::Place& place,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
return PrepareImpl<egr::EagerVariable>(ins, outs, op, place, attrs,
default_attrs);
return PrepareImpl<egr::EagerVariable>(
ins, outs, op, place, attrs, default_attrs, phi_kernel_factory,
phi_op_utils_map, default_phi_kernel_sig_map);
}
template <typename VarType>
static void PreparedOpRunImpl(
Expand All @@ -441,7 +453,6 @@ static void PreparedOpRunImpl(
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
// TODO(zjl): remove scope in dygraph
framework::Scope scope;

{
platform::RecordEvent record_event("infer_shape",
Expand All @@ -458,8 +469,8 @@ static void PreparedOpRunImpl(
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);

func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
attrs, default_attrs));
func(DygraphExecutionContext<VarType>(op, empty_scope, *dev_ctx, ctx, ins,
outs, attrs, default_attrs));
}

if (FLAGS_check_nan_inf) {
Expand Down Expand Up @@ -503,7 +514,7 @@ static void PreparedOpRunPtImpl(
const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
{
platform::RecordEvent record_event(op.Type() + "::infer_shape",
platform::RecordEvent record_event("infer_shape",
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);
DygraphInferShapeContext<VarType> infer_shape_ctx(
Expand All @@ -513,7 +524,7 @@ static void PreparedOpRunPtImpl(
}

{
platform::RecordEvent record_event(op.Type() + "::compute",
platform::RecordEvent record_event("compute",
platform::TracerEventType::OperatorInner,
1, platform::EventRole::kInnerOp);

Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/imperative/prepared_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ class PreparedOp {
const phi::KernelSignature* default_kernel_signature_;
phi::KernelSignature kernel_signature_;
const phi::Kernel& phi_kernel_;

static const phi::KernelFactory& phi_kernel_factory;
static const phi::OpUtilsMap& phi_op_utils_map;
static const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map;
};

const inline framework::Attribute& GetAttr(
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ void Tracer::TraceOpImpl(const std::string& type,
paddle::framework::AttributeMap* passed_default_attrs_,
bool use_default_attr_map) {
platform::RecordEvent op_type_record_event(
type + " trace_op", platform::TracerEventType::Operator, 1);
"trace_op", platform::TracerEventType::Operator, 1);
platform::ScopedFlushDenormal flush;
VLOG(1) << "Trace Op: " << type;
if (FLAGS_use_mkldnn) {
Expand Down
39 changes: 20 additions & 19 deletions paddle/phi/core/compat/convert_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,26 @@ namespace phi {

Backend TransToPhiBackend(const phi::Place& place) {
auto allocation_type = place.GetType();
if (allocation_type == phi::AllocationType::CPU) {
return Backend::CPU;
} else if (allocation_type == phi::AllocationType::GPU) {
return Backend::GPU;
} else if (allocation_type == phi::AllocationType::GPUPINNED) {
return Backend::GPU;
} else if (allocation_type == phi::AllocationType::XPU) {
return Backend::XPU;
} else if (allocation_type == phi::AllocationType::NPU) {
return Backend::NPU;
} else if (allocation_type == phi::AllocationType::IPU) {
return Backend::IPU;
} else if (allocation_type == phi::AllocationType::CUSTOM) {
return static_cast<Backend>(
static_cast<size_t>(Backend::NUM_BACKENDS) +
GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType()));
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported transform %s to phi Backend.", place));
switch (allocation_type) {
case phi::AllocationType::GPU:
return Backend::GPU;
case AllocationType::CPU:
return Backend::CPU;
case AllocationType::GPUPINNED:
return Backend::GPU;
case AllocationType::XPU:
return Backend::XPU;
case AllocationType::NPU:
return Backend::NPU;
case AllocationType::IPU:
return Backend::IPU;
case AllocationType::CUSTOM:
return static_cast<Backend>(
static_cast<size_t>(Backend::NUM_BACKENDS) +
GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType()));
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported transform %s to phi Backend.", place));
}
}

Expand Down
5 changes: 2 additions & 3 deletions paddle/phi/core/dense_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ void* DenseTensor::AllocateFrom(Allocator* allocator,

template <typename T>
const T* DenseTensor::data() const {
check_memory_size();
PADDLE_ENFORCE_EQ(
dtype(),
paddle::experimental::CppTypeToDataType<T>::Type(),
Expand All @@ -141,13 +140,13 @@ const T* DenseTensor::data() const {

template <typename T>
T* DenseTensor::data() {
check_memory_size();
T* ret = static_cast<T*>(data());
PADDLE_ENFORCE(
(dtype() == paddle::experimental::CppTypeToDataType<T>::Type()),
phi::errors::InvalidArgument(
"The type of data we are trying to retrieve does not match the "
"type of data currently contained in the container."));
return static_cast<T*>(data());
return ret;
}

void* DenseTensor::data() {
Expand Down
12 changes: 6 additions & 6 deletions paddle/phi/core/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ namespace phi {
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
const std::pair<int, int>& range = ctx->InputRangeAt(in_idx); \
const tensor_type& arg = ctx->InputAt<tensor_type>(range.first); \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
Expand All @@ -96,7 +96,7 @@ namespace phi {
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
const std::pair<int, int>& range = ctx->InputRangeAt(in_idx); \
auto arg = ctx->OptionalInputAt<tensor_type>(range.first); \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
Expand All @@ -117,7 +117,7 @@ namespace phi {
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
const std::pair<int, int>& range = ctx->InputRangeAt(in_idx); \
std::vector<const tensor_type*> arg = std::move( \
ctx->InputsBetween<tensor_type>(range.first, range.second)); \
KernelCallHelper<Tail...>:: \
Expand All @@ -141,7 +141,7 @@ namespace phi {
"Kernel's Input should appear before Attributes."); \
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
const std::pair<int, int> range = ctx->InputRangeAt(in_idx); \
const std::pair<int, int>& range = ctx->InputRangeAt(in_idx); \
paddle::optional<const std::vector<const tensor_type*>> arg = \
ctx->OptionalInputsBetween<tensor_type>(range.first, range.second); \
KernelCallHelper<Tail...>:: \
Expand Down Expand Up @@ -195,7 +195,7 @@ namespace phi {
int out_idx, \
typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
const std::pair<int, int> range = ctx->OutputRangeAt(out_idx); \
const std::pair<int, int>& range = ctx->OutputRangeAt(out_idx); \
tensor_type* arg = ctx->MutableOutputAt<tensor_type>(range.first); \
KernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>( \
Expand All @@ -212,7 +212,7 @@ namespace phi {
int out_idx, \
typename... PreviousArgs> \
static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \
const std::pair<int, int> range = ctx->OutputRangeAt(out_idx); \
const std::pair<int, int>& range = ctx->OutputRangeAt(out_idx); \
std::vector<tensor_type*> arg = std::move( \
ctx->MutableOutputBetween<tensor_type>(range.first, range.second)); \
KernelCallHelper<Tail...>:: \
Expand Down
Loading

0 comments on commit 2ea56c9

Please sign in to comment.