Skip to content

Commit

Permalink
Add XPU bf16 support for squeeze, unsqueeze kernels (PaddlePaddle#58161)
Browse files Browse the repository at this point in the history
  • Loading branch information
cqulilujia authored and wentaoyu committed Oct 24, 2023
1 parent deef793 commit 39bce4f
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 8 deletions.
20 changes: 16 additions & 4 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT32})},
{"squeeze2",
XPUKernelSet({phi::DataType::FLOAT64,
Expand All @@ -806,6 +808,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT32})},
{"squeeze",
XPUKernelSet({phi::DataType::FLOAT64,
Expand All @@ -814,6 +817,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT32})},
{"squeeze_grad",
XPUKernelSet({phi::DataType::FLOAT64,
Expand All @@ -822,6 +827,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT32})},
{"stack",
XPUKernelSet({phi::DataType::FLOAT32,
Expand Down Expand Up @@ -935,7 +942,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"unsqueeze2",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
Expand All @@ -944,24 +952,28 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"unsqueeze_grad",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"unsqueeze",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
phi::DataType::BFLOAT16})},
{"unstack",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
Expand Down
20 changes: 16 additions & 4 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,8 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT32})},
{"squeeze2",
XPUKernelSet({phi::DataType::FLOAT64,
Expand All @@ -788,6 +790,7 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT32})},
{"squeeze",
XPUKernelSet({phi::DataType::FLOAT64,
Expand All @@ -796,6 +799,8 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT32})},
{"squeeze_grad",
XPUKernelSet({phi::DataType::FLOAT64,
Expand All @@ -804,6 +809,8 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::FLOAT32})},
{"stack",
XPUKernelSet({phi::DataType::FLOAT32,
Expand Down Expand Up @@ -917,7 +924,8 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"unsqueeze2",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
Expand All @@ -926,24 +934,28 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"unsqueeze_grad",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32})},
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"unsqueeze",
XPUKernelSet({phi::DataType::FLOAT64,
phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::BOOL,
phi::DataType::INT8,
phi::DataType::UINT8,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
phi::DataType::BFLOAT16})},
{"unstack",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/squeeze_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ PD_REGISTER_KERNEL(squeeze_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
bool,
int,
uint8_t,
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/squeeze_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ PD_REGISTER_KERNEL(squeeze_infer,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
bool,
int,
uint8_t,
Expand All @@ -129,6 +130,7 @@ PD_REGISTER_KERNEL(squeeze,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
bool,
int,
uint8_t,
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/unsqueeze_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ PD_REGISTER_KERNEL(unsqueeze_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
bool,
int,
uint8_t,
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/unsqueeze_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ PD_REGISTER_KERNEL(unsqueeze_infer,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
bool,
int,
uint8_t,
Expand All @@ -137,6 +138,7 @@ PD_REGISTER_KERNEL(unsqueeze,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
bool,
int,
uint8_t,
Expand Down

0 comments on commit 39bce4f

Please sign in to comment.