Skip to content

Commit

Permalink
update api implementation scheme
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGuge committed Oct 5, 2023
1 parent 1be2cf6 commit e02499c
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions rfcs/APIs/20230929_api_design_for_diagonal_scatter.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,20 @@ Tensor.diagonal_scatter(x, offset=0, axis1=0, axis2=1, name=None)
## API实现方案
在python/paddle/tensor/manipulation.py中增加diagonal_scatter函数

- 方案一:通过调用`paddle.fill_diagonal_tensor_impl`实现对应逻辑
- 动态图

1. clone输入张量,获得output张量
2. 调用diagonal方法,获得output张量对应位置上的张量视图diagonal_slice
3. 通过张量索引,将diagonal_slice中的元素都变为嵌入张量

- 方案二:clone input张量得到output张量,再对output张量的diagnoal位置上的元素使用src张量元素进行覆盖
- 静态图(无法仅通过修改python代码实现)

- 方案一:通过调用`fill_diagonal_tensor`实现对应逻辑,但是该方法只能在动态图中使用

- 方案二:调用`paddle.static.setitem`方法,覆盖diagonal_slice的元素,但是该方法在动态图中调用时,只会返回新的tensor,而不是inplace写
- 如果想要调用`paddle.static.setitem(x, index, y)`,通过index来修改输入张量diagonal对应位置的元素,没有现成实现获得diagonal元素对应的index

- 方案三:类似torch实现方案,实现cpp算子逻辑

## 代码实现文件路径

Expand All @@ -287,6 +298,8 @@ Tensor.diagonal_scatter(x, offset=0, axis1=0, axis2=1, name=None)

- 检查inputslice和src的维度是否相等,这样才能进行覆盖

- 对多种offset/axis1/axis2设置的情况进行测试

# 七、可行性分析和排期规划

方案实施难度可控,工期上可以满足在当前版本周期内开发完成
Expand Down

0 comments on commit e02499c

Please sign in to comment.