From 09e7e3317ff8f68738a4ba06e90d14759bc34535 Mon Sep 17 00:00:00 2001 From: Anton <100830759+antonwolfy@users.noreply.github.com> Date: Tue, 26 Mar 2024 13:42:12 +0100 Subject: [PATCH] Implement `dpnp.searchsorted` (#1751) * Implement dpnp.searchsorted * Stated explicit wrapping behavior of out of bound values of sorter per dpctl spec * Muted previously enabled overflow tests * Corrected test_argsort_ndarray test with unique generated random values --- dpnp/backend/include/dpnp_iface_fptr.hpp | 12 +- dpnp/backend/kernels/dpnp_krnl_sorting.cpp | 20 - dpnp/dpnp_algo/dpnp_algo.pxd | 1 - dpnp/dpnp_algo/dpnp_algo_sorting.pxi | 50 --- dpnp/dpnp_array.py | 12 +- dpnp/dpnp_iface_searching.py | 54 ++- dpnp/dpnp_iface_sorting.py | 38 +- tests/skipped_tests.tbl | 22 - tests/skipped_tests_gpu.tbl | 22 - tests/test_sort.py | 351 ++++++++++----- tests/test_sycl_queue.py | 23 + tests/test_usm_type.py | 14 + .../cupy/sorting_tests/test_search.py | 409 +++++++++++++----- 13 files changed, 634 insertions(+), 394 deletions(-) diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index 780a49f37ac..4b8c4e86c0d 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -332,13 +332,11 @@ enum class DPNPFuncName : size_t DPNP_FN_RNG_ZIPF_EXT, /**< Used in numpy.random.zipf() impl, requires extra parameters */ DPNP_FN_SEARCHSORTED, /**< Used in numpy.searchsorted() impl */ - DPNP_FN_SEARCHSORTED_EXT, /**< Used in numpy.searchsorted() impl, requires - extra parameters */ - DPNP_FN_SIGN, /**< Used in numpy.sign() impl */ - DPNP_FN_SIN, /**< Used in numpy.sin() impl */ - DPNP_FN_SINH, /**< Used in numpy.sinh() impl */ - DPNP_FN_SORT, /**< Used in numpy.sort() impl */ - DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */ + DPNP_FN_SIGN, /**< Used in numpy.sign() impl */ + DPNP_FN_SIN, /**< Used in numpy.sin() impl */ + DPNP_FN_SINH, /**< Used in numpy.sinh() impl */ + DPNP_FN_SORT, /**< Used in numpy.sort() impl */ + DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */ DPNP_FN_SQRT_EXT, /**< Used in numpy.sqrt() impl, requires extra parameters */ DPNP_FN_SQUARE, /**< Used in numpy.square() impl */ diff --git a/dpnp/backend/kernels/dpnp_krnl_sorting.cpp b/dpnp/backend/kernels/dpnp_krnl_sorting.cpp index 6f33c1af723..2c50a490f15 100644 --- a/dpnp/backend/kernels/dpnp_krnl_sorting.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_sorting.cpp @@ -403,17 +403,6 @@ void (*dpnp_searchsorted_default_c)(void *, const size_t) = dpnp_searchsorted_c<_DataType, _IndexingType>; -template -DPCTLSyclEventRef (*dpnp_searchsorted_ext_c)(DPCTLSyclQueueRef, - void *, - const void *, - const void *, - bool, - const size_t, - const size_t, - const DPCTLEventVectorRef) = - dpnp_searchsorted_c<_DataType, _IndexingType>; - template class dpnp_sort_c_kernel; @@ -507,15 +496,6 @@ void func_map_init_sorting(func_map_t &fmap) fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED][eft_DBL][eft_DBL] = { eft_DBL, (void *)dpnp_searchsorted_default_c}; - fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_INT][eft_INT] = { - eft_INT, (void *)dpnp_searchsorted_ext_c}; - fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_LNG][eft_LNG] = { - eft_LNG, (void *)dpnp_searchsorted_ext_c}; - fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_FLT][eft_FLT] = { - eft_FLT, (void *)dpnp_searchsorted_ext_c}; - fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_DBL][eft_DBL] = { - eft_DBL, (void *)dpnp_searchsorted_ext_c}; - fmap[DPNPFuncName::DPNP_FN_SORT][eft_INT][eft_INT] = { eft_INT, (void *)dpnp_sort_default_c}; fmap[DPNPFuncName::DPNP_FN_SORT][eft_LNG][eft_LNG] = { diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index fb87cdd30d8..cc87448fc5a 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -94,7 +94,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_RNG_WALD_EXT DPNP_FN_RNG_WEIBULL_EXT DPNP_FN_RNG_ZIPF_EXT - DPNP_FN_SEARCHSORTED_EXT DPNP_FN_TRACE_EXT DPNP_FN_TRAPZ_EXT diff --git a/dpnp/dpnp_algo/dpnp_algo_sorting.pxi b/dpnp/dpnp_algo/dpnp_algo_sorting.pxi index 069b5335c1c..4947fa9e41d 100644 --- a/dpnp/dpnp_algo/dpnp_algo_sorting.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_sorting.pxi @@ -37,7 +37,6 @@ and the rest of the library __all__ += [ "dpnp_partition", - "dpnp_searchsorted", ] @@ -49,14 +48,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_partition_t)(c_dpctl.DPCTLSyclQueu const shape_elem_type * , const size_t, const c_dpctl.DPCTLEventVectorRef) -ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_searchsorted_t)(c_dpctl.DPCTLSyclQueueRef, - void * , - const void * , - const void * , - bool, - const size_t, - const size_t, - const c_dpctl.DPCTLEventVectorRef) cpdef utils.dpnp_descriptor dpnp_partition(utils.dpnp_descriptor arr, int kth, axis=-1, kind='introselect', order=None): @@ -98,44 +89,3 @@ cpdef utils.dpnp_descriptor dpnp_partition(utils.dpnp_descriptor arr, int kth, a c_dpctl.DPCTLEvent_Delete(event_ref) return result - - -cpdef utils.dpnp_descriptor dpnp_searchsorted(utils.dpnp_descriptor arr, utils.dpnp_descriptor v, side='left'): - if side is 'left': - side_ = True - else: - side_ = False - - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype) - - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_SEARCHSORTED_EXT, param1_type, param1_type) - - arr_obj = arr.get_array() - - cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(v.shape, - dpnp.int64, - None, - device=arr_obj.sycl_device, - usm_type=arr_obj.usm_type, - sycl_queue=arr_obj.sycl_queue) - - result_sycl_queue = result.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - cdef fptr_dpnp_searchsorted_t func = kernel_data.ptr - - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - arr.get_data(), - v.get_data(), - result.get_data(), - side_, - arr.size, - v.size, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return result diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 8f784201a2b..8a9a7717c26 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -1121,7 +1121,17 @@ def round(self, decimals=0, out=None): return dpnp.around(self, decimals, out) - # 'searchsorted', + def searchsorted(self, v, side="left", sorter=None): + """ + Find indices where elements of `v` should be inserted in `a` + to maintain order. + + Refer to :obj:`dpnp.searchsorted` for full documentation + + """ + + return dpnp.searchsorted(self, v, side=side, sorter=sorter) + # 'setfield', # 'setflags', diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py index 3e1fb4c4d98..907214058a7 100644 --- a/dpnp/dpnp_iface_searching.py +++ b/dpnp/dpnp_iface_searching.py @@ -232,9 +232,61 @@ def searchsorted(a, v, side="left", sorter=None): For full documentation refer to :obj:`numpy.searchsorted`. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input 1-D array. If `sorter` is ``None``, then it must be sorted in + ascending order, otherwise `sorter` must be an array of indices that + sort it. + v : {dpnp.ndarray, usm_ndarray, scalar} + Values to insert into `a`. + side : {'left', 'right'}, optional + If ``'left'``, the index of the first suitable location found is given. + If ``'right'``, return the last such index. If there is no suitable + index, return either 0 or N (where N is the length of `a`). + Default is ``'left'``. + sorter : {dpnp.ndarray, usm_ndarray}, optional + Optional 1-D array of integer indices that sort array a into ascending + order. They are typically the result of argsort. + Out of bound index values of `sorter` array are treated using `"wrap"` + mode documented in :py:func:`dpnp.take`. + Default is ``None``. + + Returns + ------- + indices : dpnp.ndarray + Array of insertion points with the same shape as `v`, + or 0-D array if `v` is a scalar. + + See Also + -------- + :obj:`dpnp.sort` : Return a sorted copy of an array. + :obj:`dpnp.histogram` : Produce histogram from 1-D data. + + Examples + -------- + >>> import dpnp as np + >>> a = np.array([11,12,13,14,15]) + >>> np.searchsorted(a, 13) + array(2) + >>> np.searchsorted(a, 13, side='right') + array(3) + >>> v = np.array([-10, 20, 12, 13]) + >>> np.searchsorted(a, v) + array([0, 5, 1, 2]) + """ - return call_origin(numpy.where, a, v, side, sorter) + usm_a = dpnp.get_usm_ndarray(a) + if dpnp.isscalar(v): + usm_v = dpt.asarray(v, sycl_queue=a.sycl_queue, usm_type=a.usm_type) + else: + usm_v = dpnp.get_usm_ndarray(v) + + usm_sorter = None if sorter is None else dpnp.get_usm_ndarray(sorter) + return dpnp_array._create_from_usm_ndarray( + dpt.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter) + ) def where(condition, x=None, y=None, /): diff --git a/dpnp/dpnp_iface_sorting.py b/dpnp/dpnp_iface_sorting.py index aaf62fc0995..511b3722eea 100644 --- a/dpnp/dpnp_iface_sorting.py +++ b/dpnp/dpnp_iface_sorting.py @@ -47,14 +47,13 @@ # pylint: disable=no-name-in-module from .dpnp_algo import ( dpnp_partition, - dpnp_searchsorted, ) from .dpnp_array import dpnp_array from .dpnp_utils import ( call_origin, ) -__all__ = ["argsort", "partition", "searchsorted", "sort"] +__all__ = ["argsort", "partition", "sort"] def argsort(a, axis=-1, kind=None, order=None): @@ -189,41 +188,6 @@ def partition(x1, kth, axis=-1, kind="introselect", order=None): return call_origin(numpy.partition, x1, kth, axis, kind, order) -def searchsorted(x1, x2, side="left", sorter=None): - """ - Find indices where elements should be inserted to maintain order. - - For full documentation refer to :obj:`numpy.searchsorted`. - - Limitations - ----------- - Input arrays is supported as :obj:`dpnp.ndarray`. - Input array is supported only sorted. - Input side is supported only values ``left``, ``right``. - Parameter `sorter` is supported only with default values. - - """ - - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False) - # pylint: disable=condition-evals-to-constant - if 0 and x1_desc and x2_desc: - if x1_desc.ndim != 1: - pass - elif x1_desc.dtype != x2_desc.dtype: - pass - elif side not in ["left", "right"]: - pass - elif sorter is not None: - pass - elif x1_desc.size < 2: - pass - else: - return dpnp_searchsorted(x1_desc, x2_desc, side=side).get_pyobj() - - return call_origin(numpy.searchsorted, x1, x2, side=side, sorter=sorter) - - def sort(a, axis=-1, kind=None, order=None): """ Return a sorted copy of an array. diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index 7ee04717abe..b0d7e34fefe 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -701,28 +701,6 @@ tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_bo tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit_2 -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_0_{func='argmin', is_module=True, shape=(3, 4)}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_1_{func='argmin', is_module=True, shape=()}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_2_{func='argmin', is_module=False, shape=(3, 4)}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_3_{func='argmin', is_module=False, shape=()}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_4_{func='argmax', is_module=True, shape=(3, 4)}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_5_{func='argmax', is_module=True, shape=()}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_6_{func='argmax', is_module=False, shape=(3, 4)}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_7_{func='argmax', is_module=False, shape=()}] - -tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[0] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[1] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[2] - -tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[0] -tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[1] -tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[2] -tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[3] -tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[4] - -tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_0_{array=array(0)}::test_nonzero -tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_1_{array=array(1)}::test_nonzero - tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_axis tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis1 tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis2 diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index c0401fc1a6d..e2b95476da0 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -763,28 +763,6 @@ tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_bo tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit_2 -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_0_{func='argmin', is_module=True, shape=(3, 4)}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_1_{func='argmin', is_module=True, shape=()}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_2_{func='argmin', is_module=False, shape=(3, 4)}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_3_{func='argmin', is_module=False, shape=()}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_4_{func='argmax', is_module=True, shape=(3, 4)}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_5_{func='argmax', is_module=True, shape=()}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_6_{func='argmax', is_module=False, shape=(3, 4)}] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_7_{func='argmax', is_module=False, shape=()}] - -tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[0] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[1] -tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[2] - -tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[0] -tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[1] -tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[2] -tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[3] -tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[4] - -tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_0_{array=array(0)}::test_nonzero -tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_1_{array=array(1)}::test_nonzero - tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_axis tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis1 tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis2 diff --git a/tests/test_sort.py b/tests/test_sort.py index 1899604a304..e9e8afb4454 100644 --- a/tests/test_sort.py +++ b/tests/test_sort.py @@ -1,164 +1,343 @@ import numpy import pytest -from numpy.testing import assert_array_equal +from numpy.testing import assert_array_equal, assert_equal, assert_raises import dpnp -from .helper import assert_dtype_allclose, get_all_dtypes, get_complex_dtypes +from .helper import ( + assert_dtype_allclose, + get_all_dtypes, + get_complex_dtypes, + get_float_dtypes, +) -class TestSort: +class TestArgsort: @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) - def test_sort_dtype(self, dtype): + def test_argsort_dtype(self, dtype): a = numpy.random.uniform(-5, 5, 10) np_array = numpy.array(a, dtype=dtype) dp_array = dpnp.array(np_array) - result = dpnp.sort(dp_array) - expected = numpy.sort(np_array) + result = dpnp.argsort(dp_array, kind="stable") + expected = numpy.argsort(np_array, kind="stable") assert_dtype_allclose(result, expected) @pytest.mark.parametrize("dtype", get_complex_dtypes()) - def test_sort_complex(self, dtype): + def test_argsort_complex(self, dtype): a = numpy.random.uniform(-5, 5, 10) b = numpy.random.uniform(-5, 5, 10) np_array = numpy.array(a + b * 1j, dtype=dtype) dp_array = dpnp.array(np_array) - result = dpnp.sort(dp_array) - expected = numpy.sort(np_array) + result = dpnp.argsort(dp_array) + expected = numpy.argsort(np_array) assert_dtype_allclose(result, expected) @pytest.mark.parametrize("axis", [None, -2, -1, 0, 1, 2]) - def test_sort_axis(self, axis): + def test_argsort_axis(self, axis): a = numpy.random.uniform(-10, 10, 36) np_array = numpy.array(a).reshape(3, 4, 3) dp_array = dpnp.array(np_array) - result = dpnp.sort(dp_array, axis=axis) - expected = numpy.sort(np_array, axis=axis) + result = dpnp.argsort(dp_array, axis=axis) + expected = numpy.argsort(np_array, axis=axis) assert_dtype_allclose(result, expected) @pytest.mark.parametrize("dtype", get_all_dtypes()) - @pytest.mark.parametrize("axis", [-2, -1, 0, 1]) - def test_sort_ndarray(self, dtype, axis): - a = numpy.random.uniform(-10, 10, 12) + @pytest.mark.parametrize("axis", [None, -2, -1, 0, 1]) + def test_argsort_ndarray(self, dtype, axis): + if dtype and issubclass(dtype, numpy.integer): + a = numpy.random.choice( + numpy.arange(-10, 10), replace=False, size=12 + ) + else: + a = numpy.random.uniform(-10, 10, 12) np_array = numpy.array(a, dtype=dtype).reshape(6, 2) dp_array = dpnp.array(np_array) - dp_array.sort(axis=axis) - np_array.sort(axis=axis) - assert_dtype_allclose(dp_array, np_array) + result = dp_array.argsort(axis=axis) + expected = np_array.argsort(axis=axis) + assert_dtype_allclose(result, expected) - def test_sort_stable(self): + def test_argsort_stable(self): np_array = numpy.repeat(numpy.arange(10), 10) dp_array = dpnp.array(np_array) - result = dpnp.sort(dp_array, kind="stable") - expected = numpy.sort(np_array, kind="stable") + result = dpnp.argsort(dp_array, kind="stable") + expected = numpy.argsort(np_array, kind="stable") assert_dtype_allclose(result, expected) - def test_sort_ndarray_axis_none(self): - a = numpy.random.uniform(-10, 10, 12) - dp_array = dpnp.array(a).reshape(6, 2) - with pytest.raises(TypeError): - dp_array.sort(axis=None) - - def test_sort_zero_dim(self): + def test_argsort_zero_dim(self): np_array = numpy.array(2.5) dp_array = dpnp.array(np_array) # with default axis=-1 with pytest.raises(numpy.AxisError): - dpnp.sort(dp_array) + dpnp.argsort(dp_array) # with axis = None - result = dpnp.sort(dp_array, axis=None) - expected = numpy.sort(np_array, axis=None) + result = dpnp.argsort(dp_array, axis=None) + expected = numpy.argsort(np_array, axis=None) assert_dtype_allclose(result, expected) def test_sort_notimplemented(self): dp_array = dpnp.arange(10) with pytest.raises(NotImplementedError): - dpnp.sort(dp_array, kind="quicksort") + dpnp.argsort(dp_array, kind="quicksort") with pytest.raises(NotImplementedError): - dpnp.sort(dp_array, order=["age"]) + dpnp.argsort(dp_array, order=["age"]) -class TestArgsort: +class TestSearchSorted: + @pytest.mark.parametrize("side", ["left", "right"]) + @pytest.mark.parametrize("dtype", get_float_dtypes(no_float16=False)) + def test_nans_float(self, side, dtype): + a = numpy.array([0, 1, numpy.nan], dtype=dtype) + dp_a = dpnp.array(a) + + result = dp_a.searchsorted(dp_a, side=side) + expected = a.searchsorted(a, side=side) + assert_equal(result, expected) + + result = dpnp.searchsorted(dp_a, dp_a[-1], side=side) + expected = numpy.searchsorted(a, a[-1], side=side) + assert_equal(result, expected) + + @pytest.mark.parametrize("side", ["left", "right"]) + @pytest.mark.parametrize("dtype", get_complex_dtypes()) + def test_nans_complex(self, side, dtype): + a = numpy.zeros(9, dtype=dtype) + a.real += [0, 0, 1, 1, 0, 1, numpy.nan, numpy.nan, numpy.nan] + a.imag += [0, 1, 0, 1, numpy.nan, numpy.nan, 0, 1, numpy.nan] + dp_a = dpnp.array(a) + + result = dp_a.searchsorted(dp_a, side=side) + expected = a.searchsorted(a, side=side) + assert_equal(result, expected) + + @pytest.mark.parametrize("n", range(3)) + @pytest.mark.parametrize("side", ["left", "right"]) + def test_n_elements(self, n, side): + a = numpy.ones(n) + dp_a = dpnp.array(a) + + v = numpy.array([0, 1, 2]) + dp_v = dpnp.array(v) + + result = dp_a.searchsorted(dp_v, side=side) + expected = a.searchsorted(v, side=side) + assert_equal(result, expected) + + @pytest.mark.parametrize("side", ["left", "right"]) + def test_smart_resetting(self, side): + a = numpy.arange(5) + dp_a = dpnp.array(a) + + v = numpy.array([6, 5, 4]) + dp_v = dpnp.array(v) + + result = dp_a.searchsorted(dp_v, side=side) + expected = a.searchsorted(v, side=side) + assert_equal(result, expected) + + @pytest.mark.parametrize("side", ["left", "right"]) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False)) + def test_type_specific(self, side, dtype): + if dtype == numpy.bool_: + a = numpy.arange(2, dtype=dtype) + else: + a = numpy.arange(0, 5, dtype=dtype) + dp_a = dpnp.array(a) + + result = dp_a.searchsorted(dp_a, side=side) + expected = a.searchsorted(a, side=side) + assert_equal(result, expected) + + e = numpy.ndarray(shape=0, buffer=b"", dtype=dtype) + dp_e = dpnp.array(e) + + result = dp_e.searchsorted(dp_a, side=side) + expected = e.searchsorted(a, side=side) + assert_array_equal(result, expected) + + @pytest.mark.parametrize("dtype", get_float_dtypes()) + def test_sorter(self, dtype): + a = numpy.random.rand(300).astype(dtype) + s = a.argsort() + k = numpy.linspace(0, 1, 20, dtype=dtype) + + dp_a = dpnp.array(a) + dp_s = dpnp.array(s) + dp_k = dpnp.array(k) + + result = dp_a.searchsorted(dp_k, sorter=dp_s) + expected = a.searchsorted(k, sorter=s) + assert_equal(result, expected) + + @pytest.mark.parametrize("side", ["left", "right"]) + def test_sorter_with_side(self, side): + a = numpy.array([0, 1, 2, 3, 5] * 20) + s = a.argsort() + k = [0, 1, 2, 3, 5] + + dp_a = dpnp.array(a) + dp_s = dpnp.array(s) + dp_k = dpnp.array(k) + + result = dp_a.searchsorted(dp_k, side=side, sorter=dp_s) + expected = a.searchsorted(k, side=side, sorter=s) + assert_equal(result, expected) + + @pytest.mark.parametrize("side", ["left", "right"]) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_float16=False)) + def test_sorter_type_specific(self, side, dtype): + if dtype == numpy.bool_: + a = numpy.array([1, 0], dtype=dtype) + # a sorter array to be of a type that is different + # from np.intp in all platforms + s = numpy.array([1, 0], dtype=numpy.int16) + else: + a = numpy.arange(0, 5, dtype=dtype) + # a sorter array to be of a type that is different + # from np.intp in all platforms + s = numpy.array([4, 2, 3, 0, 1], dtype=numpy.int16) + + dp_a = dpnp.array(a) + dp_s = dpnp.array(s) + + result = dp_a.searchsorted(dp_a, side, dp_s) + expected = a.searchsorted(a, side, s) + assert_equal(result, expected) + + @pytest.mark.parametrize("side", ["left", "right"]) + def test_sorter_non_contiguous(self, side): + a = numpy.array([3, 4, 1, 2, 0]) + srt = numpy.empty((10,), dtype=numpy.intp) + srt[1::2] = -1 + srt[::2] = [4, 2, 3, 0, 1] + s = srt[::2] + + dp_a = dpnp.array(a) + dp_s = dpnp.array(s) + + result = dp_a.searchsorted(dp_a, side=side, sorter=dp_s) + expected = a.searchsorted(a, side=side, sorter=s) + assert_equal(result, expected) + + def test_invalid_sorter(self): + for xp in [dpnp, numpy]: + a = xp.array([5, 2, 1, 3, 4]) + + assert_raises( + TypeError, + ValueError, + xp.searchsorted, + a, + 0, + sorter=xp.array([1.1]), + ) + assert_raises( + ValueError, xp.searchsorted, a, 0, sorter=xp.array([1, 2, 3, 4]) + ) + assert_raises( + ValueError, + xp.searchsorted, + a, + 0, + sorter=xp.array([1, 2, 3, 4, 5, 6]), + ) + + def test_v_scalar(self): + v = 0 + a = numpy.array([-8, -5, -1, 3, 6, 10]) + dp_a = dpnp.array(a) + + result = dpnp.searchsorted(dp_a, v) + expected = numpy.searchsorted(a, v) + assert_equal(result, expected) + + +class TestSort: @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) - def test_argsort_dtype(self, dtype): + def test_sort_dtype(self, dtype): a = numpy.random.uniform(-5, 5, 10) np_array = numpy.array(a, dtype=dtype) dp_array = dpnp.array(np_array) - result = dpnp.argsort(dp_array, kind="stable") - expected = numpy.argsort(np_array, kind="stable") + result = dpnp.sort(dp_array) + expected = numpy.sort(np_array) assert_dtype_allclose(result, expected) @pytest.mark.parametrize("dtype", get_complex_dtypes()) - def test_argsort_complex(self, dtype): + def test_sort_complex(self, dtype): a = numpy.random.uniform(-5, 5, 10) b = numpy.random.uniform(-5, 5, 10) np_array = numpy.array(a + b * 1j, dtype=dtype) dp_array = dpnp.array(np_array) - result = dpnp.argsort(dp_array) - expected = numpy.argsort(np_array) + result = dpnp.sort(dp_array) + expected = numpy.sort(np_array) assert_dtype_allclose(result, expected) @pytest.mark.parametrize("axis", [None, -2, -1, 0, 1, 2]) - def test_argsort_axis(self, axis): + def test_sort_axis(self, axis): a = numpy.random.uniform(-10, 10, 36) np_array = numpy.array(a).reshape(3, 4, 3) dp_array = dpnp.array(np_array) - result = dpnp.argsort(dp_array, axis=axis) - expected = numpy.argsort(np_array, axis=axis) + result = dpnp.sort(dp_array, axis=axis) + expected = numpy.sort(np_array, axis=axis) assert_dtype_allclose(result, expected) @pytest.mark.parametrize("dtype", get_all_dtypes()) - @pytest.mark.parametrize("axis", [None, -2, -1, 0, 1]) - def test_argsort_ndarray(self, dtype, axis): + @pytest.mark.parametrize("axis", [-2, -1, 0, 1]) + def test_sort_ndarray(self, dtype, axis): a = numpy.random.uniform(-10, 10, 12) np_array = numpy.array(a, dtype=dtype).reshape(6, 2) dp_array = dpnp.array(np_array) - result = dp_array.argsort(axis=axis) - expected = np_array.argsort(axis=axis) - assert_dtype_allclose(result, expected) + dp_array.sort(axis=axis) + np_array.sort(axis=axis) + assert_dtype_allclose(dp_array, np_array) - def test_argsort_stable(self): + def test_sort_stable(self): np_array = numpy.repeat(numpy.arange(10), 10) dp_array = dpnp.array(np_array) - result = dpnp.argsort(dp_array, kind="stable") - expected = numpy.argsort(np_array, kind="stable") + result = dpnp.sort(dp_array, kind="stable") + expected = numpy.sort(np_array, kind="stable") assert_dtype_allclose(result, expected) - def test_argsort_zero_dim(self): + def test_sort_ndarray_axis_none(self): + a = numpy.random.uniform(-10, 10, 12) + dp_array = dpnp.array(a).reshape(6, 2) + with pytest.raises(TypeError): + dp_array.sort(axis=None) + + def test_sort_zero_dim(self): np_array = numpy.array(2.5) dp_array = dpnp.array(np_array) # with default axis=-1 with pytest.raises(numpy.AxisError): - dpnp.argsort(dp_array) + dpnp.sort(dp_array) # with axis = None - result = dpnp.argsort(dp_array, axis=None) - expected = numpy.argsort(np_array, axis=None) + result = dpnp.sort(dp_array, axis=None) + expected = numpy.sort(np_array, axis=None) assert_dtype_allclose(result, expected) def test_sort_notimplemented(self): dp_array = dpnp.arange(10) with pytest.raises(NotImplementedError): - dpnp.argsort(dp_array, kind="quicksort") + dpnp.sort(dp_array, kind="quicksort") with pytest.raises(NotImplementedError): - dpnp.argsort(dp_array, order=["age"]) + dpnp.sort(dp_array, order=["age"]) @pytest.mark.parametrize("kth", [0, 1], ids=["0", "1"]) @@ -191,61 +370,3 @@ def test_partition(array, dtype, kth): assert (p[..., 0:kth] <= p[..., kth : kth + 1]).all() assert (p[..., kth : kth + 1] <= p[..., kth + 1 :]).all() - - -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -@pytest.mark.parametrize("side", ["left", "right"], ids=['"left"', '"right"']) -@pytest.mark.parametrize( - "v_", - [ - [[3, 4], [2, 1]], - [[1, 0], [3, 0]], - [[3, 2, 1, 6]], - [[4, 2], [3, 3], [4, 1]], - [[1, -3, 3], [0, 5, 2], [0, 1, 1], [0, 0, 1]], - [ - [[[8, 2], [3, 0]], [[5, 2], [0, 1]]], - [[[1, 3], [3, 1]], [[5, 2], [0, 1]]], - ], - ], - ids=[ - "[[3, 4], [2, 1]]", - "[[1, 0], [3, 0]]", - "[[3, 2, 1, 6]]", - "[[4, 2], [3, 3], [4, 1]]", - "[[1, -3, 3], [0, 5, 2], [0, 1, 1], [0, 0, 1]]", - "[[[[8, 2], [3, 0]], [[5, 2], [0, 1]]], [[[1, 3], [3, 1]], [[5, 2], [0, 1]]]]", - ], -) -@pytest.mark.parametrize( - "dtype", get_all_dtypes(no_none=True, no_bool=True, no_complex=True) -) -@pytest.mark.parametrize( - "array", - [ - [1, 2, 3, 4], - [-5, -1, 0, 3, 17, 100], - [1, 0, 3, 0], - [3, 2, 1, 6], - [4, 2, 3, 3, 4, 1], - [1, -3, 3, 0, 5, 2, 0, 1, 1, 0, 0, 1], - [8, 2, 3, 0, 5, 2, 0, 1, 1, 3, 3, 1, 5, 2, 0, 1], - ], - ids=[ - "[1, 2, 3, 4]", - "[-5, -1, 0, 3, 17, 100]", - "[1, 0, 3, 0]", - "[3, 2, 1, 6]", - "[4, 2, 3, 3, 4, 1]", - "[1, -3, 3, 0, 5, 2, 0, 1, 1, 0, 0, 1]", - "[8, 2, 3, 0, 5, 2, 0, 1, 1, 3, 3, 1, 5, 2, 0, 1]", - ], -) -def test_searchsorted(array, dtype, v_, side): - a = numpy.array(array, dtype) - ia = dpnp.array(array, dtype) - v = numpy.array(v_, dtype) - iv = dpnp.array(v_, dtype) - expected = numpy.searchsorted(a, v, side=side) - result = dpnp.searchsorted(ia, iv, side=side) - assert_array_equal(expected, result) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 06ca1462525..acf2801faf9 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -631,6 +631,7 @@ def test_reduce_hypot(device): [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0], ), + pytest.param("searchsorted", [11, 12, 13, 14, 15], [-10, 20, 12, 13]), pytest.param( "subtract", [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], @@ -668,6 +669,28 @@ def test_2in_1out(func, data1, data2, device): assert_sycl_queue_equal(result.sycl_queue, x2.sycl_queue) +@pytest.mark.parametrize( + "func, data, scalar", + [ + pytest.param("searchsorted", [11, 12, 13, 14, 15], 13), + ], +) +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_2in_with_scalar_1out(func, data, scalar, device): + x1_orig = numpy.array(data) + expected = getattr(numpy, func)(x1_orig, scalar) + + x1 = dpnp.array(data, device=device) + result = getattr(dpnp, func)(x1, scalar) + + assert_allclose(result, expected) + assert_sycl_queue_equal(result.sycl_queue, x1.sycl_queue) + + @pytest.mark.parametrize( "func,data1,data2", [ diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index cac0fe33ae0..4592bc91484 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -575,6 +575,7 @@ def test_1in_1out(func, data, usm_type): pytest.param("logaddexp", [[-1, 2, 5, 9]], [[4, -3, 2, -8]]), pytest.param("maximum", [[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]), pytest.param("minimum", [[0.0, 1.0, 2.0]], [[3.0, 4.0, 5.0]]), + pytest.param("searchsorted", [11, 12, 13, 14, 15], [-10, 20, 12, 13]), pytest.param( "tensordot", [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], @@ -599,6 +600,19 @@ def test_2in_1out(func, data1, data2, usm_type_x, usm_type_y): assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y]) +@pytest.mark.parametrize( + "func, data, scalar", + [ + pytest.param("searchsorted", [11, 12, 13, 14, 15], 13), + ], +) +@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) +def test_2in_with_scalar_1out(func, data, scalar, usm_type): + x = dp.array(data, usm_type=usm_type) + z = getattr(dp, func)(x, scalar) + assert z.usm_type == usm_type + + @pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) def test_broadcast_to(usm_type): x = dp.ones(7, usm_type=usm_type) diff --git a/tests/third_party/cupy/sorting_tests/test_search.py b/tests/third_party/cupy/sorting_tests/test_search.py index 875fa071f1a..7f947498e64 100644 --- a/tests/third_party/cupy/sorting_tests/test_search.py +++ b/tests/third_party/cupy/sorting_tests/test_search.py @@ -1,13 +1,10 @@ -import unittest - import numpy import pytest import dpnp as cupy +from tests.helper import has_support_aspect64 from tests.third_party.cupy import testing -# from cupy.core import _accelerator - class TestSearch: @testing.for_all_dtypes(no_complex=True) @@ -84,6 +81,12 @@ def test_argmax_zero_size_axis1(self, xp, dtype): a = testing.shaped_random((0, 1), xp, dtype) return a.argmax(axis=1) + @testing.slow + @pytest.mark.skip("slow mark is not implemented") + def test_argmax_int32_overflow(self): + a = testing.shaped_arange((2**32 + 1,), cupy, numpy.float64) + assert a.argmax().item() == 2**32 + @testing.for_all_dtypes(no_complex=True) @testing.numpy_cupy_allclose() def test_argmin_all(self, xp, dtype): @@ -158,62 +161,114 @@ def test_argmin_zero_size_axis1(self, xp, dtype): a = testing.shaped_random((0, 1), xp, dtype) return a.argmin(axis=1) + @testing.slow + @pytest.mark.skip("slow mark is not implemented") + def test_argmin_int32_overflow(self): + a = testing.shaped_arange((2**32 + 1,), cupy, numpy.float64) + cupy.negative(a, out=a) + assert a.argmin().item() == 2**32 + # This class compares CUB results against NumPy's # TODO(leofang): test axis after support is added -# @testing.parameterize(*testing.product({ -# 'shape': [(10,), (10, 20), (10, 20, 30), (10, 20, 30, 40)], -# 'order': ('C', 'F'), -# })) -# @unittest.skipUnless(cupy.cuda.cub.available, 'The CUB routine is not enabled') -# class TestCubReduction(unittest.TestCase): - -# def setUp(self): -# self.old_accelerators = _accelerator.get_routine_accelerators() -# _accelerator.set_routine_accelerators(['cub']) - -# def tearDown(self): -# _accelerator.set_routine_accelerators(self.old_accelerators) - -# @testing.for_dtypes('bhilBHILefdFD') -# @testing.numpy_cupy_allclose(rtol=1E-5) -# def test_cub_argmin(self, xp, dtype): -# a = testing.shaped_random(self.shape, xp, dtype) -# if self.order == 'C': -# a = xp.ascontiguousarray(a) -# else: -# a = xp.asfortranarray(a) - -# if xp is numpy: -# return a.argmin() - -# # xp is cupy, first ensure we really use CUB -# ret = cupy.empty(()) # Cython checks return type, need to fool it -# func = 'cupy.core._routines_statistics.cub.device_reduce' -# with testing.AssertFunctionIsCalled(func, return_value=ret): -# a.argmin() -# # ...then perform the actual computation -# return a.argmin() - -# @testing.for_dtypes('bhilBHILefdFD') -# @testing.numpy_cupy_allclose(rtol=1E-5) -# def test_cub_argmax(self, xp, dtype): -# a = testing.shaped_random(self.shape, xp, dtype) -# if self.order == 'C': -# a = xp.ascontiguousarray(a) -# else: -# a = xp.asfortranarray(a) - -# if xp is numpy: -# return a.argmax() - -# # xp is cupy, first ensure we really use CUB -# ret = cupy.empty(()) # Cython checks return type, need to fool it -# func = 'cupy.core._routines_statistics.cub.device_reduce' -# with testing.AssertFunctionIsCalled(func, return_value=ret): -# a.argmax() -# # ...then perform the actual computation -# return a.argmax() +@testing.parameterize( + *testing.product( + { + "shape": [(10,), (10, 20), (10, 20, 30), (10, 20, 30, 40)], + "order_and_axis": (("C", -1), ("C", None), ("F", 0), ("F", None)), + "backend": ("device", "block"), + } + ) +) +@pytest.mark.skip("The CUB routine is not enabled") +class TestCubReduction: + @pytest.fixture(autouse=True) + def setUp(self): + self.order, self.axis = self.order_and_axis + old_routine_accelerators = _acc.get_routine_accelerators() + old_reduction_accelerators = _acc.get_reduction_accelerators() + if self.backend == "device": + if self.axis is not None: + pytest.skip("does not support") + _acc.set_routine_accelerators(["cub"]) + _acc.set_reduction_accelerators([]) + elif self.backend == "block": + _acc.set_routine_accelerators([]) + _acc.set_reduction_accelerators(["cub"]) + yield + _acc.set_routine_accelerators(old_routine_accelerators) + _acc.set_reduction_accelerators(old_reduction_accelerators) + + @testing.for_dtypes("bhilBHILefdFD") + @testing.numpy_cupy_allclose(rtol=1e-5, contiguous_check=False) + def test_cub_argmin(self, xp, dtype): + a = testing.shaped_random(self.shape, xp, dtype) + if self.order == "C": + a = xp.ascontiguousarray(a) + else: + a = xp.asfortranarray(a) + + if xp is numpy: + return a.argmin(axis=self.axis) + + # xp is cupy, first ensure we really use CUB + ret = cupy.empty(()) # Cython checks return type, need to fool it + if self.backend == "device": + func_name = "cupy._core._routines_statistics.cub." + func_name += "device_reduce" + with testing.AssertFunctionIsCalled(func_name, return_value=ret): + a.argmin(axis=self.axis) + elif self.backend == "block": + # this is the only function we can mock; the rest is cdef'd + func_name = "cupy._core._cub_reduction." + func_name += "_SimpleCubReductionKernel_get_cached_function" + # func = _cub_reduction._SimpleCubReductionKernel_get_cached_function + if self.axis is not None and len(self.shape) > 1: + times_called = 1 # one pass + else: + times_called = 2 # two passes + with testing.AssertFunctionIsCalled( + func_name, wraps=func, times_called=times_called + ): + a.argmin(axis=self.axis) + # ...then perform the actual computation + return a.argmin(axis=self.axis) + + @testing.for_dtypes("bhilBHILefdFD") + @testing.numpy_cupy_allclose(rtol=1e-5, contiguous_check=False) + def test_cub_argmax(self, xp, dtype): + # _skip_cuda90(dtype) + a = testing.shaped_random(self.shape, xp, dtype) + if self.order == "C": + a = xp.ascontiguousarray(a) + else: + a = xp.asfortranarray(a) + + if xp is numpy: + return a.argmax(axis=self.axis) + + # xp is cupy, first ensure we really use CUB + ret = cupy.empty(()) # Cython checks return type, need to fool it + if self.backend == "device": + func_name = "cupy._core._routines_statistics.cub." + func_name += "device_reduce" + with testing.AssertFunctionIsCalled(func_name, return_value=ret): + a.argmax(axis=self.axis) + elif self.backend == "block": + # this is the only function we can mock; the rest is cdef'd + func_name = "cupy._core._cub_reduction." + func_name += "_SimpleCubReductionKernel_get_cached_function" + # func = _cub_reduction._SimpleCubReductionKernel_get_cached_function + if self.axis is not None and len(self.shape) > 1: + times_called = 1 # one pass + else: + times_called = 2 # two passes + with testing.AssertFunctionIsCalled( + func_name, wraps=func, times_called=times_called + ): + a.argmax(axis=self.axis) + # ...then perform the actual computation + return a.argmax(axis=self.axis) @testing.parameterize( @@ -225,6 +280,7 @@ def test_argmin_zero_size_axis1(self, xp, dtype): } ) ) +@pytest.mark.skip("dtype is not supported") class TestArgMinMaxDtype: @testing.for_dtypes( dtypes=[numpy.int8, numpy.int16, numpy.int32, numpy.int64], @@ -249,9 +305,9 @@ def test_argminmax_dtype(self, in_dtype, result_dtype): {"cond_shape": (2, 3, 4), "x_shape": (2, 3, 4), "y_shape": (3, 4)}, {"cond_shape": (3, 4), "x_shape": (2, 3, 4), "y_shape": (4,)}, ) -class TestWhereTwoArrays(unittest.TestCase): +class TestWhereTwoArrays: @testing.for_all_dtypes_combination(names=["cond_type", "x_type", "y_type"]) - @testing.numpy_cupy_allclose(type_check=False) + @testing.numpy_cupy_allclose(type_check=has_support_aspect64()) def test_where_two_arrays(self, xp, cond_type, x_type, y_type): m = testing.shaped_random(self.cond_shape, xp, xp.bool_) # Almost all values of a matrix `shaped_random` makes are not zero. @@ -268,7 +324,7 @@ def test_where_two_arrays(self, xp, cond_type, x_type, y_type): {"cond_shape": (2, 3, 4)}, {"cond_shape": (3, 4)}, ) -class TestWhereCond(unittest.TestCase): +class TestWhereCond: @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() def test_where_cond(self, xp, dtype): @@ -277,7 +333,7 @@ def test_where_cond(self, xp, dtype): return xp.where(cond) -class TestWhereError(unittest.TestCase): +class TestWhereError: def test_one_argument(self): for xp in (numpy, cupy): cond = testing.shaped_random((3, 4), xp, dtype=xp.bool_) @@ -287,6 +343,8 @@ def test_one_argument(self): @testing.parameterize( + {"array": numpy.random.randint(0, 2, (20,))}, + {"array": numpy.random.randn(3, 2, 4)}, {"array": numpy.empty((0,))}, {"array": numpy.empty((0, 2))}, {"array": numpy.empty((0, 2, 0))}, @@ -304,17 +362,20 @@ def test_nonzero(self, xp, dtype): {"array": numpy.array(0)}, {"array": numpy.array(1)}, ) +@pytest.mark.skip("Only positive rank is supported") @testing.with_requires("numpy>=1.17.0") -class TestNonzeroZeroDimension(unittest.TestCase): +class TestNonzeroZeroDimension: @testing.for_all_dtypes() - def test_nonzero(self, dtype): - for xp in (numpy, cupy): - array = xp.array(self.array, dtype=dtype) - with pytest.raises(DeprecationWarning): - xp.nonzero(array) + @testing.numpy_cupy_array_equal() + def test_nonzero(self, xp, dtype): + array = xp.array(self.array, dtype=dtype) + with testing.assert_warns(DeprecationWarning): + return xp.nonzero(array) @testing.parameterize( + {"array": numpy.random.randint(0, 2, (20,))}, + {"array": numpy.random.randn(3, 2, 4)}, {"array": numpy.array(0)}, {"array": numpy.array(1)}, {"array": numpy.empty((0,))}, @@ -322,6 +383,7 @@ def test_nonzero(self, dtype): {"array": numpy.empty((0, 2, 0))}, _ids=False, # Do not generate ids from randomly generated params ) +@pytest.mark.skip("flatnonzero isn't implemented yet") class TestFlatNonzero: @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() @@ -331,11 +393,14 @@ def test_flatnonzero(self, xp, dtype): @testing.parameterize( + {"array": numpy.random.randint(0, 2, (20,))}, + {"array": numpy.random.randn(3, 2, 4)}, {"array": numpy.empty((0,))}, {"array": numpy.empty((0, 2))}, {"array": numpy.empty((0, 2, 0))}, _ids=False, # Do not generate ids from randomly generated params ) +@pytest.mark.skip("argwhere isn't implemented yet") class TestArgwhere: @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() @@ -344,19 +409,18 @@ def test_argwhere(self, xp, dtype): return xp.argwhere(array) -# DPNP_BUG -# dpnp/backend.pyx:86: in dpnp.backend.dpnp_array -# raise TypeError(f"Intel NumPy array(): Unsupported non-sequence obj={type(obj)}") -# E TypeError: Intel NumPy array(): Unsupported non-sequence obj= -# @testing.parameterize( -# {'array': cupy.array(1)}, -# ) - -# class TestArgwhereZeroDimension(unittest.TestCase): - -# def test_argwhere(self): -# with testing.assert_warns(DeprecationWarning): -# return cupy.nonzero(self.array) +@testing.parameterize( + {"value": 0}, + {"value": 3}, +) +@pytest.mark.skip("argwhere isn't implemented yet") +@testing.with_requires("numpy>=1.18") +class TestArgwhereZeroDimension: + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_argwhere(self, xp, dtype): + array = xp.array(self.value, dtype=dtype) + return xp.argwhere(array) class TestNanArgMin: @@ -560,8 +624,7 @@ def test_nanargmax_zero_size_axis1(self, xp, dtype): } ) ) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -class TestSearchSorted(unittest.TestCase): +class TestSearchSorted: @testing.for_all_dtypes(no_bool=True) @testing.numpy_cupy_array_equal() def test_searchsorted(self, xp, dtype): @@ -570,10 +633,17 @@ def test_searchsorted(self, xp, dtype): y = xp.searchsorted(bins, x, side=self.side) return (y,) + @testing.for_all_dtypes(no_bool=True) + @testing.numpy_cupy_array_equal() + def test_ndarray_searchsorted(self, xp, dtype): + x = testing.shaped_arange(self.shape, xp, dtype) + bins = xp.array(self.bins) + y = bins.searchsorted(x, side=self.side) + return (y,) + @testing.parameterize({"side": "left"}, {"side": "right"}) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -class TestSearchSortedNanInf(unittest.TestCase): +class TestSearchSortedNanInf: @testing.numpy_cupy_array_equal() def test_searchsorted_nanbins(self, xp): x = testing.shaped_arange((10,), xp, xp.float64) @@ -589,33 +659,37 @@ def test_searchsorted_nan(self, xp): y = xp.searchsorted(bins, x, side=self.side) return (y,) - # DPNP_BUG - # Segmentation fault on access to negative index # x[-1] = float('nan') ####### - # @testing.numpy_cupy_array_equal() - # def test_searchsorted_nan_last(self, xp): - # x = testing.shaped_arange((10,), xp, xp.float64) - # x[-1] = float('nan') - # bins = xp.array([0, 1, 2, 4, float('nan')]) - # y = xp.searchsorted(bins, x, side=self.side) - # return y, - - # @testing.numpy_cupy_array_equal() - # def test_searchsorted_nan_last_repeat(self, xp): - # x = testing.shaped_arange((10,), xp, xp.float64) - # x[-1] = float('nan') - # bins = xp.array([0, 1, 2, float('nan'), float('nan')]) - # y = xp.searchsorted(bins, x, side=self.side) - # return y, - - # @testing.numpy_cupy_array_equal() - # def test_searchsorted_all_nans(self, xp): - # x = testing.shaped_arange((10,), xp, xp.float64) - # x[-1] = float('nan') - # bins = xp.array([float('nan'), float('nan'), float('nan'), - # float('nan'), float('nan')]) - # y = xp.searchsorted(bins, x, side=self.side) - # return y, - ############################################################################### + @testing.numpy_cupy_array_equal() + def test_searchsorted_nan_last(self, xp): + x = testing.shaped_arange((10,), xp, xp.float64) + x[-1] = float("nan") + bins = xp.array([0, 1, 2, 4, float("nan")]) + y = xp.searchsorted(bins, x, side=self.side) + return (y,) + + @testing.numpy_cupy_array_equal() + def test_searchsorted_nan_last_repeat(self, xp): + x = testing.shaped_arange((10,), xp, xp.float64) + x[-1] = float("nan") + bins = xp.array([0, 1, 2, float("nan"), float("nan")]) + y = xp.searchsorted(bins, x, side=self.side) + return (y,) + + @testing.numpy_cupy_array_equal() + def test_searchsorted_all_nans(self, xp): + x = testing.shaped_arange((10,), xp, xp.float64) + x[-1] = float("nan") + bins = xp.array( + [ + float("nan"), + float("nan"), + float("nan"), + float("nan"), + float("nan"), + ] + ) + y = xp.searchsorted(bins, x, side=self.side) + return (y,) @testing.numpy_cupy_array_equal() def test_searchsorted_inf(self, xp): @@ -634,8 +708,7 @@ def test_searchsorted_minf(self, xp): return (y,) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -class TestSearchSortedInvalid(unittest.TestCase): +class TestSearchSortedInvalid: # Cant test unordered bins due to numpy undefined # behavior for searchsorted @@ -646,9 +719,15 @@ def test_searchsorted_ndbins(self): with pytest.raises(ValueError): xp.searchsorted(bins, x) + def test_ndarray_searchsorted_ndbins(self): + for xp in (numpy, cupy): + x = testing.shaped_arange((10,), xp, xp.float64) + bins = xp.array([[10, 4], [2, 1], [7, 8]]) + with pytest.raises(ValueError): + bins.searchsorted(x) + -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -class TestSearchSortedWithSorter(unittest.TestCase): +class TestSearchSortedWithSorter: @testing.numpy_cupy_array_equal() def test_sorter(self, xp): x = testing.shaped_arange((12,), xp, xp.float64) @@ -667,8 +746,102 @@ def test_invalid_sorter(self): def test_nonint_sorter(self): for xp in (numpy, cupy): - x = testing.shaped_arange((12,), xp, xp.float32) + dt = cupy.default_float_type() + x = testing.shaped_arange((12,), xp, dt) bins = xp.array([10, 4, 2, 1, 8]) - sorter = xp.array([], dtype=xp.float32) - with pytest.raises(TypeError): + sorter = xp.array([], dtype=dt) + with pytest.raises((TypeError, ValueError)): xp.searchsorted(bins, x, sorter=sorter) + + +@testing.parameterize({"side": "left"}, {"side": "right"}) +class TestNdarraySearchSortedNanInf: + @testing.numpy_cupy_array_equal() + def test_searchsorted_nanbins(self, xp): + x = testing.shaped_arange((10,), xp, xp.float64) + bins = xp.array([0, 1, 2, 4, 10, float("nan")]) + y = bins.searchsorted(x, side=self.side) + return (y,) + + @testing.numpy_cupy_array_equal() + def test_searchsorted_nan(self, xp): + x = testing.shaped_arange((10,), xp, xp.float64) + x[5] = float("nan") + bins = xp.array([0, 1, 2, 4, 10]) + y = bins.searchsorted(x, side=self.side) + return (y,) + + @testing.numpy_cupy_array_equal() + def test_searchsorted_nan_last(self, xp): + x = testing.shaped_arange((10,), xp, xp.float64) + x[-1] = float("nan") + bins = xp.array([0, 1, 2, 4, float("nan")]) + y = bins.searchsorted(x, side=self.side) + return (y,) + + @testing.numpy_cupy_array_equal() + def test_searchsorted_nan_last_repeat(self, xp): + x = testing.shaped_arange((10,), xp, xp.float64) + x[-1] = float("nan") + bins = xp.array([0, 1, 2, float("nan"), float("nan")]) + y = bins.searchsorted(x, side=self.side) + return (y,) + + @testing.numpy_cupy_array_equal() + def test_searchsorted_all_nans(self, xp): + x = testing.shaped_arange((10,), xp, xp.float64) + x[-1] = float("nan") + bins = xp.array( + [ + float("nan"), + float("nan"), + float("nan"), + float("nan"), + float("nan"), + ] + ) + y = bins.searchsorted(x, side=self.side) + return (y,) + + @testing.numpy_cupy_array_equal() + def test_searchsorted_inf(self, xp): + x = testing.shaped_arange((10,), xp, xp.float64) + x[5] = float("inf") + bins = xp.array([0, 1, 2, 4, 10]) + y = bins.searchsorted(x, side=self.side) + return (y,) + + @testing.numpy_cupy_array_equal() + def test_searchsorted_minf(self, xp): + x = testing.shaped_arange((10,), xp, xp.float64) + x[5] = float("-inf") + bins = xp.array([0, 1, 2, 4, 10]) + y = bins.searchsorted(x, side=self.side) + return (y,) + + +class TestNdarraySearchSortedWithSorter: + @testing.numpy_cupy_array_equal() + def test_sorter(self, xp): + x = testing.shaped_arange((12,), xp, xp.float64) + bins = xp.array([10, 4, 2, 1, 8]) + sorter = xp.array([3, 2, 1, 4, 0]) + y = bins.searchsorted(x, sorter=sorter) + return (y,) + + def test_invalid_sorter(self): + for xp in (numpy, cupy): + x = testing.shaped_arange((12,), xp, xp.float64) + bins = xp.array([10, 4, 2, 1, 8]) + sorter = xp.array([0]) + with pytest.raises(ValueError): + bins.searchsorted(x, sorter=sorter) + + def test_nonint_sorter(self): + for xp in (numpy, cupy): + dt = cupy.default_float_type() + x = testing.shaped_arange((12,), xp, dt) + bins = xp.array([10, 4, 2, 1, 8]) + sorter = xp.array([], dtype=dt) + with pytest.raises((TypeError, ValueError)): + bins.searchsorted(x, sorter=sorter)