Skip to content

Commit

Permalink
Add proper handling of unput array of usm_ndarray type in dpnp.ix_ (
Browse files Browse the repository at this point in the history
#2047)

* Proper handling of input array as dpctl.tensor.usm_ndarray

* Update a page for Indexing routines

* Move dpnp.ix_ tests to keep lexycographical order

* Add tests to cover faulty use case

* Added entry to changelog
  • Loading branch information
antonwolfy committed Sep 16, 2024
1 parent be5c3f0 commit 6314346
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ In addition, this release completes implementation of `dpnp.fft` module and adds
* Resolved an issue with `dpnp.matmul` when an f_contiguous `out` keyword is passed to the the function [#1872](https://github.com/IntelPython/dpnp/pull/1872)
* Resolved a possible race condition in `dpnp.inv` [#1940](https://github.com/IntelPython/dpnp/pull/1940)
* Resolved an issue with failing tests for `dpnp.append` when running on a device without fp64 support [#2034](https://github.com/IntelPython/dpnp/pull/2034)
* Resolved an issue with input array of `usm_ndarray` passed into `dpnp.ix_` [#2047](https://github.com/IntelPython/dpnp/pull/2047)


## [0.15.0] - 05/25/2024
Expand Down
9 changes: 4 additions & 5 deletions doc/reference/indexing.rst
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
.. _routines.indexing:
.. _arrays.indexing:

Array Indexing Routines
=======================
Indexing routines
=================

.. https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
.. https://numpy.org/doc/stable/reference/routines.indexing.html
Generating index arrays
-----------------------


.. autosummary::
:toctree: generated/
:nosignatures:
Expand Down
2 changes: 1 addition & 1 deletion doc/reference/ndarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Indexing arrays
Arrays can be indexed using an extended Python slicing syntax,
``array[selection]``.

.. seealso:: :ref:`Array Indexing Routines <routines.indexing>`.
.. seealso:: :ref:`Indexing routines <routines.indexing>`.


Array attributes
Expand Down
5 changes: 3 additions & 2 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,7 @@ def ix_(*args):
and the dimension with the non-unit shape value cycles through all
N dimensions.
Using :obj:`dpnp.ix_` one can quickly construct index arrays that will
Using :obj:`dpnp.ix_` one can quickly construct index arrays that will
index the cross product. ``a[dpnp.ix_([1,3],[2,5])]`` returns the array
``[[a[1,2] a[1,5]], [a[3,2] a[3,5]]]``.
Expand Down Expand Up @@ -994,14 +994,15 @@ def ix_(*args):
"""

dpnp.check_supported_arrays_type(*args)

out = []
nd = len(args)
for k, new in enumerate(args):
if new.ndim != 1:
raise ValueError("Cross index must be 1 dimensional")
if dpnp.issubdtype(new.dtype, dpnp.bool):
(new,) = dpnp.nonzero(new)
new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
new = dpnp.reshape(new, (1,) * k + (new.size,) + (1,) * (nd - k - 1))
out.append(new)
return tuple(out)

Expand Down
84 changes: 53 additions & 31 deletions tests/test_indexing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools

import dpctl.tensor as dpt
import numpy
import pytest
from dpctl.tensor._numpy_helper import AxisError
Expand All @@ -12,6 +13,7 @@
)

import dpnp
from dpnp.dpnp_array import dpnp_array

from .helper import get_all_dtypes, get_integer_dtypes, has_support_aspect64

Expand Down Expand Up @@ -272,6 +274,57 @@ def test_indexing_array_negative_strides(self):
assert_array_equal(arr, 10.0)


class TestIx:
@pytest.mark.parametrize(
"x0", [[0, 1], [True, True]], ids=["[0, 1]", "[True, True]"]
)
@pytest.mark.parametrize(
"x1",
[[2, 4], [False, False, True, False, True]],
ids=["[2, 4]", "[False, False, True, False, True]"],
)
def test_ix(self, x0, x1):
expected = dpnp.ix_(dpnp.array(x0), dpnp.array(x1))
result = numpy.ix_(numpy.array(x0), numpy.array(x1))

assert_array_equal(result[0], expected[0])
assert_array_equal(result[1], expected[1])

@pytest.mark.parametrize("dt", [dpnp.intp, dpnp.float32])
def test_ix_empty_out(self, dt):
a = numpy.array([], dtype=dt)
ia = dpnp.array(a)

(result,) = dpnp.ix_(ia)
(expected,) = numpy.ix_(a)
assert_array_equal(result, expected)
assert a.dtype == dt

def test_repeated_input(self):
a = numpy.arange(5)
ia = dpnp.array(a)

result = dpnp.ix_(ia, ia)
expected = numpy.ix_(a, a)
assert_array_equal(result[0], expected[0])
assert_array_equal(result[1], expected[1])

@pytest.mark.parametrize("arr", [[2, 4, 0, 1], [True, False, True, True]])
def test_usm_ndarray_input(self, arr):
a = numpy.array(arr)
ia = dpt.asarray(a)

(result,) = dpnp.ix_(ia)
(expected,) = numpy.ix_(a)
assert_array_equal(result, expected)
assert isinstance(result, dpnp_array)

@pytest.mark.parametrize("xp", [dpnp, numpy])
@pytest.mark.parametrize("shape", [(), (2, 2)])
def test_ix_error(self, xp, shape):
assert_raises(ValueError, xp.ix_, xp.ones(shape))


class TestNonzero:
@pytest.mark.parametrize("list_val", [[], [0], [1]])
def test_trivial(self, list_val):
Expand Down Expand Up @@ -1143,37 +1196,6 @@ def test_empty_indices(self):
)


class TestIx:
@pytest.mark.parametrize(
"x0", [[0, 1], [True, True]], ids=["[0, 1]", "[True, True]"]
)
@pytest.mark.parametrize(
"x1",
[[2, 4], [False, False, True, False, True]],
ids=["[2, 4]", "[False, False, True, False, True]"],
)
def test_ix(self, x0, x1):
expected = dpnp.ix_(dpnp.array(x0), dpnp.array(x1))
result = numpy.ix_(numpy.array(x0), numpy.array(x1))

assert_array_equal(expected[0], result[0])
assert_array_equal(expected[1], result[1])

def test_ix_empty_out(self):
(a,) = dpnp.ix_(dpnp.array([], dtype=dpnp.intp))
assert_equal(a.dtype, dpnp.intp)

(a,) = dpnp.ix_(dpnp.array([], dtype=dpnp.float32))
assert_equal(a.dtype, dpnp.float32)

def test_ix_error(self):
with pytest.raises(ValueError):
dpnp.ix_(dpnp.ones(()))

with pytest.raises(ValueError):
dpnp.ix_(dpnp.ones((2, 2)))


class TestSelect:
choices_np = [
numpy.array([1, 2, 3]),
Expand Down

0 comments on commit 6314346

Please sign in to comment.