Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hackathon 5th No.49][pir] add some method property - Part 2 #58042

Merged
merged 9 commits into from
Oct 17, 2023
4 changes: 2 additions & 2 deletions python/paddle/base/layers/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def pop(self, *args):

if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
raise TypeError(
"Only Variable with VarType.LOD_TENSOR_ARRAY support `append` method, but received type: {}".format(
"Only Variable with VarType.LOD_TENSOR_ARRAY support `pop` method, but received type: {}".format(
self.type
)
)
Expand Down Expand Up @@ -393,7 +393,7 @@ def _ndim_(self):
>>> # create a static Variable
>>> x = paddle.static.data(name='x', shape=[3, 2, 1])
>>> # print the dimension of the Variable
>>> print(x.ndim())
>>> print(x.ndim)
3
"""
return len(self.shape)
Expand Down
114 changes: 106 additions & 8 deletions python/paddle/pir/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,23 @@
# limitations under the License.


import warnings

from paddle.base.libpaddle import DataType

from . import OpResult

_already_patch_opresult = False

_supported_int_dtype_ = [
DataType.BOOL,
DataType.UINT8,
DataType.INT8,
DataType.INT16,
DataType.INT32,
DataType.INT64,
]


def create_tensor_with_batchsize(ref_var, value, dtype):
assert isinstance(ref_var, OpResult)
Expand Down Expand Up @@ -54,14 +65,96 @@ def safe_get_dtype(var):
raise ValueError("Cannot get data type from var")
return dtype

_supported_int_dtype_ = [
DataType.BOOL,
DataType.UINT8,
DataType.INT8,
DataType.INT16,
DataType.INT32,
DataType.INT64,
]
def place(self):
"""
OpResult don't have 'place' interface in static graph mode
But this interface can greatly facilitate dy2static.
So we give a warnning here and return None.
"""
warnings.warn(
"OpResult do not have 'place' interface for pir graph mode, try not to use it. None will be returned."
)

@property
def _ndim_(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里最后为什么还有一个下划线,_item只有前下划线。可以统一一下,比如:

  1. 对于builtin 方法的映射,若有,可以「单下划线+名字+单下划线」,如 xx
  2. 对于 property 属性的映射,若有,可以「单下划线+名字」,如 _x
  3. 对于 方法的映射,若有,直接「名字」即可,如x

类似如上的约定

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, 已同步至动态图 Tensor 和老 IR Varialbe

"""
Returns the dimension of current OpResult

Returns:
the dimension

Examples:
.. code-block:: python

>>> import paddle

>>> paddle.enable_static()

>>> # create a static OpResult
>>> with paddle.pir_utils.IrGuard():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里新加的接口如果和之前是一致的,这里可以去掉with paddle.pir_utils.IrGuard(),后期pir推广默认不需要这个guard

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, 已全部删除

>>> x = paddle.static.data(name='x', shape=[3, 2, 1])
>>> # print the dimension of the OpResult
>>> print(x.ndim)
3
"""
return len(self.shape)

def ndimension(self):
"""
Returns the dimension of current OpResult

Returns:
the dimension

Examples:
.. code-block:: python

>>> import paddle

>>> paddle.enable_static()

>>> # create a static OpResult
>>> with paddle.pir_utils.IrGuard():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

>>> x = paddle.static.data(name='x', shape=[3, 2, 1])
>>> # print the dimension of the OpResult
>>> print(x.ndimension())
3
"""
return len(self.shape)

def dim(self):
"""
Returns the dimension of current OpResult

Returns:
the dimension

Examples:
.. code-block:: python

>>> import paddle

>>> paddle.enable_static()

>>> # create a static OpResult
>>> with paddle.pir_utils.IrGuard():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

>>> x = paddle.static.data(name='x', shape=[3, 2, 1])
>>> # print the dimension of the OpResult
>>> print(x.dim())
3
"""
return len(self.shape)

def _item(self):
"""
In order to be compatible with the item interface introduced by the dynamic graph, it does nothing but returns self.
It will check that the shape must be a 1-D tensor
"""
if len(self.shape) > 1:
raise TypeError(
f"Required input var should be 1-D OpResult, but received {self.shape}"
)
return self

def _scalar_div_(var, value):
return paddle.scale(var, 1.0 / value, 0.0)
Expand Down Expand Up @@ -166,6 +259,11 @@ def __impl__(self, other_var):
import paddle

opresult_methods = [
('place', place),
('item', _item),
('dim', dim),
('ndimension', ndimension),
('ndim', _ndim_),
(
'__div__',
_binary_creator_(
Expand Down
30 changes: 30 additions & 0 deletions test/legacy_test/test_math_op_patch_pir.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,43 @@

import inspect
import unittest
import warnings

import paddle

paddle.enable_static()


class TestMathOpPatchesPir(unittest.TestCase):
def test_item(self):
with paddle.pir_utils.IrGuard():
x = paddle.static.data(name='x', shape=[3, 2, 1])
y = paddle.static.data(
name='y',
shape=[
3,
],
)
self.assertTrue(y.item() == y)
with self.assertRaises(TypeError):
x.item()

def test_place(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with paddle.pir_utils.IrGuard():
x = paddle.static.data(name='x', shape=[3, 2, 1])
x.place()
self.assertTrue(len(w) == 1)
self.assertTrue("place" in str(w[-1].message))

def test_some_dim(self):
with paddle.pir_utils.IrGuard():
x = paddle.static.data(name='x', shape=[3, 2, 1])
self.assertEqual(x.dim(), 3)
self.assertEqual(x.ndimension(), 3)
self.assertEqual(x.ndim, 3)

def test_math_exists(self):
with paddle.pir_utils.IrGuard():
a = paddle.static.data(name='a', shape=[1], dtype='float32')
Expand Down