Skip to content

Commit

Permalink
[PIR] Support some while_loop op_test (PaddlePaddle#60271)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent 458d99c commit e47e4cd
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 30 deletions.
27 changes: 18 additions & 9 deletions python/paddle/static/nn/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,25 +675,34 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):

pre_cond = cond(*loop_vars)

check_variable_and_dtype(
pre_cond, 'var of cond returned', ['bool'], 'static.nn.while_loop'
)
if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
raise TypeError(
"the shape of the variable returned by cond should be [1],"
f"but given shape as {list(pre_cond.shape)}."
)

if in_pir_mode():
while_op = build_while_op(pre_cond, flatten(loop_vars))
with while_op.body() as cur_block:
args = cur_block.args()
next_var = body(*args)
try:
assert_same_structure(
flatten(next_var), flatten(loop_vars), check_types=False
)
except ValueError as e:
raise ValueError(
"body in while_loop should return the same arity "
f"(length and structure) as loop_vars: {e}"
)
next_cond = cond(*next_var)
next_cond.stop_gradient = True
cf_yield([next_cond, *next_var])
return while_op.as_operation().results()

check_variable_and_dtype(
pre_cond, 'var of cond returned', ['bool'], 'static.nn.while_loop'
)
if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
raise TypeError(
"the shape of the variable returned by cond should be [1],"
f"but given shape as {list(pre_cond.shape)}."
)

if in_dygraph_mode():
now_cond = pre_cond.item()
while now_cond:
Expand Down
55 changes: 34 additions & 21 deletions test/legacy_test/test_while_loop_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,8 @@ def fn_add_one():

class TestApiWhileLoop_Error(unittest.TestCase):
@compare_legacy_with_pt
def test_error(self):
@test_with_pir_api
def test_error1(self):
def cond_returns_constant(i):
return 1

Expand All @@ -549,27 +550,9 @@ def body_returns_error_length(i):
def body_returns_error_type(i, ten):
return paddle.increment(i)

def cond_returns_with_mutable_dict(i, test_dict):
return i > 0

def body_returns_with_mutable_dict(i, test_dict):
test_dict['new_key'] = paddle.tensor.fill_constant(
shape=[1], dtype='int64', value=1
)
return paddle.increment(i), test_dict

def cond_returns_with_mutable_list(i, test_list):
return i > 0

def body_returns_with_mutable_list(i, test_list):
test_list.append(
paddle.tensor.fill_constant(shape=[1], dtype='int64', value=1)
)
return paddle.increment(i), test_list

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with program_guard(main_program, startup_program):
with paddle.static.program_guard(main_program, startup_program):
data = paddle.tensor.fill_constant(
shape=[1], dtype='int64', value=1
)
Expand Down Expand Up @@ -656,7 +639,35 @@ def value_error_body_returns_error_type():

self.assertRaises(ValueError, value_error_body_returns_error_type)

@compare_legacy_with_pt
def test_error2(self):
def cond_returns_with_mutable_dict(i, test_dict):
return i > 0

def body_returns_with_mutable_dict(i, test_dict):
test_dict['new_key'] = paddle.tensor.fill_constant(
shape=[1], dtype='int64', value=1
)
return paddle.increment(i), test_dict

def cond_returns_with_mutable_list(i, test_list):
return i > 0

def body_returns_with_mutable_list(i, test_list):
test_list.append(
paddle.tensor.fill_constant(shape=[1], dtype='int64', value=1)
)
return paddle.increment(i), test_list

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
data = paddle.tensor.fill_constant(
shape=[1], dtype='int64', value=1
)

# The length of `output_vars` with mutable value should keep same with `loop_vars`
# TODO(zhangbo): slice error need to fix, loop_vars support list/dict
def value_error_body_returns_with_mutable_dict():
test_dict = {
"int_constant": paddle.tensor.fill_constant(
Expand All @@ -673,6 +684,7 @@ def value_error_body_returns_with_mutable_dict():
ValueError, value_error_body_returns_with_mutable_dict
)

# TODO(zhangbo): loop_vars support list/dict
def value_error_body_returns_with_mutable_list():
test_list = [
paddle.tensor.fill_constant(
Expand All @@ -691,7 +703,8 @@ def value_error_body_returns_with_mutable_list():


class TestApiWhileLoopSliceInBody(unittest.TestCase):
# @compare_legacy_with_pt
@compare_legacy_with_pt
# @test_with_pir_api (need to fix slice bug in pir)
def test_var_slice(self):
def cond(z, i):
return i + 1 <= x_shape[0]
Expand Down

0 comments on commit e47e4cd

Please sign in to comment.