Improves performance of search reductions for small numbers of elements (#1464)

* Adds SequentialSearchReduction functor to search reductions

* Search reductions use correct branch for float16

The constexpr branch logic accounted for floating-point types but not sycl::half,
so NaNs did not propagate for float16 data.
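
A minimal sketch of the float16 trap the second bullet describes (an editor's illustration, not part of this diff; it assumes a SYCL 2020 compiler providing <sycl/sycl.hpp>). std::is_floating_point_v<sycl::half> is false, so a constexpr guard written only in terms of it sends float16 down the comparison that ignores NaNs:

    #include <sycl/sycl.hpp>

    #include <cmath>
    #include <type_traits>

    template <typename argT>
    bool nan_aware_less(const argT &val, const argT &red_val)
    {
        // sycl::half is not a standard floating-point type, so it must be
        // matched explicitly; without the is_same_v test, float16 falls
        // into the plain branch below and NaN inputs never win the
        // comparison.
        if constexpr (std::is_floating_point_v<argT> ||
                      std::is_same_v<argT, sycl::half>)
        {
            return val < red_val || std::isnan(val);
        }
        else {
            return val < red_val;
        }
    }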
ndgrigorian authored Nov 3, 2023
1 parent 097ecf5 commit af28d98
251 changes: 248 additions & 3 deletions dpctl/tensor/libtensor/include/kernels/reductions.hpp
@@ -3401,6 +3401,129 @@ struct LogSumExpOverAxis0TempsContigFactory

// Argmax and Argmin

/* Sequential search reduction */

template <typename argT,
typename outT,
typename ReductionOp,
typename IdxReductionOp,
typename InputOutputIterIndexerT,
typename InputRedIndexerT>
struct SequentialSearchReduction
{
private:
const argT *inp_ = nullptr;
outT *out_ = nullptr;
ReductionOp reduction_op_;
argT identity_;
IdxReductionOp idx_reduction_op_;
outT idx_identity_;
InputOutputIterIndexerT inp_out_iter_indexer_;
InputRedIndexerT inp_reduced_dims_indexer_;
size_t reduction_max_gid_ = 0;

public:
SequentialSearchReduction(const argT *inp,
outT *res,
ReductionOp reduction_op,
const argT &identity_val,
IdxReductionOp idx_reduction_op,
const outT &idx_identity_val,
InputOutputIterIndexerT arg_res_iter_indexer,
InputRedIndexerT arg_reduced_dims_indexer,
size_t reduction_size)
: inp_(inp), out_(res), reduction_op_(reduction_op),
identity_(identity_val), idx_reduction_op_(idx_reduction_op),
idx_identity_(idx_identity_val),
inp_out_iter_indexer_(arg_res_iter_indexer),
inp_reduced_dims_indexer_(arg_reduced_dims_indexer),
reduction_max_gid_(reduction_size)
{
}

void operator()(sycl::id<1> id) const
{

auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]);
const py::ssize_t &inp_iter_offset =
inp_out_iter_offsets_.get_first_offset();
const py::ssize_t &out_iter_offset =
inp_out_iter_offsets_.get_second_offset();

argT red_val(identity_);
outT idx_val(idx_identity_);
for (size_t m = 0; m < reduction_max_gid_; ++m) {
const py::ssize_t inp_reduction_offset =
inp_reduced_dims_indexer_(m);
const py::ssize_t inp_offset =
inp_iter_offset + inp_reduction_offset;

argT val = inp_[inp_offset];
if (val == red_val) {
idx_val = idx_reduction_op_(idx_val, static_cast<outT>(m));
}
else {
if constexpr (su_ns::IsMinimum<argT, ReductionOp>::value) {
using dpctl::tensor::type_utils::is_complex;
if constexpr (is_complex<argT>::value) {
using dpctl::tensor::math_utils::less_complex;
                    // less_complex always returns false for NaNs, so check
                    // for them explicitly
if (less_complex<argT>(val, red_val) ||
std::isnan(std::real(val)) ||
std::isnan(std::imag(val)))
{
red_val = val;
idx_val = static_cast<outT>(m);
}
}
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val < red_val || std::isnan(val)) {
red_val = val;
idx_val = static_cast<outT>(m);
}
}
else {
if (val < red_val) {
red_val = val;
idx_val = static_cast<outT>(m);
}
}
}
else if constexpr (su_ns::IsMaximum<argT, ReductionOp>::value) {
using dpctl::tensor::type_utils::is_complex;
if constexpr (is_complex<argT>::value) {
using dpctl::tensor::math_utils::greater_complex;
if (greater_complex<argT>(val, red_val) ||
std::isnan(std::real(val)) ||
std::isnan(std::imag(val)))
{
red_val = val;
idx_val = static_cast<outT>(m);
}
}
else if constexpr (std::is_floating_point_v<argT> ||
std::is_same_v<argT, sycl::half>)
{
if (val > red_val || std::isnan(val)) {
red_val = val;
idx_val = static_cast<outT>(m);
}
}
else {
if (val > red_val) {
red_val = val;
idx_val = static_cast<outT>(m);
}
}
}
}
}
out_[out_iter_offset] = idx_val;
}
};
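
For intuition, a host-side model of what one work-item of the functor above computes for an argmin-style reduction over float data (a sketch with illustrative names, not dpctl API; infinity and the maximal index stand in for identity_val and idx_identity_val):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <limits>
    #include <vector>

    // Same loop structure as SequentialSearchReduction: value ties are
    // resolved through the index reduction (min), and a NaN always
    // replaces the running value so that it propagates.
    std::size_t sequential_argmin(const std::vector<float> &v)
    {
        float red_val = std::numeric_limits<float>::infinity();
        std::size_t idx_val = std::numeric_limits<std::size_t>::max();
        for (std::size_t m = 0; m < v.size(); ++m) {
            const float val = v[m];
            if (val == red_val) {
                idx_val = std::min(idx_val, m);
            }
            else if (val < red_val || std::isnan(val)) {
                red_val = val;
                idx_val = m;
            }
        }
        return idx_val;
    }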

/* Search reduction using reduce_over_group */

template <typename argT,
@@ -3670,7 +3793,9 @@ struct CustomSearchReduction
}
}
}
-                else if constexpr (std::is_floating_point_v<argT>) {
+                else if constexpr (std::is_floating_point_v<argT> ||
+                                   std::is_same_v<argT, sycl::half>)
+                {
if (val < local_red_val || std::isnan(val)) {
local_red_val = val;
if constexpr (!First) {
@@ -3714,7 +3839,9 @@
}
}
}
-                else if constexpr (std::is_floating_point_v<argT>) {
+                else if constexpr (std::is_floating_point_v<argT> ||
+                                   std::is_same_v<argT, sycl::half>)
+                {
if (val > local_red_val || std::isnan(val)) {
local_red_val = val;
if constexpr (!First) {
@@ -3757,7 +3884,9 @@
? local_idx
: idx_identity_;
}
-            else if constexpr (std::is_floating_point_v<argT>) {
+            else if constexpr (std::is_floating_point_v<argT> ||
+                               std::is_same_v<argT, sycl::half>)
+            {
// equality does not hold for NaNs, so check here
local_idx =
(red_val_over_wg == local_red_val || std::isnan(local_red_val))
@@ -3799,6 +3928,14 @@ typedef sycl::event (*search_strided_impl_fn_ptr)(
py::ssize_t,
const std::vector<sycl::event> &);

template <typename T1,
typename T2,
typename T3,
typename T4,
typename T5,
typename T6>
class search_seq_strided_krn;

template <typename T1,
typename T2,
typename T3,
@@ -3820,6 +3957,14 @@ template <typename T1,
bool b2>
class custom_search_over_group_temps_strided_krn;

template <typename T1,
typename T2,
typename T3,
typename T4,
typename T5,
typename T6>
class search_seq_contig_krn;

template <typename T1,
typename T2,
typename T3,
@@ -4019,6 +4164,36 @@ sycl::event search_over_group_temps_strided_impl(
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);

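    // Small reduction: the reduced extent fits within a single work-group,
    // so skip the tree reduction and launch one work-item per output that
    // scans the reduction axis sequentially.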
if (reduction_nelems < wg) {
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

using InputOutputIterIndexerT =
dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer;
using ReductionIndexerT =
dpctl::tensor::offset_utils::StridedIndexer;

InputOutputIterIndexerT in_out_iter_indexer{
iter_nd, iter_arg_offset, iter_res_offset,
iter_shape_and_strides};
ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset,
reduction_shape_stride};

cgh.parallel_for<class search_seq_strided_krn<
argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
ReductionIndexerT>>(
sycl::range<1>(iter_nelems),
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
InputOutputIterIndexerT,
ReductionIndexerT>(
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
idx_identity_val, in_out_iter_indexer, reduction_indexer,
reduction_nelems));
});

return comp_ev;
}

constexpr size_t preferred_reductions_per_wi = 4;
// max_max_wg prevents running out of resources on CPU
size_t max_wg =
@@ -4419,6 +4594,39 @@ sycl::event search_axis1_over_group_temps_contig_impl(
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);

if (reduction_nelems < wg) {
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

using InputIterIndexerT =
dpctl::tensor::offset_utils::Strided1DIndexer;
using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
using InputOutputIterIndexerT =
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
InputIterIndexerT, NoOpIndexerT>;
using ReductionIndexerT = NoOpIndexerT;

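            // Row i of the C-contiguous input starts at offset
            // i * reduction_nelems; the reduced axis itself is unit stride,
            // hence the NoOp reduction indexer.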
InputOutputIterIndexerT in_out_iter_indexer{
InputIterIndexerT{0, static_cast<py::ssize_t>(iter_nelems),
static_cast<py::ssize_t>(reduction_nelems)},
NoOpIndexerT{}};
ReductionIndexerT reduction_indexer{};

cgh.parallel_for<class search_seq_contig_krn<
argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT,
ReductionIndexerT>>(
sycl::range<1>(iter_nelems),
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
InputOutputIterIndexerT,
ReductionIndexerT>(
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
idx_identity_val, in_out_iter_indexer, reduction_indexer,
reduction_nelems));
});

return comp_ev;
}

constexpr size_t preferred_reductions_per_wi = 8;
// max_max_wg prevents running out of resources on CPU
size_t max_wg =
@@ -4801,6 +5009,43 @@ sycl::event search_axis0_over_group_temps_contig_impl(
const auto &sg_sizes = d.get_info<sycl::info::device::sub_group_sizes>();
size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);

if (reduction_nelems < wg) {
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);

using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
using InputOutputIterIndexerT =
dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
NoOpIndexerT, NoOpIndexerT>;
using ReductionIndexerT =
dpctl::tensor::offset_utils::Strided1DIndexer;

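            // Outputs are unit stride here; element m of column i sits at
            // offset i + m * iter_nelems, hence the strided indexer over
            // the reduced axis.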
InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
NoOpIndexerT{}};
ReductionIndexerT reduction_indexer{
0, static_cast<py::ssize_t>(reduction_nelems),
static_cast<py::ssize_t>(iter_nelems)};

using KernelName =
class search_seq_contig_krn<argTy, resTy, ReductionOpT,
IndexOpT, InputOutputIterIndexerT,
ReductionIndexerT>;

sycl::range<1> iter_range{iter_nelems};

cgh.parallel_for<KernelName>(
iter_range,
SequentialSearchReduction<argTy, resTy, ReductionOpT, IndexOpT,
InputOutputIterIndexerT,
ReductionIndexerT>(
arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
idx_identity_val, in_out_iter_indexer, reduction_indexer,
reduction_nelems));
});

return comp_ev;
}

constexpr size_t preferred_reductions_per_wi = 8;
// max_max_wg prevents running out of resources on CPU
size_t max_wg =