Skip to content

Commit

Permalink
Implement dpnp.searchsorted (IntelPython#1751)
Browse files Browse the repository at this point in the history
* Implement dpnp.searchsorted

* Stated explicit wrapping behavior of out of bound values of sorter per dpctl spec

* Muted previously enabled overflow tests

* Corrected test_argsort_ndarray test with unique generated random values
  • Loading branch information
antonwolfy authored Mar 26, 2024
1 parent 726738d commit 09e7e33
Show file tree
Hide file tree
Showing 13 changed files with 634 additions and 394 deletions.
12 changes: 5 additions & 7 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,11 @@ enum class DPNPFuncName : size_t
DPNP_FN_RNG_ZIPF_EXT, /**< Used in numpy.random.zipf() impl, requires extra
parameters */
DPNP_FN_SEARCHSORTED, /**< Used in numpy.searchsorted() impl */
DPNP_FN_SEARCHSORTED_EXT, /**< Used in numpy.searchsorted() impl, requires
extra parameters */
DPNP_FN_SIGN, /**< Used in numpy.sign() impl */
DPNP_FN_SIN, /**< Used in numpy.sin() impl */
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
DPNP_FN_SORT, /**< Used in numpy.sort() impl */
DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */
DPNP_FN_SIGN, /**< Used in numpy.sign() impl */
DPNP_FN_SIN, /**< Used in numpy.sin() impl */
DPNP_FN_SINH, /**< Used in numpy.sinh() impl */
DPNP_FN_SORT, /**< Used in numpy.sort() impl */
DPNP_FN_SQRT, /**< Used in numpy.sqrt() impl */
DPNP_FN_SQRT_EXT, /**< Used in numpy.sqrt() impl, requires extra parameters
*/
DPNP_FN_SQUARE, /**< Used in numpy.square() impl */
Expand Down
20 changes: 0 additions & 20 deletions dpnp/backend/kernels/dpnp_krnl_sorting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,17 +403,6 @@ void (*dpnp_searchsorted_default_c)(void *,
const size_t) =
dpnp_searchsorted_c<_DataType, _IndexingType>;

template <typename _DataType, typename _IndexingType>
DPCTLSyclEventRef (*dpnp_searchsorted_ext_c)(DPCTLSyclQueueRef,
void *,
const void *,
const void *,
bool,
const size_t,
const size_t,
const DPCTLEventVectorRef) =
dpnp_searchsorted_c<_DataType, _IndexingType>;

template <typename _DataType>
class dpnp_sort_c_kernel;

Expand Down Expand Up @@ -507,15 +496,6 @@ void func_map_init_sorting(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_searchsorted_default_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_searchsorted_ext_c<int32_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_searchsorted_ext_c<int64_t, int64_t>};
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_searchsorted_ext_c<float, int64_t>};
fmap[DPNPFuncName::DPNP_FN_SEARCHSORTED_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_searchsorted_ext_c<double, int64_t>};

fmap[DPNPFuncName::DPNP_FN_SORT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_sort_default_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_SORT][eft_LNG][eft_LNG] = {
Expand Down
1 change: 0 additions & 1 deletion dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
DPNP_FN_RNG_WALD_EXT
DPNP_FN_RNG_WEIBULL_EXT
DPNP_FN_RNG_ZIPF_EXT
DPNP_FN_SEARCHSORTED_EXT
DPNP_FN_TRACE_EXT
DPNP_FN_TRAPZ_EXT

Expand Down
50 changes: 0 additions & 50 deletions dpnp/dpnp_algo/dpnp_algo_sorting.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ and the rest of the library

__all__ += [
"dpnp_partition",
"dpnp_searchsorted",
]


Expand All @@ -49,14 +48,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_partition_t)(c_dpctl.DPCTLSyclQueu
const shape_elem_type * ,
const size_t,
const c_dpctl.DPCTLEventVectorRef)
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_searchsorted_t)(c_dpctl.DPCTLSyclQueueRef,
void * ,
const void * ,
const void * ,
bool,
const size_t,
const size_t,
const c_dpctl.DPCTLEventVectorRef)


cpdef utils.dpnp_descriptor dpnp_partition(utils.dpnp_descriptor arr, int kth, axis=-1, kind='introselect', order=None):
Expand Down Expand Up @@ -98,44 +89,3 @@ cpdef utils.dpnp_descriptor dpnp_partition(utils.dpnp_descriptor arr, int kth, a
c_dpctl.DPCTLEvent_Delete(event_ref)

return result


cpdef utils.dpnp_descriptor dpnp_searchsorted(utils.dpnp_descriptor arr, utils.dpnp_descriptor v, side='left'):
if side is 'left':
side_ = True
else:
side_ = False

cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)

cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_SEARCHSORTED_EXT, param1_type, param1_type)

arr_obj = arr.get_array()

cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py(v.shape,
dpnp.int64,
None,
device=arr_obj.sycl_device,
usm_type=arr_obj.usm_type,
sycl_queue=arr_obj.sycl_queue)

result_sycl_queue = result.get_array().sycl_queue

cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()

cdef fptr_dpnp_searchsorted_t func = <fptr_dpnp_searchsorted_t > kernel_data.ptr

cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref,
arr.get_data(),
v.get_data(),
result.get_data(),
side_,
arr.size,
v.size,
NULL) # dep_events_ref

with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
c_dpctl.DPCTLEvent_Delete(event_ref)

return result
12 changes: 11 additions & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,17 @@ def round(self, decimals=0, out=None):

return dpnp.around(self, decimals, out)

# 'searchsorted',
def searchsorted(self, v, side="left", sorter=None):
"""
Find indices where elements of `v` should be inserted in `a`
to maintain order.
Refer to :obj:`dpnp.searchsorted` for full documentation
"""

return dpnp.searchsorted(self, v, side=side, sorter=sorter)

# 'setfield',
# 'setflags',

Expand Down
54 changes: 53 additions & 1 deletion dpnp/dpnp_iface_searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,61 @@ def searchsorted(a, v, side="left", sorter=None):
For full documentation refer to :obj:`numpy.searchsorted`.
Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
Input 1-D array. If `sorter` is ``None``, then it must be sorted in
ascending order, otherwise `sorter` must be an array of indices that
sort it.
v : {dpnp.ndarray, usm_ndarray, scalar}
Values to insert into `a`.
side : {'left', 'right'}, optional
If ``'left'``, the index of the first suitable location found is given.
If ``'right'``, return the last such index. If there is no suitable
index, return either 0 or N (where N is the length of `a`).
Default is ``'left'``.
sorter : {dpnp.ndarray, usm_ndarray}, optional
Optional 1-D array of integer indices that sort array a into ascending
order. They are typically the result of argsort.
Out of bound index values of `sorter` array are treated using `"wrap"`
mode documented in :py:func:`dpnp.take`.
Default is ``None``.
Returns
-------
indices : dpnp.ndarray
Array of insertion points with the same shape as `v`,
or 0-D array if `v` is a scalar.
See Also
--------
:obj:`dpnp.sort` : Return a sorted copy of an array.
:obj:`dpnp.histogram` : Produce histogram from 1-D data.
Examples
--------
>>> import dpnp as np
>>> a = np.array([11,12,13,14,15])
>>> np.searchsorted(a, 13)
array(2)
>>> np.searchsorted(a, 13, side='right')
array(3)
>>> v = np.array([-10, 20, 12, 13])
>>> np.searchsorted(a, v)
array([0, 5, 1, 2])
"""

return call_origin(numpy.where, a, v, side, sorter)
usm_a = dpnp.get_usm_ndarray(a)
if dpnp.isscalar(v):
usm_v = dpt.asarray(v, sycl_queue=a.sycl_queue, usm_type=a.usm_type)
else:
usm_v = dpnp.get_usm_ndarray(v)

usm_sorter = None if sorter is None else dpnp.get_usm_ndarray(sorter)
return dpnp_array._create_from_usm_ndarray(
dpt.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter)
)


def where(condition, x=None, y=None, /):
Expand Down
38 changes: 1 addition & 37 deletions dpnp/dpnp_iface_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,13 @@
# pylint: disable=no-name-in-module
from .dpnp_algo import (
dpnp_partition,
dpnp_searchsorted,
)
from .dpnp_array import dpnp_array
from .dpnp_utils import (
call_origin,
)

__all__ = ["argsort", "partition", "searchsorted", "sort"]
__all__ = ["argsort", "partition", "sort"]


def argsort(a, axis=-1, kind=None, order=None):
Expand Down Expand Up @@ -189,41 +188,6 @@ def partition(x1, kth, axis=-1, kind="introselect", order=None):
return call_origin(numpy.partition, x1, kth, axis, kind, order)


def searchsorted(x1, x2, side="left", sorter=None):
"""
Find indices where elements should be inserted to maintain order.
For full documentation refer to :obj:`numpy.searchsorted`.
Limitations
-----------
Input arrays is supported as :obj:`dpnp.ndarray`.
Input array is supported only sorted.
Input side is supported only values ``left``, ``right``.
Parameter `sorter` is supported only with default values.
"""

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False)
# pylint: disable=condition-evals-to-constant
if 0 and x1_desc and x2_desc:
if x1_desc.ndim != 1:
pass
elif x1_desc.dtype != x2_desc.dtype:
pass
elif side not in ["left", "right"]:
pass
elif sorter is not None:
pass
elif x1_desc.size < 2:
pass
else:
return dpnp_searchsorted(x1_desc, x2_desc, side=side).get_pyobj()

return call_origin(numpy.searchsorted, x1, x2, side=side, sorter=sorter)


def sort(a, axis=-1, kind=None, order=None):
"""
Return a sorted copy of an array.
Expand Down
22 changes: 0 additions & 22 deletions tests/skipped_tests.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -701,28 +701,6 @@ tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_bo
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit_2

tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_0_{func='argmin', is_module=True, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_1_{func='argmin', is_module=True, shape=()}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_2_{func='argmin', is_module=False, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_3_{func='argmin', is_module=False, shape=()}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_4_{func='argmax', is_module=True, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_5_{func='argmax', is_module=True, shape=()}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_6_{func='argmax', is_module=False, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_7_{func='argmax', is_module=False, shape=()}]

tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[0]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[1]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[2]

tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[0]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[1]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[2]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[3]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[4]

tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_0_{array=array(0)}::test_nonzero
tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_1_{array=array(1)}::test_nonzero

tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_axis
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis1
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis2
Expand Down
22 changes: 0 additions & 22 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -763,28 +763,6 @@ tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_bo
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit_2

tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_0_{func='argmin', is_module=True, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_1_{func='argmin', is_module=True, shape=()}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_2_{func='argmin', is_module=False, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_3_{func='argmin', is_module=False, shape=()}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_4_{func='argmax', is_module=True, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_5_{func='argmax', is_module=True, shape=()}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_6_{func='argmax', is_module=False, shape=(3, 4)}]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgMinMaxDtype::test_argminmax_dtype[_param_7_{func='argmax', is_module=False, shape=()}]

tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[0]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[1]
tests/third_party/cupy/sorting_tests/test_search.py::TestArgwhere::test_argwhere[2]

tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[0]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[1]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[2]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[3]
tests/third_party/cupy/sorting_tests/test_search.py::TestFlatNonzero::test_flatnonzero[4]

tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_0_{array=array(0)}::test_nonzero
tests/third_party/cupy/sorting_tests/test_search.py::TestNonzeroZeroDimension_param_1_{array=array(1)}::test_nonzero

tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_axis
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis1
tests/third_party/cupy/sorting_tests/test_sort.py::TestArgpartition_param_0_{external=False}::test_argpartition_invalid_axis2
Expand Down
Loading

0 comments on commit 09e7e33

Please sign in to comment.