-
Notifications
You must be signed in to change notification settings - Fork 268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.26】为 Paddle 新增 diagonal_scatter API #669
Conversation
e02499c
to
d45536d
Compare
d45536d
to
65bf3be
Compare
|
||
- 方案一:通过调用`fill_diagonal_tensor`实现对应逻辑,但是该方法只能在动态图中使用 | ||
|
||
- 方案二:调用`paddle.static.setitem`方法,覆盖diagonal_slice的元素,但是该方法在动态图中调用时,只会返回新的tensor,而不是inplace写 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inplace写是指?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 实现的逻辑是这样的,我们通过
diagonal()
函数获得output矩阵(x矩阵的clone)的diagonal_slice视图- 在动态图中,我们可以直接通过张量索引将y矩阵的元素写入diagonal_slice视图,这样能够直接让output矩阵中对应元素也发生修改
- 在静态图中,设置元素需要通过
paddle.static.setitem()
函数实现,但是它在静态图中的逻辑不是直接修改diagonal_slice视图中的元素,而是返回一个新的矩阵,这样就没法实现对output矩阵的修改
paddle.static.setitem(x, index, value)
的调用最终会调用_setitem_static(x, indices, values)
函数- 函数链接如下:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果使用setitem返回一个新的矩阵,再通过Variable中的set_value函数向原矩阵赋值的做法是否能够满足需求?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None):
output = x.clone()
diagonal_slice = output.diagonal(offset, axis1, axis2)
if diagonal_slice.shape != y.shape:
raise ValueError(
"The shape of diagonal_slice is not equal to the shape of y. diagonal_slice shape = {}, y shape = {}".format(
diagonal_slice.shape, y.shape
)
)
if in_dynamic_mode():
diagonal_slice[:] = y
else:
from builtins import slice
diagonal_slice = paddle.static.setitem(diagonal_slice, slice(None, None, None), y)
return output
我们的理解上应该有些误差,这个是我目前的实现方案:
- 在动态图中,结果符合预期
- diagonal_slice是x矩阵diagonal的视图,使用
diagonal_slice[:] = y
,可以直接对视图中元素进行修改,这样x矩阵的值也会改变
- diagonal_slice是x矩阵diagonal的视图,使用
- 在静态图中,我也想做相同的事情,就是拿到x矩阵的diagonal视图,通过对视图元素的修改,引起源矩阵x的修改
- 但是
paddle.static.setitem
函数在这里不会直接对diagonal_slice视图中的元素修改,而是返回一个新的diagonal_slice矩阵
- 但是
所以并不是新的矩阵符合需求,而是没法通过这种方式得到符合需求的x矩阵
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我感觉这样做可能会有问题,你尝试过这样修改在动态图是可行的吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在动态图上测试是可行的,这个是目前的PR:PaddlePaddle/Paddle#57879
在torch的实现中,整体逻辑也是类似的:
at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64_t offset, int64_t dim1, int64_t dim2) {
// See Note [*_scatter ops preserve strides]
auto output = clone_preserve_strides(self);
auto slice = output.diagonal(offset, dim1, dim2);
TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
slice.copy_(src);
return output;
}
@register_lowering(aten.diagonal_scatter, type_promotion_kind=None)
def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1):
output = clone(input)
target = diagonal(output, offset, dim1, dim2)
mutate_to(target, src)
return output
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
静态图是不支持view,所以这样确实不行。另外看了一下pytorch,diagal_scatter是返回一个新的tensor,并不是在原值上修改,设计的时候返回一个新的tensor就可以了。这样动静态图都可以支持。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
此处使用view的方式来修改元素,是因为没有python中没有特别好的能拿到diagonal的index的方式,所以只好通过view的方式来修改,torch中也是通过调用diagonal函数拿到slice,然后直接将src矩阵的元素写入slice中
或者我可以通过复用_C_ops.fill_diagonal_tensor(x, y, offset, dim1, dim2)
的实现,来实现diagonal_scatter
吗?我看在 https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/tensor/manipulation.py#L1019 描述的基本操作逻辑和diagonal_scatter
是一致的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
看逻辑应该可以试试,辛苦修改一下方案
|
||
```python | ||
paddle.diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None) | ||
``` | ||
参数定义: | ||
|
||
- `x(Tensor)`:输入张量,张量的维度至少为2维 | ||
- `y(Tensor)`:嵌入张量,将会被嵌入到输入张量中 | ||
- `x(Tensor)`:输入张量,张量的维度至少为2维,支持bool、int32、int64、float16、float32、float64数据类型 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fill_diagonal_tensor也支持复数类型,这个类型是否也能支持?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
应该也是支持的,已修改设计文档
以下是fill_diagonal_tensor GPU 和CPU支持的数据类型,烦请参考这里将类型补全吧。
|
5894867
to
05dff6b
Compare
@zxcd dtype已经补全完了,辛苦review一下 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
为 Paddle 新增 diagonal_scatter API