Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.27】为 Paddle 新增 select_scatter API -part #59343

Merged
merged 13 commits into from
Dec 7, 2023
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@
tolist,
take_along_axis,
put_along_axis,
select_scatter,
tensordot,
as_complex,
as_real,
Expand Down Expand Up @@ -893,6 +894,7 @@
'renorm_',
'take_along_axis',
'put_along_axis',
'select_scatter',
'multigammaln',
'multigammaln_',
'nan_to_num',
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
zeros,
zeros_like,
)

from .einsum import einsum # noqa: F401
from .linalg import ( # noqa: F401
bincount,
Expand Down Expand Up @@ -173,6 +174,7 @@
scatter_,
scatter_nd,
scatter_nd_add,
select_scatter,
shard_index,
slice,
split,
Expand Down Expand Up @@ -693,6 +695,7 @@
'repeat_interleave',
'take_along_axis',
'put_along_axis',
'select_scatter',
'put_along_axis_',
'exponential_',
'heaviside',
Expand Down
106 changes: 106 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5921,3 +5921,109 @@ def diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None):

"""
return fill_diagonal_tensor(x, y, offset, axis1, axis2, name)


def select_scatter(x, values, axis, index, name=None):
"""
Embeds the values of the values tensor into x at the given index of axis.

Args:
x (Tensor) : The Destination Tensor. Supported data types are `bool`, `float16`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `bfloat16`, `complex64`, `complex64`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

complex64, complex64 should be complex64, complex128 ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

values (Tensor) : The tensor to embed into x. Supported data types are `bool`, `float16`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `bfloat16`, `complex64`, `complex64`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

complex64, complex64 should be complex64, complex128 ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

axis (int) : the dimension to insert the slice into.
index (int) : the index to select with.
name (str, optional): Name for the operation (optional, default is None).

Returns:
Tensor, same dtype and shape with x

Examples:
.. code-block:: python

>>> import paddle

>>> x = paddle.zeros((2,3,4)).astype("float32")
>>> values = paddle.ones((2,4)).astype("float32")
>>> res = paddle.select_scatter(x,values,1,1)
>>> print(res)
Tensor(shape=[2, 3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.]]])

"""
x_shape = x.shape
value_shape = values.shape
if not isinstance(x_shape, list):
x_shape = list(x_shape)
if index < 0:
index += x_shape[axis]
if axis < 0:
axis += len(x_shape)
del x_shape[axis]
if len(x_shape) != len(value_shape):
raise RuntimeError(
"expected values to have a size equal to the slice of x. value size = "
+ str(value_shape)
+ " slice size = "
+ str(x_shape)
)
for i in range(len(x_shape)):
if x_shape[i] != value_shape[i]:
raise RuntimeError(
"expected values to have a size equal to the slice of x. value size = "
+ str(value_shape)
+ " slice size = "
+ str(x_shape)
)
from ..base.framework import default_main_program

starts = [index]
ends = [index + 1]
steps = [1]
axes = [axis]
none_axes = []
decrease_axes = [axis]
inputs = {'Input': x}
attrs = {
'axes': axes,
'starts': starts,
'ends': ends,
'steps': steps,
'decrease_axes': decrease_axes,
'none_axes': none_axes,
}

dtype = x.dtype
attrs['dtype'] = dtype

values = values.astype(dtype)
inputs["ValueTensor"] = values

if in_dynamic_or_pir_mode():
return _C_ops.set_value_with_tensor(
x,
values,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
)
else:
helper = LayerHelper('select_scatter', **locals())
output = helper.create_variable_for_type_inference(dtype=x.dtype)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)

return output
167 changes: 167 additions & 0 deletions test/legacy_test/test_select_scatter_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import numpy as np

import paddle
from paddle.framework import core
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()


class TestSelectScatterAPI(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.shape = [2, 3, 4]
self.type = np.float32
self.x_np = np.random.random(self.shape).astype(self.type)
self.place = [paddle.CPUPlace()]
self.axis = 1
self.index = 1
self.value_shape = [2, 4]
self.value_np = np.random.random(self.value_shape).astype(self.type)
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def get_out_ref(self, out_ref, index, value_np):
for i in range(2):
for j in range(4):
out_ref[i, index, j] = value_np[i, j]

@test_with_pir_api
def test_api_static(self):
paddle.enable_static()

def run(place):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('Src', self.shape, self.type)
value = paddle.static.data(
'Values', self.value_shape, self.type
)
out = paddle.select_scatter(x, value, self.axis, self.index)
exe = paddle.static.Executor(place)
res = exe.run(
feed={
'Src': self.x_feed,
'Values': self.value_np,
},
fetch_list=[out],
)

out_ref = copy.deepcopy(self.x_np)
self.get_out_ref(out_ref, self.index, self.value_np)
for out in res:
np.testing.assert_allclose(out, out_ref, rtol=0.001)

for place in self.place:
run(place)

def test_api_dygraph(self):
def run(place):
paddle.disable_static(place)
x_tensor = paddle.to_tensor(self.x_np)
value_tensor = paddle.to_tensor(self.value_np)
out = paddle.select_scatter(
x_tensor, value_tensor, self.axis, self.index
)
out_ref = copy.deepcopy(self.x_np)
self.get_out_ref(out_ref, self.index, self.value_np)
np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001)

paddle.enable_static()

for place in self.place:
run(place)


class TestSelectScatterAPICase2(TestSelectScatterAPI):
def setUp(self):
np.random.seed(0)
self.shape = [2, 3, 4, 5]
self.type = np.float64
self.x_np = np.random.random(self.shape).astype(self.type)
self.place = [paddle.CPUPlace()]
self.axis = 2
self.index = 1
self.value_shape = [2, 3, 5]
self.value_np = np.random.random(self.value_shape).astype(self.type)
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def get_out_ref(self, out_ref, index, value_np):
for i in range(2):
for j in range(3):
for k in range(5):
out_ref[i, j, index, k] = value_np[i, j, k]


class TestSelectScatterAPICase3(TestSelectScatterAPI):
def setUp(self):
np.random.seed(0)
self.shape = [2, 3, 4, 5, 6]
self.type = np.int32
self.x_np = np.random.random(self.shape).astype(self.type)
self.place = [paddle.CPUPlace()]
self.axis = 2
self.index = 1
self.value_shape = [2, 3, 5, 6]
self.value_np = np.random.random(self.value_shape).astype(self.type)
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def get_out_ref(self, out_ref, index, value_np):
for i in range(2):
for j in range(3):
for k in range(5):
for w in range(6):
out_ref[i, j, index, k, w] = value_np[i, j, k, w]


class TestSelectScatterAPIError(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.shape = [2, 3, 4]
self.x_np = np.random.random(self.shape).astype(np.float32)
self.place = [paddle.CPUPlace()]
self.axis = 1
self.index = 1
self.value_shape = [2, 4]
self.value_np = np.random.random(self.value_shape).astype(np.float32)
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def test_len_of_shape_not_equal_error(self):
with self.assertRaises(RuntimeError):
x_tensor = paddle.to_tensor(self.x_np)
value_tensor = paddle.to_tensor(self.value_np).reshape((2, 2, 2))
res = paddle.select_scatter(x_tensor, value_tensor, 1, 1)

def test_one_of_size_not_equal_error(self):
with self.assertRaises(RuntimeError):
x_tensor = paddle.to_tensor(self.x_np)
value_tensor = paddle.to_tensor([[2, 2], [2, 2]]).astype(np.float32)
res = paddle.select_scatter(x_tensor, value_tensor, 1, 1)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里查错辛苦直接用 self.assertRaises吧,否则后续如果某些修改导致try中语句能过了,这个单测会检测不到。

此外,这两个case能拆成两个函数吗,命名上再清晰一些说明是检查的什么case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


if __name__ == "__main__":
paddle.enable_static()
unittest.main()