From 10d7f34d9fce02ca2c4150042e0ef2452b377a45 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 22 Dec 2023 08:48:15 +0000 Subject: [PATCH 1/2] fix --- python/paddle/static/nn/control_flow.py | 25 +++++++---- test/legacy_test/test_while_loop_op.py | 59 +++++++++++++++---------- 2 files changed, 52 insertions(+), 32 deletions(-) diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py index 06ae72db65dcb..74318cf54f452 100644 --- a/python/paddle/static/nn/control_flow.py +++ b/python/paddle/static/nn/control_flow.py @@ -675,25 +675,32 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): pre_cond = cond(*loop_vars) + check_variable_and_dtype( + pre_cond, 'var of cond returned', ['bool'], 'static.nn.while_loop' + ) + if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1: + raise TypeError( + "the shape of the variable returned by cond should be [1]," + f"but given shape as {list(pre_cond.shape)}." + ) + if in_pir_mode(): while_op = build_while_op(pre_cond, flatten(loop_vars)) with while_op.body() as cur_block: args = cur_block.args() next_var = body(*args) + try: + assert_same_structure(next_var, loop_vars, check_types=False) + except ValueError as e: + raise ValueError( + "body in while_loop should return the same arity " + f"(length and structure) as loop_vars: {e}" + ) next_cond = cond(*next_var) next_cond.stop_gradient = True cf_yield([next_cond, *next_var]) return while_op.as_operation().results() - check_variable_and_dtype( - pre_cond, 'var of cond returned', ['bool'], 'static.nn.while_loop' - ) - if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1: - raise TypeError( - "the shape of the variable returned by cond should be [1]," - f"but given shape as {list(pre_cond.shape)}." - ) - if in_dygraph_mode(): now_cond = pre_cond.item() while now_cond: diff --git a/test/legacy_test/test_while_loop_op.py b/test/legacy_test/test_while_loop_op.py index 0b73041559c0f..818fc56fc80b9 100644 --- a/test/legacy_test/test_while_loop_op.py +++ b/test/legacy_test/test_while_loop_op.py @@ -488,7 +488,8 @@ def fn_add_one(): class TestApiWhileLoop_Error(unittest.TestCase): @compare_legacy_with_pt - def test_error(self): + @test_with_pir_api + def test_error1(self): def cond_returns_constant(i): return 1 @@ -514,27 +515,9 @@ def body_returns_error_length(i): def body_returns_error_type(i, ten): return paddle.increment(i) - def cond_returns_with_mutable_dict(i, test_dict): - return i > 0 - - def body_returns_with_mutable_dict(i, test_dict): - test_dict['new_key'] = paddle.tensor.fill_constant( - shape=[1], dtype='int64', value=1 - ) - return paddle.increment(i), test_dict - - def cond_returns_with_mutable_list(i, test_list): - return i > 0 - - def body_returns_with_mutable_list(i, test_list): - test_list.append( - paddle.tensor.fill_constant(shape=[1], dtype='int64', value=1) - ) - return paddle.increment(i), test_list - - main_program = Program() - startup_program = Program() - with program_guard(main_program, startup_program): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): data = paddle.tensor.fill_constant( shape=[1], dtype='int64', value=1 ) @@ -621,7 +604,35 @@ def value_error_body_returns_error_type(): self.assertRaises(ValueError, value_error_body_returns_error_type) + @compare_legacy_with_pt + def test_error2(self): + def cond_returns_with_mutable_dict(i, test_dict): + return i > 0 + + def body_returns_with_mutable_dict(i, test_dict): + test_dict['new_key'] = paddle.tensor.fill_constant( + shape=[1], dtype='int64', value=1 + ) + return paddle.increment(i), test_dict + + def cond_returns_with_mutable_list(i, test_list): + return i > 0 + + def body_returns_with_mutable_list(i, test_list): + test_list.append( + paddle.tensor.fill_constant(shape=[1], dtype='int64', value=1) + ) + return paddle.increment(i), test_list + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + data = paddle.tensor.fill_constant( + shape=[1], dtype='int64', value=1 + ) + # The length of `output_vars` with mutable value should keep same with `loop_vars` + # TODO(zhangbo): slice error need to fix, loop_vars support list/dict def value_error_body_returns_with_mutable_dict(): test_dict = { "int_constant": paddle.tensor.fill_constant( @@ -638,6 +649,7 @@ def value_error_body_returns_with_mutable_dict(): ValueError, value_error_body_returns_with_mutable_dict ) + # TODO(zhangbo): loop_vars support list/dict def value_error_body_returns_with_mutable_list(): test_list = [ paddle.tensor.fill_constant( @@ -656,7 +668,8 @@ def value_error_body_returns_with_mutable_list(): class TestApiWhileLoopSliceInBody(unittest.TestCase): - # @compare_legacy_with_pt + @compare_legacy_with_pt + # @test_with_pir_api (need to fix slice bug in pir) def test_var_slice(self): def cond(z, i): return i + 1 <= x_shape[0] From 28b7256d07a59574a1281ecaf9d785c737a02e02 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sat, 23 Dec 2023 01:56:26 +0000 Subject: [PATCH 2/2] fix --- python/paddle/static/nn/control_flow.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py index 74318cf54f452..f8f73094d7e0f 100644 --- a/python/paddle/static/nn/control_flow.py +++ b/python/paddle/static/nn/control_flow.py @@ -690,7 +690,9 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None): args = cur_block.args() next_var = body(*args) try: - assert_same_structure(next_var, loop_vars, check_types=False) + assert_same_structure( + flatten(next_var), flatten(loop_vars), check_types=False + ) except ValueError as e: raise ValueError( "body in while_loop should return the same arity "