[XPU] Optimize fp16 xpu models (#53523)
wz1qqx committed May 8, 2023
1 parent 186f5e0 commit 0a59825
Showing 11 changed files with 150 additions and 63 deletions.
51 changes: 43 additions & 8 deletions paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc
@@ -34,6 +34,25 @@ class Scope;
} // namespace framework
} // namespace paddle

namespace {

template <typename T1, typename T2>
void ConvertTensorType(phi::DenseTensor* tensor) {
phi::DenseTensor tmp_tensor;
tmp_tensor.set_type(phi::CppTypeToDataType<T2>::Type());
tmp_tensor.Resize(tensor->dims());
auto* tmp_data = tmp_tensor.mutable_data<T2>(paddle::platform::CPUPlace());
auto* data = tensor->mutable_data<T1>(paddle::platform::CPUPlace());
for (int i = 0; i < tensor->numel(); i++) {
tmp_data[i] = static_cast<T2>(data[i]);
}
tensor->clear();
paddle::framework::TensorCopySync(
tmp_tensor, paddle::platform::CPUPlace(), tensor);
}

} // namespace

namespace paddle {
namespace framework {
namespace ir {
@@ -157,15 +176,23 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
if (with_bn_) {
ew_bias_add_out->assert_is_op_input("batch_norm", "X");
bn_bias = pattern->NewNode(bn_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Bias")
->assert_has_n_outputs(1);
bn_mean = pattern->NewNode(bn_mean_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Mean")
->assert_has_n_outputs(1);
bn_scale = pattern->NewNode(bn_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Scale")
->assert_has_n_outputs(1);
bn_var = pattern->NewNode(bn_var_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Variance")
->assert_has_n_outputs(1);
bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm");
@@ -420,13 +447,17 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
// recompute bias and weight for conv2d_xpu op
auto* filter_t =
scope->FindVar(conv_filter->Name())->GetMutable<phi::DenseTensor>();
// conv_filter fp16 --> fp32
auto tensor_type = filter_t->dtype();
if (tensor_type == phi::DataType::FLOAT16) {
ConvertTensorType<float16, float>(filter_t);
}
auto filter_dims = filter_t->dims();
bool has_bias = with_bn || with_conv_bias;
bool has_branch = with_branch_x || with_branch_y;
// Create conv_fusion_bias (conv bias) variable
Node* fusion_bias_node = nullptr;
if (has_bias) {
if (ew_bias_add != nullptr) {
if (with_conv_bias) {
auto* ew_bias_add_y_t = scope->FindVar(ew_bias_add_y->Name())
->GetMutable<phi::DenseTensor>();
auto ew_bias_add_y_dims = ew_bias_add_y_t->dims();
@@ -439,7 +470,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
filter_dims[0]));
PrepareBias(graph, scope, block, ew_bias_add_y, &fusion_bias_node);
}
if (bn != nullptr) {
if (with_bn) {
auto bn_bias_t =
scope->Var(bn_bias->Name())->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(filter_dims[0],
@@ -469,40 +500,44 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
auto filter_len = filter_t->numel();
auto filter_stride = filter_len / mean_len;
float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon"));
if (fusion_bias_node == nullptr) { // prev node is conv
if (!with_conv_bias) { // prev node is conv
PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node);
}
auto fusion_bias_t = scope->Var(fusion_bias_node->Name())
->GetMutable<phi::DenseTensor>();
float* fusion_bias_ptr =
fusion_bias_t->mutable_data<float>(paddle::platform::CPUPlace());
// recompute bias and weights
if (ew_bias_add == nullptr) {
if (!with_conv_bias) { // prev node is conv
for (int i = 0; i < mean_len; ++i) {
bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon);
fusion_bias_ptr[i] += (0.f - bn_mean_ptr[i]) * bn_scale_ptr[i];
fusion_bias_ptr[i] += (0.0f - bn_mean_ptr[i]) * bn_scale_ptr[i];
for (int j = 0; j < filter_stride; j++) {
filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i];
}
}
} else {
for (int i = 0; i < mean_len; ++i) {
bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon);
bn_bias_ptr[i] +=
fusion_bias_ptr[i] =
bn_bias_ptr[i] +
(fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i];
for (int j = 0; j < filter_stride; j++) {
filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i];
}
}
memcpy(fusion_bias_ptr, bn_bias_ptr, mean_len * sizeof(float));
}
}
}
if (tensor_type == phi::DataType::FLOAT16) {
ConvertTensorType<float, float16>(filter_t);
}
// filter max
Node* filter_int16 = nullptr;
Node* filter_max = nullptr;
PrepareWeight<int16_t>(
graph, scope, block, conv_filter, &filter_int16, &filter_max, false);
bool has_branch = with_branch_x || with_branch_y;
// output && output max
std::string conv2d_xpu_out_name;
if (!act_type.empty()) {
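Note: the recompute above is standard batch-norm folding, performed in fp32 (the fp16 filter is widened via ConvertTensorType before the rewrite and narrowed back afterwards). A minimal standalone sketch of the same math, with hypothetical buffer names and a per-output-channel filter layout assumed:

#include <cmath>

// Hedged sketch: fold batch_norm(gamma, beta, mean, var, eps) into the
// preceding conv's filter and per-channel bias. bias[i] should start at the
// conv's own bias, or 0 when the conv has no bias, matching the two branches
// in the pass above:
//   s_i   = gamma_i / sqrt(var_i + eps)
//   W'_ij = W_ij * s_i
//   b'_i  = (b_i - mean_i) * s_i + beta_i
void FoldBatchNorm(float* filter, float* bias, const float* gamma,
                   const float* beta, const float* mean, const float* var,
                   float eps, int channels, int filter_stride) {
  for (int i = 0; i < channels; ++i) {
    const float s = gamma[i] / std::sqrt(var[i] + eps);
    bias[i] = (bias[i] - mean[i]) * s + beta[i];
    for (int j = 0; j < filter_stride; ++j) {
      filter[i * filter_stride + j] *= s;
    }
  }
}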
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/fused_ops.yaml
@@ -5,14 +5,14 @@
# otherwise the operator only could be used in static mode.

- op : conv2d_xpu
args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param)
args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param)
output : Tensor(out), Tensor(out_max)
infer_meta :
func : Conv2dXPUInferMeta
kernel :
func : conv2d_xpu
data_type : x
optional : bias, branch, x_max
optional : bias, branch, branch_max ,x_max

- op : embedding_with_eltwise_add_xpu
args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx)
19 changes: 12 additions & 7 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -58,7 +58,8 @@ XPUOpMap& get_kl2_ops() {
{"atan_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"batch_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"batch_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"batch_norm",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"bmm", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"bmm_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -401,7 +402,8 @@ XPUOpMap& get_kl2_ops() {
{"grid_sampler_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_sigmoid",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"hard_swish_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"hard_swish", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -438,7 +440,8 @@ XPUOpMap& get_kl2_ops() {
{"layer_norm",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"leaky_relu_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"leaky_relu", XPUKernelSet({phi::DataType::FLOAT32})},
{"leaky_relu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"less_equal",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
@@ -554,7 +557,8 @@ XPUOpMap& get_kl2_ops() {
{"reduce_max_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_max", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -646,7 +650,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::FLOAT16})},
{"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})},
{"sigmoid",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"sign", XPUKernelSet({phi::DataType::FLOAT32})},
{"slice_grad",
@@ -676,7 +681,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::INT32})},
{"sqrt", XPUKernelSet({phi::DataType::FLOAT32})},
{"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"square_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -733,7 +738,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT16,
phi::DataType::INT32})},
{"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"swish", XPUKernelSet({phi::DataType::FLOAT32})},
{"swish", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"swish_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"tanh_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
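Note: these entries declare which dtypes the XPU2 (KL2) backend claims for each op; dtypes not listed are not routed to the XPU kernel. A hedged sketch of the lookup this table supports (header, namespace, and container typedefs are assumed, not verified against the tree):

// Hypothetical helper: does the KL2 op list advertise an fp16 kernel for `op`?
// Assumes XPUOpMap is an unordered_map<std::string, XPUKernelSet> and
// XPUKernelSet is a set of phi::DataType, as the file above suggests.
bool SupportsFp16OnKL2(const std::string& op) {
  auto& ops = phi::backends::xpu::get_kl2_ops();  // namespace assumed
  auto it = ops.find(op);
  return it != ops.end() && it->second.count(phi::DataType::FLOAT16) > 0;
}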
1 change: 1 addition & 0 deletions paddle/phi/infermeta/fusion.cc
@@ -41,6 +41,7 @@ void Conv2dXPUInferMeta(const MetaTensor& x,
const MetaTensor& filter_max,
const MetaTensor& bias,
const MetaTensor& branch,
const MetaTensor& branch_max,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
1 change: 1 addition & 0 deletions paddle/phi/infermeta/fusion.h
@@ -28,6 +28,7 @@ void Conv2dXPUInferMeta(const MetaTensor& x,
const MetaTensor& filter_max,
const MetaTensor& bias,
const MetaTensor& branch,
const MetaTensor& branch_max,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
3 changes: 2 additions & 1 deletion paddle/phi/kernels/activation_kernel.cc
@@ -63,7 +63,8 @@ PD_REGISTER_KERNEL(swish,

#if defined PADDLE_WITH_XPU
PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {}
PD_REGISTER_KERNEL(swish, XPU, ALL_LAYOUT, phi::SwishKernel, float) {}
PD_REGISTER_KERNEL(
swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {}
#endif

#ifdef PADDLE_WITH_MKLDNN
8 changes: 6 additions & 2 deletions paddle/phi/kernels/batch_norm_kernel.cc
@@ -106,6 +106,10 @@ PD_REGISTER_KERNEL(batch_norm_infer,
phi::dtype::float16) {}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(
batch_norm_infer, XPU, ALL_LAYOUT, phi::BatchNormInferKernel, float) {}
PD_REGISTER_KERNEL(batch_norm_infer,
XPU,
ALL_LAYOUT,
phi::BatchNormInferKernel,
float,
phi::dtype::float16) {}
#endif
7 changes: 5 additions & 2 deletions paddle/phi/kernels/fusion/xpu/conv2d_xpu_kernel.cc
@@ -27,6 +27,7 @@ void Conv2dXPUKernel(const Context& ctx,
const DenseTensor& filter_max,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& branch,
const paddle::optional<DenseTensor>& branch_max,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
const std::vector<int>& strides,
@@ -69,10 +70,12 @@ void Conv2dXPUKernel(const Context& ctx,
branch.get_ptr() == nullptr
? nullptr
: reinterpret_cast<const XPUType*>(branch.get_ptr()->data<T>());
const float* branch_max_data = branch_max.get_ptr() == nullptr
? nullptr
: branch_max.get_ptr()->data<float>();
const float* bias_data =
bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));

xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
if (act_type == xpu::Activation_t::LEAKY_RELU) {
act.leaky_alpha = act_param;
@@ -102,7 +105,7 @@ void Conv2dXPUKernel(const Context& ctx,
/* const float* bias */ bias_data,
/* const TY* branch */ branch_data,
/* const baidu::xpu::api::Activation_t& act */ act,
/* const float* branch_maxptr */ nullptr,
/* const float* branch_maxptr */ branch_max_data,
/* const float* scale */ nullptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_xpu");
}
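Note: the new branch_max input follows the same optional-tensor convention as bias and branch, so an unset input reaches the fused xdnn call as a null branch_maxptr. A hedged illustration of that pattern (hypothetical free function, not part of this commit):

// Hypothetical helper mirroring the branch_max/bias handling above: an unset
// paddle::optional tensor becomes nullptr, otherwise its fp32 data pointer.
const float* OptionalFloatData(const paddle::optional<phi::DenseTensor>& t) {
  return t.get_ptr() == nullptr ? nullptr : t.get_ptr()->data<float>();
}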
45 changes: 37 additions & 8 deletions paddle/phi/kernels/xpu/activation_kernel.cc
@@ -195,6 +195,13 @@ void PowKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& factor,
DenseTensor* out) {
// using XPUType = typename XPUTypeTrait<T>::Type;
// // dev_ctx.template Alloc<T>(out);
// auto pow_factor = factor.to<T>();
// const auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
// auto* y_data = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(out));
// // const T* x_data = x.data<T>();
// // T* y_data = out->data<T>();
dev_ctx.template Alloc<T>(out);
float pow_factor = factor.to<float>();
const T* x_data = x.data<T>();
@@ -534,9 +541,28 @@ PD_REGISTER_KERNEL(
relu, XPU, ALL_LAYOUT, phi::ReluKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
silu, XPU, ALL_LAYOUT, phi::SiluKernel, float, phi::dtype::float16) {}

#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}
PD_REGISTER_KERNEL(
sigmoid, XPU, ALL_LAYOUT, phi::SigmoidKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(swish_raw,
XPU,
ALL_LAYOUT,
phi::SwishRawKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(hard_sigmoid,
XPU,
ALL_LAYOUT,
phi::HardSigmoidKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(leaky_relu,
XPU,
ALL_LAYOUT,
phi::LeakyReluKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(
sqrt, XPU, ALL_LAYOUT, phi::SqrtKernel, float, phi::dtype::float16) {}

PD_REGISTER_KERNEL(
tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {}
@@ -547,18 +573,21 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL(
log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {}

#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}

PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
// PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
// PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
// PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
// PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
// PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)