Skip to content

Commit

Permalink
[PROTOTYPE] Generalized two pass algorithm and copy_if (#1700)
Browse files Browse the repository at this point in the history
This PR changes the two pass algorithm to be more generalized for use with other scan-like algorithms like copy_if.

This PR adds copy_if as an example

---------

Signed-off-by: Dan Hoeflinger <dan.hoeflinger@intel.com>
Signed-off-by: Matthew Michel <matthew.michel@intel.com>
Co-authored-by: Adam Fidel <110841220+adamfidel@users.noreply.github.com>
Co-authored-by: Matthew Michel <106704043+mmichel11@users.noreply.github.com>
  • Loading branch information
3 people committed Aug 6, 2024
1 parent 1f96902 commit 5e9ac57
Show file tree
Hide file tree
Showing 4 changed files with 314 additions and 216 deletions.
3 changes: 2 additions & 1 deletion include/oneapi/dpl/internal/exclusive_scan_by_segment_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ exclusive_scan_by_segment_impl(__internal::__hetero_tag<_BackendTag>, Policy&& p
transform_inclusive_scan(::std::move(policy2), make_zip_iterator(_temp.get(), _flags.get()),
make_zip_iterator(_temp.get(), _flags.get()) + n, make_zip_iterator(result, _flags.get()),
internal::segmented_scan_fun<ValueType, FlagType, Operator>(binary_op),
oneapi::dpl::__internal::__no_op(), ::std::make_tuple(init, FlagType(1)));
oneapi::dpl::__internal::__no_op(),
oneapi::dpl::__internal::make_tuple(init, FlagType(1)));
return result + n;
}

Expand Down
176 changes: 138 additions & 38 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,82 @@ __group_scan_fits_in_slm(const sycl::queue& __queue, ::std::size_t __n, ::std::s
return (__n <= __single_group_upper_limit && __max_slm_size >= __req_slm_size);
}

template <typename _UnaryOp>
struct __gen_transform_input
{
template <typename InRng>
auto
operator()(InRng&& __in_rng, std::size_t __idx) const
{
using _ValueType = oneapi::dpl::__internal::__value_t<InRng>;
using _OutValueType = oneapi::dpl::__internal::__decay_with_tuple_specialization_t<typename std::invoke_result<_UnaryOp, _ValueType>::type>;
return _OutValueType{__unary_op(__in_rng[__idx])};
}
_UnaryOp __unary_op;
};

struct __simple_write_to_idx
{
template <typename _OutRng, typename ValueType>
void
operator()(_OutRng&& __out, std::size_t __idx, const ValueType& __v) const
{
__out[__idx] = __v;
}
};

template <typename _Predicate>
struct __gen_count_pred
{
template <typename _InRng, typename _SizeType>
_SizeType
operator()(_InRng&& __in_rng, _SizeType __idx) const
{
return __pred(__in_rng[__idx]) ? _SizeType{1} : _SizeType{0};
}
_Predicate __pred;
};

template <typename _Predicate>
struct __gen_expand_count_pred
{
template <typename _InRng, typename _SizeType>
auto
operator()(_InRng&& __in_rng, _SizeType __idx) const
{
// Explicitly creating this element type is necessary to avoid modifying the input data when _InRng is a
// zip_iterator which will return a tuple of references when dereferenced. With this explicit type, we copy
// the values of zipped the input types rather than their references.
using _ElementType =
oneapi::dpl::__internal::__decay_with_tuple_specialization_t<oneapi::dpl::__internal::__value_t<_InRng>>;
_ElementType ele = __in_rng[__idx];
bool mask = __pred(ele);
return std::tuple(mask ? _SizeType{1} : _SizeType{0}, mask, ele);
}
_Predicate __pred;
};

struct __get_zeroth_element
{
template <typename _Tp>
auto&
operator()(_Tp&& __a) const
{
return std::get<0>(std::forward<_Tp>(__a));
}
};

struct __write_to_idx_if
{
template <typename _OutRng, typename _SizeType, typename ValueType>
void
operator()(_OutRng&& __out, _SizeType __idx, const ValueType& __v) const
{
if (std::get<1>(__v))
__out[std::get<0>(__v) - 1] = std::get<2>(__v);
}
};

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _UnaryOperation, typename _InitType,
typename _BinaryOperation, typename _Inclusive>
auto
Expand All @@ -774,39 +850,62 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag __backen
_InitType __init, _BinaryOperation __binary_op, _Inclusive)
{
using _Type = typename _InitType::__value_type;

// Next power of 2 greater than or equal to __n
auto __n_uniform = __n;
if ((__n_uniform & (__n_uniform - 1)) != 0)
__n_uniform = oneapi::dpl::__internal::__dpl_bit_floor(__n) << 1;

// TODO: can we reimplement this with support fort non-identities as well? We can then use in reduce-then-scan
// for the last block if it is sufficiently small
constexpr bool __can_use_group_scan = unseq_backend::__has_known_identity<_BinaryOperation, _Type>::value;
if constexpr (__can_use_group_scan)
// Reduce-then-scan is dependent on sycl::shift_group_right which requires the underlying type to be trivially
// copyable. If this is not met, then we must fallback to the legacy implementation. The single work-group implementation
// requires a fundamental type which must also be trivially copyable.
if constexpr (std::is_trivially_copyable_v<_Type>)
{
if (__group_scan_fits_in_slm<_Type>(__exec.queue(), __n, __n_uniform))
// Next power of 2 greater than or equal to __n
auto __n_uniform = __n;
if ((__n_uniform & (__n_uniform - 1)) != 0)
__n_uniform = oneapi::dpl::__internal::__dpl_bit_floor(__n) << 1;

// TODO: can we reimplement this with support for non-identities as well? We can then use in reduce-then-scan
// for the last block if it is sufficiently small
constexpr bool __can_use_group_scan = unseq_backend::__has_known_identity<_BinaryOperation, _Type>::value;
if constexpr (__can_use_group_scan)
{
return __parallel_transform_scan_single_group(
__backend_tag, std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range1>(__in_rng),
::std::forward<_Range2>(__out_rng), __n, __unary_op, __init, __binary_op, _Inclusive{});
if (__group_scan_fits_in_slm<_Type>(__exec.queue(), __n, __n_uniform))
{
return __parallel_transform_scan_single_group(
__backend_tag, std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range1>(__in_rng),
::std::forward<_Range2>(__out_rng), __n, __unary_op, __init, __binary_op, _Inclusive{});
}
}
oneapi::dpl::__par_backend_hetero::__gen_transform_input<_UnaryOperation> __gen_transform{__unary_op};
return __future(__parallel_transform_reduce_then_scan(
__backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__in_rng),
std::forward<_Range2>(__out_rng), __gen_transform, __binary_op, __gen_transform,
oneapi::dpl::__internal::__no_op{}, __simple_write_to_idx{}, __init, _Inclusive{})
.event());
}
else
{
using _Assigner = unseq_backend::__scan_assigner;
using _NoAssign = unseq_backend::__scan_no_assign;
using _UnaryFunctor = unseq_backend::walk_n<_ExecutionPolicy, _UnaryOperation>;
using _NoOpFunctor = unseq_backend::walk_n<_ExecutionPolicy, oneapi::dpl::__internal::__no_op>;

_Assigner __assign_op;
_NoAssign __no_assign_op;
_NoOpFunctor __get_data_op;

return __future(
__parallel_transform_scan_base(
__backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__in_rng),
std::forward<_Range2>(__out_rng), __binary_op, __init,
// local scan
unseq_backend::__scan<_Inclusive, _ExecutionPolicy, _BinaryOperation, _UnaryFunctor, _Assigner,
_Assigner, _NoOpFunctor, _InitType>{__binary_op, _UnaryFunctor{__unary_op},
__assign_op, __assign_op, __get_data_op},
// scan between groups
unseq_backend::__scan</*inclusive=*/std::true_type, _ExecutionPolicy, _BinaryOperation, _NoOpFunctor,
_NoAssign, _Assigner, _NoOpFunctor, unseq_backend::__no_init_value<_Type>>{
__binary_op, _NoOpFunctor{}, __no_assign_op, __assign_op, __get_data_op},
// global scan
unseq_backend::__global_scan_functor<_Inclusive, _BinaryOperation, _InitType>{__binary_op, __init})
.event());
}

// TODO: Reintegrate once support has been added
//// Either we can't use group scan or this input is too big for one workgroup
//using _Assigner = unseq_backend::__scan_assigner;
//using _NoAssign = unseq_backend::__scan_no_assign;
//using _UnaryFunctor = unseq_backend::walk_n<_ExecutionPolicy, _UnaryOperation>;
//using _NoOpFunctor = unseq_backend::walk_n<_ExecutionPolicy, oneapi::dpl::__internal::__no_op>;

//_Assigner __assign_op;
//_NoAssign __no_assign_op;
//_NoOpFunctor __get_data_op;
return __future(__parallel_transform_reduce_then_scan(__backend_tag, ::std::forward<_ExecutionPolicy>(__exec),
::std::forward<_Range1>(__in_rng), ::std::forward<_Range2>(__out_rng),
__binary_op, __unary_op, __init, _Inclusive{})
.event());
}

template <typename _SizeType>
Expand Down Expand Up @@ -907,15 +1006,14 @@ __parallel_copy_if(oneapi::dpl::__internal::__device_backend_tag __backend_tag,
// The kernel stores n integers for the predicate and another n integers for the offsets
const auto __req_slm_size = sizeof(::std::uint16_t) * __n_uniform * 2;

constexpr ::std::uint16_t __single_group_upper_limit = 16384;
constexpr ::std::uint16_t __single_group_upper_limit = 2048;

::std::size_t __max_wg_size = oneapi::dpl::__internal::__max_work_group_size(__exec);

if (__n <= __single_group_upper_limit && __max_slm_size >= __req_slm_size &&
__max_wg_size >= _SingleGroupInvoker::__targeted_wg_size)
{
using _SizeBreakpoints =
::std::integer_sequence<::std::uint16_t, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384>;
using _SizeBreakpoints = ::std::integer_sequence<::std::uint16_t, 16, 32, 64, 128, 256, 512, 1024, 2048>;

return __par_backend_hetero::__static_monotonic_dispatcher<_SizeBreakpoints>::__dispatch(
_SingleGroupInvoker{}, __n, ::std::forward<_ExecutionPolicy>(__exec), __n, ::std::forward<_InRng>(__in_rng),
Expand All @@ -924,13 +1022,15 @@ __parallel_copy_if(oneapi::dpl::__internal::__device_backend_tag __backend_tag,
else
{
using _ReduceOp = ::std::plus<_Size>;
using CreateOp = unseq_backend::__create_mask<_Pred, _Size>;
using CopyOp = unseq_backend::__copy_by_mask<_ReduceOp, oneapi::dpl::__internal::__pstl_assign,
/*inclusive*/ ::std::true_type, 1>;

return __parallel_scan_copy(__backend_tag, ::std::forward<_ExecutionPolicy>(__exec),
::std::forward<_InRng>(__in_rng), ::std::forward<_OutRng>(__out_rng), __n,
CreateOp{__pred}, CopyOp{});
return __parallel_transform_reduce_then_scan(
__backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng),
std::forward<_OutRng>(__out_rng), oneapi::dpl::__par_backend_hetero::__gen_count_pred<_Pred>{__pred},
_ReduceOp{}, oneapi::dpl::__par_backend_hetero::__gen_expand_count_pred<_Pred>{__pred},
oneapi::dpl::__par_backend_hetero::__get_zeroth_element{},
oneapi::dpl::__par_backend_hetero::__write_to_idx_if{},
oneapi::dpl::unseq_backend::__no_init_value<_Size>{},
/*_Inclusive=*/std::true_type{});
}
}

Expand Down
Loading

0 comments on commit 5e9ac57

Please sign in to comment.