Enrich 0d Tensor Dygraph and Shape Unit Test for `case` and `switch_case` (#49889)

Following PR #49842, this PR adds Dygraph and Shape unit tests for `case` and `switch_case`. It contains only test changes, because `case` and `switch_case` call `cond`, and PR #49842 already added the 0d tensor support there.
zhhsplendid authored Jan 18, 2023
1 parent ce04589 commit 7737672
Showing 2 changed files with 225 additions and 8 deletions.
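To illustrate what the new tests exercise, here is a minimal dygraph sketch distilled from the diff below. It is not part of the commit; it assumes a Paddle build contemporary with this PR, where `paddle.static.nn.control_flow.case` is callable in dygraph mode exactly as the new tests do:

import paddle

paddle.disable_static()

def fn_1():
    return paddle.full(shape=[], dtype='int32', fill_value=1)

def fn_2():
    return paddle.full(shape=[], dtype='int32', fill_value=2)

x = paddle.full(shape=[], dtype='float32', fill_value=0.3)  # 0-D tensor
y = paddle.full(shape=[], dtype='float32', fill_value=0.1)
pred = paddle.less_than(y, x)  # 0-D bool tensor, True here

# pred is True, so fn_1 runs and returns a 0-D int32 tensor
out = paddle.static.nn.control_flow.case(
    pred_fn_pairs=[(pred, fn_1)], default=fn_2
)
print(out.shape)  # [] -- 0-D tensors report an empty shape list in dygraph

paddle.enable_static()

In static-graph mode the same 0-D outputs come back from exe.run as 0-d numpy arrays, which is why the static tests assert shape () rather than [].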
95 changes: 95 additions & 0 deletions python/paddle/fluid/tests/unittests/test_case.py
@@ -22,6 +22,7 @@
import paddle.fluid.core as core
import paddle.fluid.layers as layers
import paddle.fluid.optimizer as optimizer
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard

paddle.enable_static()
@@ -145,10 +146,101 @@ def fn_3():
            )

            np.testing.assert_allclose(res[0], 1, rtol=1e-05)
            self.assertEqual(res[0].shape, ())
            np.testing.assert_allclose(res[1], 2, rtol=1e-05)
            self.assertEqual(res[1].shape, ())
            np.testing.assert_allclose(res[2], 3, rtol=1e-05)
            self.assertEqual(res[2].shape, ())
            np.testing.assert_allclose(res[3], 2, rtol=1e-05)
            self.assertEqual(res[3].shape, ())
            np.testing.assert_allclose(res[4], 2, rtol=1e-05)
            self.assertEqual(res[4].shape, ())

    def test_0d_tensor_backward(self):
        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            x = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
            x.stop_gradient = False
            pred = paddle.full(shape=[], dtype='bool', fill_value=0)
            # pred is False, so out = -x
            out = paddle.static.nn.case(
                pred_fn_pairs=[(pred, lambda: x)], default=lambda: -x
            )
            append_backward(out)

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)

        res = exe.run(main_program, fetch_list=[out.name, x.grad_name])
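        # out is -x = 2.0 and d(out)/dx is -1.0; both come back from exe.run
        # as 0-d numpy arrays, hence the shape () checks below.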
        np.testing.assert_allclose(
            np.asarray(res[0]), np.array(2.0), rtol=1e-05
        )
        self.assertEqual(res[0].shape, ())
        np.testing.assert_allclose(
            np.asarray(res[1]), np.array(-1.0), rtol=1e-05
        )
        self.assertEqual(res[1].shape, ())

    def test_0d_tensor_dygraph(self):
        paddle.disable_static()

        def fn_1():
            return paddle.full(shape=[], dtype='int32', fill_value=1)

        def fn_2():
            return paddle.full(shape=[], dtype='int32', fill_value=2)

        def fn_3():
            return paddle.full(shape=[], dtype='int32', fill_value=3)

        x = paddle.full(shape=[], dtype='float32', fill_value=0.3)
        y = paddle.full(shape=[], dtype='float32', fill_value=0.1)
        z = paddle.full(shape=[], dtype='float32', fill_value=0.2)
        pred_2 = paddle.less_than(x, y)  # false: 0.3 < 0.1
        pred_1 = paddle.less_than(z, x)  # true: 0.2 < 0.3

        # call fn_1
        out_0 = paddle.static.nn.control_flow.case(
            pred_fn_pairs=[(pred_1, fn_1), (pred_1, fn_2)], default=fn_3
        )

        # call fn_2
        out_1 = paddle.static.nn.control_flow.case(
            pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
        )

        # call default fn_3
        out_2 = paddle.static.nn.control_flow.case(
            pred_fn_pairs=((pred_2, fn_1), (pred_2, fn_2)), default=fn_3
        )

        # no default, call fn_2
        out_3 = paddle.static.nn.control_flow.case(
            pred_fn_pairs=[(pred_1, fn_2)]
        )

        # no default; pred_2 is false, so the last fn (fn_2) acts as the default
        out_4 = paddle.static.nn.control_flow.case(
            pred_fn_pairs=[(pred_2, fn_2)]
        )
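        # In dygraph mode a 0-D Tensor reports an empty shape list ([]),
        # which is asserted below alongside the selected values.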

        np.testing.assert_allclose(out_0, 1, rtol=1e-05)
        self.assertEqual(out_0.shape, [])
        np.testing.assert_allclose(out_1, 2, rtol=1e-05)
        self.assertEqual(out_1.shape, [])
        np.testing.assert_allclose(out_2, 3, rtol=1e-05)
        self.assertEqual(out_2.shape, [])
        np.testing.assert_allclose(out_3, 2, rtol=1e-05)
        self.assertEqual(out_3.shape, [])
        np.testing.assert_allclose(out_4, 2, rtol=1e-05)
        self.assertEqual(out_4.shape, [])

        paddle.enable_static()

    def test_return_var_tuple(self):
        def fn_1():
@@ -394,8 +486,11 @@ def fn_3():
            res = exe.run(main_program, fetch_list=[out_1, out_2, out_3])

            np.testing.assert_allclose(res[0], 1, rtol=1e-05)
            self.assertEqual(res[0].shape, ())
            np.testing.assert_allclose(res[1], 2, rtol=1e-05)
            self.assertEqual(res[1].shape, ())
            np.testing.assert_allclose(res[2], 3, rtol=1e-05)
            self.assertEqual(res[2].shape, ())


class TestAPICase_Error(unittest.TestCase):
138 changes: 130 additions & 8 deletions python/paddle/fluid/tests/unittests/test_switch_case.py
@@ -21,6 +21,7 @@
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard

paddle.enable_static()
@@ -93,25 +94,25 @@ def fn_3():
                res[1],
                2,
                rtol=1e-05,
-               err_msg='result is {} but answer is {}'.format(res[0], 2),
+               err_msg='result is {} but answer is {}'.format(res[1], 2),
            )
            np.testing.assert_allclose(
                res[2],
                3,
                rtol=1e-05,
-               err_msg='result is {} but answer is {}'.format(res[0], 3),
+               err_msg='result is {} but answer is {}'.format(res[2], 3),
            )
            np.testing.assert_allclose(
                res[3],
                2,
                rtol=1e-05,
-               err_msg='result is {} but answer is {}'.format(res[0], 2),
+               err_msg='result is {} but answer is {}'.format(res[3], 2),
            )
            np.testing.assert_allclose(
                res[4],
                2,
                rtol=1e-05,
-               err_msg='result is {} but answer is {}'.format(res[0], 2),
+               err_msg='result is {} but answer is {}'.format(res[4], 2),
            )

    def test_0d_tensor(self):
@@ -176,30 +177,148 @@ def fn_3():
                rtol=1e-05,
                err_msg='result is {} but answer is {}'.format(res[0], 1),
            )
            self.assertEqual(res[0].shape, ())
            np.testing.assert_allclose(
                res[1],
                2,
                rtol=1e-05,
-               err_msg='result is {} but answer is {}'.format(res[0], 2),
+               err_msg='result is {} but answer is {}'.format(res[1], 2),
            )
            self.assertEqual(res[1].shape, ())
            np.testing.assert_allclose(
                res[2],
                3,
                rtol=1e-05,
-               err_msg='result is {} but answer is {}'.format(res[0], 3),
+               err_msg='result is {} but answer is {}'.format(res[2], 3),
            )
            self.assertEqual(res[2].shape, ())
            np.testing.assert_allclose(
                res[3],
                2,
                rtol=1e-05,
-               err_msg='result is {} but answer is {}'.format(res[0], 2),
+               err_msg='result is {} but answer is {}'.format(res[3], 2),
            )
            self.assertEqual(res[3].shape, ())
            np.testing.assert_allclose(
                res[4],
                2,
                rtol=1e-05,
-               err_msg='result is {} but answer is {}'.format(res[0], 2),
+               err_msg='result is {} but answer is {}'.format(res[4], 2),
            )
            self.assertEqual(res[4].shape, ())

    def test_0d_tensor_backward(self):
        main_program = Program()
        startup_program = Program()
        with program_guard(main_program, startup_program):
            x = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
            x.stop_gradient = False
            pred = paddle.full(shape=[], dtype='int32', fill_value=2)
            # pred is 2, so out = 2 * x
            out = paddle.static.nn.switch_case(
                branch_index=pred,
                branch_fns=[(1, lambda: x), (2, lambda: 2 * x)],
                default=lambda: -x,
            )
            append_backward(out)

        place = (
            fluid.CUDAPlace(0)
            if core.is_compiled_with_cuda()
            else fluid.CPUPlace()
        )
        exe = fluid.Executor(place)

        res = exe.run(main_program, fetch_list=[out.name, x.grad_name])
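        # out is 2 * x = -4.0 and d(out)/dx is 2.0; both come back from exe.run
        # as 0-d numpy arrays, hence the shape () checks below.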
        np.testing.assert_allclose(
            np.asarray(res[0]), np.array(-4.0), rtol=1e-05
        )
        self.assertEqual(res[0].shape, ())
        np.testing.assert_allclose(
            np.asarray(res[1]), np.array(2.0), rtol=1e-05
        )
        self.assertEqual(res[1].shape, ())

    def test_0d_tensor_dygraph(self):
        paddle.disable_static()

        def fn_1():
            return paddle.full(shape=[], dtype='int32', fill_value=1)

        def fn_2():
            return paddle.full(shape=[], dtype='int32', fill_value=2)

        def fn_3():
            return paddle.full(shape=[], dtype='int32', fill_value=3)

        index_1 = paddle.full(shape=[], dtype='int32', fill_value=1)
        index_2 = paddle.full(shape=[], dtype='int32', fill_value=2)
        index_5 = paddle.full(shape=[], dtype='int32', fill_value=5)

        # call fn_1
        out_0 = paddle.static.nn.switch_case(
            branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
        )

        # call fn_2: a plain sequence of fns maps to indices {0: fn_1, 1: fn_2, 2: fn_3}
        out_1 = paddle.static.nn.switch_case(
            branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3)
        )

        # call default fn_3
        out_2 = paddle.static.nn.switch_case(
            branch_index=index_5,
            branch_fns=((1, fn_1), (2, fn_2)),
            default=fn_3,
        )

        # no default, call fn_2
        out_3 = paddle.static.nn.switch_case(
            branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)]
        )

        # no default; branch_index is 5, so the fn with the largest index (fn_2 at index 3) acts as the default
        out_4 = paddle.static.nn.switch_case(
            branch_index=index_5,
            branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)],
        )
        np.testing.assert_allclose(
            out_0,
            1,
            rtol=1e-05,
            err_msg='result is {} but answer is {}'.format(out_0, 1),
        )
        self.assertEqual(out_0.shape, [])
        np.testing.assert_allclose(
            out_1,
            2,
            rtol=1e-05,
            err_msg='result is {} but answer is {}'.format(out_1, 2),
        )
        self.assertEqual(out_1.shape, [])
        np.testing.assert_allclose(
            out_2,
            3,
            rtol=1e-05,
            err_msg='result is {} but answer is {}'.format(out_2, 3),
        )
        self.assertEqual(out_2.shape, [])
        np.testing.assert_allclose(
            out_3,
            2,
            rtol=1e-05,
            err_msg='result is {} but answer is {}'.format(out_3, 2),
        )
        self.assertEqual(out_3.shape, [])
        np.testing.assert_allclose(
            out_4,
            2,
            rtol=1e-05,
            err_msg='result is {} but answer is {}'.format(out_4, 2),
        )
        self.assertEqual(out_4.shape, [])

        paddle.enable_static()

    def test_return_var_tuple(self):
        def fn_1():
@@ -426,18 +545,21 @@ def fn_3():
                rtol=1e-05,
                err_msg='result is {} but answer is {}'.format(res[0], 1),
            )
            self.assertEqual(res[0].shape, ())
            np.testing.assert_allclose(
                res[1],
                2,
                rtol=1e-05,
                err_msg='result is {} but answer is {}'.format(res[1], 2),
            )
            self.assertEqual(res[1].shape, ())
            np.testing.assert_allclose(
                res[2],
                3,
                rtol=1e-05,
                err_msg='result is {} but answer is {}'.format(res[2], 3),
            )
            self.assertEqual(res[2].shape, ())


# test TypeError and ValueError of api switch_case
