Skip to content

Commit

Permalink
Adds tests for #1455 resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed Oct 27, 2023
1 parent f293713 commit 891161f
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
30 changes: 30 additions & 0 deletions dpctl/tests/test_tensor_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,36 @@ def test_axis0_bug():
assert dpt.all(s == expected)


def test_sum_axis1_axis0():
"""See gh-1455"""
get_queue_or_skip()

# The atomic case is checked in `test_usm_ndarray_reductions`
# This test checks the tree reduction path for correctness
x = dpt.reshape(dpt.arange(3 * 4 * 5, dtype="f4"), (3, 4, 5))

m = dpt.sum(x, axis=0)
expected = dpt.asarray(
[
[60, 63, 66, 69, 72],
[75, 78, 81, 84, 87],
[90, 93, 96, 99, 102],
[105, 108, 111, 114, 117],
],
dtype="f4",
)
tol = dpt.finfo(m.dtype).resolution
assert dpt.allclose(m, expected, atol=tol, rtol=tol)

x = dpt.flip(x, axis=2)
m = dpt.sum(x, axis=2)
expected = dpt.asarray(
[[10, 35, 60, 85], [110, 135, 160, 185], [210, 235, 260, 285]],
dtype="f4",
)
assert dpt.allclose(m, expected, atol=tol, rtol=tol)


def _any_complex(dtypes):
return any(dpt.isdtype(dpt.dtype(dt), "complex floating") for dt in dtypes)

Expand Down
39 changes: 39 additions & 0 deletions dpctl/tests/test_usm_ndarray_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@ def test_max_min_axis():
assert dpt.all(m == x[:, 0, 0, :, 0])


def test_max_axis1_axis0():
"""See gh-1455"""
get_queue_or_skip()

x = dpt.reshape(dpt.arange(3 * 4 * 5), (3, 4, 5))

m = dpt.max(x, axis=0)
assert dpt.all(m == x[-1, :, :])

x = dpt.flip(x, axis=2)
m = dpt.max(x, axis=2)
assert dpt.all(m == x[:, :, 0])


def test_reduction_keepdims():
get_queue_or_skip()

Expand Down Expand Up @@ -440,3 +454,28 @@ def test_hypot_complex():
x = dpt.zeros(1, dtype="c8")
with pytest.raises(TypeError):
dpt.reduce_hypot(x)


def test_tree_reduction_axis1_axis0():
"""See gh-1455"""
get_queue_or_skip()

x = dpt.reshape(dpt.arange(3 * 4 * 5, dtype="f4"), (3, 4, 5))

m = dpt.logsumexp(x, axis=0)
tol = dpt.finfo(m.dtype).resolution
assert_allclose(
dpt.asnumpy(m),
np.logaddexp.reduce(dpt.asnumpy(x), axis=0, dtype=m.dtype),
rtol=tol,
atol=tol,
)

x = dpt.flip(x, axis=2)
m = dpt.logsumexp(x, axis=2)
assert_allclose(
dpt.asnumpy(m),
np.logaddexp.reduce(dpt.asnumpy(x), axis=2, dtype=m.dtype),
rtol=tol,
atol=tol,
)

0 comments on commit 891161f

Please sign in to comment.