Skip to content

Commit

Permalink
Changes per PR review
Browse files Browse the repository at this point in the history
Make tests more maintenance-friendly
  • Loading branch information
ndgrigorian committed Dec 13, 2024
1 parent f50813e commit 7602b4f
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions dpctl/tests/test_usm_ndarray_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1864,14 +1864,13 @@ def test_take_along_axis_uint64_indices():
get_queue_or_skip()

inds = dpt.arange(1, 10, 2, dtype="u8")

x = dpt.tile(dpt.asarray([0, -1], dtype="i4"), 5)
res = dpt.take_along_axis(x, inds)
assert dpt.all(res == -1)

x = dpt.tile(dpt.asarray([0, -1], dtype="i4"), (2, 5))
inds = dpt.arange(1, 10, 2, dtype="u8")
inds = dpt.broadcast_to(inds, (2, 5))
sh0 = 2
inds = dpt.broadcast_to(inds, (sh0,) + inds.shape)
x = dpt.broadcast_to(x, (sh0,) + x.shape)
res = dpt.take_along_axis(x, inds, axis=1)
assert dpt.all(res == -1)

Expand All @@ -1880,14 +1879,14 @@ def test_put_along_axis_uint64_indices():
get_queue_or_skip()

inds = dpt.arange(1, 10, 2, dtype="u8")

x = dpt.zeros(10, dtype="i4")
dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype))
expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), 5)
assert dpt.all(x == expected)

x = dpt.zeros((2, 10), dtype="i4")
inds = dpt.broadcast_to(inds, (2, 5))
sh0 = 2
inds = dpt.broadcast_to(inds, (sh0,) + inds.shape)
x = dpt.zeros((sh0,) + x.shape, dtype="i4")
dpt.put_along_axis(x, inds, dpt.asarray(2, dtype=x.dtype), axis=1)
expected = dpt.tile(dpt.asarray([0, 2], dtype="i4"), (2, 5))
assert dpt.all(expected == x)

0 comments on commit 7602b4f

Please sign in to comment.