diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index dc0a84f8ef16e..67b25a0e98ff9 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -269,6 +269,7 @@ concat, crop, diagonal_scatter, + dsplit, expand, expand_as, flatten, @@ -276,6 +277,7 @@ flip as reverse, gather, gather_nd, + hsplit, index_add, index_add_, index_fill, @@ -309,6 +311,7 @@ row_stack, strided_slice, take_along_axis, + tensor_split, tensordot, tile, tolist, @@ -631,6 +634,9 @@ 'searchsorted', 'bucketize', 'split', + 'tensor_split', + 'hsplit', + 'dsplit', 'vsplit', 'logical_and', 'logical_and_', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 56b5022b86536..267160a7a227d 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -149,6 +149,7 @@ column_stack, concat, diagonal_scatter, + dsplit, dstack, expand, expand_as, @@ -158,6 +159,7 @@ flip as reverse, gather, gather_nd, + hsplit, hstack, index_add, index_add_, @@ -189,6 +191,7 @@ stack, strided_slice, take_along_axis, + tensor_split, tensordot, tile, unbind, @@ -608,6 +611,9 @@ 'shard_index', 'slice', 'split', + 'tensor_split', + 'hsplit', + 'dsplit', 'vsplit', 'chunk', 'tensordot', diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 807ad1aae5242..39c196dffbfa7 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2571,17 +2571,227 @@ def _get_SectionsTensorList(one_list): return outs -def vsplit(x, num_or_sections, name=None): +def tensor_split(x, num_or_indices, axis=0, name=None): """ - Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.split`` with ``axis=0``. + Split the input tensor into multiple sub-Tensors along ``axis``, allowing not being of equal size. Args: - x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64. - num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections`` - indicates the number of equal sized sub-Tensors that the ``x`` will be divided into. - If ``num_or_sections`` is a list or tuple, the length of it indicates the number of - sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly. - The length of the list must not be larger than the ``x`` 's size of axis 0. + x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64. + num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections along ``axis``. + If ``x`` is divisible by ``n``, each section will be ``x.shape[axis] / n``. If ``x`` is not divisible by ``n``, the first + ``int(x.shape[axis] % n)`` sections will have size ``int(x.shape[axis] / n) + 1``, and the rest will be ``int(x.shape[axis] / n). + If ``num_or_indices`` is a list or tuple of integter indices, ``x`` is split along ``axis`` at each of the indices. For instance, + ``num_or_indices=[2, 4]`` with ``axis=0`` would split ``x`` into ``x[:2]``, ``x[2:4]`` and ``x[4:]`` along axis 0. + axis (int|Tensor, optional): The axis along which to split, it can be a integer or a ``0-D Tensor`` + with shape [] and data type ``int32`` or ``int64``. + If :math::`axis < 0`, the axis to split along is :math:`rank(x) + axis`. Default is 0. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + Returns: + list[Tensor], The list of segmented Tensors. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> # x is a Tensor of shape [8] + >>> # evenly split + >>> x = paddle.rand([8]) + >>> out0, out1 = paddle.tensor_split(x, num_or_indices=2) + >>> print(out0.shape) + [4] + >>> print(out1.shape) + [4] + + >>> # not evenly split + >>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=3) + >>> print(out0.shape) + [3] + >>> print(out1.shape) + [3] + >>> print(out2.shape) + [2] + + >>> # split with indices + >>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3]) + >>> print(out0.shape) + [2] + >>> print(out1.shape) + [1] + >>> print(out2.shape) + [5] + + >>> # split along axis + >>> # x is a Tensor of shape [7, 8] + >>> x = paddle.rand([7, 8]) + >>> out0, out1 = paddle.tensor_split(x, num_or_indices=2, axis=1) + >>> print(out0.shape) + [7, 4] + >>> print(out1.shape) + [7, 4] + + >>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3], axis=1) + >>> print(out0.shape) + [7, 2] + >>> print(out1.shape) + [7, 1] + >>> print(out2.shape) + [7, 5] + + """ + if x.ndim <= 0 or x.ndim <= axis: + raise ValueError( + f"The input tensor's dimension must be greater than 0 or axis which is {axis}, but got {x.ndim}" + ) + + total_n = x.shape[axis] + + def _tensor_split_indices(x, total_n, indices, axis): + splits = [] + + starts = 0 + ends = 0 + for idx in list(indices) + [total_n]: + ends = idx + # convert index < 0 to positive + starts_index = starts if starts >= 0 else total_n + starts + ends_index = ends if ends >= 0 else total_n + ends + # ends index should equal or larger than starts + ends_index = max(starts_index, ends_index) + + sub_array = paddle.slice( + x, axes=[axis], starts=[starts_index], ends=[ends_index] + ) + splits.append(sub_array) + starts = ends + + return splits + + def _tensor_split_sections(x, total_n, sections, axis): + if sections <= 0: + raise ValueError('num_or_indices must be larger than 0.') + + base, mod = divmod(total_n, sections) + num_or_sections = [base + 1] * mod + [base] * (sections - mod) + return split(x, num_or_sections, axis) + + if isinstance(num_or_indices, int): + return _tensor_split_sections(x, total_n, num_or_indices, axis) + + elif isinstance(num_or_indices, (list, tuple)): + return _tensor_split_indices(x, total_n, num_or_indices, axis) + + else: + raise ValueError( + f"The num_or_indices should be int, list or tuple of ints, but got {type(num_or_indices)}" + ) + + +def hsplit(x, num_or_indices, name=None): + """ + Split the input tensor into multiple sub-Tensors along the horizontal axis, which is equivalent to ``paddle.tensor_split`` with ``axis=1`` + when ``x`` 's dimension is larger than 1, or equivalent to ``paddle.tensor_split`` with ``axis=0`` when ``x`` 's dimension is 1. + + Args: + x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64. + num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections. + If ``num_or_indices`` is a list or tuple of integter indices, ``x`` is split at each of the indices. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + Returns: + list[Tensor], The list of segmented Tensors. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> # x is a Tensor of shape [8] + >>> x = paddle.rand([8]) + >>> out0, out1 = paddle.hsplit(x, num_or_indices=2) + >>> print(out0.shape) + [4] + >>> print(out1.shape) + [4] + + >>> # x is a Tensor of shape [7, 8] + >>> x = paddle.rand([7, 8]) + >>> out0, out1 = paddle.hsplit(x, num_or_indices=2) + >>> print(out0.shape) + [7, 4] + >>> print(out1.shape) + [7, 4] + + >>> out0, out1, out2 = paddle.hsplit(x, num_or_indices=[1, 4]) + >>> print(out0.shape) + [7, 1] + >>> print(out1.shape) + [7, 3] + >>> print(out2.shape) + [7, 4] + + """ + if x.ndim < 1: + raise ValueError( + f"The input tensor's dimension must be greater than 0, but got {x.ndim}" + ) + if x.ndim > 1: + return tensor_split(x, num_or_indices, axis=1, name=name) + else: + return tensor_split(x, num_or_indices, axis=0, name=name) + + +def dsplit(x, num_or_indices, name=None): + """ + Split the input tensor into multiple sub-Tensors along the depth axis, which is equivalent to ``paddle.tensor_split`` with ``axis=2``. + + Args: + x (Tensor): A Tensor whose dimension must be greater than 2. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64. + num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections. + If ``num_or_indices`` is a list or tuple of integter indices, ``x`` is split at each of the indices. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + Returns: + list[Tensor], The list of segmented Tensors. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> # x is a Tensor of shape [7, 6, 8] + >>> x = paddle.rand([7, 6, 8]) + >>> out0, out1 = paddle.dsplit(x, num_or_indices=2) + >>> print(out0.shape) + [7, 6, 4] + >>> print(out1.shape) + [7, 6, 4] + + >>> out0, out1, out2 = paddle.dsplit(x, num_or_indices=[1, 4]) + >>> print(out0.shape) + [7, 6, 1] + >>> print(out1.shape) + [7, 6, 3] + >>> print(out2.shape) + [7, 6, 4] + + """ + if x.ndim < 3: + raise ValueError( + f"The input tensor's dimension must be greater than 2, but got {x.ndim}" + ) + return tensor_split(x, num_or_indices, axis=2, name=name) + + +def vsplit(x, num_or_indices, name=None): + """ + Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.tensor_split`` with ``axis=0``. + + Args: + x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64. + num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections. + If ``num_or_indices`` is a list or tuple of integter indices, ``x`` is split at each of the indices. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . Returns: @@ -2594,31 +2804,26 @@ def vsplit(x, num_or_sections, name=None): >>> # x is a Tensor of shape [8, 6, 7] >>> x = paddle.rand([8, 6, 7]) - >>> out0, out1 = paddle.vsplit(x, num_or_sections=2) + >>> out0, out1 = paddle.vsplit(x, num_or_indices=2) >>> print(out0.shape) [4, 6, 7] >>> print(out1.shape) [4, 6, 7] - >>> out0, out1, out2 = paddle.vsplit(x, num_or_sections=[1, 3, 4]) + + >>> out0, out1, out2 = paddle.vsplit(x, num_or_indices=[1, 4]) >>> print(out0.shape) [1, 6, 7] >>> print(out1.shape) [3, 6, 7] >>> print(out2.shape) [4, 6, 7] - >>> out0, out1, out2 = paddle.vsplit(x, num_or_sections=[2, 3, -1]) - >>> print(out0.shape) - [2, 6, 7] - >>> print(out1.shape) - [3, 6, 7] - >>> print(out2.shape) - [3, 6, 7] + """ if x.ndim < 2: raise ValueError( f"The input tensor's dimension must be greater than 1, but got {x.ndim}" ) - return split(x, num_or_sections, axis=0, name=name) + return tensor_split(x, num_or_indices, axis=0, name=name) def squeeze(x, axis=None, name=None): diff --git a/test/legacy_test/test_splits_api.py b/test/legacy_test/test_splits_api.py index 4e319e6cb4b91..a065a0eb7fb1a 100644 --- a/test/legacy_test/test_splits_api.py +++ b/test/legacy_test/test_splits_api.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import unittest import numpy as np @@ -20,168 +21,699 @@ from paddle.base import core from paddle.pir_utils import test_with_pir_api +RTOL = 1e-5 +ATOL = 1e-8 -def func_ref(func, x, num_or_sections): - # Convert the num_or_sections in paddle to indices_or_sections in numpy - # Do not support -1 - if isinstance(num_or_sections, int): - indices_or_sections = num_or_sections - else: - indices_or_sections = np.cumsum(num_or_sections)[:-1] - return func(x, indices_or_sections) +DTYPE_ALL_CPU = { + 'float64', + 'float16', + 'float32', + 'bool', + 'uint8', + 'int32', + 'int64', +} +# add `bfloat16` if core is complied with CUDA and support the bfloat16 +DTYPE_ALL_GPU = DTYPE_ALL_CPU | ( + {'bfloat16'} + if core.is_compiled_with_cuda() + and core.is_bfloat16_supported(paddle.CUDAPlace(0)) + else set() +) -# TODO: add other split API, such as dsplit、hsplit -test_list = [ - (paddle.vsplit, np.vsplit), -] +PLACES = [paddle.CPUPlace()] + ( + [paddle.CUDAPlace(0)] if core.is_compiled_with_cuda() else [] +) -class TestSplitsAPI(unittest.TestCase): - def setUp(self): - self.rtol = 1e-5 - self.atol = 1e-8 - self.set_input() - def set_input(self): - self.shape = [4, 5, 2] - self.num_or_sections = 2 - self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') - self.place = ( - paddle.CUDAPlace(0) - if core.is_compiled_with_cuda() - else paddle.CPUPlace() - ) +def generate_data(shape, dtype='int64'): + """generate test data + + Args: + shape(list of int): shape of inputs + dtype(str): dtype + + Returns: + x, dtype, shape, name + """ + return { + # bfloat16 convert to uint16 for numpy + 'x': np.random.randint(0, 255, size=shape).astype( + dtype if dtype != 'bfloat16' else 'uint16' + ), + 'dtype': dtype, + 'shape': shape, + 'name': f'{shape}_{dtype}', + } + + +class BaseTest(unittest.TestCase): + """Test in each `PLACES` and in `static/dygraph`""" @test_with_pir_api - def test_static_api(self): + def _test_static_api( + self, + func_paddle, + func_numpy, + x, + dtype, + shape, + name, + split_paddle, + split_numpy, + places=None, + ): + """Test `static` + + Args: + func_paddle: `hsplit`, `vsplit`, `dsplit`, `tensor_split` + func_numpy: `hsplit`, `vsplit`, `dsplit`, `array_split` + x: input tensor + dtype: input tensor's dtype + shape: input tensor's shape + name: input tensor's name + split_paddle: num_or_sections or indices_or_sections in paddle + split_numpy: `hsplit`, `vsplit`, `dsplit` should convert num_or_sections in paddle to indices_or_sections in numpy. For test error, `split_numpy` is None and skip compare result, ensure the error only raised from paddle. + places: exec place, default to PLACES + """ paddle.enable_static() - for func, func_type in test_list: - with paddle.static.program_guard(paddle.static.Program()): - x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype) - out = func(x, self.num_or_sections) - exe = paddle.static.Executor(self.place) - res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) - out_ref = func_ref(func_type, self.x_np, self.num_or_sections) - for n, p in zip(out_ref, res): - np.testing.assert_allclose(n, p, rtol=self.rtol, atol=self.atol) - - def test_dygraph_api(self): - paddle.disable_static(self.place) - x = paddle.to_tensor(self.x_np) - for func, func_type in test_list: - out = func(x, self.num_or_sections) - out_ref = func_ref(func_type, self.x_np, self.num_or_sections) - for n, p in zip(out_ref, out): - np.testing.assert_allclose( - n, p.numpy(), rtol=self.rtol, atol=self.atol + + places = PLACES if places is None else places + for place in places: + program = paddle.static.Program() + exe = paddle.static.Executor(place) + + with paddle.static.program_guard(program): + input = paddle.static.data(name, shape, dtype) + input.stop_gradient = False + + feed = {name: x} + + out = func_paddle(input, split_paddle) + + if paddle.framework.in_pir_mode(): + fetch_list = [out] + grads = paddle.autograd.ir_backward.grad(out, [input]) + out_grad = grads[0] + fetch_list.append(out_grad) + + *res, res_grad = exe.run(feed=feed, fetch_list=fetch_list) + + self.assertEqual(list(res_grad.shape), list(input.shape)) + + else: + res = exe.run(feed=feed, fetch_list=[out]) + + if split_numpy is not None: + out_ref = func_numpy(x, split_numpy) + + for n, p in zip(out_ref, res): + np.testing.assert_allclose(n, p, rtol=RTOL, atol=ATOL) + + def _test_dygraph_api( + self, + func_paddle, + func_numpy, + x, + dtype, + shape, + name, + split_paddle, + split_numpy, + places=None, + ): + """Test `dygraph`, and check grads""" + paddle.disable_static() + + places = PLACES if places is None else places + for place in places: + out = func_paddle(paddle.to_tensor(x).astype(dtype), split_paddle) + + if split_numpy is not None: + out_ref = func_numpy(x, split_numpy) + + for n, p in zip(out_ref, out): + np.testing.assert_allclose( + n, p.numpy(), rtol=RTOL, atol=ATOL + ) + + # check grads for the first tensor + out = out[0] + + for y in out: + y.stop_gradient = False + z = y * 123 + grads = paddle.grad(z, y) + self.assertTrue(len(grads), 1) + self.assertEqual(grads[0].dtype, y.dtype) + self.assertEqual(grads[0].shape, y.shape) + + def _test_all( + self, + kwargs, + ): + self._test_dygraph_api(self.func_paddle, self.func_numpy, **kwargs) + self._test_static_api(self.func_paddle, self.func_numpy, **kwargs) + + +class TestHSplit(BaseTest): + def setUp(self): + self.func_paddle = paddle.hsplit + self.func_numpy = np.hsplit + + def test_split_dim(self): + x = generate_data([6]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + + self._test_all( + { + **x, + 'split_paddle': [2, 4], + 'split_numpy': [2, 4], + } + ) + self._test_all( + { + **x, + 'split_paddle': (2, 1, 3), + 'split_numpy': (2, 1, 3), + } + ) + self._test_all( + {**x, 'split_paddle': [-1, 1, 3], 'split_numpy': [-1, 1, 3]} + ) + self._test_all({**x, 'split_paddle': [-1], 'split_numpy': [-1]}) + + x = generate_data([4, 6]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + + self._test_all( + { + **x, + 'split_paddle': [2, 4], + 'split_numpy': [2, 4], + } + ) + self._test_all( + { + **x, + 'split_paddle': (2, 1, 3), + 'split_numpy': (2, 1, 3), + } + ) + self._test_all( + {**x, 'split_paddle': [-1, 1, 3], 'split_numpy': [-1, 1, 3]} + ) + self._test_all({**x, 'split_paddle': [-1], 'split_numpy': [-1]}) + + x = generate_data([4, 6, 3]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + + self._test_all( + { + **x, + 'split_paddle': [2, 4], + 'split_numpy': [2, 4], + } + ) + self._test_all( + { + **x, + 'split_paddle': (2, 1, 3), + 'split_numpy': (2, 1, 3), + } + ) + self._test_all( + {**x, 'split_paddle': [-1, 1, 3], 'split_numpy': [-1, 1, 3]} + ) + self._test_all({**x, 'split_paddle': [-1], 'split_numpy': [-1]}) + + def test_dtype(self): + for dtype in DTYPE_ALL_CPU: + self._test_all( + { + **generate_data([6], dtype=dtype), + 'split_paddle': 3, + 'split_numpy': 3, + 'places': [paddle.CPUPlace()], + }, + ) + + if core.is_compiled_with_cuda(): + for dtype in DTYPE_ALL_GPU: + self._test_all( + { + **generate_data([6], dtype=dtype), + 'split_paddle': 3, + 'split_numpy': 3, + 'places': [paddle.CUDAPlace(0)], + }, ) - paddle.enable_static() + def test_error_dim(self): + # test 0-d + x = generate_data([]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 3, 'split_numpy': None}) -class TestSplitsSections(TestSplitsAPI): - """ - Test num_or_sections which is a list and date type is float64. - """ + def test_error_split(self): + x = generate_data([5]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 0, 'split_numpy': None}) - def set_input(self): - self.shape = [6, 2, 4] - self.num_or_sections = [2, 1, 3] - self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') - self.place = ( - paddle.CUDAPlace(0) - if core.is_compiled_with_cuda() - else paddle.CPUPlace() + +class TestVSplit(BaseTest): + def setUp(self): + self.func_paddle = paddle.vsplit + self.func_numpy = np.vsplit + + def test_split_dim(self): + x = generate_data([6, 4]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + + self._test_all( + { + **x, + 'split_paddle': [2, 4], + 'split_numpy': [2, 4], + } + ) + self._test_all( + { + **x, + 'split_paddle': (2, 1, 3), + 'split_numpy': (2, 1, 3), + } + ) + self._test_all( + {**x, 'split_paddle': [-1, 1, 3], 'split_numpy': [-1, 1, 3]} + ) + self._test_all({**x, 'split_paddle': [-1], 'split_numpy': [-1]}) + + x = generate_data([6, 4, 3]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + + self._test_all( + { + **x, + 'split_paddle': [2, 4], + 'split_numpy': [2, 4], + } + ) + self._test_all( + { + **x, + 'split_paddle': (2, 1, 3), + 'split_numpy': (2, 1, 3), + } ) + self._test_all( + {**x, 'split_paddle': [-1, 1, 3], 'split_numpy': [-1, 1, 3]} + ) + self._test_all({**x, 'split_paddle': [-1], 'split_numpy': [-1]}) + + def test_dtype(self): + for dtype in DTYPE_ALL_CPU: + self._test_all( + { + **generate_data([6, 4], dtype=dtype), + 'split_paddle': 3, + 'split_numpy': 3, + 'places': [paddle.CPUPlace()], + }, + ) + + if core.is_compiled_with_cuda(): + for dtype in DTYPE_ALL_GPU: + self._test_all( + { + **generate_data([6, 4], dtype=dtype), + 'split_paddle': 3, + 'split_numpy': 3, + 'places': [paddle.CUDAPlace(0)], + }, + ) + def test_error_dim(self): + # test 0-d + x = generate_data([]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 3, 'split_numpy': None}) -class TestSplitsFloat32(TestSplitsAPI): - """ - Test num_or_sections which is an integer and data type is float32. - """ + # test 1-d + x = generate_data([6]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 3, 'split_numpy': None}) + + def test_error_split(self): + x = generate_data([5, 4]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 0, 'split_numpy': None}) + + +class TestDSplit(BaseTest): + def setUp(self): + self.func_paddle = paddle.dsplit + self.func_numpy = np.dsplit + + def test_split_dim(self): + x = generate_data([4, 3, 6]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + + self._test_all( + { + **x, + 'split_paddle': [2, 4], + 'split_numpy': [2, 4], + } + ) + self._test_all( + { + **x, + 'split_paddle': (2, 1, 3), + 'split_numpy': (2, 1, 3), + } + ) + self._test_all( + {**x, 'split_paddle': [-1, 1, 3], 'split_numpy': [-1, 1, 3]} + ) + self._test_all({**x, 'split_paddle': [-1], 'split_numpy': [-1]}) + + def test_dtype(self): + for dtype in DTYPE_ALL_CPU: + self._test_all( + { + **generate_data([4, 2, 6], dtype=dtype), + 'split_paddle': 3, + 'split_numpy': 3, + 'places': [paddle.CPUPlace()], + }, + ) + + if core.is_compiled_with_cuda(): + for dtype in DTYPE_ALL_GPU: + self._test_all( + { + **generate_data([4, 2, 6], dtype=dtype), + 'split_paddle': 3, + 'split_numpy': 3, + 'places': [paddle.CUDAPlace(0)], + }, + ) + + def test_error_dim(self): + # test 0-d + x = generate_data([]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 3, 'split_numpy': None}) + + # test 1-d + x = generate_data([6]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 3, 'split_numpy': None}) + + # test 2-d + x = generate_data([4, 6]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 3, 'split_numpy': None}) + + def test_error_split(self): + x = generate_data([3, 6, 5]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 0, 'split_numpy': None}) + + +class TestTensorSplit(BaseTest): + def setUp(self): + self.func_paddle = paddle.tensor_split + self.func_numpy = np.array_split + + def test_split_dim(self): + x = generate_data([6]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + self._test_all({**x, 'split_paddle': [2, 4], 'split_numpy': [2, 4]}) + self._test_all({**x, 'split_paddle': [2, 3], 'split_numpy': [2, 3]}) + self._test_all({**x, 'split_paddle': (2, 5), 'split_numpy': (2, 5)}) + self._test_all( + {**x, 'split_paddle': [2, 4, 5], 'split_numpy': [2, 4, 5]} + ) - def set_input(self): - self.shape = [2, 3, 4] - self.num_or_sections = 2 - self.x_np = np.random.uniform(-1, 1, self.shape).astype('float32') - self.place = ( - paddle.CUDAPlace(0) - if core.is_compiled_with_cuda() - else paddle.CPUPlace() + # not evenly split + x = generate_data([7]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + self._test_all({**x, 'split_paddle': [2, 4], 'split_numpy': [2, 4]}) + self._test_all({**x, 'split_paddle': [2, 3], 'split_numpy': [2, 3]}) + self._test_all({**x, 'split_paddle': (2, 6), 'split_numpy': (2, 6)}) + self._test_all( + {**x, 'split_paddle': [2, 4, 6], 'split_numpy': [2, 4, 6]} ) + x = generate_data([7, 4]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + self._test_all({**x, 'split_paddle': [2, 4], 'split_numpy': [2, 4]}) + self._test_all({**x, 'split_paddle': [2, 3], 'split_numpy': [2, 3]}) + self._test_all({**x, 'split_paddle': (2, 6), 'split_numpy': (2, 6)}) + self._test_all( + {**x, 'split_paddle': [2, 4, 6], 'split_numpy': [2, 4, 6]} + ) -class TestSplitsInt32(TestSplitsAPI): - """ - Test data type int32. - """ + x = generate_data([7, 4, 3]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + self._test_all({**x, 'split_paddle': [2, 4], 'split_numpy': [2, 4]}) + self._test_all({**x, 'split_paddle': [2, 3], 'split_numpy': [2, 3]}) + self._test_all({**x, 'split_paddle': (2, 6), 'split_numpy': (2, 6)}) + self._test_all( + {**x, 'split_paddle': [2, 4, 6], 'split_numpy': [2, 4, 6]} + ) - def set_input(self): - self.shape = [5, 1, 2] - self.num_or_sections = 5 - self.x_np = np.random.uniform(-1, 1, self.shape).astype('int32') - self.place = ( - paddle.CUDAPlace(0) - if core.is_compiled_with_cuda() - else paddle.CPUPlace() + def test_split_axis(self): + # 1-d + self.func_paddle = functools.partial(paddle.tensor_split, axis=0) + self.func_numpy = functools.partial(np.array_split, axis=0) + + x = generate_data([7]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + self._test_all({**x, 'split_paddle': [2, 3], 'split_numpy': [2, 3]}) + self._test_all({**x, 'split_paddle': (2, 6), 'split_numpy': (2, 6)}) + self._test_all( + {**x, 'split_paddle': [2, 4, 6], 'split_numpy': [2, 4, 6]} ) + # 2-d + self.func_paddle = functools.partial(paddle.tensor_split, axis=1) + self.func_numpy = functools.partial(np.array_split, axis=1) + + x = generate_data([4, 7]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + self._test_all({**x, 'split_paddle': [2, 3], 'split_numpy': [2, 3]}) + self._test_all({**x, 'split_paddle': (2, 6), 'split_numpy': (2, 6)}) + self._test_all( + {**x, 'split_paddle': [2, 4, 6], 'split_numpy': [2, 4, 6]} + ) -class TestSplitsInt64(TestSplitsAPI): - """ - Test data type int64. - """ + # 3-d + self.func_paddle = functools.partial(paddle.tensor_split, axis=2) + self.func_numpy = functools.partial(np.array_split, axis=2) + + x = generate_data([4, 4, 7]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + self._test_all({**x, 'split_paddle': [2, 3], 'split_numpy': [2, 3]}) + self._test_all({**x, 'split_paddle': (2, 6), 'split_numpy': (2, 6)}) + self._test_all( + {**x, 'split_paddle': [2, 4, 6], 'split_numpy': [2, 4, 6]} + ) - def set_input(self): - self.shape = [4, 3, 2] - self.num_or_sections = 2 - self.x_np = np.random.uniform(-1, 1, self.shape).astype('int64') - self.place = ( - paddle.CUDAPlace(0) - if core.is_compiled_with_cuda() - else paddle.CPUPlace() + # n-d + self.func_paddle = functools.partial(paddle.tensor_split, axis=3) + self.func_numpy = functools.partial(np.array_split, axis=3) + + x = generate_data([4, 4, 4, 7]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + self._test_all({**x, 'split_paddle': [2, 3], 'split_numpy': [2, 3]}) + self._test_all({**x, 'split_paddle': (2, 6), 'split_numpy': (2, 6)}) + self._test_all( + {**x, 'split_paddle': [2, 4, 6], 'split_numpy': [2, 4, 6]} ) + # axis -2 + self.func_paddle = functools.partial(paddle.tensor_split, axis=-2) + self.func_numpy = functools.partial(np.array_split, axis=-2) + + x = generate_data([4, 4, 7, 4]) + self._test_all({**x, 'split_paddle': 3, 'split_numpy': 3}) + self._test_all({**x, 'split_paddle': 2, 'split_numpy': 2}) + self._test_all({**x, 'split_paddle': [2, 3], 'split_numpy': [2, 3]}) + self._test_all({**x, 'split_paddle': (2, 6), 'split_numpy': (2, 6)}) + self._test_all( + {**x, 'split_paddle': [2, 4, 6], 'split_numpy': [2, 4, 6]} + ) -class TestSplitsCPU(TestSplitsAPI): - """ - Test cpu place and num_or_sections which is a tuple. - """ + def test_special_indices(self): + """indices in a mess, negative index, index out of range""" + self.func_paddle = functools.partial(paddle.tensor_split, axis=0) + self.func_numpy = functools.partial(np.array_split, axis=0) - def set_input(self): - self.shape = [8, 2, 3, 5] - self.num_or_sections = (2, 3, 3) - self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') - self.place = paddle.CPUPlace() + x = generate_data([7]) + # indices' order in a mess + self._test_all( + {**x, 'split_paddle': [2, 1, 3], 'split_numpy': [2, 1, 3]} + ) + # index out of range + self._test_all( + {**x, 'split_paddle': [2, 3, 16], 'split_numpy': [2, 3, 16]} + ) -class TestSplitsError(unittest.TestCase): - """ - Test the situation that input shape less than 2. - """ + # index with -1 + self._test_all( + {**x, 'split_paddle': [3, -1, 16], 'split_numpy': [3, -1, 16]} + ) - def setUp(self): - self.num_or_sections = 1 - self.place = ( - paddle.CUDAPlace(0) - if core.is_compiled_with_cuda() - else paddle.CPUPlace() + # mix index + self._test_all( + { + **x, + 'split_paddle': [3, -1, 5, 2, 16], + 'split_numpy': [3, -1, 5, 2, 16], + } ) - @test_with_pir_api - def test_static_error(self): - paddle.enable_static() - for func, _ in test_list: - with paddle.static.program_guard(paddle.static.Program()): - x = paddle.static.data('X', [5], 'float32') - self.assertRaises(ValueError, func, x, self.num_or_sections) - - def test_dygraph_error(self): - paddle.disable_static(self.place) - for func, _ in test_list: - x_np = np.random.randn(2) - x = paddle.to_tensor(x_np, dtype='float64') - self.assertRaises(ValueError, func, x, self.num_or_sections) + def test_dtype(self): + self.func_paddle = functools.partial(paddle.tensor_split, axis=0) + self.func_numpy = functools.partial(np.array_split, axis=0) + + for dtype in DTYPE_ALL_CPU: + self._test_all( + { + **generate_data([6], dtype=dtype), + 'split_paddle': 3, + 'split_numpy': 3, + 'places': [paddle.CPUPlace()], + }, + ) + + if core.is_compiled_with_cuda(): + for dtype in DTYPE_ALL_GPU: + self._test_all( + { + **generate_data([6], dtype=dtype), + 'split_paddle': 3, + 'split_numpy': 3, + 'places': [paddle.CUDAPlace(0)], + }, + ) + + self.func_paddle = functools.partial(paddle.tensor_split, axis=1) + self.func_numpy = functools.partial(np.array_split, axis=1) + + for dtype in DTYPE_ALL_CPU: + self._test_all( + { + **generate_data([4, 6], dtype=dtype), + 'split_paddle': 3, + 'split_numpy': 3, + 'places': [paddle.CPUPlace()], + }, + ) + + if core.is_compiled_with_cuda(): + for dtype in DTYPE_ALL_GPU: + self._test_all( + { + **generate_data([4, 6], dtype=dtype), + 'split_paddle': 3, + 'split_numpy': 3, + 'places': [paddle.CUDAPlace(0)], + }, + ) + + self.func_paddle = functools.partial(paddle.tensor_split, axis=2) + self.func_numpy = functools.partial(np.array_split, axis=2) + + for dtype in DTYPE_ALL_CPU: + self._test_all( + { + **generate_data([4, 4, 6], dtype=dtype), + 'split_paddle': 3, + 'split_numpy': 3, + 'places': [paddle.CPUPlace()], + }, + ) + + if core.is_compiled_with_cuda(): + for dtype in DTYPE_ALL_GPU: + self._test_all( + { + **generate_data([4, 4, 6], dtype=dtype), + 'split_paddle': 3, + 'split_numpy': 3, + 'places': [paddle.CUDAPlace(0)], + }, + ) + + def test_error_dim(self): + # axis 0 + self.func_paddle = functools.partial(paddle.tensor_split, axis=0) + self.func_numpy = functools.partial(np.array_split, axis=0) + + # test 0-d + x = generate_data([]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 3, 'split_numpy': None}) + + # axis 1 + self.func_paddle = functools.partial(paddle.tensor_split, axis=1) + self.func_numpy = functools.partial(np.array_split, axis=1) + + # test 0-d + x = generate_data([]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 3, 'split_numpy': None}) + + # test 1-d + x = generate_data([6]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 3, 'split_numpy': None}) + + # axis 2 + self.func_paddle = functools.partial(paddle.tensor_split, axis=2) + self.func_numpy = functools.partial(np.array_split, axis=2) + + # test 0-d + x = generate_data([]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 3, 'split_numpy': None}) + + # test 1-d + x = generate_data([6]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 3, 'split_numpy': None}) + + # test 2-d + x = generate_data([4, 6]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 3, 'split_numpy': None}) + + def test_error_split(self): + x = generate_data([6]) + with self.assertRaises(ValueError): + self._test_all({**x, 'split_paddle': 0, 'split_numpy': None}) if __name__ == '__main__':