From ee7fd7875954e3ba49a484ca103db50f1b15b589 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 5 Dec 2022 16:33:41 +0800 Subject: [PATCH] Register exp/expm1/logit bf16 activation op kernels (#48702) * register more bf16 ops * update to register coresponding backward ops --- paddle/phi/kernels/gpu/activation_grad_kernel.cu | 9 ++++++--- paddle/phi/kernels/gpu/activation_kernel.cu | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index 0c8c8b43a0bac..441790aab3ae2 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -371,7 +371,8 @@ PD_REGISTER_KERNEL(exp_grad, double, int, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL(softshrink_grad, SoftShrinkGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel) @@ -386,7 +387,8 @@ PD_REGISTER_KERNEL(expm1_grad, phi::Expm1GradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(logit_grad, GPU, @@ -394,7 +396,8 @@ PD_REGISTER_KERNEL(logit_grad, phi::LogitGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(square_grad, GPU, diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index 271ad6107bce4..0b396b17f5cb8 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -215,21 +215,24 @@ PD_REGISTER_KERNEL(exp, double, int, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(expm1, GPU, ALL_LAYOUT, phi::Expm1Kernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(logit, GPU, ALL_LAYOUT, phi::LogitKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(square, GPU, ALL_LAYOUT,