Skip to content

Commit

Permalink
【Hackathon 5th No.27】为 Paddle 新增 select_scatter API -part (#59343)
Browse files Browse the repository at this point in the history
* support select_scatter op

* fix example code

* fix sc

* update example

* remove unused files

* add name

* fix conflict

* update

* remove

* update

* add type

* update type
  • Loading branch information
YibinLiu666 authored Dec 7, 2023
1 parent 6815027 commit 3a0aeaa
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@
tolist,
take_along_axis,
put_along_axis,
select_scatter,
tensordot,
as_complex,
as_real,
Expand Down Expand Up @@ -894,6 +895,7 @@
'renorm_',
'take_along_axis',
'put_along_axis',
'select_scatter',
'multigammaln',
'multigammaln_',
'nan_to_num',
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@
scatter_,
scatter_nd,
scatter_nd_add,
select_scatter,
shard_index,
slice,
split,
Expand Down Expand Up @@ -697,6 +698,7 @@
'repeat_interleave',
'take_along_axis',
'put_along_axis',
'select_scatter',
'put_along_axis_',
'exponential_',
'heaviside',
Expand Down
106 changes: 106 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6006,3 +6006,109 @@ def diagonal_scatter(x, y, offset=0, axis1=0, axis2=1, name=None):
"""
return fill_diagonal_tensor(x, y, offset, axis1, axis2, name)


def select_scatter(x, values, axis, index, name=None):
"""
Embeds the values of the values tensor into x at the given index of axis.
Args:
x (Tensor) : The Destination Tensor. Supported data types are `bool`, `float16`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `bfloat16`, `complex64`, `complex128`.
values (Tensor) : The tensor to embed into x. Supported data types are `bool`, `float16`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `bfloat16`, `complex64`, `complex128`.
axis (int) : the dimension to insert the slice into.
index (int) : the index to select with.
name (str, optional): Name for the operation (optional, default is None).
Returns:
Tensor, same dtype and shape with x
Examples:
.. code-block:: python
>>> import paddle
>>> x = paddle.zeros((2,3,4)).astype("float32")
>>> values = paddle.ones((2,4)).astype("float32")
>>> res = paddle.select_scatter(x,values,1,1)
>>> print(res)
Tensor(shape=[2, 3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.]],
[[0., 0., 0., 0.],
[1., 1., 1., 1.],
[0., 0., 0., 0.]]])
"""
x_shape = x.shape
value_shape = values.shape
if not isinstance(x_shape, list):
x_shape = list(x_shape)
if index < 0:
index += x_shape[axis]
if axis < 0:
axis += len(x_shape)
del x_shape[axis]
if len(x_shape) != len(value_shape):
raise RuntimeError(
"expected values to have a size equal to the slice of x. value size = "
+ str(value_shape)
+ " slice size = "
+ str(x_shape)
)
for i in range(len(x_shape)):
if x_shape[i] != value_shape[i]:
raise RuntimeError(
"expected values to have a size equal to the slice of x. value size = "
+ str(value_shape)
+ " slice size = "
+ str(x_shape)
)
from ..base.framework import default_main_program

starts = [index]
ends = [index + 1]
steps = [1]
axes = [axis]
none_axes = []
decrease_axes = [axis]
inputs = {'Input': x}
attrs = {
'axes': axes,
'starts': starts,
'ends': ends,
'steps': steps,
'decrease_axes': decrease_axes,
'none_axes': none_axes,
}

dtype = x.dtype
attrs['dtype'] = dtype

values = values.astype(dtype)
inputs["ValueTensor"] = values

if in_dynamic_or_pir_mode():
return _C_ops.set_value_with_tensor(
x,
values,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
)
else:
helper = LayerHelper('select_scatter', **locals())
output = helper.create_variable_for_type_inference(dtype=x.dtype)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)

return output
167 changes: 167 additions & 0 deletions test/legacy_test/test_select_scatter_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import numpy as np

import paddle
from paddle.framework import core
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()


class TestSelectScatterAPI(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.shape = [2, 3, 4]
self.type = np.float32
self.x_np = np.random.random(self.shape).astype(self.type)
self.place = [paddle.CPUPlace()]
self.axis = 1
self.index = 1
self.value_shape = [2, 4]
self.value_np = np.random.random(self.value_shape).astype(self.type)
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def get_out_ref(self, out_ref, index, value_np):
for i in range(2):
for j in range(4):
out_ref[i, index, j] = value_np[i, j]

@test_with_pir_api
def test_api_static(self):
paddle.enable_static()

def run(place):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('Src', self.shape, self.type)
value = paddle.static.data(
'Values', self.value_shape, self.type
)
out = paddle.select_scatter(x, value, self.axis, self.index)
exe = paddle.static.Executor(place)
res = exe.run(
feed={
'Src': self.x_feed,
'Values': self.value_np,
},
fetch_list=[out],
)

out_ref = copy.deepcopy(self.x_np)
self.get_out_ref(out_ref, self.index, self.value_np)
for out in res:
np.testing.assert_allclose(out, out_ref, rtol=0.001)

for place in self.place:
run(place)

def test_api_dygraph(self):
def run(place):
paddle.disable_static(place)
x_tensor = paddle.to_tensor(self.x_np)
value_tensor = paddle.to_tensor(self.value_np)
out = paddle.select_scatter(
x_tensor, value_tensor, self.axis, self.index
)
out_ref = copy.deepcopy(self.x_np)
self.get_out_ref(out_ref, self.index, self.value_np)
np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001)

paddle.enable_static()

for place in self.place:
run(place)


class TestSelectScatterAPICase2(TestSelectScatterAPI):
def setUp(self):
np.random.seed(0)
self.shape = [2, 3, 4, 5]
self.type = np.float64
self.x_np = np.random.random(self.shape).astype(self.type)
self.place = [paddle.CPUPlace()]
self.axis = 2
self.index = 1
self.value_shape = [2, 3, 5]
self.value_np = np.random.random(self.value_shape).astype(self.type)
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def get_out_ref(self, out_ref, index, value_np):
for i in range(2):
for j in range(3):
for k in range(5):
out_ref[i, j, index, k] = value_np[i, j, k]


class TestSelectScatterAPICase3(TestSelectScatterAPI):
def setUp(self):
np.random.seed(0)
self.shape = [2, 3, 4, 5, 6]
self.type = np.int32
self.x_np = np.random.random(self.shape).astype(self.type)
self.place = [paddle.CPUPlace()]
self.axis = 2
self.index = 1
self.value_shape = [2, 3, 5, 6]
self.value_np = np.random.random(self.value_shape).astype(self.type)
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def get_out_ref(self, out_ref, index, value_np):
for i in range(2):
for j in range(3):
for k in range(5):
for w in range(6):
out_ref[i, j, index, k, w] = value_np[i, j, k, w]


class TestSelectScatterAPIError(unittest.TestCase):
def setUp(self):
np.random.seed(0)
self.shape = [2, 3, 4]
self.x_np = np.random.random(self.shape).astype(np.float32)
self.place = [paddle.CPUPlace()]
self.axis = 1
self.index = 1
self.value_shape = [2, 4]
self.value_np = np.random.random(self.value_shape).astype(np.float32)
self.x_feed = copy.deepcopy(self.x_np)
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def test_len_of_shape_not_equal_error(self):
with self.assertRaises(RuntimeError):
x_tensor = paddle.to_tensor(self.x_np)
value_tensor = paddle.to_tensor(self.value_np).reshape((2, 2, 2))
res = paddle.select_scatter(x_tensor, value_tensor, 1, 1)

def test_one_of_size_not_equal_error(self):
with self.assertRaises(RuntimeError):
x_tensor = paddle.to_tensor(self.x_np)
value_tensor = paddle.to_tensor([[2, 2], [2, 2]]).astype(np.float32)
res = paddle.select_scatter(x_tensor, value_tensor, 1, 1)


if __name__ == "__main__":
paddle.enable_static()
unittest.main()

0 comments on commit 3a0aeaa

Please sign in to comment.