NPU and MLU support 0D index for gather.
NPU and MLU support 0D index and 0D updates for scatter.
FeixLiu committed Dec 2, 2022
1 parent b8f2cad commit 841da05
Showing 5 changed files with 110 additions and 10 deletions.
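For context, the behaviour enabled here is the one exercised by the new unit tests below: paddle.gather with a 0D (scalar) index drops the indexed axis from the output. A minimal usage sketch, mirroring the added tests (device/plugin selection for the NPU or MLU backend is omitted):

import paddle

x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0])
index = paddle.full([], 2, 'int64')   # 0D (scalar) index tensor

out = paddle.gather(x, index)         # the indexed axis is dropped
print(out.shape)                      # [] -> a 0D tensor holding 5.0
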
12 changes: 6 additions & 6 deletions backends/mlu/kernels/gather_kernel.cc
@@ -35,10 +35,10 @@ void GatherKernel(const Context& dev_ctx,
                           index_dims[1]));
   } else {
     PADDLE_ENFORCE_EQ(
-        index_dims.size(),
-        1,
+        index_dims.size() == 1 || index_dims.size() == 0,
+        true,
         phi::errors::InvalidArgument(
-            "The index should be 1D, when it is not 2D, but we get %d",
+            "The index should be 0D or 1D, when it is not 2D, but we get %d",
             index_dims.size()));
   }

@@ -80,10 +80,10 @@ void GatherGradKernel(const Context& dev_ctx,
                           index_dims[1]));
   } else {
     PADDLE_ENFORCE_EQ(
-        index_dims.size(),
-        1,
+        index_dims.size() == 1 || index_dims.size() == 0,
+        true,
         phi::errors::InvalidArgument(
-            "The index should be 1D, when it is not 2D, but we get %d",
+            "The index should be 0D or 1D, when it is not 2D, but we get %d",
             index_dims.size()));
   }

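The enforce in GatherKernel and GatherGradKernel is thus relaxed from "index must be exactly 1D" to "index may be 0D or 1D" (2D indices are handled by the branch above). At the Python level, both of the following index ranks now pass the check; a small sketch mirroring the enforce condition (eager mode assumed):

import paddle

x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# rank-0 index: newly accepted; rank-1 index: accepted as before
for index in (paddle.full([], 1, 'int64'),
              paddle.to_tensor([0, 1], dtype='int64')):
    out = paddle.gather(x, index)
    print(out.shape)  # [3] for the 0D index, [2, 3] for the 1D index
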
49 changes: 49 additions & 0 deletions backends/mlu/tests/unittests/test_zero_dim_tensor_mlu.py
@@ -436,6 +436,55 @@ def test_searchsorted(self):
        self.assertEqual(out.shape, [])
        self.assertEqual(out.numpy(), 0)

    def test_gather_1D(self):
        x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
        index = paddle.full([], 2, 'int64')
        out = paddle.gather(x, index)
        out.backward()

        self.assertEqual(out.shape, [])
        self.assertEqual(out.numpy(), 5)
        self.assertEqual(out.grad.shape, [])

    def test_gather_xD_axis_0(self):
        x = paddle.to_tensor(
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
        )
        index = paddle.full([], 1, 'int64')
        out = paddle.gather(x, index)
        out.backward()

        self.assertEqual(out.shape, [3])
        for i in range(3):
            self.assertEqual(out.numpy()[i], x.numpy()[1][i])
        self.assertEqual(out.grad.shape, [3])

    def test_gather_xD_axis_1(self):
        x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        index = paddle.full([], 1, 'int64')
        out = paddle.gather(x, index, axis=1)

        self.assertEqual(out.shape, [2])
        for i in range(2):
            self.assertEqual(out.numpy()[i], x.numpy()[i][1])

    def test_scatter_0D(self):
        x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0])
        index = paddle.full([], 2, 'int64')
        updates = paddle.full([], 4.0)
        out = paddle.scatter(x, index, updates)

        self.assertEqual(out.numpy()[2], 4)

    def test_scatter_XD(self):
        x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        index = paddle.full([], 1, 'int64')
        updates = paddle.to_tensor([1.0, 2.0, 3.0])
        out = paddle.scatter(x, index, updates)

        for i in range(3):
            self.assertEqual(out.numpy()[1][i], updates.numpy()[i])


# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
5 changes: 3 additions & 2 deletions backends/npu/kernels/gather_kernel.cc
@@ -48,8 +48,9 @@ void GatherGradKernel(const Context& dev_ctx,
   // step1: Unsqueeze index
   phi::DenseTensor tmp_tensor(index);
   const auto index_dims = index.dims();
-  if (index_dims.size() == 1) {
-    std::vector<int64_t> new_dim = {index_dims[0], 1};
+  if (index_dims.size() == 1 || index_dims.size() == 0) {
+    std::vector<int64_t> new_dim =
+        {index_dims.size() == 0 ? 1 : index_dims[0], 1};
     tmp_tensor.Resize(phi::make_ddim(new_dim));
     p_index = &tmp_tensor;
   }
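Here the kernel resizes a 0D index to shape [1, 1] (and, as before, a 1D index of length n to [n, 1]) before handing it to the device op. A rough Python analogue of that shape normalisation; normalize_index is an illustrative helper, not part of any Paddle API:

import paddle

def normalize_index(index):
    # Mirror of the kernel's unsqueeze step: a 0D index becomes [1, 1],
    # a 1D index of length n becomes [n, 1].
    n = 1 if index.ndim == 0 else index.shape[0]
    return paddle.reshape(index, [n, 1])

print(normalize_index(paddle.full([], 2, 'int64')).shape)               # [1, 1]
print(normalize_index(paddle.to_tensor([0, 3], dtype='int64')).shape)   # [2, 1]
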
5 changes: 3 additions & 2 deletions backends/npu/kernels/scatter_kernel.cc
@@ -29,8 +29,9 @@ void ScatterKernel(const Context& dev_ctx,

   phi::DenseTensor tmp_tensor(index);
   const auto index_dims = index.dims();
-  if (index_dims.size() == 1) {
-    std::vector<int64_t> new_dim = {index_dims[0], 1};
+  if (index_dims.size() == 1 || index_dims.size() == 0) {
+    std::vector<int64_t> new_dim =
+        {index_dims.size() == 0 ? 1 : index_dims[0], 1};
     tmp_tensor.Resize(phi::make_ddim(new_dim));
   }

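ScatterKernel applies the same index normalisation, so a 0D index can be paired with 0D updates to overwrite a single element. A minimal sketch mirroring the new test_scatter_0D test (default overwrite behaviour assumed):

import paddle

x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0])
index = paddle.full([], 2, 'int64')   # 0D index
updates = paddle.full([], 4.0)        # 0D updates
out = paddle.scatter(x, index, updates)
print(out.numpy())                    # [1. 3. 4. 7. 9.] -- element at index 2 overwritten
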
49 changes: 49 additions & 0 deletions backends/npu/tests/unittests/test_zero_dim_tensor_npu.py
@@ -436,6 +436,55 @@ def test_searchsorted(self):
        self.assertEqual(out.shape, [])
        self.assertEqual(out.numpy(), 0)

    def test_gather_1D(self):
        x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
        index = paddle.full([], 2, 'int64')
        out = paddle.gather(x, index)
        out.backward()

        self.assertEqual(out.shape, [])
        self.assertEqual(out.numpy(), 5)
        self.assertEqual(out.grad.shape, [])

    def test_gather_xD_axis_0(self):
        x = paddle.to_tensor(
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
        )
        index = paddle.full([], 1, 'int64')
        out = paddle.gather(x, index)
        out.backward()

        self.assertEqual(out.shape, [3])
        for i in range(3):
            self.assertEqual(out.numpy()[i], x.numpy()[1][i])
        self.assertEqual(out.grad.shape, [3])

    def test_gather_xD_axis_1(self):
        x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        index = paddle.full([], 1, 'int64')
        out = paddle.gather(x, index, axis=1)

        self.assertEqual(out.shape, [2])
        for i in range(2):
            self.assertEqual(out.numpy()[i], x.numpy()[i][1])

    def test_scatter_0D(self):
        x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0])
        index = paddle.full([], 2, 'int64')
        updates = paddle.full([], 4.0)
        out = paddle.scatter(x, index, updates)

        self.assertEqual(out.numpy()[2], 4)

    def test_scatter_XD(self):
        x = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        index = paddle.full([], 1, 'int64')
        updates = paddle.to_tensor([1.0, 2.0, 3.0])
        out = paddle.scatter(x, index, updates)

        for i in range(3):
            self.assertEqual(out.numpy()[1][i], updates.numpy()[i])

# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
    def setUp(self):
