[SYCL][Reduction] Support range version with multiple reductions (#7456)
aelovikov-intel authored Nov 21, 2022
1 parent f5f512b commit 572bc50
Showing 3 changed files with 64 additions and 44 deletions.
34 changes: 16 additions & 18 deletions sycl/include/sycl/handler.hpp
@@ -2029,25 +2029,24 @@ class __SYCL_EXPORT handler {

  /// Reductions @{

-  template <typename KernelName = detail::auto_name, typename KernelType,
-            typename PropertiesT, int Dims, typename Reduction>
+  template <typename KernelName = detail::auto_name, int Dims,
+            typename PropertiesT, typename... RestT>
   std::enable_if_t<
-      detail::IsReduction<Reduction>::value &&
+      (sizeof...(RestT) > 1) &&
+      detail::AreAllButLastReductions<RestT...>::value &&
       ext::oneapi::experimental::is_property_list<PropertiesT>::value>
-  parallel_for(range<Dims> Range, PropertiesT Properties, Reduction Redu,
-               _KERNELFUNCPARAM(KernelFunc)) {
-    detail::reduction_parallel_for<KernelName>(*this, Range, Properties, Redu,
-                                               std::move(KernelFunc));
+  parallel_for(range<Dims> Range, PropertiesT Properties, RestT &&...Rest) {
+    detail::reduction_parallel_for<KernelName>(*this, Range, Properties,
+                                               std::forward<RestT>(Rest)...);
   }

-  template <typename KernelName = detail::auto_name, typename KernelType,
-            int Dims, typename Reduction>
-  std::enable_if_t<detail::IsReduction<Reduction>::value>
-  parallel_for(range<Dims> Range, Reduction Redu,
-               _KERNELFUNCPARAM(KernelFunc)) {
+  template <typename KernelName = detail::auto_name, int Dims,
+            typename... RestT>
+  std::enable_if_t<detail::AreAllButLastReductions<RestT...>::value>
+  parallel_for(range<Dims> Range, RestT &&...Rest) {
     parallel_for<KernelName>(
-        Range, ext::oneapi::experimental::detail::empty_properties_t{}, Redu,
-        std::move(KernelFunc));
+        Range, ext::oneapi::experimental::detail::empty_properties_t{},
+        std::forward<RestT>(Rest)...);
   }

  template <typename KernelName = detail::auto_name, int Dims,
@@ -2520,11 +2519,10 @@ class __SYCL_EXPORT handler {
  friend void detail::reduction::withAuxHandler(handler &CGH, FunctorTy Func);

  template <typename KernelName, detail::reduction::strategy Strategy, int Dims,
-            typename PropertiesT, typename KernelType, typename Reduction>
-  friend void detail::reduction_parallel_for(handler &CGH, range<Dims> Range,
+            typename PropertiesT, typename... RestT>
+  friend void detail::reduction_parallel_for(handler &CGH, range<Dims> NDRange,
                                              PropertiesT Properties,
-                                             Reduction Redu,
-                                             KernelType KernelFunc);
+                                             RestT... Rest);

  template <typename KernelName, detail::reduction::strategy Strategy, int Dims,
            typename PropertiesT, typename... RestT>
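In user code, the effect of the handler.hpp change is that the range-based parallel_for now accepts several reduction objects ahead of the kernel functor. A minimal sketch of such a call (illustrative names; assumes a DPC++ build containing this commit and a device supporting USM shared allocations):

#include <sycl/sycl.hpp>
#include <cstddef>

int main() {
  sycl::queue Q;
  constexpr std::size_t N = 1024;
  int *Data = sycl::malloc_shared<int>(N, Q);
  int *Sum = sycl::malloc_shared<int>(1, Q);
  int *Max = sycl::malloc_shared<int>(1, Q);
  for (std::size_t I = 0; I < N; ++I)
    Data[I] = static_cast<int>(I);
  *Sum = 0;
  *Max = 0;

  // One range-based parallel_for carrying two reductions, which this commit
  // enables (previously only a single reduction was accepted here).
  Q.submit([&](sycl::handler &CGH) {
     CGH.parallel_for(sycl::range<1>{N},
                      sycl::reduction(Sum, sycl::plus<int>()),
                      sycl::reduction(Max, sycl::maximum<int>()),
                      [=](sycl::id<1> Idx, auto &SumRed, auto &MaxRed) {
                        SumRed += Data[Idx];
                        MaxRed.combine(Data[Idx]);
                      });
   }).wait();

  sycl::free(Data, Q);
  sycl::free(Sum, Q);
  sycl::free(Max, Q);
  return 0;
}

The nd_range reduction path already took a parameter pack (see the neighbouring variadic declaration above); this change brings the range form in line with it.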
66 changes: 45 additions & 21 deletions sycl/include/sycl/reduction.hpp
@@ -2302,16 +2302,29 @@ __SYCL_EXPORT uint32_t
reduGetMaxNumConcurrentWorkGroups(std::shared_ptr<queue_impl> Queue);

template <typename KernelName, reduction::strategy Strategy, int Dims,
-          typename PropertiesT, typename KernelType, typename Reduction>
+          typename PropertiesT, typename... RestT>
void reduction_parallel_for(handler &CGH, range<Dims> Range,
-                            PropertiesT Properties, Reduction Redu,
-                            KernelType KernelFunc) {
+                            PropertiesT Properties, RestT... Rest) {
+  std::tuple<RestT...> ArgsTuple(Rest...);
+  constexpr size_t NumArgs = sizeof...(RestT);
+  static_assert(NumArgs > 1, "No reduction!");
+  auto KernelFunc = std::get<NumArgs - 1>(ArgsTuple);
+  auto ReduIndices = std::make_index_sequence<NumArgs - 1>();
+  auto ReduTuple = detail::tuple_select_elements(ArgsTuple, ReduIndices);
+
  // Before running the kernels, check that device has enough local memory
  // to hold local arrays required for the tree-reduction algorithm.
-  constexpr bool IsTreeReduction =
-      !Reduction::has_fast_reduce && !Reduction::has_fast_atomics;
-  size_t OneElemSize =
-      IsTreeReduction ? sizeof(typename Reduction::result_type) : 0;
+  size_t OneElemSize = [&]() {
+    if constexpr (NumArgs == 2) {
+      using Reduction = std::tuple_element_t<0, decltype(ReduTuple)>;
+      constexpr bool IsTreeReduction =
+          !Reduction::has_fast_reduce && !Reduction::has_fast_atomics;
+      return IsTreeReduction ? sizeof(typename Reduction::result_type) : 0;
+    } else {
+      return reduGetMemPerWorkItem(ReduTuple, ReduIndices);
+    }
+  }();

  uint32_t NumConcurrentWorkGroups =
#ifdef __SYCL_REDUCTION_NUM_CONCURRENT_WORKGROUPS
      __SYCL_REDUCTION_NUM_CONCURRENT_WORKGROUPS;
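As an aside, the new entry point splits its trailing argument (the kernel functor) from the leading reduction objects using the index-sequence idiom shown above. A self-contained sketch of that idiom, with hypothetical helper names rather than the actual detail::tuple_select_elements:

#include <cstddef>
#include <iostream>
#include <tuple>
#include <utility>

// Hypothetical helper mirroring the idiom: copy the tuple elements at the
// given indices into a new tuple.
template <typename TupleT, std::size_t... Is>
auto select_elements(const TupleT &Tuple, std::index_sequence<Is...>) {
  return std::make_tuple(std::get<Is>(Tuple)...);
}

// Split a (reductions..., kernel-functor) argument pack into the functor and
// a tuple holding only the reductions, as the new entry point does.
template <typename... ArgsT> void split_args(ArgsT &&...Args) {
  std::tuple<ArgsT...> ArgsTuple(std::forward<ArgsT>(Args)...);
  constexpr std::size_t NumArgs = sizeof...(ArgsT);
  static_assert(NumArgs > 1, "Expected at least one reduction and a functor");
  auto KernelFunc = std::get<NumArgs - 1>(ArgsTuple);
  auto ReduTuple =
      select_elements(ArgsTuple, std::make_index_sequence<NumArgs - 1>());
  std::cout << "reductions: " << std::tuple_size_v<decltype(ReduTuple)> << '\n';
  KernelFunc(); // the trailing argument is the kernel body
}

int main() {
  split_args(1, 2.5, [] { std::cout << "kernel functor\n"; });
  return 0;
}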
@@ -2341,7 +2354,7 @@ void reduction_parallel_for(handler &CGH, range<Dims> Range,
  // stride equal to 1. For each of the index the given the original KernelFunc
  // is called and the reduction value hold in \p Reducer is accumulated in
  // those calls.
-  auto UpdatedKernelFunc = [=](auto NDId, auto &Reducer) {
+  auto UpdatedKernelFunc = [=](auto NDId, auto &...Reducers) {
    // Divide into contiguous chunks and assign each chunk to a Group
    // Rely on precomputed division to avoid repeating expensive operations
    // TODO: Some devices may prefer alternative remainder handling
@@ -2357,23 +2370,34 @@ void reduction_parallel_for(handler &CGH, range<Dims> Range,
    size_t End = GroupEnd;
    size_t Stride = NDId.get_local_range(0);
    for (size_t I = Start; I < End; I += Stride)
-      KernelFunc(getDelinearizedId(Range, I), Reducer);
+      KernelFunc(getDelinearizedId(Range, I), Reducers...);
  };
+  if constexpr (NumArgs == 2) {
+    using Reduction = std::tuple_element_t<0, decltype(ReduTuple)>;
+    auto &Redu = std::get<0>(ReduTuple);

-  constexpr auto StrategyToUse = [&]() {
-    if constexpr (Strategy != reduction::strategy::auto_select)
-      return Strategy;
+    constexpr auto StrategyToUse = [&]() {
+      if constexpr (Strategy != reduction::strategy::auto_select)
+        return Strategy;

-    if constexpr (Reduction::has_fast_reduce)
-      return reduction::strategy::group_reduce_and_last_wg_detection;
-    else if constexpr (Reduction::has_fast_atomics)
-      return reduction::strategy::local_atomic_and_atomic_cross_wg;
-    else
-      return reduction::strategy::range_basic;
-  }();
+      if constexpr (Reduction::has_fast_reduce)
+        return reduction::strategy::group_reduce_and_last_wg_detection;
+      else if constexpr (Reduction::has_fast_atomics)
+        return reduction::strategy::local_atomic_and_atomic_cross_wg;
+      else
+        return reduction::strategy::range_basic;
+    }();

-  reduction_parallel_for<KernelName, StrategyToUse>(CGH, NDRange, Properties,
-                                                    Redu, UpdatedKernelFunc);
+    reduction_parallel_for<KernelName, StrategyToUse>(CGH, NDRange, Properties,
+                                                      Redu, UpdatedKernelFunc);
+  } else {
+    return std::apply(
+        [&](auto &...Reds) {
+          return reduction_parallel_for<KernelName, Strategy>(
+              CGH, NDRange, Properties, Reds..., UpdatedKernelFunc);
+        },
+        ReduTuple);
+  }
}
} // namespace detail

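When more than one reduction is present, the branch above re-expands ReduTuple into individual arguments with std::apply before tail-calling the strategy-specific overload. The same pattern in isolation, with hypothetical stand-ins for the reductions and the backend call:

#include <iostream>
#include <tuple>

// Hypothetical variadic "backend" standing in for the strategy-specific
// reduction_parallel_for overload that takes (reductions..., functor).
template <typename... ArgsT> void backend(ArgsT... Args) {
  std::cout << "forwarded " << sizeof...(ArgsT) << " arguments\n";
}

int main() {
  std::tuple<int, double> ReduTuple{1, 2.5}; // stand-ins for reduction objects
  auto KernelFunc = [] {};                   // stand-in for the kernel body
  // Re-expand the tuple into individual arguments, appending the functor --
  // the std::apply pattern used by the multi-reduction branch.
  std::apply([&](auto &...Reds) { backend(Reds..., KernelFunc); }, ReduTuple);
  return 0;
}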
8 changes: 3 additions & 5 deletions sycl/include/sycl/reduction_forward.hpp
@@ -46,11 +46,9 @@ template <class FunctorTy> void withAuxHandler(handler &CGH, FunctorTy Func);

template <typename KernelName,
          reduction::strategy Strategy = reduction::strategy::auto_select,
-          int Dims, typename PropertiesT, typename KernelType,
-          typename Reduction>
-void reduction_parallel_for(handler &CGH, range<Dims> Range,
-                            PropertiesT Properties, Reduction Redu,
-                            KernelType KernelFunc);
+          int Dims, typename PropertiesT, typename... RestT>
+void reduction_parallel_for(handler &CGH, range<Dims> NDRange,
+                            PropertiesT Properties, RestT... Rest);

template <typename KernelName,
          reduction::strategy Strategy = reduction::strategy::auto_select,
