-
Notifications
You must be signed in to change notification settings - Fork 273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.27】为 Paddle 新增 select_scatter API RFC #757
Conversation
const DenseTensor& value, | ||
int axis, | ||
int index, | ||
DenseTensor* out); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个地方,可以调研下paddle的OP set_value
呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,我再看看,目前已经实现了一个版本,如果方便的话可以先麻烦您看看嘛
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我参考https://github.com/PaddlePaddle/Paddle/blob/f822c32b3734e971e3b71e2b78dcf16096528d91/python/paddle/base/variable_index.py#L917
这个里面的setitem_static实现了一个版本做了尝试,只把动态模式下的_C_ops.set_value_换成了_C_ops.set_value,但是这个op似乎不支持PIR模式,用我之前的测试代码会爆这个错。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zoooo0820 这个PIR模式是否需要支持呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zoooo0820 这个PIR模式是否需要支持呢
理论上应该支持的,现在PIR里应该是有相关实现的(参考paddle/fluid/pir/dialect/operator/ir/pd_api.cc
)。辛苦再调试下,还是不行的话,可以把具体代码和报错贴出来一起看下呢
@zoooo0820 我目前实现的代码如下,我的理解是原本能通过setitem_static来组合实现这个算子,我需要做的就是简化解析过程,然后直接调用set_value的底层C算子就行,所以我在setitem的代码基础上修改的代码为: def select_scatter(src, values, axis, index):
"""
Embeds the values of the values tensor into src at the given index of axis.
Args:
src (Tensor) : The Destination Tensor.
values (Tensor) : The tensor to embed into src.
axis (int) : the dimension to insert the slice into.
index (int) : the index to select with.
Returns:
Tensor, same dtype and shape with src
Examples:
.. code-block:: python
>>> import paddle
>>> x = paddle.zeros((2,3,4)).astype("float32")
>>> values = paddle.ones((2,4)).astype("float32")
>>> res = paddle.select_scatter(x,values,1,1)
>>> print(res)
Tensor(shape=[2, 3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.]]])
"""
from ..base.framework import default_main_program
starts = [index]
ends = [index+1]
steps = [1]
axes = [axis]
none_axes = []
decrease_axes = [axis]
inputs = {'Input': src}
attrs = {
'axes': axes,
'starts': starts,
'ends': ends,
'steps': steps,
'decrease_axes': decrease_axes,
'none_axes': none_axes,
}
StartsTensorList = None
EndsTensorList = None
StepsTensorList = None
if paddle.utils._contain_var(starts):
StartsTensorList = paddle.utils._convert_to_tensor_list(starts)
inputs['StartsTensorList'] = StartsTensorList
del attrs['starts']
if paddle.utils._contain_var(ends):
EndsTensorList = paddle.utils._convert_to_tensor_list(ends)
inputs['EndsTensorList'] = EndsTensorList
del attrs['ends']
if paddle.utils._contain_var(steps):
StepsTensorList = paddle.utils._convert_to_tensor_list(steps)
inputs['StepsTensorList'] = StepsTensorList
del attrs['steps']
# step2. Parse values
dtype = src.dtype
attrs['dtype'] = dtype
values = values.astype(dtype)
inputs["ValueTensor"] = values
if in_dynamic_or_pir_mode():
return _C_ops.set_value_with_tensor(
src,
values,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
)
else:
helper = LayerHelper(
'set_value', **locals()
)
if helper.main_program.current_block_idx != 0:
# not in global block, we should create a global variable.
output = helper._create_global_variable_for_type_inference(
dtype=src.dtype
)
else:
output = helper.create_variable_for_type_inference(
dtype=src.dtype
)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)
# map var to the new output
paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
cur_block.program, src.desc.id(), output
)
return output 但是这个代码在我之前提的PR中测试代码(https://github.com/PaddlePaddle/Paddle/pull/59343/files#diff-dd8e117af37176658197bd7ef61a66708f4f18e5dbf273fdebc6de5eccd4c84c) 中pir模式下的测试样例全挂了 |
@YibinLiu666 这个问题初步看的确是这个算子没适配好PIR,相关问题我这边这两天排查和修复下。PR里可以暂时不用管PIR报错,优先先review代码及完成其他部分的单测。 待前面PIR的问题解决后再看上述报错的问题是否解决呢 |
目前已经修改了RFC, @zoooo0820 可以先麻烦review一下RFC,我先忽略掉PIR模式的测试修改一下代码 |
|
||
# 四、对比分析 | ||
|
||
PyTorch 是使用 C++ API 实现的,Python 端直接调用 C++ 接口,性能较好。尽管paddle能够通过算子组合实现该api,但是使用slice来 setitem 性能较差,并且无法达到非inplace的效果。因此计划在实现paddle的`select_scatter`时实现相关c++ kernel |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以在现状中简要说明下paddle set_value
OP的情况。以及这里需要修改下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
**axis** (int) – 需要嵌入到src Tensor的维度。 | ||
|
||
**index** (int) – 选择的索引。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名可以参考下 API 设计和命名规范 主要可以关注下 src
/ name
此外数据类型上,目前应该是支持全dtype了,可以简单验证下。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
再加上name
参数吧,可以参考下其他API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
上述问题应该已经解决了,PaddlePaddle/Paddle#59457 ,可以带着这个PR的修改测试下PIR下的情况呢 |
好的,我重新测一下 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
新增 select_scatter API RFC
PaddlePaddle/Paddle#57262