From 9f0a57b7f8a39a3ae0cd0939097e9636c81568e6 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 23 Feb 2022 06:03:53 +0000 Subject: [PATCH 01/11] add layer norm --- paddle/fluid/operators/layer_norm_kernel.cu.h | 6 +++--- paddle/fluid/operators/layer_norm_op.cu | 14 ++++++++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/layer_norm_kernel.cu.h b/paddle/fluid/operators/layer_norm_kernel.cu.h index b31c7a1cde0f1..62c21dd2eee40 100644 --- a/paddle/fluid/operators/layer_norm_kernel.cu.h +++ b/paddle/fluid/operators/layer_norm_kernel.cu.h @@ -474,11 +474,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_1024_kernel( for (int it = 0; it < LDGS; it++) { #pragma unroll for (int jt = 0; jt < VecSize; jt++) { - U x_tmp = x[it][jt]; + U x_tmp = static_cast(x[it][jt]); U y_tmp = var_cur_row * (x_tmp - mean_cur_row); U dy_tmp = static_cast(gamma[it][jt]) * - static_cast(dout[it][jt]); // scale * dy - U dout_tmp = dout[it][jt]; // dy + static_cast(dout[it][jt]); // scale * dy + U dout_tmp = static_cast(dout[it][jt]); // dy // used for get dx (row reduction) sum_loss1 += dy_tmp; // scale * dy, sum_1 diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index d439b3220d96e..9ca73d8ecdd98 100644 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -253,22 +253,28 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( layer_norm, ops::LayerNormKernel, - ops::LayerNormKernel); + ops::LayerNormKernel, + ops::LayerNormKernel); REGISTER_OP_CUDA_KERNEL( layer_norm_grad, ops::LayerNormGradKernel, ops::LayerNormGradKernel); + plat::float16>, + ops::LayerNormGradKernel); #else REGISTER_OP_CUDA_KERNEL( layer_norm, ops::LayerNormKernel, ops::LayerNormKernel, - ops::LayerNormKernel); + ops::LayerNormKernel, + ops::LayerNormKernel); REGISTER_OP_CUDA_KERNEL( layer_norm_grad, ops::LayerNormGradKernel, ops::LayerNormGradKernel, ops::LayerNormGradKernel); + plat::float16>, + ops::LayerNormGradKernel); #endif From b40e3732778b439c7579299900461c95f90fcd92 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 23 Feb 2022 06:04:33 +0000 Subject: [PATCH 02/11] add p norm --- paddle/fluid/operators/p_norm_op.cu | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu index f2cb427a0a5b1..d0b78b9b0643d 100644 --- a/paddle/fluid/operators/p_norm_op.cu +++ b/paddle/fluid/operators/p_norm_op.cu @@ -39,6 +39,11 @@ __device__ __forceinline__ int sgn(T val) { __device__ __forceinline__ platform::float16 inline_abs(platform::float16 x) { return static_cast(abs(static_cast(x))); } + +__device__ __forceinline__ platform::bfloat16 inline_abs(platform::bfloat16 x) { + return static_cast(abs(static_cast(x))); +} + __device__ __forceinline__ float inline_abs(float x) { return abs(x); } __device__ __forceinline__ double inline_abs(double x) { return abs(x); } @@ -53,6 +58,11 @@ __device__ __forceinline__ platform::float16 inline_pow( return static_cast( pow(static_cast(base), static_cast(exponent))); } +__device__ __forceinline__ platform::bfloat16 inline_pow( + platform::bfloat16 base, platform::bfloat16 exponent) { + return static_cast( + pow(static_cast(base), static_cast(exponent))); +} __device__ __forceinline__ float inline_pow(float base, float exponent) { return pow(base, exponent); } @@ -202,9 +212,11 @@ using CUDA = paddle::platform::CUDADeviceContext; REGISTER_OP_CUDA_KERNEL(p_norm, 
ops::PnormCUDAKernel, + ops::PnormCUDAKernel, ops::PnormCUDAKernel, ops::PnormCUDAKernel); REGISTER_OP_CUDA_KERNEL( p_norm_grad, ops::PnormGradCUDAKernel, + ops::PnormGradCUDAKernel, ops::PnormGradCUDAKernel, ops::PnormGradCUDAKernel); From 67a7e745cc85acd503d771e15a87b40208b65ede Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 23 Feb 2022 06:05:12 +0000 Subject: [PATCH 03/11] add reduce sum --- paddle/fluid/operators/reduce_ops/reduce_sum_op.cc | 1 + paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu | 1 + paddle/phi/kernels/cpu/math_kernel.cc | 1 + paddle/phi/kernels/gpu/math_kernel.cu | 1 + paddle/phi/kernels/math_kernel.cc | 1 + 5 files changed, 5 insertions(+) diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc index bdab14a18a05a..a9a6e6581b794 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc @@ -116,6 +116,7 @@ REGISTER_OP_CPU_KERNEL( reduce_sum_grad, CPUReduceSumGradKernel, CPUReduceSumGradKernel, CPUReduceSumGradKernel, CPUReduceSumGradKernel, + CPUReduceSumGradKernel, CPUReduceSumGradKernel, CPUReduceSumGradKernel, CPUReduceSumGradKernel>, CPUReduceSumGradKernel>); diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu index c3d3e0cf6ecd5..2f6bf12751809 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu @@ -23,6 +23,7 @@ REGISTER_OP_CUDA_KERNEL( reduce_sum_grad, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel>, CUDAReduceSumGradKernel>); diff --git a/paddle/phi/kernels/cpu/math_kernel.cc b/paddle/phi/kernels/cpu/math_kernel.cc index 581c5f90f35e5..cf61a2bc4f920 100644 --- a/paddle/phi/kernels/cpu/math_kernel.cc +++ b/paddle/phi/kernels/cpu/math_kernel.cc @@ -169,6 +169,7 @@ PD_REGISTER_KERNEL(sum_raw, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int16_t, int, int64_t, diff --git a/paddle/phi/kernels/gpu/math_kernel.cu b/paddle/phi/kernels/gpu/math_kernel.cu index 02e3f00bd3425..cd30eaa2915e1 100644 --- a/paddle/phi/kernels/gpu/math_kernel.cu +++ b/paddle/phi/kernels/gpu/math_kernel.cu @@ -152,6 +152,7 @@ PD_REGISTER_KERNEL(sum_raw, float, double, float16, + bfloat16, int16_t, int, int64_t, diff --git a/paddle/phi/kernels/math_kernel.cc b/paddle/phi/kernels/math_kernel.cc index db6c5e1ac3591..06c4aa0436c45 100644 --- a/paddle/phi/kernels/math_kernel.cc +++ b/paddle/phi/kernels/math_kernel.cc @@ -92,6 +92,7 @@ PD_REGISTER_KERNEL(sum, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int16_t, int, int64_t, From 9490159037ef40405f39e909a126c874c21339b6 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 23 Feb 2022 08:50:29 +0000 Subject: [PATCH 04/11] refine layer norm register bf16 for cudnn811 --- paddle/fluid/operators/layer_norm_op.cu | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index 9ca73d8ecdd98..dfe73d3727132 100644 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -253,11 +253,23 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( layer_norm, ops::LayerNormKernel, + ops::LayerNormKernel); +REGISTER_OP_CUDA_KERNEL( + 
layer_norm_grad, + ops::LayerNormGradKernel, + ops::LayerNormGradKernel); +#elif CUDNN_VERSION_MIN(8, 1, 0) +REGISTER_OP_CUDA_KERNEL( + layer_norm, + ops::LayerNormKernel, + ops::LayerNormKernel, ops::LayerNormKernel, ops::LayerNormKernel); REGISTER_OP_CUDA_KERNEL( layer_norm_grad, ops::LayerNormGradKernel, + ops::LayerNormGradKernel, ops::LayerNormGradKernel, ops::LayerNormGradKernel, ops::LayerNormKernel, - ops::LayerNormKernel, - ops::LayerNormKernel); + ops::LayerNormKernel); REGISTER_OP_CUDA_KERNEL( layer_norm_grad, ops::LayerNormGradKernel, ops::LayerNormGradKernel, ops::LayerNormGradKernel, - ops::LayerNormGradKernel); + plat::float16>); #endif From c2ca074518e60476d42a1ba5eccb10b8547eaf9d Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Wed, 23 Feb 2022 09:23:21 +0000 Subject: [PATCH 05/11] add bf16 cast for hip --- paddle/fluid/operators/cast_op.cu | 8 ++++---- paddle/phi/kernels/gpu/cast_kernel.cu | 8 ++++---- paddle/phi/kernels/math_kernel.cc | 1 + 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 5c7dd0e2561fa..026311e420f21 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -29,9 +29,9 @@ using CUDA = paddle::platform::CUDADeviceContext; ops::CastOpKernel>, \ ops::CastOpKernel>, ##__VA_ARGS__); -#if !defined(PADDLE_WITH_HIP) +// #if !defined(PADDLE_WITH_HIP) // See [ why register transfer_dtype_op alias with cast_op? ] in cast_op.cc REGISTER_CAST_CUDA_BASE(transfer_dtype, ops::CastOpKernel) -#else -REGISTER_CAST_CUDA_BASE(transfer_dtype) -#endif +// #else +// REGISTER_CAST_CUDA_BASE(transfer_dtype) +// #endif diff --git a/paddle/phi/kernels/gpu/cast_kernel.cu b/paddle/phi/kernels/gpu/cast_kernel.cu index 7a6c99c5fe15f..7f90a17be6347 100644 --- a/paddle/phi/kernels/gpu/cast_kernel.cu +++ b/paddle/phi/kernels/gpu/cast_kernel.cu @@ -80,8 +80,8 @@ void CastKernel(const Context& dev_ctx, paddle::experimental::DataType::UNDEFINED); \ } -#if !defined(PADDLE_WITH_HIP) +// #if !defined(PADDLE_WITH_HIP) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, phi::dtype::bfloat16) -#else -PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) -#endif +// #else +// PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) +// #endif diff --git a/paddle/phi/kernels/math_kernel.cc b/paddle/phi/kernels/math_kernel.cc index 06c4aa0436c45..e8efdd5ef29ae 100644 --- a/paddle/phi/kernels/math_kernel.cc +++ b/paddle/phi/kernels/math_kernel.cc @@ -164,6 +164,7 @@ PD_REGISTER_KERNEL(sum, float, double, phi::dtype::float16, + phi::dtype::bfloat16, int16_t, int, int64_t, From a2f7e07d3b5acba155f03ccb686deaeae142dca2 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 25 Feb 2022 11:55:11 +0000 Subject: [PATCH 06/11] add unittest --- paddle/fluid/operators/cast_op.cu | 8 +- .../operators/reduce_ops/reduce_sum_op.cc | 1 - paddle/phi/kernels/cpu/math_kernel.cc | 1 - paddle/phi/kernels/gpu/cast_kernel.cu | 8 +- paddle/phi/kernels/math_kernel.cc | 1 - .../tests/unittests/test_layer_norm_op.py | 47 ++++++++++++ .../fluid/tests/unittests/test_norm_all.py | 76 ++++++++++++++++++- .../fluid/tests/unittests/test_reduce_op.py | 35 ++++++++- 8 files changed, 164 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 026311e420f21..5c7dd0e2561fa 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -29,9 +29,9 @@ using CUDA = paddle::platform::CUDADeviceContext; ops::CastOpKernel>, \ ops::CastOpKernel>, ##__VA_ARGS__); -// #if 
!defined(PADDLE_WITH_HIP) +#if !defined(PADDLE_WITH_HIP) // See [ why register transfer_dtype_op alias with cast_op? ] in cast_op.cc REGISTER_CAST_CUDA_BASE(transfer_dtype, ops::CastOpKernel) -// #else -// REGISTER_CAST_CUDA_BASE(transfer_dtype) -// #endif +#else +REGISTER_CAST_CUDA_BASE(transfer_dtype) +#endif diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc index a9a6e6581b794..bdab14a18a05a 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc @@ -116,7 +116,6 @@ REGISTER_OP_CPU_KERNEL( reduce_sum_grad, CPUReduceSumGradKernel, CPUReduceSumGradKernel, CPUReduceSumGradKernel, CPUReduceSumGradKernel, - CPUReduceSumGradKernel, CPUReduceSumGradKernel, CPUReduceSumGradKernel, CPUReduceSumGradKernel>, CPUReduceSumGradKernel>); diff --git a/paddle/phi/kernels/cpu/math_kernel.cc b/paddle/phi/kernels/cpu/math_kernel.cc index cf61a2bc4f920..581c5f90f35e5 100644 --- a/paddle/phi/kernels/cpu/math_kernel.cc +++ b/paddle/phi/kernels/cpu/math_kernel.cc @@ -169,7 +169,6 @@ PD_REGISTER_KERNEL(sum_raw, float, double, phi::dtype::float16, - phi::dtype::bfloat16, int16_t, int, int64_t, diff --git a/paddle/phi/kernels/gpu/cast_kernel.cu b/paddle/phi/kernels/gpu/cast_kernel.cu index 7f90a17be6347..7a6c99c5fe15f 100644 --- a/paddle/phi/kernels/gpu/cast_kernel.cu +++ b/paddle/phi/kernels/gpu/cast_kernel.cu @@ -80,8 +80,8 @@ void CastKernel(const Context& dev_ctx, paddle::experimental::DataType::UNDEFINED); \ } -// #if !defined(PADDLE_WITH_HIP) +#if !defined(PADDLE_WITH_HIP) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, phi::dtype::bfloat16) -// #else -// PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) -// #endif +#else +PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) +#endif diff --git a/paddle/phi/kernels/math_kernel.cc b/paddle/phi/kernels/math_kernel.cc index e8efdd5ef29ae..aea874f142680 100644 --- a/paddle/phi/kernels/math_kernel.cc +++ b/paddle/phi/kernels/math_kernel.cc @@ -92,7 +92,6 @@ PD_REGISTER_KERNEL(sum, float, double, phi::dtype::float16, - phi::dtype::bfloat16, int16_t, int, int64_t, diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py index 7dd310d2b88a9..994b3a5adcd89 100644 --- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py @@ -375,6 +375,53 @@ def assert_equal(x, y): assert_equal(b_g_np_1, b_g_np_2) +class TestBF16ScaleBiasLayerNorm(unittest.TestCase): + def check_main(self, x_np, weight_np, bias_np, dtype): + paddle.disable_static() + + x = paddle.to_tensor(x_np) + weight = paddle.to_tensor(weight_np) + bias = paddle.to_tensor(bias_np) + + if dtype == "bfloat16": + x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16) + + x.stop_gradient = False + weight.stop_gradient = False + bias.stop_gradient = False + + y = F.layer_norm(x, x.shape[1:], weight, bias) + x_g, w_g, b_g = paddle.grad(y, [x, weight, bias]) + + y_np = y.cast('float32').numpy() + x_g_np = x_g.cast('float32').numpy() + w_g_np = w_g.cast('float32').numpy() + b_g_np = b_g.cast('float32').numpy() + + paddle.enable_static() + return y_np, x_g_np, w_g_np, b_g_np + + def test_main(self): + if not paddle.is_compiled_with_cuda(): + return + x_np = np.random.random([10, 20]).astype('float32') + weight_np = np.random.random([20]).astype('float32') + bias_np = np.random.random([20]).astype('float32') + + y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main( + x_np, weight_np, 
bias_np, 'float32') + y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main( + x_np, weight_np, bias_np, 'bfloat16') + + def assert_equal(x, y): + self.assertTrue(np.allclose(x, y, atol=1.e-1)) + + assert_equal(y_np_1, y_np_2) + assert_equal(x_g_np_1, x_g_np_2) + assert_equal(w_g_np_1, w_g_np_2) + assert_equal(b_g_np_1, b_g_np_2) + + class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase): def test_main(self): self.assertTrue(_keep_layer_norm_scale_bias_to_fp32()) diff --git a/python/paddle/fluid/tests/unittests/test_norm_all.py b/python/paddle/fluid/tests/unittests/test_norm_all.py index b20305b78efe2..575bc653618a5 100644 --- a/python/paddle/fluid/tests/unittests/test_norm_all.py +++ b/python/paddle/fluid/tests/unittests/test_norm_all.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from op_test import OpTest +from op_test import OpTest, convert_float_to_uint16 import paddle import paddle.fluid as fluid import paddle.fluid.core as core @@ -282,6 +282,80 @@ def init_test_case(self): self.asvector = True +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestPnormBF16Op(OpTest): + def setUp(self): + self.op_type = "p_norm" + self.init_test_case() + self.x = (np.random.random(self.shape) + 0.5).astype(np.float32) + self.norm = p_norm(self.x, self.axis, self.porder, self.keepdim, + self.asvector) + self.gradient = self.calc_gradient() + self.inputs = {'X': convert_float_to_uint16(self.x)} + self.attrs = { + 'epsilon': self.epsilon, + 'axis': self.axis, + 'keepdim': self.keepdim, + 'porder': float(self.porder), + 'asvector': self.asvector + } + self.outputs = {'Out': convert_float_to_uint16(self.norm)} + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, atol=1e-3) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', user_defined_grads=self.gradient) + + def init_test_case(self): + self.shape = [2, 3, 4, 5] + self.axis = 1 + self.epsilon = 1e-12 + self.porder = 2.0 + self.keepdim = False + self.dtype = np.uint16 + self.asvector = False + + def calc_gradient(self): + self.attrs = { + 'epsilon': self.epsilon, + 'axis': self.axis, + 'keepdim': self.keepdim, + 'porder': float(self.porder), + 'asvector': self.asvector + } + x = self.x + porder = self.attrs["porder"] + axis = self.attrs["axis"] + asvector = self.attrs["asvector"] + x_dtype = x.dtype + x = x.astype(np.float32) if x.dtype == np.float16 else x + if porder == 0: + grad = np.zeros(x.shape).astype(x.dtype) + elif porder in [float("inf"), float("-inf")]: + norm = p_norm( + x, axis=axis, porder=porder, keepdims=True, reduce_all=asvector) + x_abs = np.abs(x) + grad = np.sign(x) + grad[x_abs != norm] = 0.0 + else: + norm = p_norm( + x, axis=axis, porder=porder, keepdims=True, reduce_all=asvector) + grad = np.power(norm, 1 - porder) * np.power( + np.abs(x), porder - 1) * np.sign(x) + + numel = 1 + for s in x.shape: + numel *= s + divisor = numel if asvector else x.shape[axis] + numel /= divisor + return [grad.astype(x_dtype) * 1 / numel] + + def run_fro(self, p, axis, shape_x, dtype, keep_dim, check_dim=False): with fluid.program_guard(fluid.Program()): data = fluid.data(name="X", shape=shape_x, dtype=dtype) diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index faa67e1d6da8f..75c57bed25a66 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ 
b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from op_test import OpTest, skip_check_grad_ci +from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 import paddle import paddle.fluid.core as core import paddle.fluid as fluid @@ -61,6 +61,39 @@ def test_check_grad(self): self.check_grad(['X'], 'Out', user_defined_grads=self.gradient) +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSumOp_bf16(OpTest): + def setUp(self): + np.random.seed(100) + self.op_type = "reduce_sum" + self.dtype = np.uint16 + self.x = np.random.uniform(1, 2, (5, 6, 10)).astype(np.float32) + print(self.x) + self.attrs = {'dim': [0, 1, 2]} + self.out = self.x.sum(axis=tuple(self.attrs['dim'])) + print(self.out) + self.gradient = self.calc_gradient() + + self.inputs = {'X': convert_float_to_uint16(self.x)} + self.outputs = {'Out': convert_float_to_uint16(self.out)} + self.gradient = self.calc_gradient() + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', user_defined_grads=self.gradient) + + def calc_gradient(self): + x = self.x + grad = np.ones(x.shape, dtype=x.dtype) + return [grad] + + class TestSumOp_fp16_withInt(OpTest): def setUp(self): self.op_type = "reduce_sum" From eba7f4cbf73a5d7a9a77c4c87210dbb46ea618f6 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sat, 26 Feb 2022 02:27:17 +0000 Subject: [PATCH 07/11] refine rocm --- paddle/fluid/operators/cast_op.cu | 4 ---- paddle/phi/kernels/gpu/cast_kernel.cu | 4 ---- 2 files changed, 8 deletions(-) diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu index 5c7dd0e2561fa..eb51215790bbc 100644 --- a/paddle/fluid/operators/cast_op.cu +++ b/paddle/fluid/operators/cast_op.cu @@ -29,9 +29,5 @@ using CUDA = paddle::platform::CUDADeviceContext; ops::CastOpKernel>, \ ops::CastOpKernel>, ##__VA_ARGS__); -#if !defined(PADDLE_WITH_HIP) // See [ why register transfer_dtype_op alias with cast_op? 
] in cast_op.cc REGISTER_CAST_CUDA_BASE(transfer_dtype, ops::CastOpKernel) -#else -REGISTER_CAST_CUDA_BASE(transfer_dtype) -#endif diff --git a/paddle/phi/kernels/gpu/cast_kernel.cu b/paddle/phi/kernels/gpu/cast_kernel.cu index 7a6c99c5fe15f..569a46f56d563 100644 --- a/paddle/phi/kernels/gpu/cast_kernel.cu +++ b/paddle/phi/kernels/gpu/cast_kernel.cu @@ -80,8 +80,4 @@ void CastKernel(const Context& dev_ctx, paddle::experimental::DataType::UNDEFINED); \ } -#if !defined(PADDLE_WITH_HIP) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, phi::dtype::bfloat16) -#else -PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) -#endif From fe561c2c213724cbc904d029838f74ed60284e6f Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sat, 26 Feb 2022 04:11:06 +0000 Subject: [PATCH 08/11] refine layer_norm unittest --- python/paddle/fluid/tests/unittests/test_layer_norm_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py index 994b3a5adcd89..ca9a489c7496f 100644 --- a/python/paddle/fluid/tests/unittests/test_layer_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_layer_norm_op.py @@ -402,7 +402,7 @@ def check_main(self, x_np, weight_np, bias_np, dtype): return y_np, x_g_np, w_g_np, b_g_np def test_main(self): - if not paddle.is_compiled_with_cuda(): + if (not core.is_compiled_with_cuda()) or (core.cudnn_version() < 8100): return x_np = np.random.random([10, 20]).astype('float32') weight_np = np.random.random([20]).astype('float32') From 0ea4582646af3fb7d5cda6f6dd845b468902997f Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sat, 26 Feb 2022 13:04:12 +0000 Subject: [PATCH 09/11] refine reduce op --- python/paddle/fluid/tests/unittests/test_reduce_op.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 75c57bed25a66..af2defe8b27f0 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -68,11 +68,9 @@ def setUp(self): np.random.seed(100) self.op_type = "reduce_sum" self.dtype = np.uint16 - self.x = np.random.uniform(1, 2, (5, 6, 10)).astype(np.float32) - print(self.x) + self.x = np.random.uniform(0, 1, (2, 5, 10)).astype(np.float32) self.attrs = {'dim': [0, 1, 2]} self.out = self.x.sum(axis=tuple(self.attrs['dim'])) - print(self.out) self.gradient = self.calc_gradient() self.inputs = {'X': convert_float_to_uint16(self.x)} From 33cb7bdbb1a64412c202b54cc358747c3b9035c4 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Mon, 28 Feb 2022 01:58:22 +0000 Subject: [PATCH 10/11] refine unittest --- python/paddle/fluid/tests/unittests/test_reduce_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index af2defe8b27f0..d246356b4ec75 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -68,7 +68,7 @@ def setUp(self): np.random.seed(100) self.op_type = "reduce_sum" self.dtype = np.uint16 - self.x = np.random.uniform(0, 1, (2, 5, 10)).astype(np.float32) + self.x = np.random.uniform(0, 0.1, (2, 5, 10)).astype(np.float32) self.attrs = {'dim': [0, 1, 2]} self.out = self.x.sum(axis=tuple(self.attrs['dim'])) self.gradient = self.calc_gradient() From 44ef3a6dd7da3bb5ee4db7b19935fd71b331093b Mon Sep 17 
00:00:00 2001
From: zhangbo9674
Date: Mon, 28 Feb 2022 06:25:56 +0000
Subject: [PATCH 11/11] enhance atol for reduce unittest

---
 python/paddle/fluid/tests/unittests/op_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index 848ebae0706e3..7caf1e68ce49a 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -1135,7 +1135,7 @@ def check_output_with_place(self,
                 else:
                     atol = 2
             else:
-                atol = 1e-2
+                atol = 1e-1
         if no_check_set is not None:
             if self.op_type not in no_check_set_white_list.no_check_set_white_list:
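
A note on the relaxed tolerance in the final patch: bfloat16 keeps only 8 mantissa bits, so tensors round-tripped through the uint16 representation used by convert_float_to_uint16 lose roughly two to three decimal digits of precision, and that rounding error accumulates in reductions such as reduce_sum. The sketch below only illustrates this effect under the assumption of plain truncation to the upper 16 bits of a float32; float32_to_bfloat16_bits and bfloat16_bits_to_float32 are hypothetical helper names for this illustration, not Paddle's actual converter (which may round rather than truncate).

# Illustrative sketch only: hypothetical truncation-based bfloat16 emulation,
# not Paddle's convert_float_to_uint16.
import numpy as np

def float32_to_bfloat16_bits(x):
    # bfloat16 shares float32's exponent range; keeping the upper 16 bits
    # leaves only 8 mantissa bits (about 2-3 decimal digits of precision).
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bfloat16_bits_to_float32(bits):
    # Reverse of the truncation above: pad the dropped mantissa bits with zeros.
    return (np.asarray(bits, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)

np.random.seed(100)
# Same shape and value range as TestSumOp_bf16 in test_reduce_op.py.
x = np.random.uniform(0, 0.1, (2, 5, 10)).astype(np.float32)
x_bf16 = bfloat16_bits_to_float32(float32_to_bfloat16_bits(x))

# Per-element rounding error is tiny, but a reduce_sum over all 100 elements
# accumulates it, so the gap from the float32 reference grows with tensor size.
print(np.abs(x.astype(np.float64).sum() - x_bf16.astype(np.float64).sum()))

With this input the accumulated truncation error of the full sum is roughly on the order of the old 1e-2 bound, which is what the looser atol = 1e-1 for bfloat16 ops in check_output_with_place accommodates.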