From 8243e6b2733ca255eb6cabb2f09141e115f4493b Mon Sep 17 00:00:00 2001 From: HarperCy <1312489431@qq.com> Date: Thu, 17 Nov 2022 08:25:19 +0000 Subject: [PATCH] add square fp16 *test=kunlun --- paddle/fluid/platform/device/xpu/xpu2_op_list.h | 4 +++- paddle/phi/kernels/xpu/activation_kernel.cc | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index cbcbde8f9ddcd..12c25e24011c9 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -544,7 +544,9 @@ XPUOpMap& get_kl2_ops() { {"sqrt", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"sqrt_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"square_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, - {"square", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"square", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"squeeze2_grad", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc index f730c38e8f0f2..450ec3d5f476a 100644 --- a/paddle/phi/kernels/xpu/activation_kernel.cc +++ b/paddle/phi/kernels/xpu/activation_kernel.cc @@ -456,6 +456,9 @@ PD_REGISTER_KERNEL( PD_REGISTER_KERNEL( tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {} +PD_REGISTER_KERNEL( + square, XPU, ALL_LAYOUT, phi::SquareKernel, float, phi::dtype::float16) {} + PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) @@ -468,4 +471,3 @@ PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel) PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel) PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel) -PD_REGISTER_ACTIVATION_KERNEL(square, SquareKernel)