From d86ca707c557684996d496518735b551f8aaae18 Mon Sep 17 00:00:00 2001
From: daidaiershidi <1154864382@qq.com>
Date: Wed, 23 Feb 2022 03:50:12 +0000
Subject: [PATCH] fix unittests for eigvalsh

---
 .../fluid/tests/unittests/test_eigvalsh_op.py | 40 ++++++++++++-------
 1 file changed, 25 insertions(+), 15 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_eigvalsh_op.py b/python/paddle/fluid/tests/unittests/test_eigvalsh_op.py
index db02372267677..93745d9561f5d 100644
--- a/python/paddle/fluid/tests/unittests/test_eigvalsh_op.py
+++ b/python/paddle/fluid/tests/unittests/test_eigvalsh_op.py
@@ -60,8 +60,12 @@ def setUp(self):
         self.dtype = "float32"
         np.random.seed(123)
         self.x_np = np.random.random(self.x_shape).astype(self.dtype)
-        self.rtol = 1e-5
-        self.atol = 1e-5
+        if (paddle.version.cuda() >= "11.6"):
+            self.rtol = 5e-6
+            self.atol = 6e-5
+        else:
+            self.rtol = 1e-5
+            self.atol = 1e-5
 
     def test_check_output_gpu(self):
         if paddle.is_compiled_with_cuda():
@@ -75,23 +79,29 @@
 
 class TestEigvalshAPI(unittest.TestCase):
     def setUp(self):
-        self.init_input_shape()
+        self.x_shape = [5, 5]
         self.dtype = "float32"
         self.UPLO = 'L'
-        self.rtol = 1e-6
-        self.atol = 1e-6
+        if (paddle.version.cuda() >= "11.6"):
+            self.rtol = 5e-6
+            self.atol = 6e-5
+        else:
+            self.rtol = 1e-5
+            self.atol = 1e-5
         self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
             else paddle.CPUPlace()
         np.random.seed(123)
+        self.init_input_data()
+
+    def init_input_data(self):
         self.real_data = np.random.random(self.x_shape).astype(self.dtype)
-        self.complex_data = np.random.random(self.x_shape).astype(
+        complex_data = np.random.random(self.x_shape).astype(
             self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype)
         self.trans_dims = list(range(len(self.x_shape) - 2)) + [
             len(self.x_shape) - 1, len(self.x_shape) - 2
         ]
-
-    def init_input_shape(self):
-        self.x_shape = [5, 5]
+        self.complex_symm = np.divide(
+            complex_data + np.conj(complex_data.transpose(self.trans_dims)), 2)
 
     def compare_result(self, actual_w, expected_w):
         np.testing.assert_allclose(
@@ -122,9 +132,9 @@ def check_static_complex_result(self):
             output_w = paddle.linalg.eigvalsh(input_x)
             exe = paddle.static.Executor(self.place)
             expected_w = exe.run(main_prog,
-                                 feed={"input_x": self.complex_data},
+                                 feed={"input_x": self.complex_symm},
                                  fetch_list=[output_w])
-        actual_w = np.linalg.eigvalsh(self.complex_data)
+        actual_w = np.linalg.eigvalsh(self.complex_symm)
         self.compare_result(actual_w, expected_w[0])
 
     def test_in_static_mode(self):
@@ -139,14 +149,14 @@ def test_in_dynamic_mode(self):
         actual_w = paddle.linalg.eigvalsh(input_real_data)
         self.compare_result(actual_w, expected_w)
 
-        input_complex_data = paddle.to_tensor(self.complex_data)
-        expected_w = np.linalg.eigvalsh(self.complex_data)
-        actual_w = paddle.linalg.eigvalsh(input_complex_data)
+        input_complex_symm = paddle.to_tensor(self.complex_symm)
+        expected_w = np.linalg.eigvalsh(self.complex_symm)
+        actual_w = paddle.linalg.eigvalsh(input_complex_symm)
         self.compare_result(actual_w, expected_w)
 
     def test_eigvalsh_grad(self):
         paddle.disable_static(self.place)
-        x = paddle.to_tensor(self.complex_data, stop_gradient=False)
+        x = paddle.to_tensor(self.complex_symm, stop_gradient=False)
         w = paddle.linalg.eigvalsh(x)
         (w.sum()).backward()
         np.testing.assert_allclose(
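
Note for reviewers: eigvalsh assumes a Hermitian input and reads only one
triangle of the matrix (selected by UPLO), so the old test, which fed an
arbitrary random complex matrix, compared results that depend on which
triangle each backend happens to read; the gradient is likewise only
well-defined for a Hermitian matrix. The patch therefore builds
self.complex_symm = (A + conj(A^T)) / 2 before calling eigvalsh. A minimal
NumPy-only sketch of the failure mode (the variable names are illustrative,
not taken from the test):

    import numpy as np

    np.random.seed(123)
    shape = [5, 5]  # same shape as self.x_shape in the test
    a = np.random.random(shape) + 1j * np.random.random(shape)

    # Non-Hermitian input: the answer depends on which triangle is read.
    lower = np.linalg.eigvalsh(a, UPLO='L')
    upper = np.linalg.eigvalsh(a, UPLO='U')
    assert not np.allclose(lower, upper)

    # Symmetrized input, as in the patch: both triangles now agree, so the
    # NumPy reference value and the Paddle result are comparable.
    a_sym = (a + np.conj(a.T)) / 2
    assert np.allclose(np.linalg.eigvalsh(a_sym, UPLO='L'),
                       np.linalg.eigvalsh(a_sym, UPLO='U'))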
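
One caveat: paddle.version.cuda() returns a string, and the patch compares
it lexicographically, so "11.10" >= "11.6" evaluates to False while
"9.2" >= "11.6" evaluates to True, both the reverse of the numeric ordering.
This works for the CUDA versions currently in CI but is fragile. A possible
hardening is sketched below; cuda_version_at_least is a hypothetical helper,
not part of the patch or of Paddle:

    # Hypothetical helper: compare dotted version strings numerically
    # instead of lexicographically, falling back to False for values that
    # are not versions (e.g. builds without CUDA support).
    def cuda_version_at_least(version_str, minimum="11.6"):
        try:
            have = tuple(int(p) for p in version_str.split("."))
            want = tuple(int(p) for p in minimum.split("."))
        except (AttributeError, ValueError):
            return False
        return have >= want

    # Usage in setUp, replacing the raw string comparison:
    #     if cuda_version_at_least(paddle.version.cuda()):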