-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Hackathon 5th No.49][pir] add some method property - Part 2 #58042
Changes from 7 commits
ccbd29c
2ef5686
52e41ae
596b03e
211db2b
7bc46e3
46bfd7c
4806073
2ee5a9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,12 +13,23 @@ | |
# limitations under the License. | ||
|
||
|
||
import warnings | ||
|
||
from paddle.base.libpaddle import DataType | ||
|
||
from . import OpResult | ||
|
||
_already_patch_opresult = False | ||
|
||
_supported_int_dtype_ = [ | ||
DataType.BOOL, | ||
DataType.UINT8, | ||
DataType.INT8, | ||
DataType.INT16, | ||
DataType.INT32, | ||
DataType.INT64, | ||
] | ||
|
||
|
||
def create_tensor_with_batchsize(ref_var, value, dtype): | ||
assert isinstance(ref_var, OpResult) | ||
|
@@ -54,14 +65,96 @@ def safe_get_dtype(var): | |
raise ValueError("Cannot get data type from var") | ||
return dtype | ||
|
||
_supported_int_dtype_ = [ | ||
DataType.BOOL, | ||
DataType.UINT8, | ||
DataType.INT8, | ||
DataType.INT16, | ||
DataType.INT32, | ||
DataType.INT64, | ||
] | ||
def place(self): | ||
""" | ||
OpResult don't have 'place' interface in static graph mode | ||
But this interface can greatly facilitate dy2static. | ||
So we give a warnning here and return None. | ||
""" | ||
warnings.warn( | ||
"OpResult do not have 'place' interface for pir graph mode, try not to use it. None will be returned." | ||
) | ||
|
||
@property | ||
def _ndim_(self): | ||
""" | ||
Returns the dimension of current OpResult | ||
|
||
Returns: | ||
the dimension | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
|
||
>>> paddle.enable_static() | ||
|
||
>>> # create a static OpResult | ||
>>> with paddle.pir_utils.IrGuard(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里新加的接口如果和之前是一致的,这里可以去掉with paddle.pir_utils.IrGuard(),后期pir推广默认不需要这个guard There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done, 已全部删除 |
||
>>> x = paddle.static.data(name='x', shape=[3, 2, 1]) | ||
>>> # print the dimension of the OpResult | ||
>>> print(x.ndim) | ||
3 | ||
""" | ||
return len(self.shape) | ||
|
||
def ndimension(self): | ||
""" | ||
Returns the dimension of current OpResult | ||
|
||
Returns: | ||
the dimension | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
|
||
>>> paddle.enable_static() | ||
|
||
>>> # create a static OpResult | ||
>>> with paddle.pir_utils.IrGuard(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
>>> x = paddle.static.data(name='x', shape=[3, 2, 1]) | ||
>>> # print the dimension of the OpResult | ||
>>> print(x.ndimension()) | ||
3 | ||
""" | ||
return len(self.shape) | ||
|
||
def dim(self): | ||
""" | ||
Returns the dimension of current OpResult | ||
|
||
Returns: | ||
the dimension | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
|
||
>>> paddle.enable_static() | ||
|
||
>>> # create a static OpResult | ||
>>> with paddle.pir_utils.IrGuard(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
>>> x = paddle.static.data(name='x', shape=[3, 2, 1]) | ||
>>> # print the dimension of the OpResult | ||
>>> print(x.dim()) | ||
3 | ||
""" | ||
return len(self.shape) | ||
|
||
def _item(self): | ||
""" | ||
In order to be compatible with the item interface introduced by the dynamic graph, it does nothing but returns self. | ||
It will check that the shape must be a 1-D tensor | ||
""" | ||
if len(self.shape) > 1: | ||
raise TypeError( | ||
f"Required input var should be 1-D OpResult, but received {self.shape}" | ||
) | ||
return self | ||
|
||
def _scalar_div_(var, value): | ||
return paddle.scale(var, 1.0 / value, 0.0) | ||
|
@@ -166,6 +259,11 @@ def __impl__(self, other_var): | |
import paddle | ||
|
||
opresult_methods = [ | ||
('place', place), | ||
('item', _item), | ||
('dim', dim), | ||
('ndimension', ndimension), | ||
('ndim', _ndim_), | ||
( | ||
'__div__', | ||
_binary_creator_( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里最后为什么还有一个下划线,_item只有前下划线。可以统一一下,比如:
类似如上的约定
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, 已同步至动态图 Tensor 和老 IR Varialbe