-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 7th No.20】为 Paddle 新增 Tensor.resize_ -part #69841
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3389,3 +3389,82 @@ def set_( | |
shape = source.shape | ||
|
||
return _C_ops.set_(x, source, shape, stride, offset) | ||
|
||
|
||
@inplace_apis_in_dygraph_only | ||
def resize_( | ||
x: paddle.Tensor, | ||
shape: Sequence[int], | ||
fill_zero: bool = False, | ||
name: str | None = None, | ||
) -> paddle.Tensor: | ||
""" | ||
Resize ``x`` with specified ``shape``. | ||
|
||
Args: | ||
x (Tensor): An arbitrary Tensor. The data type supports ``bfloat16``, ``float16``, ``float32``, ``float64``, | ||
``bool``, ``int8``, ``int16``, ``int32``, ``int64``, ``uint8``, ``complex64`` or ``complex128``. | ||
shape (list|tuple): Define the target shape. Each element of it should be integer. | ||
fill_zero (bool, optional): If the size of specified ``shape`` is greater than the original Tensor size, the | ||
new Tensor will be filled with zero if ``fill_zero`` is True. Default: False, which means the filled value | ||
will be undetermined. | ||
name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. | ||
|
||
Returns: | ||
Tensor, the resized Tensor. | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
|
||
>>> x = paddle.to_tensor([1., 2., 3.]) | ||
>>> x.resize_([2, 1]) | ||
>>> print(x) | ||
Tensor(shape=[2, 1], dtype=float32, place=Place(cpu), stop_gradient=True, | ||
[[1.], | ||
[2.]]) | ||
|
||
>>> x = paddle.to_tensor([1., 2., 3.]) | ||
>>> x.resize_([2, 3], fill_zero=True) | ||
>>> print(x) | ||
Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True, | ||
[[1., 2., 3.], | ||
[0., 0., 0.]]) | ||
|
||
""" | ||
if in_dynamic_mode(): | ||
check_dtype( | ||
x.dtype, | ||
'x', | ||
[ | ||
'bool', | ||
'float16', | ||
'uint16', | ||
'float32', | ||
'float64', | ||
'int8', | ||
'int16', | ||
'int32', | ||
'int64', | ||
'uint8', | ||
'complex64', | ||
'complex128', | ||
], | ||
'resize', | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. miss data format check for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also converge this codes in unit test? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
if not isinstance(shape, (list, tuple)): | ||
raise ValueError( | ||
f"Input (shape) should be list or tuple but received {type(shape)}" | ||
) | ||
new_size = math.prod(shape) | ||
old_size = math.prod(x.shape) | ||
if (new_size > old_size) and fill_zero: | ||
repeats = -(-new_size // old_size) # ceil division | ||
flatten_x = x.flatten() | ||
tmp = paddle.concat( | ||
(flatten_x,) + (paddle.zeros_like(flatten_x),) * (repeats - 1) | ||
)[:new_size] | ||
return x.set_(tmp, shape) | ||
|
||
return x.set_(x, shape) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where did this parameter come from?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To make the functionality aligned with numpy (when calling
ndarray.resize()
)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then you should add this parameter to interface?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean I should not add this parameter and functionality to the interface?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
your interface have three inputs, but doc get four?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh I forgot to include the parameter
name
.added.