Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Support some while_loop op_test #60271

Merged
merged 4 commits into from
Dec 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions python/paddle/static/nn/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,25 +675,34 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):

pre_cond = cond(*loop_vars)

check_variable_and_dtype(
pre_cond, 'var of cond returned', ['bool'], 'static.nn.while_loop'
)
if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
raise TypeError(
"the shape of the variable returned by cond should be [1],"
f"but given shape as {list(pre_cond.shape)}."
)

if in_pir_mode():
while_op = build_while_op(pre_cond, flatten(loop_vars))
with while_op.body() as cur_block:
args = cur_block.args()
next_var = body(*args)
try:
assert_same_structure(
flatten(next_var), flatten(loop_vars), check_types=False
)
except ValueError as e:
raise ValueError(
"body in while_loop should return the same arity "
f"(length and structure) as loop_vars: {e}"
)
next_cond = cond(*next_var)
next_cond.stop_gradient = True
cf_yield([next_cond, *next_var])
return while_op.as_operation().results()

check_variable_and_dtype(
pre_cond, 'var of cond returned', ['bool'], 'static.nn.while_loop'
)
if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
raise TypeError(
"the shape of the variable returned by cond should be [1],"
f"but given shape as {list(pre_cond.shape)}."
)

if in_dygraph_mode():
now_cond = pre_cond.item()
while now_cond:
Expand Down
55 changes: 34 additions & 21 deletions test/legacy_test/test_while_loop_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,8 @@ def fn_add_one():

class TestApiWhileLoop_Error(unittest.TestCase):
@compare_legacy_with_pt
def test_error(self):
@test_with_pir_api
def test_error1(self):
def cond_returns_constant(i):
return 1

Expand All @@ -549,27 +550,9 @@ def body_returns_error_length(i):
def body_returns_error_type(i, ten):
return paddle.increment(i)

def cond_returns_with_mutable_dict(i, test_dict):
return i > 0

def body_returns_with_mutable_dict(i, test_dict):
test_dict['new_key'] = paddle.tensor.fill_constant(
shape=[1], dtype='int64', value=1
)
return paddle.increment(i), test_dict

def cond_returns_with_mutable_list(i, test_list):
return i > 0

def body_returns_with_mutable_list(i, test_list):
test_list.append(
paddle.tensor.fill_constant(shape=[1], dtype='int64', value=1)
)
return paddle.increment(i), test_list

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with program_guard(main_program, startup_program):
with paddle.static.program_guard(main_program, startup_program):
data = paddle.tensor.fill_constant(
shape=[1], dtype='int64', value=1
)
Expand Down Expand Up @@ -656,7 +639,35 @@ def value_error_body_returns_error_type():

self.assertRaises(ValueError, value_error_body_returns_error_type)

@compare_legacy_with_pt
def test_error2(self):
def cond_returns_with_mutable_dict(i, test_dict):
return i > 0

def body_returns_with_mutable_dict(i, test_dict):
test_dict['new_key'] = paddle.tensor.fill_constant(
shape=[1], dtype='int64', value=1
)
return paddle.increment(i), test_dict

def cond_returns_with_mutable_list(i, test_list):
return i > 0

def body_returns_with_mutable_list(i, test_list):
test_list.append(
paddle.tensor.fill_constant(shape=[1], dtype='int64', value=1)
)
return paddle.increment(i), test_list

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
data = paddle.tensor.fill_constant(
shape=[1], dtype='int64', value=1
)

# The length of `output_vars` with mutable value should keep same with `loop_vars`
# TODO(zhangbo): slice error need to fix, loop_vars support list/dict
def value_error_body_returns_with_mutable_dict():
test_dict = {
"int_constant": paddle.tensor.fill_constant(
Expand All @@ -673,6 +684,7 @@ def value_error_body_returns_with_mutable_dict():
ValueError, value_error_body_returns_with_mutable_dict
)

# TODO(zhangbo): loop_vars support list/dict
def value_error_body_returns_with_mutable_list():
test_list = [
paddle.tensor.fill_constant(
Expand All @@ -691,7 +703,8 @@ def value_error_body_returns_with_mutable_list():


class TestApiWhileLoopSliceInBody(unittest.TestCase):
# @compare_legacy_with_pt
@compare_legacy_with_pt
# @test_with_pir_api (need to fix slice bug in pir)
def test_var_slice(self):
def cond(z, i):
return i + 1 <= x_shape[0]
Expand Down