Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR API adaptor No.65, 69】Migrate some ops into pir #58698

Merged
merged 8 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions python/paddle/tensor/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,7 +1183,9 @@ def eye(num_rows, num_columns=None, dtype=None, name=None):
"""

def _check_attr(attr, message):
if isinstance(attr, ((Variable, core.eager.Tensor))):
if isinstance(
attr, ((Variable, core.eager.Tensor, paddle.pir.OpResult))
):
assert len(attr.shape) == 1 and attr.shape[0] in [1, -1]
elif not isinstance(attr, int) or attr < 0:
raise TypeError(f"{message} should be a non-negative int.")
Expand All @@ -1199,7 +1201,7 @@ def _check_attr(attr, message):
else:
num_columns = num_rows

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的适配有些问题。上面的 _check_attr 应该还需要适配 OpResult 的情况:
image

out = _C_ops.eye(
num_rows, num_columns, dtype, _current_expected_place()
)
Expand Down Expand Up @@ -2177,7 +2179,7 @@ def empty_like(x, dtype=None, name=None):
dtype = x.dtype
dtype = convert_dtype(dtype)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里静态图多了一些处理逻辑,有时候动态图逻辑是不能复用的,需要有pir自己的逻辑的
image
可修改如下:

    elif in_pir_mode():
        shape = paddle.shape(x)
        out = _C_ops.empty(
            shape,
            convert_np_dtype_to_dtype_(dtype),
            _current_expected_place(),
        )
        out.stop_gradient = True
        return out

out = _C_ops.empty(
x.shape,
convert_np_dtype_to_dtype_(dtype),
Expand Down
33 changes: 18 additions & 15 deletions test/legacy_test/test_empty_like_op.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件下继承自 TestEmptyLikeAPICommon 的单测均为动态图单测,不在本次组网API PIR 迁移的单测覆盖目标内。目前看来 empty_like 的单测中,除了 TestEmptyError 尚未支持,其他都已经适配了。所以请更新一下 pr 描述里的单测覆盖率吧~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import paddle
from paddle.base import core
from paddle.base.data_feeder import convert_dtype
from paddle.static import Program, program_guard
from paddle.pir_utils import test_with_pir_api


class TestEmptyLikeAPICommon(unittest.TestCase):
Expand Down Expand Up @@ -163,32 +163,33 @@ class TestEmptyLikeAPI_Static(TestEmptyLikeAPICommon):
def setUp(self):
self.init_config()

@test_with_pir_api
def test_static_graph(self):
paddle.enable_static()
train_program = Program()
startup_program = Program()
train_program = paddle.static.Program()
startup_program = paddle.static.Program()

with program_guard(train_program, startup_program):
with paddle.static.program_guard(train_program, startup_program):
x = np.random.random(self.x_shape).astype(self.dtype)
data_x = paddle.static.data(
'x', shape=self.data_x_shape, dtype=self.dtype
)

out = paddle.empty_like(data_x)

place = (
paddle.CUDAPlace(0)
if core.is_compiled_with_cuda()
else paddle.CPUPlace()
)
exe = paddle.static.Executor(place)
res = exe.run(train_program, feed={'x': x}, fetch_list=[out])
place = (
paddle.CUDAPlace(0)
if core.is_compiled_with_cuda()
else paddle.CPUPlace()
)
exe = paddle.static.Executor(place)
res = exe.run(train_program, feed={'x': x}, fetch_list=[out])

self.dst_dtype = self.dtype
self.dst_shape = x.shape
self.__check_out__(res[0])
self.dst_dtype = self.dtype
self.dst_shape = x.shape
self.__check_out__(res[0])

paddle.disable_static()
paddle.disable_static()

def init_config(self):
self.x_shape = (200, 3)
Expand All @@ -212,6 +213,7 @@ def init_config(self):
self.data_x_shape = [200, 3]
self.dtype = 'float16'

@test_with_pir_api
def test_static_graph(self):
paddle.enable_static()
if paddle.base.core.is_compiled_with_cuda():
Expand Down Expand Up @@ -245,6 +247,7 @@ def init_config(self):
self.data_x_shape = [200, 3]
self.dtype = 'uint16'

@test_with_pir_api
def test_static_graph(self):
paddle.enable_static()
if paddle.base.core.is_compiled_with_cuda():
Expand Down
10 changes: 6 additions & 4 deletions test/legacy_test/test_eye_op.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件下还缺少了 API_TestTensorEye 单测没有适配,辛苦有空的时候适配一下吧~ 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from paddle import base
from paddle.base import core, framework
from paddle.base.framework import Program, program_guard
from paddle.pir_utils import test_with_pir_api


class TestEyeOp(OpTest):
Expand All @@ -46,7 +47,7 @@ def setUp(self):
}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def init_dtype(self):
self.dtype = np.int32
Expand All @@ -69,7 +70,7 @@ def setUp(self):
self.outputs = {'Out': np.eye(50, dtype=float)}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestEyeOp2(OpTest):
Expand All @@ -85,7 +86,7 @@ def setUp(self):
self.outputs = {'Out': np.eye(99, 1, dtype=float)}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class API_TestTensorEye(unittest.TestCase):
Expand Down Expand Up @@ -144,6 +145,7 @@ def init_info(self):
self.shapes = [[2, 3, 4]]
self.save_path = os.path.join(self.temp_dir.name, self.path_prefix())

@test_with_pir_api
Copy link
Contributor

@MarioLulab MarioLulab Nov 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以先跳过这个单测,然后在 PR 描述里记录该单测暂不支持

def test_static(self):
main_prog = Program()
starup_prog = Program()
Expand Down Expand Up @@ -215,7 +217,7 @@ def setUp(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)


if __name__ == "__main__":
Expand Down