Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dpnp.extract implementation to get rid of limitations for input arguments #1906

Merged
merged 7 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/reference/sorting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ Searching
dpnp.nanargmax
dpnp.argmin
dpnp.nanargmin
dpnp.argwhere
dpnp.nonzero
dpnp.flatnonzero
dpnp.where
dpnp.argwhere
dpnp.searchsorted
dpnp.extract

Expand Down
92 changes: 68 additions & 24 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,42 +490,86 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
)


def extract(condition, x):
def extract(condition, a):
"""
Return the elements of an array that satisfy some condition.

This is equivalent to
``dpnp.compress(dpnp.ravel(condition), dpnp.ravel(a))``. If `condition`
is boolean :obj:`dpnp.extract` is equivalent to ``a[condition]``.

Note that :obj:`dpnp.place` does the exact opposite of :obj:`dpnp.extract`.

For full documentation refer to :obj:`numpy.extract`.

Parameters
----------
condition : {array_like, scalar}
An array whose non-zero or ``True`` entries indicate the element of `a`
to extract.
a : {dpnp_array, usm_ndarray}
Input array of the same size as `condition`.

Returns
-------
out : dpnp.ndarray
Rank 1 array of values from `x` where `condition` is True.
Rank 1 array of values from `a` where `condition` is ``True``.

See Also
--------
:obj:`dpnp.take` : Take elements from an array along an axis.
:obj:`dpnp.put` : Replaces specified elements of an array with given values.
:obj:`dpnp.copyto` : Copies values from one array to another, broadcasting
as necessary.
:obj:`dpnp.compress` : eturn selected slices of an array along given axis.
:obj:`dpnp.place` : Change elements of an array based on conditional and
input values.

Examples
--------
>>> import dpnp as np
>>> a = np.arange(12).reshape((3, 4))
>>> a
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> condition = np.mod(a, 3) == 0
>>> condition
array([[ True, False, False, True],
[False, False, True, False],
[False, True, False, False]])
>>> np.extract(condition, a)
array([0, 3, 6, 9])

If `condition` is boolean:

>>> a[condition]
array([0, 3, 6, 9])

Limitations
-----------
Parameters `condition` and `x` are supported either as
:class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
Parameter `x` must be the same shape as `condition`.
Otherwise the function will be executed sequentially on CPU.
"""

if dpnp.is_supported_array_type(condition) and dpnp.is_supported_array_type(
x
):
if condition.shape != x.shape:
pass
else:
dpt_condition = (
condition.get_array()
if isinstance(condition, dpnp_array)
else condition
)
dpt_array = x.get_array() if isinstance(x, dpnp_array) else x
return dpnp_array._create_from_usm_ndarray(
dpt.extract(dpt_condition, dpt_array)
)
usm_a = dpnp.get_usm_ndarray(a)
if not dpnp.is_supported_array_type(condition):
usm_cond = dpt.asarray(
condition, usm_type=a.usm_type, sycl_queue=a.sycl_queue
)
else:
usm_cond = dpnp.get_usm_ndarray(condition)

if usm_cond.size != usm_a.size:
usm_a = dpt.reshape(usm_a, -1)
usm_cond = dpt.reshape(usm_cond, -1)

usm_res = dpt.take(usm_a, dpt.nonzero(usm_cond)[0])
else:
if usm_cond.shape != usm_a.shape:
usm_a = dpt.reshape(usm_a, -1)
usm_cond = dpt.reshape(usm_cond, -1)

usm_res = dpt.extract(usm_cond, usm_a)

return call_origin(numpy.extract, condition, x)
dpnp.synchronize_array_data(usm_res)
return dpnp_array._create_from_usm_ndarray(usm_res)


def fill_diagonal(a, val, wrap=False):
Expand Down
5 changes: 0 additions & 5 deletions tests/skipped_tests.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_i
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_index
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_order

tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast
Expand Down
5 changes: 0 additions & 5 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,6 @@ tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_i
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_index
tests/third_party/cupy/indexing_tests/test_generate.py::TestUnravelIndex::test_invalid_order

tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis
tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist
tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast
Expand Down
2 changes: 1 addition & 1 deletion tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_density(self, dtype):
result_hist, result_edges = dpnp.histogram(iv, density=True)

if numpy.issubdtype(dtype, numpy.inexact):
tol = numpy.finfo(dtype).resolution
tol = 4 * numpy.finfo(dtype).resolution
assert_allclose(result_hist, expected_hist, rtol=tol, atol=tol)
assert_allclose(result_edges, expected_edges, rtol=tol, atol=tol)
else:
Expand Down
Loading
Loading