Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.32】为 Paddle 新增 tensor_split / hsplit / dsplit API -part #58917

Merged
merged 15 commits into from
Dec 13, 2023

Conversation

megemini
Copy link
Contributor

PR types

New features

PR changes

APIs

Description

【Hackathon 5th No.32】为 Paddle 新增 tensor_split / hsplit / dsplit API

RFC: PaddlePaddle/community#682

涉及文件:

  • python/paddle/__init__.py 将 API 暴露出来
  • python/paddle/tensor/__init__.py 将 API 暴露出来,并绑定为 Tensor 的方法
  • python/paddle/tensor/manipulation.py 实现 API
  • test/legacy_test/test_splits_api.py 单元测试。这里为了保持测试的一致性,重写(并涵盖)了原有的 vsplit 测试用例。

请评审!

@paddle-bot paddle-bot bot added the contributor External developers label Nov 10, 2023
@megemini
Copy link
Contributor Author

@zoooo0820 PR-CI-Inference 不清楚为什么 fail 了 ~ 重新跑了一次还是一样 ~

请评审 ~

starts = ends
ends = total_n
sub_array = paddle.slice(x, axes=[axis], starts=[starts], ends=[ends])
splits.append(sub_array)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处的实现方式为多次调用slice操作并append到一个list中,预期耗时会较高。
能否考虑将indices_or_sections转化成num_or_sections 再一次性调用用split操作完成?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好想法!我试试看~

DTYPE_STATIC_V_SPLIT = DTYPE_STATIC_SPLIT

DTYPE_DYGRAPH_D_SPLIT = DTYPE_DYGRAPH_SPLIT
DTYPE_STATIC_D_SPLIT = DTYPE_STATIC_SPLIT
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些常量值都是同一个,是否简化统一一下?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是特意这样写的,很多接口的文档对于数据类型的支持都不完善或者明确,我希望至少在单测里面明确一下,所以才逐个写出来 ~ 如果不需要的话,可以合并 ~

with self.assertRaises(TypeError):
self._test_static_api(
self.func_paddle, self.func_numpy, **kwargs
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处不支持的dtype,属于kernel缺少支持而非预期不应当支持。建议不要增加单测case来测试是否会报错,直接测试已经支持的dtype即可

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK ~

def test_error_split(self):
x = generate_data([6])
with self.assertRaises(ValueError):
self._test_all({**x, 'split_paddle': 0, 'split_numpy': None})


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

原来paddle.splitAPI的测试被移除了,这里应该补上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zoooo0820
这个单测文件里面原来只有 paddle.vsplit 的测试(已经覆盖测试),没有 paddle.splitpaddle.split 的单测在 test/legacy_test/test_split_op.py,这里 补上 是指?

'split_numpy': convert_num_or_sections((2, 1, 3)),
}
)
self._test_all({**x, 'split_paddle': [1, -1, 3], 'split_numpy': [1, 3]})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

再补充个-1推断出非0的case吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK ~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update 20231114

@zoooo0820

之前看这条 review 意见理解错了,这里解释一下:

由于 paddle 的 split 与 numpy 的不同,paddle [1, -1, 3] 推断出的 sections 相当于 [1, 2, 3] (总长度为 6,-1 推断为 2,并不是 0),对应 numpy (indices_or_sections = np.cumsum(num_or_sections)[:-1])的 indices_or_sections 为 [1, 3]。

之前的 convert_num_or_sections 函数是不支持 -1 的,如果要支持的话会把问题复杂化,所以,这里对于 -1 这种情况,是把 paddle 与 numpy 分别输入的。

def convert_num_or_sections(num_or_sections):
    # hsplit, vsplit, dsplit
    # Convert the num_or_sections in paddle to indices_or_sections in numpy
    # Do not support -1
    if isinstance(num_or_sections, int):
        indices_or_sections = num_or_sections
    else:
        indices_or_sections = np.cumsum(num_or_sections)[:-1]

    return indices_or_sections

所以,这里 -1 并不是推断为 0,是我 test case 写的不好 ~ 为避免歧义,我把这里改一下吧 ~

另外,推断为 0 是 error case,我这边补充一下 ~

'complex128',
'bfloat16',
'bool',
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fp16/bf16类型在某些设备上运行可能会有问题,这里最好确认下。可以参考下其他单测文件里的is_bfloat16_supported等写法,把这两个dtype单独处理下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK 这里疏忽了~

@megemini
Copy link
Contributor Author

megemini commented Nov 14, 2023

Update 20231114

  • 修改 tensor_split 调用 split ,但是需要说明:
    由于不使用 slice,而 split 有一些局限性,所以,修改后的函数不能支持如 [2, 1, 4] 这种乱序的 indices,以及超过 range 的 index。如果需要支持,那么使用 split 优化的好处就不明显了,如,需要判断 indices 的顺序、范围等。由于这些 case 属于 异常 或者不常用的情况,所以就没做处理,但是增加了单测进行验证。

  • 合并 DTYPE_ALL 数据类型的测试,并且增加 bfloat16 在 GPU 环境下的单独处理

  • 将 CPU/GPU 支持的数据类型分开测试

  • 增加其他正常/异常单测

另外,paddle.splittest/legacy_test/test_split_op.py 已经有测试了,还需要再测一下?

@zoooo0820
Copy link
Contributor

Update 20231114

  • 修改 tensor_split 调用 split ,但是需要说明:
    由于不使用 slice,而 split 有一些局限性,所以,修改后的函数不能支持如 [2, 1, 4] 这种乱序的 indices,以及超过 range 的 index。如果需要支持,那么使用 split 优化的好处就不明显了,如,需要判断 indices 的顺序、范围等。由于这些 case 属于 异常 或者不常用的情况,所以就没做处理,但是增加了单测进行验证。

这个地方想确认下,np.array_split / torch.tensor_split功能上是否支持上述两种情况?

另外,paddle.splittest/legacy_test/test_split_op.py 已经有测试了,还需要再测一下?

之前是我遗漏了,那这里可以不用再测了

@megemini
Copy link
Contributor Author

Update 20231114

  • 修改 tensor_split 调用 split ,但是需要说明:
    由于不使用 slice,而 split 有一些局限性,所以,修改后的函数不能支持如 [2, 1, 4] 这种乱序的 indices,以及超过 range 的 index。如果需要支持,那么使用 split 优化的好处就不明显了,如,需要判断 indices 的顺序、范围等。由于这些 case 属于 异常 或者不常用的情况,所以就没做处理,但是增加了单测进行验证。

这个地方想确认下,np.array_split / torch.tensor_split功能上是否支持上述两种情况?

支持 ~

In [1]: import numpy as np

In [2]: a = np.arange(8)

In [3]: np.array_split(a, [2, 1, 4])
Out[3]: [array([0, 1]), array([], dtype=int64), array([1, 2, 3]), array([4, 5, 6, 7])]

In [4]: import torch

In [5]: x = torch.arange(8)

In [6]: torch.tensor_split(x, [2, 1, 4])
Out[6]: 
(tensor([0, 1]),
 tensor([], dtype=torch.int64),
 tensor([1, 2, 3]),
 tensor([4, 5, 6, 7]))

In [7]: np.array_split(a, [2, 1, 15])
Out[7]: 
[array([0, 1]),
 array([], dtype=int64),
 array([1, 2, 3, 4, 5, 6, 7]),
 array([], dtype=int64)]

In [8]: torch.tensor_split(x, [2, 1, 15])
Out[8]: 
(tensor([0, 1]),
 tensor([], dtype=torch.int64),
 tensor([1, 2, 3, 4, 5, 6, 7]),
 tensor([], dtype=torch.int64))

两者是用类似我之前的 slice 方式处理的,所以能够处理这类 异常 情况 ~

之前测试用例里面也有这类 异常 情况的测试用例,是可以对齐 numpy 的 ~

如果要支持这类 异常 情况,那么可能还是要回退到之前用 slice 的方法,用 split 来做就更复杂了 ... ...

比如 np.array_split(a, [2, 1, 4]) 这种情况,结果是 [array([0, 1]), array([], dtype=int64), array([1, 2, 3]), array([4, 5, 6, 7])],虽然 paddle.split 可以生成长度为 0 的结果(第二个分段),但是,接下来的 index 是从 1 开始的(第三个分段),这个 paddle.split 好像是做不了吧?

因此,如果要达到上面的结果,一个 paddle.split 好像实现不了?

@zoooo0820 看看有什么好办法?!

@zoooo0820
Copy link
Contributor

Update 20231114

  • 修改 tensor_split 调用 split ,但是需要说明:
    由于不使用 slice,而 split 有一些局限性,所以,修改后的函数不能支持如 [2, 1, 4] 这种乱序的 indices,以及超过 range 的 index。如果需要支持,那么使用 split 优化的好处就不明显了,如,需要判断 indices 的顺序、范围等。由于这些 case 属于 异常 或者不常用的情况,所以就没做处理,但是增加了单测进行验证。

这个地方想确认下,np.array_split / torch.tensor_split功能上是否支持上述两种情况?

支持 ~

In [1]: import numpy as np

In [2]: a = np.arange(8)

In [3]: np.array_split(a, [2, 1, 4])
Out[3]: [array([0, 1]), array([], dtype=int64), array([1, 2, 3]), array([4, 5, 6, 7])]

In [4]: import torch

In [5]: x = torch.arange(8)

In [6]: torch.tensor_split(x, [2, 1, 4])
Out[6]: 
(tensor([0, 1]),
 tensor([], dtype=torch.int64),
 tensor([1, 2, 3]),
 tensor([4, 5, 6, 7]))

In [7]: np.array_split(a, [2, 1, 15])
Out[7]: 
[array([0, 1]),
 array([], dtype=int64),
 array([1, 2, 3, 4, 5, 6, 7]),
 array([], dtype=int64)]

In [8]: torch.tensor_split(x, [2, 1, 15])
Out[8]: 
(tensor([0, 1]),
 tensor([], dtype=torch.int64),
 tensor([1, 2, 3, 4, 5, 6, 7]),
 tensor([], dtype=torch.int64))

两者是用类似我之前的 slice 方式处理的,所以能够处理这类 异常 情况 ~

之前测试用例里面也有这类 异常 情况的测试用例,是可以对齐 numpy 的 ~

如果要支持这类 异常 情况,那么可能还是要回退到之前用 slice 的方法,用 split 来做就更复杂了 ... ...

比如 np.array_split(a, [2, 1, 4]) 这种情况,结果是 [array([0, 1]), array([], dtype=int64), array([1, 2, 3]), array([4, 5, 6, 7])],虽然 paddle.split 可以生成长度为 0 的结果(第二个分段),但是,接下来的 index 是从 1 开始的(第三个分段),这个 paddle.split 好像是做不了吧?

因此,如果要达到上面的结果,一个 paddle.split 好像实现不了?

@zoooo0820 看看有什么好办法?!

@megemini 这个地方,从用户使用的角度来看需要优先对齐功能。那还是按照此前的做法吧,辛苦修改一下~

@megemini
Copy link
Contributor Author

@zoooo0820

刚才想要修改回之前的 slice 的实现方式,结果发现 slice 接口实现改变了,不能 starts 大于 ends 了 😆😆😆 ... ...

这里我单独处理一下吧 ~

另外,slice 有点问题,对于 axis 为负值,处理有问题:

动态图可以:

In [11]: input = paddle.rand(shape=[4, 5, 6], dtype='float32')

In [12]: paddle.slice(input, axes=[-2], starts=[1], ends=[3])
Out[12]: 
Tensor(shape=[4, 2, 6], dtype=float32, place=Place(gpu:0), stop_gradient=True,
...

静态图不行:

    ...: program = paddle.static.Program()
    ...: exe = paddle.static.Executor()
    ...: 
    ...: with paddle.static.program_guard(program):
    ...:     input = paddle.static.data('x', [4, 5, 6], 'float64')
    ...:     feed = {'x': np.random.rand(4, 5, 6)}
    ...: 
    ...:     out = paddle.slice(input, axes=[-2], starts=[1], ends=[3])
    ...:     res = exe.run(feed=feed, fetch_list=[out])


In [24]: res[0].shape
Out[24]: (4, 5, 6)

In [25]: len(res)
Out[25]: 1

静态图没有 slice ... ... 我单独提个 PR 改一下吧?

@zoooo0820
Copy link
Contributor

静态图没有 slice ... ... 我单独提个 PR 改一下吧?

可以的,如果slice API验证下来确实有问题,可以提个PR修复一下

Copy link

paddle-ci-bot bot commented Nov 23, 2023

Sorry to inform you that fc8fdbd's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@megemini
Copy link
Contributor Author

@zoooo0820

这几个 api 依赖的底层算子静态图回传好像有问题 ~

#59365 这个 PR 里面我整理了一下 ~

等问题解决了,我会把这几个 api 的静态图回传再测一下 ~

Copy link
Contributor

@zoooo0820 zoooo0820 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, 关于反馈的pir下的其他op导致的问题,后续根据相关任务进展修复即可

@luotao1
Copy link
Contributor

luotao1 commented Nov 30, 2023

关于反馈的pir下的其他op导致的问题,后续根据相关任务进展修复即可

@megemini @zoooo0820
@xiaoguoguo626807 讨论后:本PR可以先合,PIR下其他OP导致的问题,@xiaoguoguo626807 后面会修。具体见 #59365 (comment) (因为这几个测试是本地测试,没有加入单测中,不影响合)

@zoooo0820
Copy link
Contributor

@megemini 你好,
这个PR在进一步review的时候,考虑到用户侧的易用性和API设计原则等因素,我们内部做了讨论,主要有下面两个结论:

  1. 考虑到hsplit / dsplit / vsplit API的行为,在numpy / pytorch均是基于切分数量(num)和末位索引位置(indices),如果我们同名API行为却有差异,从用户角度出发会增大使用成本。 我们讨论后决定这几个API也采用这个行为,即基于该PR新增的paddle.tensor_split,而非此前的paddle.split实现。 通过排查发现vsplit不兼容改动影响面不大,所以我们可以在这个PR中也一并修改了。

2.切分参数支持int / int数组的输入,语义上,int表示切分数量(num),int数组表示末位索引位置(indices),现参数名indices_or_sections并不完全符合这个语义,所以这里最好修改成num_or_indices,涉及tensor_split / hsplit / dsplit / vsplit

辛苦再根据这个意见做下改动,同时也同步再提PR在对应的rfc中做下修改~

@megemini
Copy link
Contributor Author

megemini commented Dec 6, 2023

@zoooo0820

赞同!其实之前也不是很理解为什么不对齐 numpy 和 pytorch 的实现 ~

OK,再确认一下:

  • 修改 hsplit / dsplit / vsplit 基于 tensor_split
  • 以上四者的参数为 num_or_indices,表示切分数量和索引
  • split 保持不变,参数为 num_or_sections
  • 以上所有接口相关参数都不使用 indices_or_sections

请帮忙看一下是否有错误或遗漏?

@zoooo0820
Copy link
Contributor

@zoooo0820

赞同!其实之前也不是很理解为什么不对齐 numpy 和 pytorch 的实现 ~

OK,再确认一下:

  • 修改 hsplit / dsplit / vsplit 基于 tensor_split
  • 以上四者的参数为 num_or_indices,表示切分数量和索引
  • split 保持不变,参数为 num_or_sections
  • 以上所有接口相关参数都不使用 indices_or_sections

请帮忙看一下是否有错误或遗漏?

嗯嗯,按这个开发就好

Copy link
Contributor

@zoooo0820 zoooo0820 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jeff41404
Copy link
Contributor

jeff41404 commented Dec 7, 2023

code is fine, but the implement and parameter names in rfc also need to be modified according to the new conclusion

@megemini
Copy link
Contributor Author

megemini commented Dec 7, 2023

code is fine, but the implement and parameter names in rfc also need to be modified according to the new conclusion

roger ~

@megemini
Copy link
Contributor Author

megemini commented Dec 7, 2023

PaddlePaddle/community#776

请评审 ~

Copy link
Contributor

@zoooo0820 zoooo0820 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

jeff41404
jeff41404 previously approved these changes Dec 12, 2023
Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luotao1
Copy link
Contributor

luotao1 commented Dec 12, 2023

中文文档也可以提了

@luotao1
Copy link
Contributor

luotao1 commented Dec 12, 2023

需要解决下冲突

x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections along ``axis``.
If ``x`` is divisible by ``n``, each section will be ``x.shape[axis] / n``. If ``x`` is not divisible by ``n``, the first
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n).
``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n)``.

少了(其实应该在 CI 检查里加一下关于 的检查...顺师傅有意向做不🐶)
image

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不如顺师傅下一个PR单独改一下这个吧

Copy link
Contributor

@sunzhongkai588 sunzhongkai588 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个改的东西比较多,先 approve ~顺师傅新提一个 PR 修改上述的文档问题叭

@luotao1 luotao1 merged commit 538905c into PaddlePaddle:develop Dec 13, 2023
29 checks passed
@luotao1 luotao1 changed the title 【Hackathon 5th No.32】为 Paddle 新增 tensor_split / hsplit / dsplit API 【Hackathon 5th No.32】为 Paddle 新增 tensor_split / hsplit / dsplit API -part Dec 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants