Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[New features]Add function node in phi_kernel for MKLDNN #51073

Merged
merged 15 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,8 @@ if(WITH_XPU)
kernel_factory
infershape_utils
op_utils
op_compat_infos)
op_compat_infos
get_kerneltype_forvar_utils)
else()
cc_library(
operator
Expand All @@ -505,7 +506,8 @@ else()
kernel_factory
infershape_utils
op_utils
op_compat_infos)
op_compat_infos
get_kerneltype_forvar_utils)
endif()

cc_test(
Expand Down
26 changes: 26 additions & 0 deletions paddle/fluid/framework/data_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,31 @@ void SetTensorToVariable(const Variable &in_var,
}
}

phi::GetKernelTypeForVarContext BuildGetKernelTypeForVarContext(
const phi::KernelKey &kernel_key,
const AttributeMap &fluid_attrs,
phi::AttributeMap *phi_attrs,
heavyrain-lzy marked this conversation as resolved.
Show resolved Hide resolved
bool has_infer_varkernel_fn) {
// According to "GetKernelTypeForVar" in some ops those have MKLDNN codes,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// According to "GetKernelTypeForVar" in some ops those have MKLDNN codes,
// According to "GetKernelTypeForVar" in some ops executed with oneDNN,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I will change the comment according to your suggestion.

// the only "string" member, such as "data_layout" 、"data_format" of
// AttibuteMap is useful. In the future the other args maybe used. Because the
// "phi" module should not depend on the "fluid", transform
// "framework::AttributeMap" to "phi::AttributeMap".
if (has_infer_varkernel_fn) {
for (auto &attr : fluid_attrs) {
heavyrain-lzy marked this conversation as resolved.
Show resolved Hide resolved
switch (attr.second.index()) {
case 3: // string type in framwork::Attribute
(*phi_attrs)[attr.first] = PADDLE_GET_CONST(std::string, attr.second);
break;
default:
heavyrain-lzy marked this conversation as resolved.
Show resolved Hide resolved
heavyrain-lzy marked this conversation as resolved.
Show resolved Hide resolved
VLOG(6) << "GetKernelTypeForVarContext currently only use "
"std::string. You add other type if need.";
break;
}
}
}
return phi::GetKernelTypeForVarContext(&kernel_key, phi_attrs);
}

} // namespace framework
} // namespace paddle
8 changes: 8 additions & 0 deletions paddle/fluid/framework/data_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/common/transform.h"
#include "paddle/phi/core/compat/get_kerneltype_forvar_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
Expand All @@ -45,5 +46,12 @@ void TransformData(const phi::KernelKey &expected_kernel_type,
void SetTensorToVariable(const Variable &in_var,
const phi::DenseTensor &tensor,
Variable *out_var);

phi::GetKernelTypeForVarContext BuildGetKernelTypeForVarContext(
const phi::KernelKey &kernel_key,
const AttributeMap &fluid_attrs,
phi::AttributeMap *phi_attrs,
bool has_infer_varkernel_fn);

} // namespace framework
} // namespace paddle
28 changes: 21 additions & 7 deletions paddle/fluid/framework/new_executor/interpreter/data_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h"

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
Expand Down Expand Up @@ -474,7 +475,17 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,

bool transfered = false;
DataTranferHelper data_transfer_helper(place, var_scope, local_scope);

phi::Kernel* phi_kernel = op_func_node->phi_kernel_;
auto has_infer_varkernel_fn =
(phi_kernel && phi_kernel->get_kerneltype_forvar_fn_ != nullptr);
phi::AttributeMap infer_attrs{};
auto fluid_attrs =
static_cast<const framework::OperatorWithKernel*>(op_base)->Attrs();
auto phi_kernelkey =
framework::TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
phi::GetKernelTypeForVarContext infer_varkernel_context =
BuildGetKernelTypeForVarContext(
phi_kernelkey, fluid_attrs, &infer_attrs, has_infer_varkernel_fn);
auto apply_data_transform_for_one_parameter =
[&](const std::string& parameter_name,
const std::vector<std::string>& argument_names,
Expand Down Expand Up @@ -551,11 +562,15 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
auto kernel_key_for_var =
static_cast<const framework::OperatorWithKernel*>(op_base)
->GetKernelTypeForVar(
parameter_name,
*tensor_in,
framework::TransOpKernelTypeToPhiKernelKey(
expected_kernel_key));

parameter_name, *tensor_in, phi_kernelkey);
if (has_infer_varkernel_fn) {
infer_varkernel_context.SetVarName(
const_cast<std::string*>(&parameter_name));
infer_varkernel_context.SetDenseTensor(
const_cast<phi::DenseTensor*>(tensor_in));
kernel_key_for_var = phi_kernel->get_kerneltype_forvar_fn_(
&infer_varkernel_context);
}
std::unique_ptr<phi::KernelKey>
expected_kernel_key_for_argument_def = nullptr;
if (argument_def &&
Expand Down Expand Up @@ -634,7 +649,6 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
}
};

phi::Kernel* phi_kernel = op_func_node->phi_kernel_;
if (phi_kernel && phi_kernel->IsValid() &&
phi_kernel->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
Expand Down
18 changes: 18 additions & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/compat/get_kerneltype_forvar_utils.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
Expand Down Expand Up @@ -2448,6 +2449,16 @@ Scope* OperatorWithKernel::PrepareData(
}
}

auto has_infer_varkernel_fn =
(run_phi_kernel_ && phi_kernel_->get_kerneltype_forvar_fn_ != nullptr);
phi::AttributeMap infer_attrs{};
auto fluid_attrs = Attrs();
phi::GetKernelTypeForVarContext infer_varkernel_context =
BuildGetKernelTypeForVarContext(expected_kernel_key,
fluid_attrs,
&infer_attrs,
has_infer_varkernel_fn);

const auto& name_map = Inputs();
auto prepare_input_data = [&](const std::string& in_name,
std::vector<Variable*>* in_vars,
Expand Down Expand Up @@ -2510,6 +2521,13 @@ Scope* OperatorWithKernel::PrepareData(

auto kernel_type_for_var =
GetKernelTypeForVar(in_name, *tensor_in, expected_kernel_key);
if (has_infer_varkernel_fn) {
infer_varkernel_context.SetVarName(const_cast<std::string*>(&in_name));
infer_varkernel_context.SetDenseTensor(
const_cast<phi::DenseTensor*>(tensor_in));
kernel_type_for_var =
phi_kernel_->get_kerneltype_forvar_fn_(&infer_varkernel_context);
}
bool need_trans_dtype =
NeedTransformDataType(expected_kernel_key, kernel_type_for_var);
bool need_trans_layout = NeedTransformLayout(
Expand Down
Loading