Skip to content

Commit

Permalink
Implement support of tuple key in __getitem__ and __setitem__
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy committed Apr 3, 2023
1 parent d41ed51 commit 5b082f8
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 5 deletions.
24 changes: 20 additions & 4 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,23 @@

import dpnp


def _get_unwrapped_index_key(key):
"""
Return a key where each nested instance of DPNP array is unwrapped into USM ndarray
for futher processing in DPCTL advanced indexing functions.
"""

if isinstance(key, tuple):
if any(isinstance(x, dpnp_array) for x in key):
# create a new tuple from the input key with unwrapped DPNP arrays
return tuple(x.get_array() if isinstance(x, dpnp_array) else x for x in key)
elif isinstance(key, dpnp_array):
return key.get_array()
return key


class dpnp_array:
"""
Multi-dimensional array object.
Expand Down Expand Up @@ -176,8 +193,7 @@ def __ge__(self, other):
# '__getattribute__',

def __getitem__(self, key):
if isinstance(key, dpnp_array):
key = key.get_array()
key = _get_unwrapped_index_key(key)

item = self._array_obj.__getitem__(key)
if not isinstance(item, dpt.usm_ndarray):
Expand Down Expand Up @@ -337,8 +353,8 @@ def __rxor__(self, other):
# '__setattr__',

def __setitem__(self, key, val):
if isinstance(key, dpnp_array):
key = key.get_array()
key = _get_unwrapped_index_key(key)

if isinstance(val, dpnp_array):
val = val.get_array()

Expand Down
62 changes: 61 additions & 1 deletion tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,70 @@

import numpy
from numpy.testing import (
assert_array_equal
assert_,
assert_array_equal,
assert_equal
)


class TestIndexing:
def test_ellipsis_index(self):
a = dpnp.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
assert_(a[...] is not a)
assert_equal(a[...], a)

# test that slicing with ellipsis doesn't skip an arbitrary number of dimensions
assert_equal(a[0, ...], a[0])
assert_equal(a[0, ...], a[0,:])
assert_equal(a[..., 0], a[:, 0])

# test that slicing with ellipsis always results in an array
assert_equal(a[0, ..., 1], dpnp.array(2))

# assignment with `(Ellipsis,)` on 0-d arrays
b = dpnp.array(1)
b[(Ellipsis,)] = 2
assert_equal(b, 2)

def test_boolean_indexing_list(self):
a = dpnp.array([1, 2, 3])
b = dpnp.array([True, False, True])

assert_equal(a[b], [1, 3])
assert_equal(a[None, b], [[1, 3]])

def test_indexing_array_weird_strides(self):
np_x = numpy.ones(10)
dp_x = dpnp.ones(10)

np_ind = numpy.arange(10)[:, None, None, None]
np_ind = numpy.broadcast_to(np_ind, (10, 55, 4, 4))

dp_ind = dpnp.arange(10)[:, None, None, None]
dp_ind = dpnp.broadcast_to(dp_ind, (10, 55, 4, 4))

# single advanced index case
assert_array_equal(dp_x[dp_ind], np_x[np_ind])

np_x2 = numpy.ones((10, 2))
dp_x2 = dpnp.ones((10, 2))

np_zind = numpy.zeros(4, dtype=np_ind.dtype)
dp_zind = dpnp.zeros(4, dtype=dp_ind.dtype)

# higher dimensional advanced index
assert_array_equal(dp_x2[dp_ind, dp_zind], np_x2[np_ind, np_zind])

def test_indexing_array_negative_strides(self):
arr = dpnp.zeros((4, 4))[::-1, ::-1]

slices = (slice(None), dpnp.array([0, 1, 2, 3]))
arr[slices] = 10
assert_array_equal(arr, 10.)


@pytest.mark.usefixtures("allow_fall_back_on_numpy")
def test_choose():
a = numpy.r_[:4]
Expand Down

0 comments on commit 5b082f8

Please sign in to comment.