-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.32】为 Paddle 新增 tensor_split / hsplit / dsplit API -part #58917
Conversation
@zoooo0820 PR-CI-Inference 不清楚为什么 fail 了 ~ 重新跑了一次还是一样 ~ 请评审 ~ |
python/paddle/tensor/manipulation.py
Outdated
starts = ends | ||
ends = total_n | ||
sub_array = paddle.slice(x, axes=[axis], starts=[starts], ends=[ends]) | ||
splits.append(sub_array) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此处的实现方式为多次调用slice
操作并append到一个list中,预期耗时会较高。
能否考虑将indices_or_sections
转化成num_or_sections
再一次性调用用split
操作完成?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好想法!我试试看~
test/legacy_test/test_splits_api.py
Outdated
DTYPE_STATIC_V_SPLIT = DTYPE_STATIC_SPLIT | ||
|
||
DTYPE_DYGRAPH_D_SPLIT = DTYPE_DYGRAPH_SPLIT | ||
DTYPE_STATIC_D_SPLIT = DTYPE_STATIC_SPLIT |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些常量值都是同一个,是否简化统一一下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是特意这样写的,很多接口的文档对于数据类型的支持都不完善或者明确,我希望至少在单测里面明确一下,所以才逐个写出来 ~ 如果不需要的话,可以合并 ~
test/legacy_test/test_splits_api.py
Outdated
with self.assertRaises(TypeError): | ||
self._test_static_api( | ||
self.func_paddle, self.func_numpy, **kwargs | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此处不支持的dtype,属于kernel缺少支持而非预期不应当支持。建议不要增加单测case来测试是否会报错,直接测试已经支持的dtype即可
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK ~
def test_error_split(self): | ||
x = generate_data([6]) | ||
with self.assertRaises(ValueError): | ||
self._test_all({**x, 'split_paddle': 0, 'split_numpy': None}) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
原来paddle.split
API的测试被移除了,这里应该补上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zoooo0820
这个单测文件里面原来只有 paddle.vsplit
的测试(已经覆盖测试),没有 paddle.split
,paddle.split
的单测在 test/legacy_test/test_split_op.py
,这里 补上
是指?
test/legacy_test/test_splits_api.py
Outdated
'split_numpy': convert_num_or_sections((2, 1, 3)), | ||
} | ||
) | ||
self._test_all({**x, 'split_paddle': [1, -1, 3], 'split_numpy': [1, 3]}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
再补充个-1推断出非0的case吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK ~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update 20231114
之前看这条 review 意见理解错了,这里解释一下:
由于 paddle 的 split 与 numpy 的不同,paddle [1, -1, 3] 推断出的 sections 相当于 [1, 2, 3] (总长度为 6,-1 推断为 2,并不是 0),对应 numpy (indices_or_sections = np.cumsum(num_or_sections)[:-1])的 indices_or_sections 为 [1, 3]。
之前的 convert_num_or_sections
函数是不支持 -1
的,如果要支持的话会把问题复杂化,所以,这里对于 -1
这种情况,是把 paddle 与 numpy 分别输入的。
def convert_num_or_sections(num_or_sections):
# hsplit, vsplit, dsplit
# Convert the num_or_sections in paddle to indices_or_sections in numpy
# Do not support -1
if isinstance(num_or_sections, int):
indices_or_sections = num_or_sections
else:
indices_or_sections = np.cumsum(num_or_sections)[:-1]
return indices_or_sections
所以,这里 -1 并不是推断为 0,是我 test case 写的不好 ~ 为避免歧义,我把这里改一下吧 ~
另外,推断为 0 是 error case,我这边补充一下 ~
test/legacy_test/test_splits_api.py
Outdated
'complex128', | ||
'bfloat16', | ||
'bool', | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fp16/bf16类型在某些设备上运行可能会有问题,这里最好确认下。可以参考下其他单测文件里的is_bfloat16_supported
等写法,把这两个dtype单独处理下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK 这里疏忽了~
Update 20231114
另外, |
这个地方想确认下,np.array_split / torch.tensor_split功能上是否支持上述两种情况?
之前是我遗漏了,那这里可以不用再测了 |
支持 ~ In [1]: import numpy as np
In [2]: a = np.arange(8)
In [3]: np.array_split(a, [2, 1, 4])
Out[3]: [array([0, 1]), array([], dtype=int64), array([1, 2, 3]), array([4, 5, 6, 7])]
In [4]: import torch
In [5]: x = torch.arange(8)
In [6]: torch.tensor_split(x, [2, 1, 4])
Out[6]:
(tensor([0, 1]),
tensor([], dtype=torch.int64),
tensor([1, 2, 3]),
tensor([4, 5, 6, 7]))
In [7]: np.array_split(a, [2, 1, 15])
Out[7]:
[array([0, 1]),
array([], dtype=int64),
array([1, 2, 3, 4, 5, 6, 7]),
array([], dtype=int64)]
In [8]: torch.tensor_split(x, [2, 1, 15])
Out[8]:
(tensor([0, 1]),
tensor([], dtype=torch.int64),
tensor([1, 2, 3, 4, 5, 6, 7]),
tensor([], dtype=torch.int64)) 两者是用类似我之前的 slice 方式处理的,所以能够处理这类 之前测试用例里面也有这类 如果要支持这类 比如 因此,如果要达到上面的结果,一个 @zoooo0820 看看有什么好办法?! |
@megemini 这个地方,从用户使用的角度来看需要优先对齐功能。那还是按照此前的做法吧,辛苦修改一下~ |
刚才想要修改回之前的 这里我单独处理一下吧 ~ 另外, 动态图可以: In [11]: input = paddle.rand(shape=[4, 5, 6], dtype='float32')
In [12]: paddle.slice(input, axes=[-2], starts=[1], ends=[3])
Out[12]:
Tensor(shape=[4, 2, 6], dtype=float32, place=Place(gpu:0), stop_gradient=True,
... 静态图不行: ...: program = paddle.static.Program()
...: exe = paddle.static.Executor()
...:
...: with paddle.static.program_guard(program):
...: input = paddle.static.data('x', [4, 5, 6], 'float64')
...: feed = {'x': np.random.rand(4, 5, 6)}
...:
...: out = paddle.slice(input, axes=[-2], starts=[1], ends=[3])
...: res = exe.run(feed=feed, fetch_list=[out])
In [24]: res[0].shape
Out[24]: (4, 5, 6)
In [25]: len(res)
Out[25]: 1 静态图没有 slice ... ... 我单独提个 PR 改一下吧? |
可以的,如果slice API验证下来确实有问题,可以提个PR修复一下 |
Sorry to inform you that fc8fdbd's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
这几个 api 依赖的底层算子静态图回传好像有问题 ~ #59365 这个 PR 里面我整理了一下 ~ 等问题解决了,我会把这几个 api 的静态图回传再测一下 ~ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, 关于反馈的pir下的其他op导致的问题,后续根据相关任务进展修复即可
@megemini @zoooo0820 |
@megemini 你好,
2.切分参数支持int / int数组的输入,语义上,int表示切分数量(num),int数组表示末位索引位置(indices),现参数名 辛苦再根据这个意见做下改动,同时也同步再提PR在对应的rfc中做下修改~ |
赞同!其实之前也不是很理解为什么不对齐 numpy 和 pytorch 的实现 ~ OK,再确认一下:
请帮忙看一下是否有错误或遗漏? |
嗯嗯,按这个开发就好 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
code is fine, but the implement and parameter names in rfc also need to be modified according to the new conclusion |
roger ~ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
中文文档也可以提了 |
需要解决下冲突 |
x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64. | ||
num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections along ``axis``. | ||
If ``x`` is divisible by ``n``, each section will be ``x.shape[axis] / n``. If ``x`` is not divisible by ``n``, the first | ||
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n). | |
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n)``. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不如顺师傅下一个PR单独改一下这个吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个改的东西比较多,先 approve ~顺师傅新提一个 PR 修改上述的文档问题叭
PR types
New features
PR changes
APIs
Description
【Hackathon 5th No.32】为 Paddle 新增 tensor_split / hsplit / dsplit API
RFC: PaddlePaddle/community#682
涉及文件:
python/paddle/__init__.py
将 API 暴露出来python/paddle/tensor/__init__.py
将 API 暴露出来,并绑定为 Tensor 的方法python/paddle/tensor/manipulation.py
实现 APItest/legacy_test/test_splits_api.py
单元测试。这里为了保持测试的一致性,重写(并涵盖)了原有的vsplit
测试用例。请评审!