-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PIR API adaptor No.65, 69】Migrate some ops into pir #58698
Changes from 5 commits
c51a08b
1978cf5
f1fb0ee
773d8c2
ba8dc2b
f5ee528
ebb025b
77ed49e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1183,7 +1183,9 @@ def eye(num_rows, num_columns=None, dtype=None, name=None): | |
""" | ||
|
||
def _check_attr(attr, message): | ||
if isinstance(attr, ((Variable, core.eager.Tensor))): | ||
if isinstance( | ||
attr, ((Variable, core.eager.Tensor, paddle.pir.OpResult)) | ||
): | ||
assert len(attr.shape) == 1 and attr.shape[0] in [1, -1] | ||
elif not isinstance(attr, int) or attr < 0: | ||
raise TypeError(f"{message} should be a non-negative int.") | ||
|
@@ -1199,7 +1201,7 @@ def _check_attr(attr, message): | |
else: | ||
num_columns = num_rows | ||
|
||
if in_dynamic_mode(): | ||
if in_dynamic_or_pir_mode(): | ||
out = _C_ops.eye( | ||
num_rows, num_columns, dtype, _current_expected_place() | ||
) | ||
|
@@ -2177,7 +2179,7 @@ def empty_like(x, dtype=None, name=None): | |
dtype = x.dtype | ||
dtype = convert_dtype(dtype) | ||
|
||
if in_dynamic_mode(): | ||
if in_dynamic_or_pir_mode(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
out = _C_ops.empty( | ||
x.shape, | ||
convert_np_dtype_to_dtype_(dtype), | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个文件下继承自 TestEmptyLikeAPICommon 的单测均为动态图单测,不在本次组网API PIR 迁移的单测覆盖目标内。目前看来 empty_like 的单测中,除了 TestEmptyError 尚未支持,其他都已经适配了。所以请更新一下 pr 描述里的单测覆盖率吧~ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个文件下还缺少了 API_TestTensorEye 单测没有适配,辛苦有空的时候适配一下吧~ 😄 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
from paddle import base | ||
from paddle.base import core, framework | ||
from paddle.base.framework import Program, program_guard | ||
from paddle.pir_utils import test_with_pir_api | ||
|
||
|
||
class TestEyeOp(OpTest): | ||
|
@@ -46,7 +47,7 @@ def setUp(self): | |
} | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
self.check_output(check_pir=True) | ||
|
||
def init_dtype(self): | ||
self.dtype = np.int32 | ||
|
@@ -69,7 +70,7 @@ def setUp(self): | |
self.outputs = {'Out': np.eye(50, dtype=float)} | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
self.check_output(check_pir=True) | ||
|
||
|
||
class TestEyeOp2(OpTest): | ||
|
@@ -85,7 +86,7 @@ def setUp(self): | |
self.outputs = {'Out': np.eye(99, 1, dtype=float)} | ||
|
||
def test_check_output(self): | ||
self.check_output() | ||
self.check_output(check_pir=True) | ||
|
||
|
||
class API_TestTensorEye(unittest.TestCase): | ||
|
@@ -144,6 +145,7 @@ def init_info(self): | |
self.shapes = [[2, 3, 4]] | ||
self.save_path = os.path.join(self.temp_dir.name, self.path_prefix()) | ||
|
||
@test_with_pir_api | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以先跳过这个单测,然后在 PR 描述里记录该单测暂不支持 |
||
def test_static(self): | ||
main_prog = Program() | ||
starup_prog = Program() | ||
|
@@ -215,7 +217,7 @@ def setUp(self): | |
|
||
def test_check_output(self): | ||
place = core.CUDAPlace(0) | ||
self.check_output_with_place(place) | ||
self.check_output_with_place(place, check_pir=True) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的适配有些问题。上面的 _check_attr 应该还需要适配 OpResult 的情况: