From 4b6b66054c5f2921eebae689f9654039c3af0f2e Mon Sep 17 00:00:00 2001
From: Guoxia Wang
Date: Tue, 21 Sep 2021 10:23:48 +0800
Subject: [PATCH] support fp16 (#35888)

---
 .../elementwise/elementwise_max_op.cu |  4 +
 .../elementwise/elementwise_max_op.h  |  4 +-
 paddle/fluid/operators/p_norm_op.cu   | 74 ++++++++++++-------
 python/paddle/nn/functional/norm.py   |  3 +-
 4 files changed, 56 insertions(+), 29 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu
index 9657e1896e334..65505381db174 100644
--- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu
@@ -41,12 +41,16 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(
     elementwise_max,
+    ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext,
+                              paddle::platform::float16>,
     ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, int>,
     ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     elementwise_max_grad,
+    ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext,
+                                  paddle::platform::float16>,
     ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, int>,
diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h
index 8ee8fe923a811..06269b12e8e20 100644
--- a/paddle/fluid/operators/elementwise/elementwise_max_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h
@@ -39,14 +39,14 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> {
 template <typename T>
 struct MaxGradDx {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-    return dout * (x > y);
+    return dout * static_cast<T>(x > y);
   }
 };
 
 template <typename T>
 struct MaxGradDy {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
-    return dout * (x <= y);
+    return dout * static_cast<T>(x <= y);
   }
 };
 
diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu
index bd6694abdbf76..cfe778c49121f 100644
--- a/paddle/fluid/operators/p_norm_op.cu
+++ b/paddle/fluid/operators/p_norm_op.cu
@@ -20,7 +20,9 @@ limitations under the License. */
 #include <hipcub/hipcub.hpp>
 namespace cub = hipcub;
 #endif
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/p_norm_op.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -30,12 +32,23 @@ __device__ __forceinline__ int sgn(T val) {
   return (T(0) < val) - (val < T(0));
 }
 
+__device__ __forceinline__ platform::float16 inline_abs(platform::float16 x) {
+  return static_cast<platform::float16>(abs(static_cast<float>(x)));
+}
 __device__ __forceinline__ float inline_abs(float x) { return abs(x); }
 __device__ __forceinline__ double inline_abs(double x) { return abs(x); }
 
+__device__ __forceinline__ int inline_sign(platform::float16 x) {
+  return sgn(x);
+}
 __device__ __forceinline__ int inline_sign(float x) { return sgn(x); }
 __device__ __forceinline__ int inline_sign(double x) { return sgn(x); }
 
+__device__ __forceinline__ platform::float16 inline_pow(
+    platform::float16 base, platform::float16 exponent) {
+  return static_cast<platform::float16>(
+      pow(static_cast<float>(base), static_cast<float>(exponent)));
+}
 __device__ __forceinline__ float inline_pow(float base, float exponent) {
   return pow(base, exponent);
 }
@@ -47,21 +60,23 @@ template <typename T, int BlockDim>
 __global__ void Pnorm(const T* x, const int pre,
                       const int axis_n,  // dim in axis
                       const int post, float porder, T* out_norm) {
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  using MT = typename details::MPTypeTrait<T>::Type;
+  typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   int num = pre * post;
-  auto porder_t = static_cast<T>(porder);
-  auto porder_inv = static_cast<T>(1.0 / porder);
+  auto porder_t = static_cast<MT>(porder);
+  auto porder_inv = static_cast<MT>(1.0 / porder);
 
   for (int i = blockIdx.x; i < num; i += gridDim.x) {
     int base = (i / post) * post * axis_n + (i % post);
-    T sum = 0.0;
+    MT sum = static_cast<MT>(0.0);
     for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      const T x_ij = x[base + j * post];
+      const MT x_ij = static_cast<MT>(x[base + j * post]);
      sum += inline_pow(inline_abs(x_ij), porder_t);
     }
-    T reduce_result = BlockReduce(temp_storage).Sum(sum);
-    if (threadIdx.x == 0) out_norm[i] = inline_pow(reduce_result, porder_inv);
+    MT reduce_result = BlockReduce(temp_storage).Sum(sum);
+    if (threadIdx.x == 0)
+      out_norm[i] = static_cast<T>(inline_pow(reduce_result, porder_inv));
   }
 }
 
@@ -69,18 +84,19 @@ template <typename T, int BlockDim>
 __global__ void ZeorNorm(const T* x, const int pre,
                          const int axis_n,  // dim in axis
                          const int post, T* out_norm) {
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  using MT = typename details::MPTypeTrait<T>::Type;
+  typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage temp_storage;
   int num = pre * post;
   for (int i = blockIdx.x; i < num; i += gridDim.x) {
     int base = (i / post) * post * axis_n + (i % post);
-    T sum = 0.0;
+    MT sum = static_cast<MT>(0.0);
     for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      const T x_ij = x[base + j * post];
-      sum += static_cast<T>(x_ij != 0);
+      const MT x_ij = static_cast<MT>(x[base + j * post]);
+      sum += static_cast<MT>(static_cast<double>(x_ij) != 0);
     }
-    T reduce_result = BlockReduce(temp_storage).Sum(sum);
-    if (threadIdx.x == 0) out_norm[i] = reduce_result;
+    MT reduce_result = BlockReduce(temp_storage).Sum(sum);
+    if (threadIdx.x == 0) out_norm[i] = static_cast<T>(reduce_result);
   }
 }
 
@@ -172,27 +188,29 @@ __global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
                               const float porder, const int pre,
                               const int axis_n, const int post, const T eps,
                               T* x_grad) {
+  using MT = typename details::MPTypeTrait<T>::Type;
   // dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
   int num = pre * post;
-  auto porder_grad = static_cast<T>(porder - 1.0f);
+  auto porder_grad = static_cast<MT>(porder - 1.0f);
   for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    __shared__ T pnorm_i;
-    __shared__ T yout_i;
+    __shared__ MT pnorm_i;
+    __shared__ MT yout_i;
 
     auto base = (i / post) * post * axis_n + (i % post);
 
     if (threadIdx.x == 0) {
-      pnorm_i = x_norm[i];
-      yout_i = y_grad[i];
+      pnorm_i = static_cast<MT>(x_norm[i]);
+      yout_i = static_cast<MT>(y_grad[i]);
     }
     __syncthreads();
 
     for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
       int index = base + j * post;
-      const T x_ij = inline_abs(x[index]);
-      x_grad[index] = inline_pow(x_ij, porder_grad) /
-                      (inline_pow(pnorm_i, porder_grad) + eps) * yout_i *
-                      inline_sign(x[index]);
+      const MT x_ij = static_cast<MT>(inline_abs(x[index]));
+      x_grad[index] = static_cast<T>(
+          inline_pow(x_ij, porder_grad) /
+          (inline_pow(pnorm_i, porder_grad) + static_cast<MT>(eps)) * yout_i *
+          static_cast<MT>(inline_sign(x[index])));
     }
   }
 }
@@ -216,7 +234,7 @@ __global__ void InfNormGradient(const T* x, const T* x_norm, const T* y_grad,
       int index = base + j * post;
       const T x_ij = inline_abs(x[index]);
       if (x_ij == pnorm_i) {
-        x_grad[index] = inline_sign(x[index]) * yout_i;
+        x_grad[index] = static_cast<T>(inline_sign(x[index])) * yout_i;
       } else {
         x_grad[index] = static_cast<T>(0);
       }
@@ -278,7 +296,11 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
 
 namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
-REGISTER_OP_CUDA_KERNEL(p_norm, ops::PnormCUDAKernel<CUDA, float>,
+REGISTER_OP_CUDA_KERNEL(p_norm,
+                        ops::PnormCUDAKernel<CUDA, paddle::platform::float16>,
+                        ops::PnormCUDAKernel<CUDA, float>,
                         ops::PnormCUDAKernel<CUDA, double>);
-REGISTER_OP_CUDA_KERNEL(p_norm_grad, ops::PnormGradCUDAKernel<CUDA, float>,
-                        ops::PnormGradCUDAKernel<CUDA, double>);
+REGISTER_OP_CUDA_KERNEL(
+    p_norm_grad, ops::PnormGradCUDAKernel<CUDA, paddle::platform::float16>,
+    ops::PnormGradCUDAKernel<CUDA, float>,
+    ops::PnormGradCUDAKernel<CUDA, double>);
diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py
index db73e56f879a7..89843885c8a12 100644
--- a/python/paddle/nn/functional/norm.py
+++ b/python/paddle/nn/functional/norm.py
@@ -86,7 +86,8 @@ def normalize(x, p=2, axis=1, epsilon=1e-12, name=None):
 
     check_type(p, 'p', (float, int), 'normalize')
     check_type(axis, 'axis', (int), 'normalize')
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'normalize')
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                             'normalize')
     if len(x.shape) == 1 and axis != 0 and axis != -1:
         raise ValueError(
             "Axis must be 0 or -1 when x is a 1-D tensor, but received axis = {}".
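
Below is a minimal usage sketch (not part of the patch itself) of the float16 path this change enables. It assumes a GPU build of Paddle that includes this commit, since the new float16 p_norm and elementwise_max kernels are registered only for the CUDA device context.

import paddle
import paddle.nn.functional as F

# The float16 kernels added above are CUDA-only.
paddle.set_device('gpu')

# normalize() accepts float16 inputs after the dtype check change in norm.py.
x = paddle.rand([3, 4]).astype('float16')
y = F.normalize(x, p=2, axis=1)  # L2-normalize each row
print(y.dtype)  # paddle.float16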