Skip to content

Commit

Permalink
Fix search reductions giving incorrect results for F-contiguous inputs (#1462)

* Fixes correctness regression in search functions

``py_search_over_axis`` no longer calls the ``axis1`` contiguous variant

``py_search_over_axis`` now only calls the ``axis0`` contiguous variant when the source is F-contiguous and the destination is one-dimensional

* Adds tests for fixed search reduction behavior
  • Loading branch information
ndgrigorian authored Nov 1, 2023
1 parent 9131925 commit 11ecba8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
13 changes: 3 additions & 10 deletions dpctl/tensor/libtensor/source/reductions/reduction_over_axis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -874,14 +874,11 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

// handle special case when both reduction and iteration are 1D contiguous
// and can be done with atomics
bool is_src_c_contig = src.is_c_contiguous();
bool is_dst_c_contig = dst.is_c_contiguous();
bool is_src_f_contig = src.is_f_contiguous();

if ((is_src_c_contig && is_dst_c_contig) ||
(is_src_f_contig && dst_nelems == 1))
{
if (is_src_c_contig && is_dst_c_contig) {
auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid];
if (fn != nullptr) {
size_t iter_nelems = dst_nelems;
Expand All @@ -903,9 +900,7 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
reduction_over_axis_contig_ev);
}
}
else if (is_src_f_contig &&
((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous()))
{
else if (is_src_f_contig && dst_nd == 1) {
auto fn = axis0_contig_dispatch_table[src_typeid][dst_typeid];
if (fn != nullptr) {
size_t iter_nelems = dst_nelems;
Expand Down Expand Up @@ -983,11 +978,9 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
if ((reduction_nd == 1) && (iteration_nd == 1)) {
bool mat_reduce_over_axis1 = false;
bool mat_reduce_over_axis0 = false;
bool array_reduce_all_elems = false;
size_t iter_nelems = dst_nelems;

if (compact_reduction_src_strides[0] == 1) {
array_reduce_all_elems = (simplified_iteration_shape[0] == 1);
mat_reduce_over_axis1 =
(simplified_iteration_dst_strides[0] == 1) &&
(static_cast<size_t>(simplified_iteration_src_strides[0]) ==
Expand All @@ -1000,7 +993,7 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
(simplified_iteration_src_strides[0] == 1);
}

if (mat_reduce_over_axis1 || array_reduce_all_elems) {
if (mat_reduce_over_axis1) {
auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid];
if (fn != nullptr) {
sycl::event reduction_over_axis1_contig_ev =
Expand Down
16 changes: 16 additions & 0 deletions dpctl/tests/test_usm_ndarray_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,22 @@ def test_argmax_argmin_identities():
assert dpt.argmin(x) == 0


@pytest.mark.parametrize("order", ["C", "F"])
def test_argmax_axis0_axis1(order):
    """Check argmax without an axis and along each axis of a 2x3 array.

    Parametrized over the memory order passed to ``dpt.asarray`` so the same
    expectations are asserted for both "C" and "F" inputs.
    """
    get_queue_or_skip()

    data = dpt.asarray([[1, 2, 3], [6, 5, 4]], dtype="i4", order=order)

    # No axis given: the largest value, 6, is expected at flat index 3.
    assert dpt.argmax(data) == 3

    # Per-axis expectations, checked in the order axis 0 then axis 1:
    #   axis 0 -> the second row holds the larger entry in every column
    #   axis 1 -> row 0 peaks at column 2, row 1 peaks at column 0
    per_axis_expected = ((0, [1, 1, 1]), (1, [2, 0]))
    for axis, indices in per_axis_expected:
        found = dpt.argmax(data, axis=axis)
        wanted = dpt.asarray(indices, dtype=found.dtype)
        assert dpt.all(found == wanted)


def test_reduction_arg_validation():
get_queue_or_skip()

Expand Down

0 comments on commit 11ecba8

Please sign in to comment.