Skip to content

Commit

Permalink
【Hackathon 5th No.5】 为 Paddle 增强scatter API (#631)
Browse files Browse the repository at this point in the history
* add scatter rfc

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix

* update

* update

* update
  • Loading branch information
lisamhy authored Sep 21, 2023
1 parent 7e2b63f commit 9559b65
Showing 1 changed file with 264 additions and 0 deletions.
264 changes: 264 additions & 0 deletions rfcs/APIs/20230916_api_design_for_scatter.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
# paddle.scatter API增强设计文档

| API名称 | paddle.scatter |
| ------------ | -------------------------------------- |
| 提交作者 | mhy |
| 提交时间 | 2023-09-16 |
| 版本号 | V1.0 |
| 依赖飞桨版本 | develop |
| 文件名 | 20230916_api_design_for_scatter.md |

# 一、概述

## 1、相关背景

`scatter` 是一个常用的API, API提供了根据index信息更新原Tensor的功能。

## 2、功能目标

当前 `paddle.scatter` API提供了根据index信息更新原Tensor的功能,但指定维度和归约方式功能尚不支持。

目前 `paddle.scatter` 的规约支持 None和sum。

## 3、意义

该API是一个常用的API,可以方便用户使用。让用户不用自己实现该功能,提高用户的使用效率。

# 二、飞桨现状

当前 `paddle.scatter` API但缺少指定轴和归约方式等功能。

paddle是通过 op kernel 的形式实现 `scatter``scatter_` API。
```python
_C_ops.scatter(x, index, updates, overwrite)
_C_ops.scatter_(x, index, updates, overwrite)

helper.append_op(
type="scatter",
inputs={"X": x, "Ids": index, "Updates": updates},
attrs={'overwrite': overwrite},
outputs={"Out": out},
)
```

paddle 实现的 index 功能和 torch的`scatter`中的index的功能不一致。paddle 的 index 只能是一维或者0维的。

```python
import paddle
#input:
x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32')
index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')
# shape of updates should be the same as x
# shape of updates with dim > 1 should be the same as input
updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32')
overwrite = False
# calculation:
if not overwrite:
for i in range(len(index)):
x[index[i]] = paddle.zeros([2])

for i in range(len(index)):
if (overwrite):
x[index[i]] = updates[i]
else:
x[index[i]] += updates[i]
# output:
out = paddle.to_tensor([[3, 3], [6, 6], [1, 1]])
out.shape # [3, 2]
```

# 三、业内方案调研

## Pytorch

### Pytorch中 有 API `Tensor.scatter_(dim, index, src, reduce=None) → Tensor`

在pytorch中,介绍为:

```
Writes all values from the tensor `src` into `self` at the indices specified in the `index` tensor. For each value in `src`, its output index is specified by its index in `src` for` dimension != dim` and by the corresponding value in `index` for `dimension = dim`.
```

其中输入参数的描述如下:

- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.
- src (Tensor or float) – the source element(s) to scatter.
- reduce (str, optional) – reduction operation to apply, can be either 'sum' or 'multiply'.

`torch.scatter` 和 paddle.scatter 的区别在于:
1. torch 支持 dim 配置.

```python
# For a 3-D tensor, self is updated as:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
```

2. reduce 额外支持 ‘multiply'
3. index的维度和src是一致的。


### Pytorch中 有 API `Tensor.index_reduce_(dim, index, source, reduce, *, include_self=True) → Tensor`

其中输入参数的描述如下:

- dim (int) – dimension along which to index
- index (Tensor) – indices of source to select from, should have dtype either torch.int64 or torch.int32
- source (FloatTensor) – the tensor containing values to accumulate
- reduce (str) – the reduction operation to apply ("prod", "mean", "amax", "amin")

`torch.index_reduce_` 和 paddle.scatter 的区别在于:
1. torch 支持 dim 配置.

```python
self[index[i], :, :] *= src[i, :, :] # if dim == 0
self[:, index[i], :] *= src[:, i, :] # if dim == 1
self[:, :, index[i]] *= src[:, :, i] # if dim == 2
```

2. reduce 支持 mean, amax, amin, 但不支持 add
3. index的维度是一维的。
4. paddle 只支持 include_self = False。


## Tensorflow

Tensorflow 没有提供 `scatter` 的API。

# 四、对比分析
- Pytorch 自定义Kernel的方式更加高效. index支持多维度,支持指定dim和reduce方式。
- Tensorflow 不支持 scatter API

# 五、方案设计

## 命名与参数设计

为了兼容现有API,并支持axis, reduce的功能扩展,scatter API 按如下接口设计:

```python
# https://github.com/PaddlePaddle/Paddle/blob/release/2.5/python/paddle/tensor/manipulation.py#L2849
paddle.scatter(x, index, updates, overwrite=True, axis=0, reduce='sum', include_self=False, name=None)
paddle.scatter_(x, index, updates, overwrite=True, axis=0, reduce='sum', include_self=False, name=None)

scatter 参数如下:
```
- `x (Tensor)` - ndim > = 1 的输入 N-D Tensor。数据类型可以是 float32,float64,int32,int64。GPU额外支持 bfloat16和float16。
- `index (Tensor)`- 一维或者零维 Tensor。数据类型可以是 int32,int64。 index 的长度不能超过 updates 的长度,并且 index 中的值不能超过输入的长度。
- `updates (Tensor` - 根据 index 使用 update 参数更新输入 x。当 index 为一维 tensor 时,updates 形状应与输入 x 相同,并且 dim>1 的 dim 值应与输入 x 相同。当 index 为零维 tensor 时,updates 应该是一个 (N-1)-D 的 Tensor,并且 updates 的第 i 个维度应该与 x 的 i+1 个维度相同。
- `overwrite (bool,可选)`- 指定索引 index 相同时,更新输出的方式。如果为 True,则使用覆盖模式更新相同索引的输出,如果为 False,则根据`reduce`参数指定的模式更新相同索引的输出。默认值为 True。
- `axis (int, 可选)` - 要索引的维度。默认值为0.
- `reduce(str,可选)` - 指定规约运算,可以是 add, mul/multiply, mean, amax, amin。默认值为 add.
- `include_self (bool,可选)` - arr 张量中的元素是否包含在规约中。默认值 include_self = False.
- `name (str,可选)` - 具体用法请参见 [Name](https://www.paddlepaddle.org.cn/documentation/docs/zh/api_guides/low_level/program.html#api-guide-name),一般无需设置,默认值为 None。


相比于 paddle.scatter :
1. 目前不支持执行axis的属性,默认支持axis=0。
2. reduce方式只支持 add,需要新增 mul、mean、amax、amin等规约方式。
3. 不支持include_self配置,目前默认是 include_self=False。
4. assign规约是由overwite控制,而不是reduce。注意当是assign模式时,index的数值不唯一时,计算结果行为不确定,梯度计算也不对。

## 底层OP设计

增强现有 scatter op 实现,支持指定dim和指定reduce方法。具体为`_C_ops.scatter``_C_ops.scatter_`. OP 的参数遵循API的参数与顺序。

axis 使用如下规则索引:

```python
self[index[i], :, :] *= src[i, :, :] # if axis == 0
self[:, index[i], :] *= src[:, i, :] # if axis == 1
self[:, :, index[i]] *= src[:, :, i] # if axis == 2
```

为了保证兼容性,保留 `overwrite` 字段。只有到 `overwrite` 为False时,按照 `reduce` 选择规约方式;否则则按照 `assign` 逻辑处理。


反向梯度计算逻辑如下:

```c++
if (reduce == "prod") {
Tensor masked_self = self.masked_fill(self == 0, 1);
Tensor masked_self_result = masked_self.index_reduce(dim, index, source, reduce, include_self);
grad_self = grad * masked_self_result / masked_self;
Tensor src_zero = source == 0;
Tensor src_num_zeros = zeros_like(self).index_add(dim, index, src_zero.to(self.dtype())).index_select(dim, index);
Tensor src_single_zero = bitwise_and(src_zero, src_num_zeros == 1);
// For src positions with src_single_zero, (grad * result).index_select(dim,index) / source.masked_fill(src_zero, 1)
// would incorrectly propagate zeros as the gradient
Tensor masked_src = source.masked_fill(src_single_zero, 1);
Tensor masked_src_result = self.index_reduce(dim, index, masked_src, reduce, include_self);
Tensor grad_src1 = where(src_single_zero,
(grad * masked_src_result).index_select(dim, index),
(grad * result).index_select(dim, index) / source.masked_fill(src_zero, 1));
if ((src_num_zeros > 1).any().item<bool>()) {
auto node = std::make_shared<DelayedError>(
"index_reduce(): Double backward is unsupported for source when >1 zeros in source are scattered to the same position in self",
/* num inputs */ 1);
auto result = node->apply({ grad_src1 });
grad_src = result[0];
} else {
grad_src = grad_src1;
}
} else if (reduce == "mean") {
Tensor N = include_self ? ones_like(grad) : zeros_like(grad);
N = N.index_add(dim, index, ones_like(source));
N.masked_fill_(N == 0, 1);
grad_self = grad / N;
Tensor N_src = N.index_select(dim, index);
grad_src = grad.index_select(dim, index) / N_src;
} else if (reduce == "amax" || reduce == "amin") {
Tensor value = result.index_select(dim, index);
Tensor self_is_result = (self == result).to(self.scalar_type());
Tensor source_is_result = (source == value).to(self.scalar_type());
Tensor N_to_distribute = self_is_result.index_add(dim, index, source_is_result);
Tensor grad_distributed = grad / N_to_distribute;
grad_self = self_is_result * grad_distributed;
grad_src = source_is_result * grad_distributed.index_select(dim, index);
} else {
AT_ERROR("Expected 'reduce' to be one of 'prod', 'amax', 'amin' or 'mean' but got ", reduce, ".");
}

if (!include_self) {
grad_self = grad_self.index_fill(dim, index, 0);
}
```
reduce=sum 的计算逻辑和 mean 类似。
overwite的计算逻辑不变,同assign。
## API实现方案
[API已有](https://github.com/PaddlePaddle/Paddle/blob/release/2.5/python/paddle/tensor/manipulation.py#L2849),需要新增axis和reduce参数。
## 代码实现文件路径
函数API实现路径: python/paddle/tensor/manipulation.py
单元测试路径:在 Paddle repo 的 test/ 目录, 同时在 paddle/test/legacy_test/test_inplace.py、paddle/test/legacy_test/test_scatter_op.py 修改对应的单侧。
# 六、测试和验收的考量
测试考虑的case如下:
- 不同 dim 下功能是否符合预期。
- 不同 reduce 下功能是否符合预期。
- 验证反向梯度是否符合预期。
# 七、可行性分析及规划排期
方案实施难度可控,工期上可以满足在当前版本周期内开发完成。
# 八、影响面
为独立新增API,对其他模块没有影响
# 名词解释
# 附件及参考资料

0 comments on commit 9559b65

Please sign in to comment.