From 63323080be426be5c67e5add827097ebdf866ad3 Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Tue, 17 Jan 2023 22:16:33 +0800 Subject: [PATCH] [Zero-Dim] Support input 0D Tensor for masked_select (#365) --- .../tests/unittests/test_zero_dim_tensor_mlu.py | 14 ++++++++++++++ .../tests/unittests/test_zero_dim_tensor_npu.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/backends/mlu/tests/unittests/test_zero_dim_tensor_mlu.py b/backends/mlu/tests/unittests/test_zero_dim_tensor_mlu.py index b54baffba..e6a0e5b53 100644 --- a/backends/mlu/tests/unittests/test_zero_dim_tensor_mlu.py +++ b/backends/mlu/tests/unittests/test_zero_dim_tensor_mlu.py @@ -681,6 +681,20 @@ def test_argsort(self): self.assertEqual(x1.grad.numpy(), 0) self.assertEqual(x2.grad.numpy(), 0) + def test_maseked_select(self): + x = paddle.rand([]) + x.stop_gradient = False + mask = paddle.full([], True, dtype="bool") + y = paddle.masked_select(x, mask) + + y.retain_grads() + y.backward() + self.assertEqual(y.shape, [1]) + self.assertEqual(y.numpy(), x.numpy()) + self.assertEqual(y.grad.shape, [1]) + self.assertEqual(x.grad.shape, []) + self.assertEqual(x.grad.numpy(), 1) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase): diff --git a/backends/npu/tests/unittests/test_zero_dim_tensor_npu.py b/backends/npu/tests/unittests/test_zero_dim_tensor_npu.py index 4a6fa5742..81a5f4350 100644 --- a/backends/npu/tests/unittests/test_zero_dim_tensor_npu.py +++ b/backends/npu/tests/unittests/test_zero_dim_tensor_npu.py @@ -706,6 +706,20 @@ def test_argsort(self): self.assertEqual(x1.grad.numpy(), 0) self.assertEqual(x2.grad.numpy(), 0) + def test_maseked_select(self): + x = paddle.rand([]) + x.stop_gradient = False + mask = paddle.full([], True, dtype="bool") + y = paddle.masked_select(x, mask) + + y.retain_grads() + y.backward() + self.assertEqual(y.shape, [1]) + self.assertEqual(y.numpy(), x.numpy()) + self.assertEqual(y.grad.shape, [1]) + self.assertEqual(x.grad.shape, []) + self.assertEqual(x.grad.numpy(), 1) + # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest. class TestNoBackwardAPI(unittest.TestCase):