From 272b375250c1f76fae26bef6490601c5b950aef6 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Thu, 10 Aug 2023 16:35:37 +0800
Subject: [PATCH] fix A100 fused linear grad add ut bug (#56136)

---
 .../test_fused_linear_param_grad_add.py      | 61 ++++++++++++-------
 1 file changed, 39 insertions(+), 22 deletions(-)

diff --git a/test/legacy_test/test_fused_linear_param_grad_add.py b/test/legacy_test/test_fused_linear_param_grad_add.py
index e707bbc41fa2a..762b2a99b52e9 100644
--- a/test/legacy_test/test_fused_linear_param_grad_add.py
+++ b/test/legacy_test/test_fused_linear_param_grad_add.py
@@ -54,7 +54,7 @@ def recreate(x, multi_precision):
     return paddle.to_tensor(x.numpy())
 
 
-def run_ground_truth(x, dy, dweight, dbias, multi_precision):
+def run_ground_truth(x, dy, dweight, dbias, multi_precision, has_bias):
     x, dy, dweight, dbias = recreate([x, dy, dweight, dbias], multi_precision)
 
     dweight_tmp = paddle.matmul(
@@ -69,24 +69,35 @@ def run_ground_truth(x, dy, dweight, dbias, multi_precision):
         assert dweight.dtype == dweight_tmp.dtype
         dweight += dweight_tmp
 
-    dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
-    if dbias is None:
-        dbias = dbias_tmp
-    else:
-        assert dbias.shape == dbias_tmp.shape
-        assert dbias.dtype == dbias_tmp.dtype
-        dbias += dbias_tmp
+    if has_bias:
+        dbias_tmp = dy.reshape([-1, dy.shape[-1]]).sum(axis=0)
+        if dbias is None:
+            dbias = dbias_tmp
+        else:
+            assert dbias.shape == dbias_tmp.shape
+            assert dbias.dtype == dbias_tmp.dtype
+            dbias += dbias_tmp
 
-    return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
+        return promote_dtype(dweight).numpy(), promote_dtype(dbias).numpy()
+    else:
+        return promote_dtype(dweight).numpy()
 
 
-def run_fused_linear_param_grad_add(x, dy, dweight, dbias, multi_precision):
+def run_fused_linear_param_grad_add(
+    x, dy, dweight, dbias, multi_precision, has_bias
+):
     dweight_new, dbias_new = _C_ops.fused_linear_param_grad_add(
-        x, dy, dweight, dbias, multi_precision
+        x, dy, dweight, dbias, multi_precision, has_bias
     )
     if dweight is not None:
         assert dweight_new.data_ptr() == dweight.data_ptr()
-    return promote_dtype(dweight_new).numpy(), promote_dtype(dbias_new).numpy()
+    if has_bias:
+        return (
+            promote_dtype(dweight_new).numpy(),
+            promote_dtype(dbias_new).numpy(),
+        )
+    else:
+        return promote_dtype(dweight_new).numpy()
 
 
 class TestMainClassBase(unittest.TestCase):
@@ -103,7 +114,9 @@ def rand(self, shape, dtype=None):
         x = paddle.to_tensor(x)
         return x.astype(dtype or self.dtype)
 
-    def generate_rand_inputs(self, has_dweight, has_dbias, multi_precision):
+    def generate_rand_inputs(
+        self, has_dweight, has_dbias, multi_precision, has_bias
+    ):
         x_shape = self.shape
         dy_shape = self.shape[:-1] + [self.output_size]
         dweight_shape = [self.shape[-1], self.output_size]
@@ -118,7 +131,7 @@ def generate_rand_inputs(self, has_dweight, has_dbias, multi_precision):
         else:
             dweight = None
 
-        if has_dbias:
+        if has_bias and has_dbias:
             dbias = self.rand(dbias_shape)
             if multi_precision:
                 dbias = promote_dtype(dbias)
@@ -126,14 +139,15 @@ def generate_rand_inputs(self, has_dweight, has_dbias, multi_precision):
             dbias = None
         return x, dy, dweight, dbias
 
-    def check_main(self, has_dweight, has_dbias, multi_precision):
-        print(has_dweight, has_dbias, multi_precision)
+    def check_main(self, has_dweight, has_dbias, multi_precision, has_bias):
         x, dy, dweight, dbias = self.generate_rand_inputs(
-            has_dweight, has_dbias, multi_precision
+            has_dweight, has_dbias, multi_precision, has_bias
+        )
+        res1 = run_ground_truth(
+            x, dy, dweight, dbias, multi_precision, has_bias
         )
-        res1 = run_ground_truth(x, dy, dweight, dbias, multi_precision)
         res2 = run_fused_linear_param_grad_add(
-            x, dy, dweight, dbias, multi_precision
+            x, dy, dweight, dbias, multi_precision, has_bias
         )
         self.assertEqual(len(res1), len(res2))
         for r1, r2 in zip(res1, res2):
@@ -153,9 +167,12 @@ def test_main(self):
             return
 
         for has_dweight in [False, True]:
-            for has_dbias in [False, True]:
-                for multi_precision in [False, True]:
-                    self.check_main(has_dweight, has_dbias, multi_precision)
+            for has_bias in [False, True]:
+                for has_dbias in [False, True]:
+                    for multi_precision in [False, True]:
+                        self.check_main(
+                            has_dweight, has_dbias, multi_precision, has_bias
+                        )
 
 
 class TestMainClassBF16(TestMainClassBase):