Skip to content

Commit

Permalink
supporting assign into single wg copy_if
Browse files Browse the repository at this point in the history
Signed-off-by: Dan Hoeflinger <dan.hoeflinger@intel.com>
  • Loading branch information
danhoeflinger committed Jul 22, 2024
1 parent 7ad26d7 commit 964a9d2
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -573,10 +573,10 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
__internal::__optional_kernel_name<_ScanKernelName...>>
{
template <typename _Policy, typename _InRng, typename _OutRng, typename _InitType, typename _BinaryOperation,
typename _UnaryOp>
typename _UnaryOp, typename _Assign>
auto
operator()(_Policy&& __policy, _InRng&& __in_rng, _OutRng&& __out_rng, ::std::size_t __n, _InitType __init,
_BinaryOperation __bin_op, _UnaryOp __unary_op)
_BinaryOperation __bin_op, _UnaryOp __unary_op, _Assign __assign)
{
using _ValueType = ::std::uint16_t;

Expand Down Expand Up @@ -636,12 +636,13 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W

__scan_work_group<_ValueType, /* _Inclusive */ false>(
__group, __lacc_ptr, __lacc_ptr + __elems_per_wg, __lacc_ptr + __elems_per_wg, __bin_op,
__init);
__init);

for (::std::uint16_t __idx = __item_id; __idx < __n; __idx += _WGSize)
{
if (__lacc[__idx])
__out_rng[__lacc[__idx + __elems_per_wg]] = static_cast<__tuple_type>(__in_rng[__idx]);
__assign(static_cast<__tuple_type>(__in_rng[__idx]),
__out_rng[__lacc[__idx + __elems_per_wg]]);
}

const ::std::uint16_t __residual = __n % _WGSize;
Expand All @@ -650,7 +651,8 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
{
auto __idx = __residual_start + __item_id;
if (__lacc[__idx])
__out_rng[__lacc[__idx + __elems_per_wg]] = static_cast<__tuple_type>(__in_rng[__idx]);
__assign(static_cast<__tuple_type>(__in_rng[__idx]),
__out_rng[__lacc[__idx + __elems_per_wg]]);
}

if (__item_id == 0)
Expand Down Expand Up @@ -982,9 +984,11 @@ struct __invoke_single_group_copy_if
// Specialization for devices that have a max work-group size of at least 1024
static constexpr ::std::uint16_t __targeted_wg_size = 1024;

template <::std::uint16_t _Size, typename _ExecutionPolicy, typename _InRng, typename _OutRng, typename _Pred>
template <::std::uint16_t _Size, typename _ExecutionPolicy, typename _InRng, typename _OutRng, typename _Pred,
typename _Assign = oneapi::dpl::__internal::__pstl_assign>
auto
operator()(_ExecutionPolicy&& __exec, ::std::size_t __n, _InRng&& __in_rng, _OutRng&& __out_rng, _Pred&& __pred)
operator()(_ExecutionPolicy&& __exec, ::std::size_t __n, _InRng&& __in_rng, _OutRng&& __out_rng, _Pred&& __pred,
_Assign __assign)
{
constexpr ::std::uint16_t __wg_size = ::std::min(_Size, __targeted_wg_size);
constexpr ::std::uint16_t __num_elems_per_item = ::oneapi::dpl::__internal::__dpl_ceiling_div(_Size, __wg_size);
Expand All @@ -1002,7 +1006,7 @@ struct __invoke_single_group_copy_if
/* _IsFullGroup= */ std::true_type, _CustomName>>
>()(
__exec, ::std::forward<_InRng>(__in_rng), ::std::forward<_OutRng>(__out_rng), __n, _InitType{},
_ReduceOp{}, ::std::forward<_Pred>(__pred));
_ReduceOp{}, ::std::forward<_Pred>(__pred), __assign);
else
return __par_backend_hetero::__parallel_copy_if_static_single_group_submitter<
_SizeType, __num_elems_per_item, __wg_size, false,
Expand All @@ -1012,7 +1016,7 @@ struct __invoke_single_group_copy_if
/* _IsFullGroup= */ std::false_type, _CustomName>>
>()(
__exec, ::std::forward<_InRng>(__in_rng), ::std::forward<_OutRng>(__out_rng), __n, _InitType{},
_ReduceOp{}, ::std::forward<_Pred>(__pred));
_ReduceOp{}, ::std::forward<_Pred>(__pred), __assign);
}
};

Expand Down Expand Up @@ -1063,7 +1067,7 @@ __parallel_copy_if(oneapi::dpl::__internal::__device_backend_tag __backend_tag,

return __par_backend_hetero::__static_monotonic_dispatcher<_SizeBreakpoints>::__dispatch(
_SingleGroupInvoker{}, __n, ::std::forward<_ExecutionPolicy>(__exec), __n, ::std::forward<_InRng>(__in_rng),
::std::forward<_OutRng>(__out_rng), __pred);
::std::forward<_OutRng>(__out_rng), __pred, __assign);
}
else
{
Expand Down

0 comments on commit 964a9d2

Please sign in to comment.