add erfinv FP16 test and BF16 test #53101

Status: Closed (wants to merge 27 commits)
12 changes: 10 additions & 2 deletions paddle/phi/kernels/gpu/erfinv_grad_kernel.cu
@@ -19,8 +19,16 @@
#include "paddle/phi/kernels/erfinv_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h"

PD_REGISTER_KERNEL(
erfinv_grad, GPU, ALL_LAYOUT, phi::ErfinvGradKernel, float, double) {}
PD_REGISTER_KERNEL(erfinv_grad,
GPU,
ALL_LAYOUT,
phi::ErfinvGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
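Note: with this registration, FP16/BF16 erfinv and its gradient exist on GPU only. A minimal dygraph sketch that would exercise the new kernels (an illustration assuming a CUDA build, not part of the PR; BF16 is analogous with `'bfloat16'`):

```python
import paddle

paddle.set_device('gpu')  # the FP16/BF16 kernels are registered for GPU only
x = paddle.uniform([8], min=-0.9, max=0.9).astype('float16')
x.stop_gradient = False
y = paddle.erfinv(x)
y.sum().backward()
print(x.grad)  # computed by the newly registered FP16 ErfinvGradKernel
```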
27 changes: 26 additions & 1 deletion paddle/phi/kernels/gpu/erfinv_kernel.cu
@@ -14,6 +14,8 @@

#include "paddle/phi/kernels/erfinv_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

@@ -24,6 +26,22 @@ struct ErfinvFunctor {
   HOSTDEVICE inline T operator()(const T x) const { return erfinv(x); }
 };

+template <>
+struct ErfinvFunctor<float16> {
+  HOSTDEVICE inline float16 operator()(const float16 x) const {
+    auto newx = static_cast<float>(x);
+    return static_cast<float16>(erfinv(newx));
+  }
+};
+
+template <>
+struct ErfinvFunctor<bfloat16> {
+  HOSTDEVICE inline bfloat16 operator()(const bfloat16 x) const {
+    auto newx = static_cast<float>(x);
+    return static_cast<bfloat16>(erfinv(newx));
+  }
+};
+
 template <typename T, typename Context>
 void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
   ctx.template Alloc<T>(out);
@@ -34,4 +52,11 @@ void ErfinvKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {

 }  // namespace phi

-PD_REGISTER_KERNEL(erfinv, GPU, ALL_LAYOUT, phi::ErfinvKernel, float, double) {}
+PD_REGISTER_KERNEL(erfinv,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ErfinvKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
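The float16/bfloat16 specializations of ErfinvFunctor upcast to float, evaluate erfinv in single precision, and cast the result back down, rather than computing natively in half precision. A rough NumPy illustration of the same upcast-compute-downcast pattern (the helper name is illustrative, not Paddle API):

```python
import numpy as np
from scipy.special import erfinv

def erfinv_low_precision(x, dtype=np.float16):
    # Upcast to float32, evaluate, downcast -- mirrors ErfinvFunctor<float16>.
    return erfinv(np.float32(x)).astype(dtype)

print(erfinv_low_precision(np.float16(0.5)))  # ~0.477 in half precision
```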
5 changes: 4 additions & 1 deletion paddle/phi/kernels/impl/erfinv_grad_kernel_impl.h
@@ -15,7 +15,10 @@
 #pragma once

 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/erfinv_grad_kernel.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

 namespace phi {

@@ -29,7 +32,7 @@ void ErfinvGradKernel(const Context& ctx,
   auto eigen_dout = EigenVector<T>::Flatten(out_grad);
   auto eigen_dx = EigenVector<T>::Flatten(*x_grad);
   auto& place = *ctx.eigen_device();
-  constexpr T half_sqrt_pi = static_cast<T>(1 / M_2_SQRTPI);
+  const T half_sqrt_pi = static_cast<T>(1 / M_2_SQRTPI);
Contributor: What is the reason for this change?

Contributor (Author): Per the CI error, this has to be a `const` variable; `constexpr` fails to compile here, presumably because the float16/bfloat16 conversions involved are not constant expressions.

   eigen_dx.device(place) = half_sqrt_pi * eigen_dout * eigen_out.square().exp();
 }

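For reference, the expression above implements d/dx erfinv(x) = (sqrt(pi)/2) * exp(erfinv(x)^2): differentiating erf(y) = x with (d/dy) erf(y) = (2/sqrt(pi)) * exp(-y^2) gives y' = (sqrt(pi)/2) * exp(y^2), and 1 / M_2_SQRTPI is exactly sqrt(pi)/2. A standalone finite-difference check of that formula (independent of Paddle):

```python
import numpy as np
from scipy.special import erfinv

x = np.linspace(-0.9, 0.9, 7)
analytic = (np.sqrt(np.pi) / 2) * np.exp(erfinv(x) ** 2)  # half_sqrt_pi * exp(out^2)
h = 1e-6
numeric = (erfinv(x + h) - erfinv(x - h)) / (2 * h)  # central difference
np.testing.assert_allclose(numeric, analytic, rtol=1e-5)
```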
39 changes: 38 additions & 1 deletion python/paddle/fluid/tests/unittests/test_erfinv_op.py
@@ -15,7 +15,7 @@
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 from scipy.special import erfinv

 import paddle
@@ -110,5 +110,42 @@ def run(place):
             run(place)


+class TestErfinvFP16OP(TestErfinv):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not paddle.fluid.core.is_compiled_with_cuda()
+    or not paddle.fluid.core.is_bfloat16_supported(
+        paddle.fluid.core.CUDAPlace(0)
+    ),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestErfinvBF16OP(OpTest):
+    def setUp(self):
+        self.op_type = "erfinv"
+        self.public_python_api = paddle.erfinv
+        self.python_api = paddle.erfinv
+        self.dtype = np.uint16
+        self.shape = [11, 17]
+        x = np.random.uniform(-1, 1, size=self.shape).astype(np.float32)
+        res_ref = erfinv(x).astype(np.float32)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.outputs = {'Out': convert_float_to_uint16(res_ref)}
+
+    def test_check_output(self):
+        place = paddle.fluid.core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = paddle.fluid.core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place,
+            ['X'],
+            'Out',
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
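For context, convert_float_to_uint16 stores bfloat16 values as raw uint16 bit patterns, since NumPy has no native bfloat16 dtype; that is why self.dtype is np.uint16 above. A simplified sketch of the idea (the actual test-suite helper may round rather than truncate):

```python
import numpy as np

def float32_to_bf16_bits(x):
    # bfloat16 keeps float32's sign, exponent, and top 7 mantissa bits, so its
    # bit pattern is (roughly) the high 16 bits of the float32 word.
    x = np.asarray(x, dtype=np.float32)
    return np.right_shift(x.view(np.uint32), 16).astype(np.uint16)
```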
4 changes: 3 additions & 1 deletion python/paddle/tensor/math.py
@@ -4374,7 +4374,9 @@ def erfinv(x, name=None):
     if in_dygraph_mode():
         return _C_ops.erfinv(x)
     else:
-        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'erfinv')
+        check_variable_and_dtype(
+            x, 'x', ['float32', 'float64', 'float16', 'uint16'], 'erfinv'
+        )
Contributor: The docstring around line 4358 of this file should also be updated with the newly supported data types; please also resolve the code conflicts, and then this should be ready to merge.

         helper = LayerHelper('erfinv', **locals())
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
         helper.append_op(type='erfinv', inputs={'X': x}, outputs={'Out': out})
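A minimal static-graph sketch of why 'float16' and 'uint16' (the static-graph dtype name for bfloat16) must be whitelisted here (an illustration, not part of the PR):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[4], dtype='float16')
    y = paddle.erfinv(x)  # passes check_variable_and_dtype after this change
print(y.dtype)
```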