From d45536d017a1abd60440394362a883fa1023d233 Mon Sep 17 00:00:00 2001 From: wucc <77946882+DanGuge@users.noreply.github.com> Date: Thu, 5 Oct 2023 23:13:27 +0800 Subject: [PATCH] update api implementation scheme --- .../20230929_api_design_for_diagonal_scatter.md | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/rfcs/APIs/20230929_api_design_for_diagonal_scatter.md b/rfcs/APIs/20230929_api_design_for_diagonal_scatter.md index d727e8261..df67ea0d2 100644 --- a/rfcs/APIs/20230929_api_design_for_diagonal_scatter.md +++ b/rfcs/APIs/20230929_api_design_for_diagonal_scatter.md @@ -265,9 +265,20 @@ Tensor.diagonal_scatter(x, offset=0, axis1=0, axis2=1, name=None) ## API实现方案 在python/paddle/tensor/manipulation.py中增加diagonal_scatter函数 -- 方案一:通过调用`paddle.fill_diagonal_tensor_impl`实现对应逻辑 +- 动态图 + + 1. clone输入张量,获得output张量 + 2. 调用diagonal方法,获得output张量对应位置上的张量视图diagonal_slice + 3. 通过张量索引,将diagonal_slice中的元素都变为嵌入张量 -- 方案二:clone input张量得到output张量,再对output张量的diagnoal位置上的元素使用src张量元素进行覆盖 +- 静态图(无法仅通过修改python代码实现) + + - 方案一:通过调用`fill_diagonal_tensor`实现对应逻辑,但是该方法只能在动态图中使用 + + - 方案二:调用`paddle.static.setitem`方法,覆盖diagonal_slice的元素,但是该方法在动态图中调用时,只会返回新的tensor,而不是inplace写 + - 如果想要调用`paddle.static.setitem(x, index, y)`,通过index来修改输入张量diagonal对应位置的元素,没有现成实现获得diagonal元素对应的index + + - 方案三:类似torch实现方案,实现cpp算子逻辑 ## 代码实现文件路径 @@ -287,6 +298,8 @@ Tensor.diagonal_scatter(x, offset=0, axis1=0, axis2=1, name=None) - 检查input的slice和src的维度是否相等,这样才能进行覆盖 +- 对多种offset/axis1/axis2设置的情况进行测试 + # 七、可行性分析和排期规划 方案实施难度可控,工期上可以满足在当前版本周期内开发完成