【PIR API adaptor No.183】 Migrate python/paddle/nn/layer/rnn.py into pir #60180

Merged (8 commits) on Dec 21, 2023
29 changes: 3 additions & 26 deletions python/paddle/nn/layer/rnn.py
@@ -24,7 +24,7 @@
from paddle.base.dygraph.base import NON_PERSISTABLE_VAR_NAME_SUFFIX
from paddle.base.framework import (
default_startup_program,
in_dygraph_mode,
in_dynamic_or_pir_mode,
program_guard,
)
from paddle.common_ops_import import Variable
@@ -106,7 +106,7 @@ def rnn(

"""

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
return _rnn_dynamic_graph(
cell,
inputs,
@@ -1590,7 +1590,7 @@ def _cudnn_impl(self, inputs, initial_states, sequence_length):
if not self.time_major:
inputs = paddle.tensor.transpose(inputs, [1, 0, 2])

if in_dygraph_mode():
if in_dynamic_or_pir_mode():
out, _, state = _C_ops.rnn(
inputs,
initial_states,
@@ -1606,29 +1606,6 @@ def _cudnn_impl(self, inputs, initial_states, sequence_length):
0,
not self.training,
)
elif in_dynamic_mode():
_, _, out, state = _legacy_C_ops.rnn(
inputs,
initial_states,
self._all_weights,
sequence_length,
self._dropout_state,
self.state_components,
'dropout_prob',
self.dropout,
'is_bidirec',
self.num_directions == 2,
'input_size',
self.input_size,
'hidden_size',
self.hidden_size,
'num_layers',
self.num_layers,
'mode',
self.mode,
'is_test',
not self.training,
)
else:
out = self._helper.create_variable_for_type_inference(inputs.dtype)
state = [
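The substance of the rnn.py change is the widened dispatch condition: dygraph and PIR programs now share the eager _C_ops path, and the separate _legacy_C_ops.rnn branch disappears. A minimal sketch of the resulting control flow, assuming the _rnn_static_graph helper that the unchanged else-branch of rnn() falls back to (only the branch condition is the point here):

from paddle.base.framework import in_dynamic_or_pir_mode

def rnn_dispatch(cell, inputs, initial_states=None, **kwargs):
    if in_dynamic_or_pir_mode():
        # dygraph *and* PIR static programs take the dynamic-graph unrolling path
        return _rnn_dynamic_graph(cell, inputs, initial_states, **kwargs)
    # the legacy static graph keeps its program-construction implementation
    return _rnn_static_graph(cell, inputs, initial_states, **kwargs)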
97 changes: 62 additions & 35 deletions test/legacy_test/op_test.py
@@ -39,7 +39,12 @@
from prim_op_test import OpTestUtils, PrimForwardChecker, PrimGradChecker
from testsuite import append_input_output, append_loss_ops, create_op, set_input

sys.path.append("..")
# Add test/legacy and test to sys.path
legacy_test_dir = pathlib.Path(__file__).parent # test/legacy_test
test_dir = legacy_test_dir.parent # test
sys.path.append(str(legacy_test_dir.absolute()))
sys.path.append(str(test_dir.absolute()))

from utils import static_guard
from white_list import (
check_shape_white_list,
@@ -66,8 +71,6 @@
)
from paddle.base.wrapped_decorator import signature_safe_contextmanager

sys.path.append(os.path.abspath(os.path.dirname(__file__)))


@signature_safe_contextmanager
def paddle_static_guard():
@@ -1385,7 +1388,8 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
fetch_list = getattr(self, "fetch_list", [])
# if the fetch_list is customized by user, we use it directly.
# if not, fill the fetch_list by the user configured outputs in test.

# filter ret_tuple
ret_to_check = []
if len(fetch_list) == 0:
if isinstance(ret_tuple, (tuple, list)):
assert len(ret_tuple) == len(outputs_sig)
@@ -1395,14 +1399,17 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
if not self._need_fetch(sig_name):
continue
if isinstance(var, list):
ret_to_check.append(var)
for v in var:
fetch_list.append(v)
else:
ret_to_check.append(var)
fetch_list.append(var)
elif isinstance(
ret_tuple, paddle.base.libpaddle.pir.OpResult
):
fetch_list.append(ret_tuple)
ret_to_check = ret_tuple
elif ret_tuple is None:
pass
else:
@@ -1415,19 +1422,27 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
outs = executor.run(
ir_program, feed=feed, fetch_list=[fetch_list]
)

outputs_sig = [
sig_name
for sig_name in outputs_sig
if self._need_fetch(sig_name)
]

if paddle.utils.is_sequence(
ret_to_check
) and paddle.utils.is_sequence(outs):
outs = paddle.utils.pack_sequence_as(ret_to_check, outs)

result = construct_output_dict_by_kernel_sig(outs, outputs_sig)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
result[key][0][
i
].name = self.python_out_sig_sub_name[key][i]
result[key][0] = {
a: [b]
for a, b in zip(
self.python_out_sig_sub_name[key],
result[key][0],
)
}
return result

def _check_ir_output(self, place, program, feed_map, fetch_list, outs):
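Two details in the hunk above are easy to miss: executor.run returns a flat list, so pack_sequence_as restores the nesting of the original return structure before the outputs dict is built, and entries listed in python_out_sig_sub_name are rebuilt as small name-keyed dicts instead of renaming the fetched values. A toy illustration of the re-nesting (values are made up):

import paddle

ret_to_check = [["h", "c"], "out"]   # nesting of the python API's return value
outs = ["H", "C", "OUT"]             # flat fetch results from executor.run
if paddle.utils.is_sequence(ret_to_check) and paddle.utils.is_sequence(outs):
    outs = paddle.utils.pack_sequence_as(ret_to_check, outs)
print(outs)                          # [['H', 'C'], 'OUT'] -- nesting restored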
@@ -2435,12 +2450,24 @@ def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
expect_np = convert_uint16_to_float(expect_np)
return actual_np, expect_np

def find_imperative_actual(target_name, pir_outs, place):
def find_pir_actual(self, target_name, pir_outs, place):
for name in pir_outs:
if name == target_name:
return pir_outs[name][0]

var_list = pir_outs[name]
sub_dict = pir_outs[name][0]
if isinstance(sub_dict, dict):
for key, value in sub_dict.items():
if key == target_name:
return value[0]

raise AssertionError("No pir output named " + target_name)

def find_pir_expect(self, target_name, dygraph_outs, place):
for name in dygraph_outs:
if name == target_name:
return dygraph_outs[name][0]
var_list = dygraph_outs[name]
for i, var in enumerate(var_list):
if isinstance(var, list):
for tensor in var:
@@ -2450,26 +2477,14 @@ def find_imperative_actual(target_name, pir_outs, place):
isinstance(var, paddle.Tensor)
and var.name == target_name
):
return pir_outs[name][i]
self.assertTrue(
False,
f"Found failed {pir_outs.keys()} {target_name}",
)

def find_imperative_expect(self, target_name, pir_outs, place):
for name in pir_outs:
if name == target_name:
return pir_outs[name][0]
self.assertTrue(
False,
f"Found failed {pir_outs.keys()} {target_name}",
)
return dygraph_outs[name][i]
raise AssertionError("No pir ref_output named " + target_name)

def find_actual_value(self, target_name):
with paddle.pir.core.program_guard(
paddle.pir.core.default_main_program()
):
actual = find_imperative_actual(
actual = self.find_pir_actual(
target_name, self.outputs, place
)
actual_t = np.array(actual)
@@ -2479,7 +2494,7 @@ def find_expect_value(self, target_name):
with paddle.pir.core.program_guard(
paddle.pir.core.default_main_program()
):
expect = self.find_imperative_expect(
expect = self.find_pir_expect(
target_name, self.ref_outputs, place
)
expect_t = np.array(expect)
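The data shape these helpers walk is worth spelling out: after the dict remapping above, a multi-output entry such as State holds a sub-dict keyed by the names from python_out_sig_sub_name. A toy example of the two lookup paths in find_pir_actual (names and values are illustrative, not real fetch results):

pir_outs = {
    "Out": ["out_value"],
    "State": [{"last_hidden": ["h_value"], "last_cell": ["c_value"]}],
}
# target_name == "Out"       -> matched at the top level, returns pir_outs["Out"][0]
# target_name == "last_cell" -> found inside the State sub-dict, returns "c_value"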
@@ -3674,10 +3689,19 @@ def _get_gradient(

return res

def _find_var_in_pir(self, output_vars, name):
if name in output_vars:
return output_vars[name]
raise AssertionError(name, " not in outputs:", output_vars.keys())
def _find_var_in_pir(self, output_vars, target_name):
for name in output_vars:
if name == target_name:
return output_vars[name]

sub_dict = output_vars[name][0]
if isinstance(sub_dict, dict):
for key, value in sub_dict.items():
if key == target_name:
return value
raise AssertionError(
target_name, " not in outputs:", output_vars.keys()
)

def _get_ir_gradient(
self,
@@ -3751,10 +3775,13 @@ def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
outputs[key][0][
i
].name = self.python_out_sig_sub_name[key][i]
outputs[key][0] = {
a: [b]
for a, b in zip(
self.python_out_sig_sub_name[key],
outputs[key][0],
)
}
fetch_list = getattr(self, "fetch_list", [])

# cast outputs
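The gradient path uses the same remapping: instead of renaming each fetched value after the sub-signature (the old .name assignment), the entry is rebuilt as a {sub_name: [value]} dict that _find_var_in_pir can then search by key. A sketch with made-up names:

sub_names = ["last_hidden", "last_cell"]   # e.g. self.python_out_sig_sub_name["State"]
fetched = ["h_value", "c_value"]           # outputs["State"][0] before remapping
outputs_state = {a: [b] for a, b in zip(sub_names, fetched)}
print(outputs_state)   # {'last_hidden': ['h_value'], 'last_cell': ['c_value']}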
27 changes: 20 additions & 7 deletions test/legacy_test/test_rnn_op.py
@@ -15,14 +15,17 @@
import random
import sys
import unittest
from pathlib import Path

import numpy as np
from op_test import OpTest

import paddle
from paddle.base import core

sys.path.append("../../test/rnn")
# Add test/rnn to sys.path
legacy_test_dir = Path(__file__).resolve().parents[1]
sys.path.append(str(legacy_test_dir / "rnn"))
from convert import get_params_for_net
from rnn_numpy import LSTM

@@ -45,7 +48,7 @@ def rnn_wrapper(
seed=0,
is_test=False,
):
dropout_state_in = paddle.Tensor()
dropout_state_in = paddle.tensor.fill_constant([], "float32", 0.0)
return paddle._C_ops.rnn(
Input,
PreState,
@@ -168,7 +171,9 @@ def rocm_rnn_get_place():
}

def test_output(self):
self.check_output(no_check_set=['Reserve', 'DropoutState'])
self.check_output(
no_check_set=['Reserve', 'DropoutState'], check_pir=True
)

def set_attrs(self):
pass
@@ -179,7 +184,9 @@ def test_grad(self):
grad_check_list = ['Input', 'init_h', 'init_c']
grad_check_list.extend(var_name_list)
self.check_grad(
set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
set(grad_check_list),
['Out', 'last_hidden', 'last_cell'],
check_pir=True,
)

def test_grad_only_input(self):
@@ -188,7 +195,9 @@ def test_grad_only_input(self):
grad_check_list = ['Input']
grad_check_list.extend(var_name_list)
self.check_grad(
set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
set(grad_check_list),
['Out', 'last_hidden', 'last_cell'],
check_pir=True,
)

def test_grad_only_h(self):
@@ -197,7 +206,9 @@ def test_grad_only_h(self):
grad_check_list = ['init_h']
grad_check_list.extend(var_name_list)
self.check_grad(
set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
set(grad_check_list),
['Out', 'last_hidden', 'last_cell'],
check_pir=True,
)

def test_grad_only_c(self):
@@ -206,7 +217,9 @@ def test_grad_only_c(self):
grad_check_list = ['init_c']
grad_check_list.extend(var_name_list)
self.check_grad(
set(grad_check_list), ['Out', 'last_hidden', 'last_cell']
set(grad_check_list),
['Out', 'last_hidden', 'last_cell'],
check_pir=True,
)


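One further detail in test_rnn_op.py: the DropoutState placeholder passed to the wrapper changed from paddle.Tensor() to a fill_constant call. A bare paddle.Tensor() is an uninitialized eager tensor, which presumably cannot be traced through a PIR static program, while fill_constant yields a concrete 0-D float32 value in either mode. A quick dygraph check of the new placeholder (the PIR behaviour is the assumption this PR relies on, not demonstrated here):

import paddle

dropout_state_in = paddle.tensor.fill_constant(shape=[], dtype="float32", value=0.0)
print(dropout_state_in.shape, float(dropout_state_in))   # [] 0.0 -- a real 0-D tensor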