From 39bce4ff4db31a57ebfcfad32bfd06d5892490ff Mon Sep 17 00:00:00 2001 From: Lucas <33367939+cqulilujia@users.noreply.github.com> Date: Mon, 23 Oct 2023 10:49:41 +0800 Subject: [PATCH] Add XPU bf16 support for squeeze, unsqueeze kernels (#58161) --- paddle/phi/backends/xpu/xpu2_op_list.cc | 20 ++++++++++++++++---- paddle/phi/backends/xpu/xpu3_op_list.cc | 20 ++++++++++++++++---- paddle/phi/kernels/squeeze_grad_kernel.cc | 1 + paddle/phi/kernels/squeeze_kernel.cc | 2 ++ paddle/phi/kernels/unsqueeze_grad_kernel.cc | 1 + paddle/phi/kernels/unsqueeze_kernel.cc | 2 ++ 6 files changed, 38 insertions(+), 8 deletions(-) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 74a8cf0bc1150e..cdc6d895b84be0 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -797,6 +797,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze2", XPUKernelSet({phi::DataType::FLOAT64, @@ -806,6 +808,7 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze", XPUKernelSet({phi::DataType::FLOAT64, @@ -814,6 +817,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze_grad", XPUKernelSet({phi::DataType::FLOAT64, @@ -822,6 +827,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"stack", XPUKernelSet({phi::DataType::FLOAT32, @@ -935,7 +942,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze2", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -944,7 +952,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze_grad", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -952,7 +961,9 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, - phi::DataType::FLOAT32})}, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -960,8 +971,9 @@ XPUOpMap& get_kl2_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT32, phi::DataType::FLOAT16, - phi::DataType::FLOAT32})}, + phi::DataType::BFLOAT16})}, {"unstack", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index 29a85493958949..6174f13cd30b26 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -779,6 +779,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze2", XPUKernelSet({phi::DataType::FLOAT64, @@ -788,6 +790,7 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze", XPUKernelSet({phi::DataType::FLOAT64, @@ -796,6 +799,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"squeeze_grad", XPUKernelSet({phi::DataType::FLOAT64, @@ -804,6 +809,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16, phi::DataType::FLOAT32})}, {"stack", XPUKernelSet({phi::DataType::FLOAT32, @@ -917,7 +924,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze2", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -926,7 +934,8 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32, - phi::DataType::FLOAT16})}, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze_grad", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -934,7 +943,9 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, - phi::DataType::FLOAT32})}, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"unsqueeze", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, @@ -942,8 +953,9 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BOOL, phi::DataType::INT8, phi::DataType::UINT8, + phi::DataType::FLOAT32, phi::DataType::FLOAT16, - phi::DataType::FLOAT32})}, + phi::DataType::BFLOAT16})}, {"unstack", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, diff --git a/paddle/phi/kernels/squeeze_grad_kernel.cc b/paddle/phi/kernels/squeeze_grad_kernel.cc index 473acf9d7a1d15..a8a788e817472b 100644 --- a/paddle/phi/kernels/squeeze_grad_kernel.cc +++ b/paddle/phi/kernels/squeeze_grad_kernel.cc @@ -76,6 +76,7 @@ PD_REGISTER_KERNEL(squeeze_grad, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, diff --git a/paddle/phi/kernels/squeeze_kernel.cc b/paddle/phi/kernels/squeeze_kernel.cc index d495b040921b59..a8d24423fcb45c 100644 --- a/paddle/phi/kernels/squeeze_kernel.cc +++ b/paddle/phi/kernels/squeeze_kernel.cc @@ -116,6 +116,7 @@ PD_REGISTER_KERNEL(squeeze_infer, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, @@ -129,6 +130,7 @@ PD_REGISTER_KERNEL(squeeze, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, diff --git a/paddle/phi/kernels/unsqueeze_grad_kernel.cc b/paddle/phi/kernels/unsqueeze_grad_kernel.cc index 3c119db2c73d6e..d26753ece47cdc 100644 --- a/paddle/phi/kernels/unsqueeze_grad_kernel.cc +++ b/paddle/phi/kernels/unsqueeze_grad_kernel.cc @@ -77,6 +77,7 @@ PD_REGISTER_KERNEL(unsqueeze_grad, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, diff --git a/paddle/phi/kernels/unsqueeze_kernel.cc b/paddle/phi/kernels/unsqueeze_kernel.cc index c08c31da4ef0ce..ddfe142894b151 100644 --- a/paddle/phi/kernels/unsqueeze_kernel.cc +++ b/paddle/phi/kernels/unsqueeze_kernel.cc @@ -124,6 +124,7 @@ PD_REGISTER_KERNEL(unsqueeze_infer, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t, @@ -137,6 +138,7 @@ PD_REGISTER_KERNEL(unsqueeze, float, double, phi::dtype::float16, + phi::dtype::bfloat16, bool, int, uint8_t,