From 9bc3f7eea782c02d4271fe21028c54a411a2fcd3 Mon Sep 17 00:00:00 2001 From: zhouwei25 Date: Fri, 13 Jan 2023 07:11:33 +0000 Subject: [PATCH] [Zero-Dim]simplify static unittest --- .../tests/unittests/test_zero_dim_tensor.py | 224 +++++++----------- .../unittests/xpu/test_zero_dim_tensor_xpu.py | 2 +- 2 files changed, 87 insertions(+), 139 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py index 8bea782b74425..8961de5540a76 100644 --- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py @@ -20,7 +20,6 @@ import paddle import paddle.fluid as fluid import paddle.nn.functional as F -from paddle.fluid.framework import grad_var_name fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) @@ -138,10 +137,8 @@ def test_static_unary(self): paddle.static.append_backward(loss) fetch_list = [x, out] - if block.has_var(grad_var_name(x.name)): - out_grad = block.var(grad_var_name(out.name)) - x_grad = block.var(grad_var_name(x.name)) - fetch_list.extend([x_grad, out_grad]) + if block.has_var(x.grad_name): + fetch_list.extend([x.grad_name, out.grad_name]) # 1) Test Program res = exe.run(main_prog, fetch_list=fetch_list) @@ -235,10 +232,9 @@ def test_static_reduce(self): paddle.static.append_backward(out.sum()) fetch_list = [x, out] - if block.has_var(grad_var_name(x.name)): - out_grad = block.var(grad_var_name(out.name)) - x_grad = block.var(grad_var_name(x.name)) - fetch_list.append([x_grad, out_grad]) + if block.has_var(x.grad_name): + fetch_list.extend([x.grad_name, out.grad_name]) + res = exe.run(main_prog, fetch_list=fetch_list) self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) @@ -411,10 +407,10 @@ def test_static_binary(self): self.assertEqual(x.shape, ()) self.assertEqual(y.shape, ()) self.assertEqual(out.shape, ()) - if block.has_var(grad_var_name(x.name)): - out_grad = block.var(grad_var_name(out.name)) - x_grad = block.var(grad_var_name(x.name)) - y_grad = block.var(grad_var_name(y.name)) + if block.has_var(x.grad_name): + out_grad = block.var(out.grad_name) + x_grad = block.var(x.grad_name) + y_grad = block.var(y.grad_name) self.assertEqual(x_grad.shape, ()) self.assertEqual(y_grad.shape, ()) @@ -438,10 +434,10 @@ def test_static_binary(self): self.assertEqual(x.shape, ()) self.assertEqual(y.shape, (2, 3, 4)) self.assertEqual(out.shape, (2, 3, 4)) - if block.has_var(grad_var_name(x.name)): - out_grad = block.var(grad_var_name(out.name)) - x_grad = block.var(grad_var_name(x.name)) - y_grad = block.var(grad_var_name(y.name)) + if block.has_var(x.grad_name): + out_grad = block.var(out.grad_name) + x_grad = block.var(x.grad_name) + y_grad = block.var(y.grad_name) self.assertEqual(x_grad.shape, ()) self.assertEqual(y_grad.shape, (2, 3, 4)) @@ -465,10 +461,10 @@ def test_static_binary(self): self.assertEqual(x.shape, (2, 3, 4)) self.assertEqual(y.shape, ()) self.assertEqual(out.shape, (2, 3, 4)) - if block.has_var(grad_var_name(x.name)): - out_grad = block.var(grad_var_name(out.name)) - x_grad = block.var(grad_var_name(x.name)) - y_grad = block.var(grad_var_name(y.name)) + if block.has_var(x.grad_name): + out_grad = block.var(out.grad_name) + x_grad = block.var(x.grad_name) + y_grad = block.var(y.grad_name) self.assertEqual(x_grad.shape, (2, 3, 4)) self.assertEqual(y_grad.shape, ()) @@ -489,9 +485,9 @@ def test_static_binary(self): self.assertEqual(x.shape, ()) self.assertEqual(out.shape, ()) - if block.has_var(grad_var_name(x.name)): - out_grad = block.var(grad_var_name(out.name)) - x_grad = block.var(grad_var_name(x.name)) + if block.has_var(x.name): + out_grad = block.var(out.grad_name) + x_grad = block.var(x.grad_name) self.assertEqual(out_grad.shape, ()) self.assertEqual(x_grad.shape, ()) @@ -1160,10 +1156,9 @@ def test_flip(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad]) + res = self.exe.run( + prog, fetch_list=[x, out, x.grad_name, out.grad_name] + ) self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) self.assertEqual(res[2].shape, ()) @@ -1177,10 +1172,9 @@ def test_pow_factor(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad]) + res = self.exe.run( + prog, fetch_list=[x, out, x.grad_name, out.grad_name] + ) self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) self.assertEqual(res[2].shape, ()) @@ -1194,10 +1188,9 @@ def test_cast(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad]) + res = self.exe.run( + prog, fetch_list=[x, out, x.grad_name, out.grad_name] + ) self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) self.assertEqual(res[2].shape, ()) @@ -1211,10 +1204,7 @@ def test_cumprod(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad]) + res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name]) self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) self.assertEqual(res[2].shape, ()) @@ -1230,10 +1220,9 @@ def test_clip(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad]) + res = self.exe.run( + prog, fetch_list=[x, out, x.grad_name, out.grad_name] + ) self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) self.assertEqual(res[2].shape, ()) @@ -1247,10 +1236,9 @@ def test_increment(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad]) + res = self.exe.run( + prog, fetch_list=[x, out, x.grad_name, out.grad_name] + ) self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) self.assertEqual(res[2].shape, ()) @@ -1299,10 +1287,7 @@ def test_gather_1D(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad]) + res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name]) self.assertEqual(res[0].shape, ()) self.assertEqual(res[0], 1) self.assertEqual(res[1].shape, (10,)) @@ -1317,10 +1302,7 @@ def test_gather_XD_axis_0(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad]) + res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name]) self.assertEqual(res[0].shape, (3,)) np.testing.assert_array_equal(res[0], [1.0, 1.0, 1.0]) self.assertEqual(res[1].shape, (2, 3)) @@ -1335,10 +1317,7 @@ def _test_gather_XD_axis_1(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad]) + res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name]) self.assertEqual(res[0].shape, (2,)) np.testing.assert_array_equal(res[0], [1.0, 1.0]) self.assertEqual(res[1].shape, (2, 3)) @@ -1354,10 +1333,7 @@ def test_scatter_1D(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad]) + res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name]) self.assertEqual(res[0].shape, (10,)) self.assertEqual(res[0][2], 4.0) self.assertEqual(res[1].shape, (10,)) @@ -1373,10 +1349,7 @@ def test_scatter_XD(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad]) + res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name]) self.assertEqual(res[0].shape, (2, 3)) np.testing.assert_array_equal(res[0][1], [4.0, 4.0, 4.0]) self.assertEqual(res[1].shape, (2, 3)) @@ -1431,10 +1404,9 @@ def test_scatter_nd(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - updates_grad = block.var(grad_var_name(updates.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[out, out_grad, updates_grad]) + res = self.exe.run( + prog, fetch_list=[out, out.grad_name, updates.grad_name] + ) self.assertEqual(res[0].shape, (5,)) self.assertEqual(res[0][3], 2) self.assertEqual(res[1].shape, (5,)) @@ -1448,9 +1420,7 @@ def test_kthvalue(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - res = self.exe.run(prog, fetch_list=[out, index, x_grad]) + res = self.exe.run(prog, fetch_list=[out, index, x.grad_name]) self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) self.assertEqual(res[2].shape, ()) @@ -1464,9 +1434,7 @@ def test_mode(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - res = self.exe.run(prog, fetch_list=[out, index, x_grad]) + res = self.exe.run(prog, fetch_list=[out, index, x.grad_name]) self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) self.assertEqual(res[2].shape, ()) @@ -1484,10 +1452,9 @@ def test_flatten(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - out_grad = block.var(grad_var_name(out.name)) - x_grad = block.var(grad_var_name(x.name)) - res = self.exe.run(prog, feed={}, fetch_list=[out, x_grad, out_grad]) + res = self.exe.run( + prog, feed={}, fetch_list=[out, x.grad_name, out.grad_name] + ) self.assertEqual(res[0].shape, (1,)) self.assertEqual(res[1].shape, ()) @@ -1501,10 +1468,7 @@ def test_scale(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[out, x_grad, out_grad]) + res = self.exe.run(prog, fetch_list=[out, x.grad_name, out.grad_name]) self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) self.assertEqual(res[2].shape, ()) @@ -1563,15 +1527,6 @@ def test_reshape_list(self): paddle.static.append_backward(out4.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x1_grad = block.var(grad_var_name(x1.name)) - x2_grad = block.var(grad_var_name(x2.name)) - x3_grad = block.var(grad_var_name(x3.name)) - x4_grad = block.var(grad_var_name(x4.name)) - out1_grad = block.var(grad_var_name(out1.name)) - out2_grad = block.var(grad_var_name(out2.name)) - out3_grad = block.var(grad_var_name(out3.name)) - out4_grad = block.var(grad_var_name(out4.name)) res = self.exe.run( prog, fetch_list=[ @@ -1579,14 +1534,14 @@ def test_reshape_list(self): out2, out3, out4, - x1_grad, - x2_grad, - x3_grad, - x4_grad, - out1_grad, - out2_grad, - out3_grad, - out4_grad, + x1.grad_name, + x2.grad_name, + x3.grad_name, + x4.grad_name, + out1.grad_name, + out2.grad_name, + out3.grad_name, + out4.grad_name, ], ) self.assertEqual(res[0].shape, ()) @@ -1625,25 +1580,18 @@ def test_reshape_tensor(self): paddle.static.append_backward(out3.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x1_grad = block.var(grad_var_name(x1.name)) - x2_grad = block.var(grad_var_name(x2.name)) - x3_grad = block.var(grad_var_name(x3.name)) - out1_grad = block.var(grad_var_name(out1.name)) - out2_grad = block.var(grad_var_name(out2.name)) - out3_grad = block.var(grad_var_name(out3.name)) res = self.exe.run( prog, fetch_list=[ out1, out2, out3, - x1_grad, - x2_grad, - x3_grad, - out1_grad, - out2_grad, - out3_grad, + x1.grad_name, + x2.grad_name, + x3.grad_name, + out1.grad_name, + out2.grad_name, + out3.grad_name, ], ) self.assertEqual(res[0].shape, (1, 1, 1)) @@ -1667,10 +1615,9 @@ def test_reverse(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - out_grad = block.var(grad_var_name(out.name)) - res = self.exe.run(prog, fetch_list=[x, out, x_grad, out_grad]) + res = self.exe.run( + prog, fetch_list=[x, out, x.grad_name, out.grad_name] + ) self.assertEqual(res[0].shape, ()) self.assertEqual(res[1].shape, ()) self.assertEqual(res[2].shape, ()) @@ -1689,14 +1636,16 @@ def test_sort(self): paddle.static.append_backward(out2.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x1_grad = block.var(grad_var_name(x1.name)) - x2_grad = block.var(grad_var_name(x2.name)) - out1_grad = block.var(grad_var_name(out1.name)) - out2_grad = block.var(grad_var_name(out2.name)) res = self.exe.run( prog, - fetch_list=[out1, out2, out1_grad, out2_grad, x1_grad, x2_grad], + fetch_list=[ + out1, + out2, + out1.grad_name, + out2.grad_name, + x1.grad_name, + x2.grad_name, + ], ) self.assertEqual(res[0].shape, ()) @@ -1744,12 +1693,9 @@ def test_lerp(self): paddle.static.append_backward(out.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x_grad = block.var(grad_var_name(x.name)) - y_grad = block.var(grad_var_name(y.name)) - out_grad = block.var(grad_var_name(out.name)) - - res = self.exe.run(prog, fetch_list=[out, out_grad, y_grad, x_grad]) + res = self.exe.run( + prog, fetch_list=[out, out.grad_name, y.grad_name, x.grad_name] + ) self.assertEqual(res[0].shape, shape[3]) self.assertEqual(res[1].shape, shape[3]) self.assertEqual(res[2].shape, shape[1]) @@ -1769,14 +1715,16 @@ def test_repeat_interleave(self): paddle.static.append_backward(out2.sum()) prog = paddle.static.default_main_program() - block = prog.global_block() - x1_grad = block.var(grad_var_name(x1.name)) - x2_grad = block.var(grad_var_name(x2.name)) - out1_grad = block.var(grad_var_name(out1.name)) - out2_grad = block.var(grad_var_name(out2.name)) res = self.exe.run( prog, - fetch_list=[out1, out2, x1_grad, x2_grad, out1_grad, out2_grad], + fetch_list=[ + out1, + out2, + x1.grad_name, + x2.grad_name, + out1.grad_name, + out2.grad_name, + ], ) self.assertEqual(res[0].shape, (2,)) self.assertEqual(res[1].shape, (3,)) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py index 97925b72beaa7..728fcb09f8333 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py @@ -463,7 +463,7 @@ def test_gather_xD_axis_0(self): self.assertEqual(x.grad.shape, [2, 3]) self.assertEqual(out.grad.shape, [3]) - def test_gather_xD_axis_1(self): + def _test_gather_xD_axis_1(self): x = paddle.to_tensor( [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False )