
【Hackathon 5th No.26】Add the diagonal_scatter API to Paddle #669

Merged (10 commits) on Oct 19, 2023

Conversation

DanGuge (Contributor) commented Sep 28, 2023

Add the diagonal_scatter API to Paddle

paddle-bot bot commented Sep 28, 2023

Your PR has been submitted. Thanks for your contribution!
Please check its format and content. For this, you can refer to Template and Demo.

CLAassistant commented Sep 28, 2023

CLA assistant check
All committers have signed the CLA.

DanGuge (Contributor Author) commented Oct 7, 2023

To add some context on the MindSpore implementation: it currently appears to support only square matrices, limited by the operation in the red box below:

(screenshot: MindSpore implementation)

link:https://github.com/mindspore-ai/mindspore/blob/master/mindspore/python/mindspore/ops/function/array_func.py#L6942

This method tries to preserve the elements on the diagonal of input, but the shapes of ones_like(src) and input make that operation fail.
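For illustration, here is a minimal numpy sketch of the shape problem. The mask construction is an assumption inferred from the screenshot, not MindSpore's exact code: a mask built by embedding ones_like(src) on a diagonal is square by construction, so it cannot broadcast against a non-square input.

```python
import numpy as np

x = np.zeros((2, 3))               # non-square input
src = np.ones(2)                   # the diagonal of a (2, 3) matrix has length 2
mask = np.diag(np.ones_like(src))  # shape (2, 2): square by construction

try:
    x * (1 - mask)                 # (2, 3) cannot broadcast with (2, 2)
except ValueError as err:
    print("shape mismatch:", err)
```

For a square input the shapes line up and the masked update succeeds, which would explain the square-only limitation.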

Tests

  • Input matrix x has shape (2, 3), matrix y has shape (2,)

  • Test on mindspore 2.1.1
    (screenshot: mindspore)

  • Test on torch 2.0.1
    (screenshot: torch)


- Option 1: implement the logic by calling `fill_diagonal_tensor`, but that method can only be used in dynamic-graph mode

- Option 2: call `paddle.static.setitem` to overwrite the elements of diagonal_slice, but when called in dynamic-graph mode that method only returns a new tensor instead of performing an in-place write
Contributor:

What do you mean by an in-place write?


Contributor:

If setitem returns a new matrix, would assigning it back to the original matrix via the Variable set_value function satisfy the requirement?

Contributor Author:

```python
import paddle

def diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None):
    output = x.clone()
    diagonal_slice = output.diagonal(offset, axis1, axis2)
    if diagonal_slice.shape != y.shape:
        raise ValueError(
            "The shape of diagonal_slice is not equal to the shape of y. "
            "diagonal_slice shape = {}, y shape = {}".format(
                diagonal_slice.shape, y.shape
            )
        )
    if paddle.in_dynamic_mode():
        # diagonal_slice is a view of output, so this writes through to it
        diagonal_slice[:] = y
    else:
        from builtins import slice  # avoid shadowing by paddle's own `slice`

        diagonal_slice = paddle.static.setitem(
            diagonal_slice, slice(None, None, None), y
        )
    return output
```

There seems to be some misunderstanding between us; this is my current implementation:

  • In dynamic-graph mode, the result is as expected
    • diagonal_slice is a view of the diagonal of x; writing diagonal_slice[:] = y modifies the elements of the view directly, so the values of x change as well
  • In static-graph mode I want to do the same thing: obtain a view of the diagonal of x and, by modifying the view's elements, cause the source matrix x to change
    • But here paddle.static.setitem does not modify the elements of the diagonal_slice view directly; it returns a new diagonal_slice matrix instead

So it is not that the new matrix meets the requirement; rather, there is no way to obtain the required x matrix this way.
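The view-versus-new-tensor distinction discussed here can be illustrated with numpy (numpy only stands in for dynamic-graph tensors; `as_strided` is used because numpy's own `.diagonal()` returns a read-only view):

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.zeros((3, 3))
# a writable view over the main diagonal of x
diag = as_strided(x, shape=(3,), strides=(x.strides[0] + x.strides[1],))
diag[:] = 7.0                          # in-place write through the view mutates x
assert x[1, 1] == 7.0

# the "static graph" style: a functional update returns a new array,
# leaving x untouched unless the result is assigned back
y = np.where(np.eye(3, dtype=bool), 9.0, x)
assert x[1, 1] == 7.0 and y[1, 1] == 9.0
```

The first style is what the dynamic-graph branch relies on; the second matches what `paddle.static.setitem` does, which is why the write does not propagate back to x.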

Contributor:

I feel this approach may be problematic. Have you tried it and confirmed that this modification works in dynamic-graph mode?

Contributor Author:

Testing in dynamic-graph mode shows it works; this is the current PR: PaddlePaddle/Paddle#57879
The overall logic in torch's implementation is similar:

```cpp
at::Tensor diagonal_scatter(const at::Tensor& self, const at::Tensor& src, int64_t offset, int64_t dim1, int64_t dim2) {
    // See Note [*_scatter ops preserve strides]
    auto output = clone_preserve_strides(self);
    auto slice = output.diagonal(offset, dim1, dim2);
    TORCH_CHECK(slice.sizes() == src.sizes(), "expected src to have a size equal to the slice of self. src size = ", src.sizes(), ", slice size = ", slice.sizes());
    slice.copy_(src);
    return output;
}
```

```python
@register_lowering(aten.diagonal_scatter, type_promotion_kind=None)
def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1):
    output = clone(input)
    target = diagonal(output, offset, dim1, dim2)
    mutate_to(target, src)
    return output
```

zxcd (Contributor) commented Oct 12, 2023:

Static graph does not support views, so this indeed won't work. Also, looking at pytorch, diagonal_scatter returns a new tensor rather than modifying the original in place, so it is enough to design the API to return a new tensor. That way both dynamic and static graphs can be supported.

Contributor Author:

The reason for modifying elements through a view here is that Python has no particularly good way to obtain the indices of the diagonal, so modifying through a view was the only option; torch likewise calls the diagonal function to get the slice and then writes the elements of the src matrix directly into the slice.

Alternatively, could I implement diagonal_scatter by reusing the implementation of _C_ops.fill_diagonal_tensor(x, y, offset, dim1, dim2)? The basic operation logic described at https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/tensor/manipulation.py#L1019 looks consistent with diagonal_scatter.

Contributor:

The logic looks workable; please go ahead and revise the design.

@DanGuge DanGuge requested a review from zxcd October 11, 2023 02:18
@DanGuge DanGuge requested a review from zxcd October 13, 2023 07:25

```python
paddle.diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None)
```
Parameter definitions:

- `x (Tensor)`: the input tensor, with at least 2 dimensions, supporting bool, int32, int64, float16, float32 and float64 data types
- `y (Tensor)`: the embedding tensor, which will be embedded into the input tensor
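To make the intended semantics of this signature concrete, here is a numpy emulation (`diagonal_scatter_np` is a hypothetical helper for illustration only, not the Paddle implementation; it computes the diagonal indices explicitly instead of relying on a view):

```python
import numpy as np

def diagonal_scatter_np(x, y, offset=0, axis1=0, axis2=1):
    """Emulate the proposed semantics: return a copy of x whose
    (offset, axis1, axis2) diagonal is overwritten with y."""
    out = np.array(x, copy=True)
    # moveaxis returns a view into out, so writes below reach out
    view = np.moveaxis(out, (axis1, axis2), (-2, -1))
    rows, cols = view.shape[-2], view.shape[-1]
    if offset >= 0:
        n = min(rows, cols - offset)
        r = np.arange(n)
        c = r + offset
    else:
        n = min(rows + offset, cols)
        c = np.arange(n)
        r = c - offset
    view[..., r, c] = y
    return out

x = np.zeros((2, 3))
print(diagonal_scatter_np(x, np.array([1.0, 2.0])))
# the main diagonal becomes [1, 2]; x itself is left unchanged
```

Note that, matching the design decision above, the function returns a new array rather than mutating its input, which is what makes the same semantics expressible in both dynamic and static graph modes.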
Contributor:

fill_diagonal_tensor also supports complex types; can those be supported as well?

Contributor Author:

They should be supported too; I have updated the design document.

@DanGuge DanGuge requested a review from zxcd October 17, 2023 03:01
zxcd (Contributor) commented Oct 17, 2023

Below are the data types supported by fill_diagonal_tensor on GPU and CPU; please use this to complete the type list.

```cpp
PD_REGISTER_KERNEL(fill_diagonal_tensor,
                   GPU,
                   ALL_LAYOUT,
                   phi::FillDiagonalTensorKernel,
                   float,
                   double,
                   int64_t,
                   int,
                   int8_t,
                   uint8_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>,
                   bool) {}
PD_REGISTER_KERNEL(fill_diagonal_tensor,
                   CPU,
                   ALL_LAYOUT,
                   phi::FillDiagonalTensorKernel,
                   float,
                   double,
                   int64_t,
                   int,
                   int8_t,
                   uint8_t,
                   phi::dtype::float16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>,
                   bool) {}
```

DanGuge (Contributor Author) commented Oct 19, 2023

@zxcd the dtypes are now complete; please review.

zxcd (Contributor) left a comment:

LGTM
