diff --git a/dpctl/tests/test_tensor_sum.py b/dpctl/tests/test_tensor_sum.py
index fbfd9547e1..749ca055b9 100644
--- a/dpctl/tests/test_tensor_sum.py
+++ b/dpctl/tests/test_tensor_sum.py
@@ -212,6 +212,36 @@ def test_axis0_bug():
     assert dpt.all(s == expected)
 
 
+def test_sum_axis1_axis0():
+    """Sum over leading and (flipped) trailing axis of 3D array; gh-1455"""
+    get_queue_or_skip()
+
+    # The atomic case is checked in `test_usm_ndarray_reductions`
+    # This test checks the tree reduction path for correctness
+    x = dpt.reshape(dpt.arange(3 * 4 * 5, dtype="f4"), (3, 4, 5))
+
+    m = dpt.sum(x, axis=0)
+    expected = dpt.asarray(
+        [
+            [60, 63, 66, 69, 72],
+            [75, 78, 81, 84, 87],
+            [90, 93, 96, 99, 102],
+            [105, 108, 111, 114, 117],
+        ],
+        dtype="f4",
+    )
+    tol = dpt.finfo(m.dtype).resolution
+    assert dpt.allclose(m, expected, atol=tol, rtol=tol)
+
+    x = dpt.flip(x, axis=2)
+    m = dpt.sum(x, axis=2)
+    expected = dpt.asarray(
+        [[10, 35, 60, 85], [110, 135, 160, 185], [210, 235, 260, 285]],
+        dtype="f4",
+    )
+    assert dpt.allclose(m, expected, atol=tol, rtol=tol)
+
+
 def _any_complex(dtypes):
     return any(dpt.isdtype(dpt.dtype(dt), "complex floating") for dt in dtypes)
 
diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py
index 56059e54b8..45afb26aac 100644
--- a/dpctl/tests/test_usm_ndarray_reductions.py
+++ b/dpctl/tests/test_usm_ndarray_reductions.py
@@ -61,6 +61,20 @@ def test_max_min_axis():
     assert dpt.all(m == x[:, 0, 0, :, 0])
 
 
+def test_max_axis1_axis0():
+    """Max over leading and (flipped) trailing axis of 3D array; gh-1455"""
+    get_queue_or_skip()
+
+    x = dpt.reshape(dpt.arange(3 * 4 * 5), (3, 4, 5))
+
+    m = dpt.max(x, axis=0)
+    assert dpt.all(m == x[-1, :, :])
+
+    x = dpt.flip(x, axis=2)
+    m = dpt.max(x, axis=2)
+    assert dpt.all(m == x[:, :, 0])
+
+
 def test_reduction_keepdims():
     get_queue_or_skip()
 
@@ -440,3 +454,28 @@ def test_hypot_complex():
     x = dpt.zeros(1, dtype="c8")
     with pytest.raises(TypeError):
         dpt.reduce_hypot(x)
+
+
+def test_tree_reduction_axis1_axis0():
+    """logsumexp over axis 0 and flipped axis 2 vs. NumPy; see gh-1455"""
+    get_queue_or_skip()
+
+    x = dpt.reshape(dpt.arange(3 * 4 * 5, dtype="f4"), (3, 4, 5))
+
+    m = dpt.logsumexp(x, axis=0)
+    tol = dpt.finfo(m.dtype).resolution
+    assert_allclose(
+        dpt.asnumpy(m),
+        np.logaddexp.reduce(dpt.asnumpy(x), axis=0, dtype=m.dtype),
+        rtol=tol,
+        atol=tol,
+    )
+
+    x = dpt.flip(x, axis=2)
+    m = dpt.logsumexp(x, axis=2)
+    assert_allclose(
+        dpt.asnumpy(m),
+        np.logaddexp.reduce(dpt.asnumpy(x), axis=2, dtype=m.dtype),
+        rtol=tol,
+        atol=tol,
+    )