Skip to content

Commit

Permalink
Merge pull request #1458 from IntelPython/fix-reduction-contig_impl-offset-handling
Browse files Browse the repository at this point in the history

Fix reduction contig impl offset handling
  • Loading branch information
oleksandr-pavlyk authored Oct 27, 2023
2 parents 03fd737 + bfba152 commit d82f3a9
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 20 deletions.
53 changes: 33 additions & 20 deletions dpctl/tensor/libtensor/include/kernels/reductions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,18 @@ namespace tensor
namespace kernels
{

template <typename ReductionOpT, typename T> struct needs_workaround
{
static constexpr bool value =
std::is_same_v<ReductionOpT, sycl::multiplies<T>> &&
(std::is_same_v<T, std::int64_t> || std::is_same_v<T, std::uint64_t>);
};

// Compile-time trait over a (reduction operator, element type) pair.
//
// `value` is true when SYCL reports a known identity for the pair
// (sycl::has_known_identity) and the pair is not flagged by
// needs_workaround<ReductionOpT, T>.
// NOTE(review): judging by the name, this gates use of a
// reduce-over-group based kernel -- confirm at the use sites.
//
// The diff rendering had left the two removed `!std::is_same_v<...>` lines
// next to the added `!needs_workaround<...>` line, which terminated the
// initializer early and made the struct ill-formed; only the post-change
// condition is kept.
template <typename ReductionOpT, typename T> struct can_use_reduce_over_group
{
    static constexpr bool value =
        sycl::has_known_identity<ReductionOpT, T>::value &&
        !needs_workaround<ReductionOpT, T>::value;
};

template <typename argT,
Expand Down Expand Up @@ -1088,7 +1094,7 @@ sycl::event reduction_over_group_temps_strided_impl(
// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);

size_t reductions_per_wi(preferrered_reductions_per_wi);
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
Expand Down Expand Up @@ -1339,7 +1345,7 @@ sycl::event reduction_over_group_temps_strided_impl(
static_cast<py::ssize_t>(remaining_reduction_nelems)};
ResIndexerT res_iter_indexer{iter_nd, iter_res_offset,
/* shape */ iter_shape_and_strides,
/*s trides */ iter_shape_and_strides +
/* strides */ iter_shape_and_strides +
2 * iter_nd};

InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
Expand Down Expand Up @@ -1424,8 +1430,9 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
py::ssize_t reduction_arg_offset,
const std::vector<sycl::event> &depends)
{
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
iter_arg_offset + reduction_arg_offset;
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;

constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;

Expand All @@ -1437,7 +1444,7 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);

size_t reductions_per_wi(preferrered_reductions_per_wi);
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
Expand Down Expand Up @@ -1767,8 +1774,9 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
py::ssize_t reduction_arg_offset,
const std::vector<sycl::event> &depends)
{
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
iter_arg_offset + reduction_arg_offset;
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;

constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;

Expand All @@ -1780,7 +1788,7 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
// max_max_wg prevents running out of resources on CPU
constexpr size_t max_max_wg = 2048;
size_t max_wg = std::min(
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>());
max_max_wg, d.get_info<sycl::info::device::max_work_group_size>() / 2);

size_t reductions_per_wi(preferrered_reductions_per_wi);
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
Expand Down Expand Up @@ -3875,8 +3883,9 @@ sycl::event search_over_group_temps_strided_impl(

constexpr size_t preferrered_reductions_per_wi = 4;
// max_max_wg prevents running out of resources on CPU
size_t max_wg = std::min(
size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
size_t max_wg =
std::min(size_t(2048),
d.get_info<sycl::info::device::max_work_group_size>() / 2);

size_t reductions_per_wi(preferrered_reductions_per_wi);
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
Expand Down Expand Up @@ -4258,8 +4267,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(
py::ssize_t reduction_arg_offset,
const std::vector<sycl::event> &depends)
{
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
iter_arg_offset + reduction_arg_offset;
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;

constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
Expand All @@ -4270,8 +4280,9 @@ sycl::event search_axis1_over_group_temps_contig_impl(

constexpr size_t preferrered_reductions_per_wi = 8;
// max_max_wg prevents running out of resources on CPU
size_t max_wg = std::min(
size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
size_t max_wg =
std::min(size_t(2048),
d.get_info<sycl::info::device::max_work_group_size>() / 2);

size_t reductions_per_wi(preferrered_reductions_per_wi);
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
Expand Down Expand Up @@ -4635,8 +4646,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(
py::ssize_t reduction_arg_offset,
const std::vector<sycl::event> &depends)
{
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp);
resTy *res_tp = reinterpret_cast<resTy *>(res_cp);
const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
iter_arg_offset + reduction_arg_offset;
resTy *res_tp = reinterpret_cast<resTy *>(res_cp) + iter_res_offset;

constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
Expand All @@ -4647,8 +4659,9 @@ sycl::event search_axis0_over_group_temps_contig_impl(

constexpr size_t preferrered_reductions_per_wi = 8;
// max_max_wg prevents running out of resources on CPU
size_t max_wg = std::min(
size_t(2048), d.get_info<sycl::info::device::max_work_group_size>());
size_t max_wg =
std::min(size_t(2048),
d.get_info<sycl::info::device::max_work_group_size>() / 2);

size_t reductions_per_wi(preferrered_reductions_per_wi);
if (reduction_nelems <= preferrered_reductions_per_wi * max_wg) {
Expand Down
9 changes: 9 additions & 0 deletions dpctl/tests/test_tensor_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(arg_dtype, q)

# test reduction for C-contiguous input
m = dpt.ones(100, dtype=arg_dtype)
r = dpt.sum(m)

Expand All @@ -55,12 +56,20 @@ def test_sum_arg_dtype_default_output_dtype_matrix(arg_dtype):
assert r.dtype.kind == "f"
elif m.dtype.kind == "c":
assert r.dtype.kind == "c"

assert dpt.all(r == 100)

# test reduction for strided input
m = dpt.ones(200, dtype=arg_dtype)[:1:-2]
r = dpt.sum(m)
assert dpt.all(r == 99)

# test reduction for strided input which can be simplified
# to contiguous computation
m = dpt.ones(100, dtype=arg_dtype)
r = dpt.sum(dpt.flip(m))
assert dpt.all(r == 100)


@pytest.mark.parametrize("arg_dtype", _all_dtypes)
@pytest.mark.parametrize("out_dtype", _all_dtypes[1:])
Expand Down
11 changes: 11 additions & 0 deletions dpctl/tests/test_usm_ndarray_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,17 @@ def test_search_reduction_kernels(arg_dtype):
m = dpt.argmax(x)
assert m == idx

# test case of strided input mapping to contig
# implementation
m = dpt.argmax(dpt.flip(x))
assert m == x.size - 1 - idx

# test case of strided implementation
y = dpt.ones(2 * x.size, dtype=arg_dtype, sycl_queue=q)
y[::2] = x
m = dpt.argmax(y)
assert m == 2 * idx

x = dpt.reshape(x, (24, 1025))

x[idx_tup[0], :] = 3
Expand Down

0 comments on commit d82f3a9

Please sign in to comment.