diff --git a/paddle/phi/kernels/gpu/reduce_min_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_min_grad_kernel.cu index ea1d377c45976..86cccc5e03b1c 100644 --- a/paddle/phi/kernels/gpu/reduce_min_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_min_grad_kernel.cu @@ -16,8 +16,63 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/compare_functors.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" +namespace phi { + +template +void ReduceMinGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& out_grad, + const IntArray& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* x_grad) { + dev_ctx.Alloc(x_grad, x.dtype()); + reduce_all = recompute_reduce_all(x, dims, reduce_all); + + // get reduce_dim + int dim_size = x.dims().size(); + auto reduce_dims = + funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all); + auto update_dims = vectorize(x.dims()); + for (auto i : reduce_dims) { + update_dims[i] = 1; + } + + // make new tensor of out and out_grad + phi::DenseTensor new_out(out.type()); + new_out.ShareDataWith(out); + new_out.Resize(phi::make_ddim(update_dims)); + + phi::DenseTensor new_out_grad(out_grad.type()); + new_out_grad.ShareDataWith(out_grad); + new_out_grad.Resize(phi::make_ddim(update_dims)); + + // make equal_out + phi::DenseTensor* equal_out = new phi::DenseTensor(); + equal_out->Resize(x.dims()); + dev_ctx.template Alloc(equal_out); + + // compute + // 1. equal_out = Equal(x, y) + std::vector equal_inputs = {&new_out, &x}; + std::vector equal_outputs = {equal_out}; + funcs::BroadcastKernel( + dev_ctx, equal_inputs, &equal_outputs, 0, funcs::EqualFunctor()); + + // 2. dx = dout * 1 + std::vector mul_inputs = {&new_out_grad, equal_out}; + std::vector mul_outputs = {x_grad}; + funcs::BroadcastKernel( + dev_ctx, mul_inputs, &mul_outputs, 0, funcs::MultiplyFunctor()); + delete equal_out; +} +} // namespace phi PD_REGISTER_KERNEL(min_grad, GPU, ALL_LAYOUT, @@ -25,4 +80,6 @@ PD_REGISTER_KERNEL(min_grad, float, double, int, - int64_t) {} + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/kps/reduce_min_kernel.cu b/paddle/phi/kernels/kps/reduce_min_kernel.cu index 450fee16b4ca9..8ed9ec30c1920 100644 --- a/paddle/phi/kernels/kps/reduce_min_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_min_kernel.cu @@ -36,6 +36,14 @@ void MinRawKernel(const Context& dev_ctx, #ifdef PADDLE_WITH_XPU_KP PD_REGISTER_KERNEL(min_raw, KPS, ALL_LAYOUT, phi::MinRawKernel, float) {} #else -PD_REGISTER_KERNEL( - min_raw, KPS, ALL_LAYOUT, phi::MinRawKernel, float, double, int, int64_t) {} +PD_REGISTER_KERNEL(min_raw, + KPS, + ALL_LAYOUT, + phi::MinRawKernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} #endif diff --git a/paddle/phi/kernels/reduce_min_kernel.cc b/paddle/phi/kernels/reduce_min_kernel.cc index c4c58c8342e60..ff50e9d1077b0 100644 --- a/paddle/phi/kernels/reduce_min_kernel.cc +++ b/paddle/phi/kernels/reduce_min_kernel.cc @@ -39,7 +39,20 @@ void MinKernel(const Context& dev_ctx, PD_REGISTER_KERNEL( min, CPU, ALL_LAYOUT, phi::MinKernel, float, double, int, int64_t) {} -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_CUDA) +PD_REGISTER_KERNEL(min, + GPU, + ALL_LAYOUT, + phi::MinKernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} +#endif + +#if defined(PADDLE_WITH_HIP) PD_REGISTER_KERNEL( min, GPU, ALL_LAYOUT, phi::MinKernel, float, double, int, int64_t) {} #endif diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 050879369244d..631b760a7b8da 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -418,6 +418,51 @@ def test_check_output(self): self.check_output() +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework." +) +class TestMinFP16Op(OpTest): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.python_api = paddle.min + self.public_python_api = paddle.min + self.init_dtype() + if self.dtype == np.uint16: + x = np.random.random((5, 6, 10)).astype(np.float32) + self.inputs = {'X': convert_float_to_uint16(x)} + else: + x = np.random.random((5, 6, 10)).astype(self.dtype) + self.inputs = {'X': x} + self.attrs = {'dim': [2], 'keep_dim': True} + out = x.min(axis=tuple(self.attrs['dim']), keepdims=True) + if self.dtype == np.uint16: + self.outputs = {'Out': convert_float_to_uint16(out)} + else: + self.outputs = {'Out': out} + + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output() + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support the bfloat16", +) +class TestMinBF16Op(TestMinFP16Op): + def init_dtype(self): + self.dtype = np.uint16 + + def test_check_output(self): + self.check_output_with_place(core.CUDAPlace(0)) + + def raw_reduce_prod(x, dim=[0], keep_dim=False): return paddle.prod(x, dim, keep_dim) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 0e6b55142bf70..0d6f80e4e205a 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2449,7 +2449,10 @@ def min(x, axis=None, keepdim=False, name=None): reduce_all, axis = _get_reduce_axis_with_tensor(axis, x) helper = LayerHelper('min', **locals()) check_variable_and_dtype( - x, 'x', ['float32', 'float64', 'int32', 'int64'], 'min' + x, + 'x', + ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'], + 'min', ) out = helper.create_variable_for_type_inference(dtype=x.dtype)