diff --git a/paddle/phi/kernels/dist_grad_kernel.cc b/paddle/phi/kernels/dist_grad_kernel.cc
index 17c24fa905b5c..e6ef962c665c2 100644
--- a/paddle/phi/kernels/dist_grad_kernel.cc
+++ b/paddle/phi/kernels/dist_grad_kernel.cc
@@ -98,6 +98,12 @@ PD_REGISTER_KERNEL(
     dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PD_REGISTER_KERNEL(
-    dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
+PD_REGISTER_KERNEL(dist_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::DistGradKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
 #endif
diff --git a/paddle/phi/kernels/gpu/dist_kernel.cu b/paddle/phi/kernels/gpu/dist_kernel.cu
index 5040be8eaaca7..9129c87b91434 100644
--- a/paddle/phi/kernels/gpu/dist_kernel.cu
+++ b/paddle/phi/kernels/gpu/dist_kernel.cu
@@ -12,9 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/phi/kernels/dist_kernel.h"
+#include <algorithm>
+
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/dist_kernel.h"
 #include "paddle/phi/kernels/elementwise_subtract_kernel.h"
 #include "paddle/phi/kernels/funcs/math_cuda_utils.h"
 #include "paddle/phi/kernels/gpu/reduce.h"
@@ -24,47 +27,53 @@ namespace phi {
 
 #define FULL_MASK 0xffffffff
 
-template <typename T>
+template <typename Tx, typename Ty = Tx>
 struct ZeroOrderFunctor {
  public:
-  __device__ T operator()(const T& x, const T& y) const {
-    return static_cast<T>((x - y) != 0);
+  HOSTDEVICE explicit inline ZeroOrderFunctor() {}
+  HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const {
+    return static_cast<Ty>(x != y);
   }
 };
 
-template <typename T>
+template <typename Tx, typename Ty = Tx>
 struct OtherOrderFunctor {
-  explicit OtherOrderFunctor(const T& p_order) : p_order_(p_order) {}
-  __device__ T operator()(const T& x, const T& y) const {
-    return static_cast<T>(pow(abs(x - y), p_order_));
+  HOSTDEVICE explicit inline OtherOrderFunctor(const Ty& p_order)
+      : p_order_(p_order) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const {
+    return static_cast<Ty>(
+        pow(abs(static_cast<Ty>(x) - static_cast<Ty>(y)), p_order_));
   }
 
  private:
-  T p_order_;
+  Ty p_order_;
 };
 
-template <typename T>
+template <typename Tx, typename Ty = Tx>
 struct PowFunctor {
-  explicit PowFunctor(const T& p_order) : p_order_(p_order) {}
-  HOSTDEVICE inline T operator()(const T x) const {
-    return static_cast<T>(pow(x, p_order_));
+  HOSTDEVICE explicit inline PowFunctor(const Ty& p_order)
+      : p_order_(p_order) {}
+  HOSTDEVICE inline Tx operator()(const Tx x) const {
+    return static_cast<Tx>(pow(static_cast<Ty>(x), p_order_));
   }
-  T p_order_;
+  Ty p_order_;
 };
 
 template <typename T, typename Functor>
 __global__ void ReduceSumWithSubtract(
     const T* x, const T* y, T* out, int64_t N, Functor func) {
-  T sum_val = 0;
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  MT sum_val(0.0);
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
     sum_val += func(x[i], y[i]);
   }
 
   __syncthreads();
-  sum_val = phi::funcs::BlockReduceSum<T>(sum_val, FULL_MASK);
+  sum_val = phi::funcs::BlockReduceSum<MT>(sum_val, FULL_MASK);
   if (threadIdx.x == 0) {
-    out[blockIdx.x] = sum_val;
+    out[blockIdx.x] = static_cast<T>(sum_val);
   }
 }
 
@@ -73,16 +82,17 @@ __global__ void ReduceMaxWithSubtract(const T* x,
                                       const T* y,
                                       T* out,
                                       int64_t N) {
-  T max_val = -1e10f;
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  MT max_val = std::numeric_limits<MT>::min();
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
-    max_val = max(max_val, abs(x[i] - y[i]));
+    max_val = max(max_val, abs(static_cast<MT>(x[i]) - static_cast<MT>(y[i])));
   }
 
   __syncthreads();
-  max_val = phi::funcs::BlockReduceMax<T>(max_val, FULL_MASK);
+  max_val = phi::funcs::BlockReduceMax<MT>(max_val, FULL_MASK);
   if (threadIdx.x == 0) {
-    out[blockIdx.x] = max_val;
+    out[blockIdx.x] = static_cast<T>(max_val);
   }
 }
 
@@ -91,16 +101,17 @@ __global__ void ReduceMinWithSubtract(const T* x,
                                       const T* y,
                                       T* out,
                                       int64_t N) {
-  T min_val = 1e10f;
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  MT min_val = std::numeric_limits<MT>::max();
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
-    min_val = min(min_val, abs(x[i] - y[i]));
+    min_val = min(min_val, abs(static_cast<MT>(x[i]) - static_cast<MT>(y[i])));
  }
 
   __syncthreads();
-  min_val = phi::funcs::BlockReduceMin<T>(min_val, FULL_MASK);
+  min_val = phi::funcs::BlockReduceMin<MT>(min_val, FULL_MASK);
   if (threadIdx.x == 0) {
-    out[blockIdx.x] = min_val;
+    out[blockIdx.x] = static_cast<T>(min_val);
   }
 }
 
@@ -110,6 +121,7 @@ void DistKernel(const Context& dev_ctx,
                 const DenseTensor& y,
                 float p,
                 DenseTensor* out) {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   DenseTensor intermediate;
   const T* x_ptr = x.data<T>();
   const T* y_ptr = y.data<T>();
@@ -130,10 +142,9 @@ void DistKernel(const Context& dev_ctx,
     if (p == 0) {
       ReduceSumWithSubtract<T>
           <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
-              x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T>());
-      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
-
+              x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T, T>());
+      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
     } else if (p == INFINITY) {
       ReduceMaxWithSubtract<T>
           <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
@@ -150,19 +161,19 @@ void DistKernel(const Context& dev_ctx,
           dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
 
     } else {
-      T p_order = static_cast<T>(p);
+      MT p_order = static_cast<MT>(p);
       ReduceSumWithSubtract<T>
           <<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
-              x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T>(p_order));
-      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
+              x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T, MT>(p_order));
+      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
 
       const DenseTensor* tmp_norm = out;
       std::vector<const DenseTensor*> ins = {tmp_norm};
       std::vector<DenseTensor*> outs = {out};
-      T p_order_ = static_cast<T>(1. / p_order);
+      MT p_order_ = static_cast<MT>(static_cast<MT>(1.) / p_order);
       phi::funcs::ElementwiseKernel<T>(
-          dev_ctx, ins, &outs, PowFunctor<T>(p_order_));
+          dev_ctx, ins, &outs, PowFunctor<T, MT>(p_order_));
     }
 
   } else {
@@ -173,4 +184,11 @@ void DistKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
+PD_REGISTER_KERNEL(dist,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::DistKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 76ae5438cefba..8d9db61f92929 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -670,8 +670,8 @@ def dist(x, y, p=2, name=None):
         ||z||_{p}=(\sum_{i=1}^{m}|z_i|^p)^{\\frac{1}{p}}
 
     Args:
-        x (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
-        y (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
+        x (Tensor): 1-D to 6-D Tensor, its data type is bfloat16, float16, float32 or float64.
+        y (Tensor): 1-D to 6-D Tensor, its data type is bfloat16, float16, float32 or float64.
         p (float, optional): The norm to be computed, its data type is float32 or float64. Default: 2.
         name (str, optional): The default value is `None`. Normally there is no need for user to set this property.
             For more information, please refer to :ref:`api_guide_Name`.
@@ -701,8 +701,12 @@ def dist(x, y, p=2, name=None):
     if in_dynamic_mode():
         return _C_ops.dist(x, y, p)
 
-    check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'dist')
-    check_variable_and_dtype(y, 'dtype', ['float32', 'float64'], 'dist')
+    check_variable_and_dtype(
+        x, 'dtype', ['bfloat16', 'float16', 'float32', 'float64'], 'dist'
+    )
+    check_variable_and_dtype(
+        y, 'dtype', ['bfloat16', 'float16', 'float32', 'float64'], 'dist'
+    )
     check_type(p, 'p', (float, int), 'dist')
     helper = LayerHelper("dist", **locals())
     out = helper.create_variable_for_type_inference(x.dtype)
diff --git a/test/legacy_test/test_dist_op.py b/test/legacy_test/test_dist_op.py
index 96c0de915cff2..958071cfc62a1 100644
--- a/test/legacy_test/test_dist_op.py
+++ b/test/legacy_test/test_dist_op.py
@@ -158,6 +158,86 @@ def init_case(self):
         self.p = 1.5
 
 
+class TestDistBF16Op(OpTest):
+    def init_data_type(self):
+        self.data_type = 'bfloat16'
+
+
+class TestDistBF16OpCase1(TestDistBF16Op):
+    def init_case(self):
+        self.x_shape = (3, 5, 5, 6)
+        self.y_shape = (5, 5, 6)
+        self.p = 1.0
+
+
+class TestDistBF16OpCase2(TestDistBF16Op):
+    def init_case(self):
+        self.x_shape = (10, 10)
+        self.y_shape = (4, 10, 10)
+        self.p = 2.0
+
+
+class TestDistBF16OpCase3(TestDistBF16Op):
+    def init_case(self):
+        self.x_shape = (15, 10)
+        self.y_shape = (15, 10)
+        self.p = float("inf")
+
+
+class TestDistBF16OpCase4(TestDistBF16Op):
+    def init_case(self):
+        self.x_shape = (2, 3, 4, 5, 8)
+        self.y_shape = (3, 1, 5, 8)
+        self.p = float("-inf")
+
+
+class TestDistBF16OpCase5(TestDistBF16Op):
+    def init_case(self):
+        self.x_shape = (4, 1, 4, 8)
+        self.y_shape = (2, 2, 1, 4, 4, 8)
+        self.p = 1.5
+
+
+class TestDistFP16Op(OpTest):
+    def init_data_type(self):
+        self.data_type = 'float16'
+
+
+class TestDistFP16OpCase1(TestDistFP16Op):
+    def init_case(self):
+        self.x_shape = (3, 5, 5, 6)
+        self.y_shape = (5, 5, 6)
+        self.p = 1.0
+
+
+class TestDistFP16OpCase2(TestDistFP16Op):
+    def init_case(self):
+        self.x_shape = (10, 10)
+        self.y_shape = (4, 10, 10)
+        self.p = 2.0
+
+
+class TestDistFP16OpCase3(TestDistFP16Op):
+    def init_case(self):
+        self.x_shape = (15, 10)
+        self.y_shape = (15, 10)
+        self.p = float("inf")
+
+
+class TestDistFP16OpCase4(TestDistFP16Op):
+    def init_case(self):
+        self.x_shape = (2, 3, 4, 5, 8)
+        self.y_shape = (3, 1, 5, 8)
+        self.p = float("-inf")
+
+
+class TestDistFP16OpCase5(TestDistFP16Op):
+    def init_case(self):
+        self.x_shape = (4, 1, 4, 8)
+        self.y_shape = (2, 2, 1, 4, 4, 8)
+        self.p = 1.5
+
+
 class TestDistAPI(unittest.TestCase):
     def init_data_type(self):
         self.data_type = (