From d4b44d509bbb822598d15cab84505c18ae7f92a0 Mon Sep 17 00:00:00 2001
From: Dan Hoeflinger
Date: Tue, 6 Aug 2024 12:01:42 -0400
Subject: [PATCH] __result_and_scratch_storage changes for scan

Replace the ad hoc sycl::buffer temporaries in the scan patterns with
__result_and_scratch_storage so that the per-workgroup sums (scratch) and the
final result share one USM-or-buffer allocation, and return that storage
through the __future instead of a raw buffer or bare event.

Signed-off-by: Dan Hoeflinger
---
 .../pstl/hetero/dpcpp/parallel_backend_sycl.h | 99 +++++++++----------
 .../dpcpp/parallel_backend_sycl_utils.h       |  3 +-
 .../pstl/hetero/dpcpp/unseq_backend_sycl.h    | 65 +++++++-----
 3 files changed, 88 insertions(+), 79 deletions(-)

diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
index 3b7356bbaa7..35d0685edf2 100644
--- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
+++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
@@ -262,29 +262,6 @@ __parallel_for(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&&
 //------------------------------------------------------------------------
 // parallel_transform_scan - async pattern
 //------------------------------------------------------------------------
-template <typename _GlobalScan, typename _Range2, typename _Range1, typename _Accessor, typename _Size>
-struct __global_scan_caller
-{
-    __global_scan_caller(const _GlobalScan& __global_scan, const _Range2& __rng2, const _Range1& __rng1,
-                         const _Accessor& __wg_sums_acc, _Size __n, ::std::size_t __size_per_wg)
-        : __m_global_scan(__global_scan), __m_rng2(__rng2), __m_rng1(__rng1), __m_wg_sums_acc(__wg_sums_acc),
-          __m_n(__n), __m_size_per_wg(__size_per_wg)
-    {
-    }
-
-    void
-    operator()(sycl::item<1> __item) const
-    {
-        __m_global_scan(__item, __m_rng2, __m_rng1, __m_wg_sums_acc, __m_n, __m_size_per_wg);
-    }
-
-  private:
-    _GlobalScan __m_global_scan;
-    _Range2 __m_rng2;
-    _Range1 __m_rng1;
-    _Accessor __m_wg_sums_acc;
-    _Size __m_n;
-    ::std::size_t __m_size_per_wg;
-};
 
 // Please see the comment for __parallel_for_submitter for optional kernel name explanation
 template <typename _CustomName, typename _PropagateScanName>
@@ -332,14 +309,16 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
         auto __size_per_wg = __iters_per_witem * __wgroup_size;
         auto __n_groups = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __size_per_wg);
         // Storage for the results of scan for each workgroup
-        sycl::buffer<_Type> __wg_sums(__n_groups);
+
+        using _TempStorage = __result_and_scratch_storage<std::decay_t<_ExecutionPolicy>, _Type>;
+        _TempStorage __result_and_scratch{__exec, __n_groups + 1};
 
         _PRINT_INFO_IN_DEBUG_MODE(__exec, __wgroup_size, __max_cu);
 
         // 1. Local scan on each workgroup
         auto __submit_event = __exec.queue().submit([&](sycl::handler& __cgh) {
             oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2); //get an access to data under SYCL buffer
-            auto __wg_sums_acc = __wg_sums.template get_access<access_mode::discard_write>(__cgh);
+            auto __temp_acc = __result_and_scratch.__get_scratch_acc(__cgh);
             __dpl_sycl::__local_accessor<_Type> __local_acc(__wgroup_size, __cgh);
 #if _ONEDPL_COMPILE_KERNEL && _ONEDPL_KERNEL_BUNDLE_PRESENT
             __cgh.use_kernel_bundle(__kernel_1.get_kernel_bundle());
@@ -349,7 +328,8 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
                 __kernel_1,
 #endif
                 sycl::nd_range<1>(__n_groups * __wgroup_size, __wgroup_size), [=](sycl::nd_item<1> __item) {
-                    __local_scan(__item, __n, __local_acc, __rng1, __rng2, __wg_sums_acc, __size_per_wg, __wgroup_size,
+                    auto __temp_ptr = _TempStorage::__get_usm_or_buffer_accessor_ptr(__temp_acc);
+                    __local_scan(__item, __n, __local_acc, __rng1, __rng2, __temp_ptr, __size_per_wg, __wgroup_size,
                                  __iters_per_witem, __init);
                 });
         });
@@ -359,7 +339,7 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
             auto __iters_per_single_wg = oneapi::dpl::__internal::__dpl_ceiling_div(__n_groups, __wgroup_size);
             __submit_event = __exec.queue().submit([&](sycl::handler& __cgh) {
                 __cgh.depends_on(__submit_event);
-                auto __wg_sums_acc = __wg_sums.template get_access<access_mode::read_write>(__cgh);
+                auto __temp_acc = __result_and_scratch.__get_scratch_acc(__cgh);
                 __dpl_sycl::__local_accessor<_Type> __local_acc(__wgroup_size, __cgh);
 #if _ONEDPL_COMPILE_KERNEL && _ONEDPL_KERNEL_BUNDLE_PRESENT
                 __cgh.use_kernel_bundle(__kernel_2.get_kernel_bundle());
@@ -370,8 +350,9 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
 #endif
                     // TODO: try to balance work between several workgroups instead of one
                     sycl::nd_range<1>(__wgroup_size, __wgroup_size), [=](sycl::nd_item<1> __item) {
-                        __group_scan(__item, __n_groups, __local_acc, __wg_sums_acc, __wg_sums_acc,
-                                     /*dummy*/ __wg_sums_acc, __n_groups, __wgroup_size, __iters_per_single_wg);
+                        auto __temp_ptr = _TempStorage::__get_usm_or_buffer_accessor_ptr(__temp_acc);
+                        __group_scan(__item, __n_groups, __local_acc, __temp_ptr, __temp_ptr,
+                                     /*dummy*/ __temp_ptr, __n_groups, __wgroup_size, __iters_per_single_wg);
                     });
             });
         }
@@ -380,15 +361,16 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
         auto __final_event = __exec.queue().submit([&](sycl::handler& __cgh) {
             __cgh.depends_on(__submit_event);
             oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2); //get an access to data under SYCL buffer
-            auto __wg_sums_acc = __wg_sums.template get_access<access_mode::read>(__cgh);
-            __cgh.parallel_for<_PropagateScanName...>(
-                sycl::range<1>(__n_groups * __size_per_wg),
-                __global_scan_caller<_GlobalScan, ::std::decay_t<_Range2>, ::std::decay_t<_Range1>,
-                                     decltype(__wg_sums_acc), decltype(__n)>(__global_scan, __rng2, __rng1,
-                                                                             __wg_sums_acc, __n, __size_per_wg));
+            auto __temp_acc = __result_and_scratch.__get_scratch_acc(__cgh);
+            auto __res_acc = __result_and_scratch.__get_result_acc(__cgh);
+            __cgh.parallel_for<_PropagateScanName...>(sycl::range<1>(__n_groups * __size_per_wg), [=](auto __item) {
+                auto __temp_ptr = _TempStorage::__get_usm_or_buffer_accessor_ptr(__temp_acc);
+                auto __res_ptr = _TempStorage::__get_usm_or_buffer_accessor_ptr(__res_acc, __n_groups + 1);
+                __global_scan(__item, __rng2, __rng1, __temp_ptr, __res_ptr, __n, __size_per_wg);
+            });
         });
 
-        return __future(__final_event, sycl::buffer(__wg_sums,
-                                                    sycl::id<1>(__n_groups - 1), sycl::range<1>(1)));
+        return __future(__final_event, __result_and_scratch);
     }
 };
 
@@ -434,7 +416,7 @@ struct __parallel_transform_scan_dynamic_single_group_submitter<_Inclusive,
         const ::std::uint16_t __elems_per_item = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __wg_size);
         const ::std::uint16_t __elems_per_wg = __elems_per_item * __wg_size;
 
-        auto __event = __policy.queue().submit([&](sycl::handler& __hdl) {
+        return __policy.queue().submit([&](sycl::handler& __hdl) {
             oneapi::dpl::__ranges::__require_access(__hdl, __in_rng, __out_rng);
 
             auto __lacc = __dpl_sycl::__local_accessor<_ValueType>(sycl::range<1>{__elems_per_wg}, __hdl);
@@ -466,7 +448,6 @@ struct __parallel_transform_scan_dynamic_single_group_submitter<_Inclusive,
                 }
             });
         });
-        return __future(__event);
     }
 };
 
@@ -489,7 +470,7 @@ struct __parallel_transform_scan_static_single_group_submitter<_Inclusive, _Elem
 
         constexpr ::std::uint32_t __elems_per_wg = _ElemsPerItem * _WGSize;
 
-        auto __event = __policy.queue().submit([&](sycl::handler& __hdl) {
+        return __policy.queue().submit([&](sycl::handler& __hdl) {
             oneapi::dpl::__ranges::__require_access(__hdl, __in_rng, __out_rng);
 
             auto __lacc = __dpl_sycl::__local_accessor<_ValueType>(sycl::range<1>{__elems_per_wg}, __hdl);
@@ -559,7 +540,6 @@ struct __parallel_transform_scan_static_single_group_submitter<_Inclusive, _Elem
                 }
             });
         });
-        return __future(__event);
     }
 };
 
@@ -575,7 +555,7 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
     template <typename _Policy, typename _InRng, typename _OutRng, typename _InitType, typename _BinaryOperation,
               typename _UnaryOp>
     auto
-    operator()(const _Policy& __policy, _InRng&& __in_rng, _OutRng&& __out_rng, ::std::size_t __n, _InitType __init,
+    operator()(_Policy&& __policy, _InRng&& __in_rng, _OutRng&& __out_rng, ::std::size_t __n, _InitType __init,
                _BinaryOperation __bin_op, _UnaryOp __unary_op)
     {
         using _ValueType = ::std::uint16_t;
@@ -587,7 +567,7 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
 
         constexpr ::std::uint32_t __elems_per_wg = _ElemsPerItem * _WGSize;
 
-        sycl::buffer<_Size> __res(sycl::range<1>(1));
+        __result_and_scratch_storage<std::decay_t<_Policy>, _Size> __result{__policy, 0};
 
         auto __event = __policy.queue().submit([&](sycl::handler& __hdl) {
             oneapi::dpl::__ranges::__require_access(__hdl, __in_rng, __out_rng);
@@ -596,10 +576,12 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
             // predicate on each element of the input range. The second half stores the index of the output
             // range to copy elements of the input range.
             auto __lacc = __dpl_sycl::__local_accessor<_ValueType>(sycl::range<1>{__elems_per_wg * 2}, __hdl);
-            auto __res_acc = __res.template get_access<access_mode::write>(__hdl);
+            auto __res_acc = __result.__get_result_acc(__hdl);
 
             __hdl.parallel_for<_ScanKernelName...>(
                 sycl::nd_range<1>(_WGSize, _WGSize), [=](sycl::nd_item<1> __self_item) {
+                    auto __res_ptr =
+                        __result_and_scratch_storage<_Policy, _Size>::__get_usm_or_buffer_accessor_ptr(__res_acc);
                     const auto& __group = __self_item.get_group();
                     const auto& __subgroup = __self_item.get_sub_group();
                     // This kernel is only launched for sizes less than 2^16
@@ -654,11 +636,11 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
                     if (__item_id == 0)
                     {
                         // Add predicate of last element to account for the scan's exclusivity
-                        __res_acc[0] = __lacc[__elems_per_wg + __n - 1] + __lacc[__n - 1];
+                        __res_ptr[0] = __lacc[__elems_per_wg + __n - 1] + __lacc[__n - 1];
                     }
                 });
         });
-        return __future(__event, __res);
+        return __future(__event, __result);
     }
 };
 
@@ -677,6 +659,13 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
     // Specialization for devices that have a max work-group size of 1024
     constexpr ::std::uint16_t __targeted_wg_size = 1024;
 
+    using _ValueType = typename _InitType::__value_type;
+
+    // Although we do not actually need result storage in this case, we need to construct
+    // a placeholder here to match the return type of the non-single-work-group implementation
+    using _TempStorage = __result_and_scratch_storage<std::decay_t<_ExecutionPolicy>, _ValueType>;
+    _TempStorage __dummy_result_and_scratch{__exec, 0};
+
     if (__max_wg_size >= __targeted_wg_size)
     {
         auto __single_group_scan_f = [&](auto __size_constant) {
@@ -686,8 +675,9 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
                 oneapi::dpl::__internal::__dpl_ceiling_div(__size, __wg_size);
             const bool __is_full_group = __n == __wg_size;
 
+            sycl::event __event;
             if (__is_full_group)
-                return __parallel_transform_scan_static_single_group_submitter<
+                __event = __parallel_transform_scan_static_single_group_submitter<
                     _Inclusive::value, __num_elems_per_item, __wg_size, /* _IsFullGroup= */ true,
                     oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__scan_single_wg_kernel<
@@ -697,7 +687,7 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
                     ::std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng),
                     std::forward<_OutRng>(__out_rng), __n, __init, __binary_op, __unary_op);
             else
-                return __parallel_transform_scan_static_single_group_submitter<
+                __event = __parallel_transform_scan_static_single_group_submitter<
                     _Inclusive::value, __num_elems_per_item, __wg_size, /* _IsFullGroup= */ false,
                     oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__scan_single_wg_kernel<
@@ -706,6 +696,7 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
                     /* _IsFullGroup= */ ::std::false_type, _Inclusive, _CustomName>>>()(
                     ::std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng),
                     std::forward<_OutRng>(__out_rng), __n, __init, __binary_op, __unary_op);
+            return __future(__event, __dummy_result_and_scratch);
         };
         if (__n <= 16)
             return __single_group_scan_f(std::integral_constant<::std::uint16_t, 16>{});
@@ -735,9 +726,11 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
     else
     {
         using _DynamicGroupScanKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
            __par_backend_hetero::__scan_single_wg_dynamic_kernel<_BinaryOperation, _CustomName>>;
-        return __parallel_transform_scan_dynamic_single_group_submitter<_Inclusive::value, _DynamicGroupScanKernel>()(
-            ::std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng), std::forward<_OutRng>(__out_rng),
-            __n, __init, __binary_op, __unary_op, __max_wg_size);
+        auto __event =
+            __parallel_transform_scan_dynamic_single_group_submitter<_Inclusive::value, _DynamicGroupScanKernel>()(
+                std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng),
+                std::forward<_OutRng>(__out_rng), __n, __init, __binary_op, __unary_op, __max_wg_size);
+        return __future(__event, __dummy_result_and_scratch);
     }
 }
 
@@ -807,8 +800,7 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag __backen
     _NoAssign __no_assign_op;
     _NoOpFunctor __get_data_op;
 
-    return __future(
-        __parallel_transform_scan_base(
+    return __parallel_transform_scan_base(
         __backend_tag, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range1>(__in_rng),
         ::std::forward<_Range2>(__out_rng), __binary_op, __init,
         // local scan
@@ -820,8 +812,7 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag __backen
             _NoAssign, _Assigner, _NoOpFunctor, unseq_backend::__no_init_value<_Type>>{
             __binary_op, _NoOpFunctor{}, __no_assign_op, __assign_op, __get_data_op},
         // global scan
-        unseq_backend::__global_scan_functor<_Inclusive, _BinaryOperation, _InitType>{__binary_op, __init})
-            .event());
+        unseq_backend::__global_scan_functor<_Inclusive, _BinaryOperation, _InitType>{__binary_op, __init});
 }
 
 template <typename _SizeType>
diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h
index 4b045ce6292..0641e575dbb 100644
--- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h
+++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h
@@ -547,7 +547,8 @@ struct __result_and_scratch_storage
     }
 
   public:
-    __result_and_scratch_storage(_ExecutionPolicy& __exec, ::std::size_t __scratch_n)
+    template <typename _Policy>
+    __result_and_scratch_storage(_Policy&& __exec, ::std::size_t __scratch_n)
         : __exec{__exec}, __scratch_n{__scratch_n}, __use_USM_host{__use_USM_host_allocations(__exec.queue())},
           __supports_USM_device{__use_USM_allocations(__exec.queue())}
     {
diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/unseq_backend_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/unseq_backend_sycl.h
index 0d4b3a6b84e..b904a577c92 100644
--- a/include/oneapi/dpl/pstl/hetero/dpcpp/unseq_backend_sycl.h
+++ b/include/oneapi/dpl/pstl/hetero/dpcpp/unseq_backend_sycl.h
@@ -529,12 +529,19 @@ struct __mask_assigner
 struct __scan_assigner
 {
     template <typename _OutAcc, typename _OutIdx, typename _InAcc, typename _InIdx>
-    void
+    std::enable_if_t<!std::is_pointer_v<_OutAcc>>
     operator()(_OutAcc& __out_acc, const _OutIdx __out_idx, const _InAcc& __in_acc, _InIdx __in_idx) const
     {
         __out_acc[__out_idx] = __in_acc[__in_idx];
     }
 
+    template <typename _OutAcc, typename _OutIdx, typename _InAcc, typename _InIdx>
+    std::enable_if_t<std::is_pointer_v<_OutAcc>>
+    operator()(_OutAcc __out_acc, const _OutIdx __out_idx, const _InAcc& __in_acc, _InIdx __in_idx) const
+    {
+        __out_acc[__out_idx] = __in_acc[__in_idx];
+    }
+
     template <typename _Acc, typename _OutAcc, typename _OutIdx, typename _InAcc, typename _InIdx>
     void
     operator()(_Acc&, _OutAcc& __out_acc, const _OutIdx __out_idx, const _InAcc& __in_acc, _InIdx __in_idx) const
@@ -578,11 +585,11 @@ struct __copy_by_mask
     _BinaryOp __binary_op;
     _Assigner __assigner;
 
-    template <typename _Item, typename _OutAcc, typename _InAcc, typename _WgSumsAcc, typename _Size,
-              typename _SizePerWg>
-    void
-    operator()(_Item __item, _OutAcc& __out_acc, const _InAcc& __in_acc, const _WgSumsAcc& __wg_sums_acc, _Size __n,
-               _SizePerWg __size_per_wg) const
+    template <typename _Item, typename _OutAcc, typename _InAcc, typename _WgSumsPtr, typename _RetPtr,
+              typename _Size, typename _SizePerWg>
+    void
+    operator()(_Item __item, _OutAcc& __out_acc, const _InAcc& __in_acc, _WgSumsPtr* __wg_sums_ptr, _RetPtr* __ret_ptr,
+               _Size __n, _SizePerWg __size_per_wg) const
     {
         using ::std::get;
         auto __item_idx = __item.get_linear_id();
@@ -598,7 +605,7 @@ struct __copy_by_mask
             if (__item_idx >= __size_per_wg)
             {
                 auto __wg_sums_idx = __item_idx / __size_per_wg - 1;
-                __out_idx = __binary_op(__out_idx, __wg_sums_acc[__wg_sums_idx]);
+                __out_idx = __binary_op(__out_idx, __wg_sums_ptr[__wg_sums_idx]);
             }
             if (__item_idx % __size_per_wg == 0 || (get<1>(__in_acc[__item_idx]) != get<1>(__in_acc[__item_idx - 1])))
                 // If we work with tuples we might have a situation when internal tuple is assigned to ::std::tuple
@@ -617,6 +624,11 @@ struct __copy_by_mask
             // is performed (i.e. __tuple_type is the same type as its operand).
             __assigner(static_cast<__tuple_type>(get<0>(__in_acc[__item_idx])), __out_acc[__out_idx]);
         }
+        if (__item_idx == 0)
+        {
+            // copy the final result to the output
+            __ret_ptr[0] = __wg_sums_ptr[(__n - 1) / __size_per_wg];
+        }
     }
 };
 
@@ -625,11 +637,11 @@ struct __partition_by_mask
 {
     _BinaryOp __binary_op;
 
-    template <typename _Item, typename _OutAcc, typename _InAcc, typename _WgSumsAcc, typename _Size,
-              typename _SizePerWg>
-    void
-    operator()(_Item __item, _OutAcc& __out_acc, const _InAcc& __in_acc, const _WgSumsAcc& __wg_sums_acc, _Size __n,
-               _SizePerWg __size_per_wg) const
+    template <typename _Item, typename _OutAcc, typename _InAcc, typename _WgSumsPtr, typename _RetPtr,
+              typename _Size, typename _SizePerWg>
+    void
+    operator()(_Item __item, _OutAcc& __out_acc, const _InAcc& __in_acc, _WgSumsPtr* __wg_sums_ptr, _RetPtr* __ret_ptr,
+               _Size __n, _SizePerWg __size_per_wg) const
     {
         auto __item_idx = __item.get_linear_id();
         if (__item_idx < __n)
@@ -646,7 +658,7 @@ struct __partition_by_mask
                     __in_type, ::std::decay_t<decltype(get<0>(__out_acc[__out_idx]))>>::__type;
 
                 if (__not_first_wg)
-                    __out_idx = __binary_op(__out_idx, __wg_sums_acc[__wg_sums_idx - 1]);
+                    __out_idx = __binary_op(__out_idx, __wg_sums_ptr[__wg_sums_idx - 1]);
                 get<0>(__out_acc[__out_idx]) = static_cast<__tuple_type>(get<0>(__in_acc[__item_idx]));
             }
             else
@@ -656,10 +668,15 @@ struct __partition_by_mask
                     __in_type, ::std::decay_t<decltype(get<1>(__out_acc[__out_idx]))>>::__type;
 
                 if (__not_first_wg)
-                    __out_idx -= __wg_sums_acc[__wg_sums_idx - 1];
+                    __out_idx -= __wg_sums_ptr[__wg_sums_idx - 1];
                 get<1>(__out_acc[__out_idx]) = static_cast<__tuple_type>(get<0>(__in_acc[__item_idx]));
             }
         }
+        if (__item_idx == 0)
+        {
+            // copy the final result to the output
+            __ret_ptr[0] = __wg_sums_ptr[(__n - 1) / __size_per_wg];
+        }
     }
 };
 
@@ -669,10 +686,10 @@ struct __global_scan_functor
     _BinaryOp __binary_op;
     _InitType __init;
 
-    template <typename _Item, typename _OutAcc, typename _InAcc, typename _WgSumsAcc, typename _Size,
-              typename _SizePerWg>
-    void
-    operator()(_Item __item, _OutAcc& __out_acc, const _InAcc&, const _WgSumsAcc& __wg_sums_acc, _Size __n,
+    template <typename _Item, typename _OutAcc, typename _InAcc, typename _WgSumsPtr, typename _RetPtr,
+              typename _Size, typename _SizePerWg>
+    void
+    operator()(_Item __item, _OutAcc& __out_acc, const _InAcc&, _WgSumsPtr* __wg_sums_ptr, _RetPtr*, _Size __n,
                _SizePerWg __size_per_wg) const
     {
         constexpr auto __shift = _Inclusive{} ? 0 : 1;
@@ -683,7 +700,7 @@ struct __global_scan_functor
             auto __wg_sums_idx = __item_idx / __size_per_wg - 1;
             // an initial value precedes the first group for the exclusive scan
             __item_idx += __shift;
-            auto __bin_op_result = __binary_op(__wg_sums_acc[__wg_sums_idx], __out_acc[__item_idx]);
+            auto __bin_op_result = __binary_op(__wg_sums_ptr[__wg_sums_idx], __out_acc[__item_idx]);
             using __out_type = ::std::decay_t<decltype(__out_acc[__item_idx])>;
             using __in_type = ::std::decay_t<decltype(__bin_op_result)>;
             __out_acc[__item_idx] =
@@ -712,10 +729,10 @@ struct __scan
     _DataAccessor __data_acc;
 
     template <typename _NDItemId, typename _Size, typename _AccLocal, typename _InAcc, typename _OutAcc,
-              typename _WGSumsAcc, typename _SizePerWG, typename _WGSize, typename _ItersPerWG>
+              typename _WGSumsPtr, typename _SizePerWG, typename _WGSize, typename _ItersPerWG>
     void
    scan_impl(_NDItemId __item, _Size __n, _AccLocal& __local_acc, const _InAcc& __acc, _OutAcc& __out_acc,
-              _WGSumsAcc& __wg_sums_acc, _SizePerWG __size_per_wg, _WGSize __wgroup_size, _ItersPerWG __iters_per_wg,
+              _WGSumsPtr* __wg_sums_ptr, _SizePerWG __size_per_wg, _WGSize __wgroup_size, _ItersPerWG __iters_per_wg,
               _InitType __init, std::false_type /*has_known_identity*/) const
     {
         ::std::size_t __group_id = __item.get_group(0);
@@ -785,18 +802,18 @@ struct __scan
                 __gl_assigner(__acc, __out_acc, __adjusted_global_id + __shift, __local_acc, __local_id);
 
             if (__adjusted_global_id == __n - 1)
-                __wg_assigner(__wg_sums_acc, __group_id, __local_acc, __local_id);
+                __wg_assigner(__wg_sums_ptr, __group_id, __local_acc, __local_id);
         }
 
         if (__local_id == __wgroup_size - 1 && __adjusted_global_id - __wgroup_size < __n)
-            __wg_assigner(__wg_sums_acc, __group_id, __local_acc, __local_id);
+            __wg_assigner(__wg_sums_ptr, __group_id, __local_acc, __local_id);
     }
 
     template <typename _NDItemId, typename _Size, typename _AccLocal, typename _InAcc, typename _OutAcc,
-              typename _WGSumsAcc, typename _SizePerWG, typename _WGSize, typename _ItersPerWG>
+              typename _WGSumsPtr, typename _SizePerWG, typename _WGSize, typename _ItersPerWG>
     void
    scan_impl(_NDItemId __item, _Size __n, _AccLocal& __local_acc, const _InAcc& __acc, _OutAcc& __out_acc,
-              _WGSumsAcc& __wg_sums_acc, _SizePerWG __size_per_wg, _WGSize __wgroup_size, _ItersPerWG __iters_per_wg,
+              _WGSumsPtr* __wg_sums_ptr, _SizePerWG __size_per_wg, _WGSize __wgroup_size, _ItersPerWG __iters_per_wg,
               _InitType __init, std::true_type /*has_known_identity*/) const
     {
         auto __group_id = __item.get_group(0);
@@ -831,21 +848,21 @@ struct __scan
                 __gl_assigner(__acc, __out_acc, __adjusted_global_id + __shift, __local_acc, __local_id);
 
             if (__adjusted_global_id == __n - 1)
-                __wg_assigner(__wg_sums_acc, __group_id, __local_acc, __local_id);
+                __wg_assigner(__wg_sums_ptr, __group_id, __local_acc, __local_id);
         }
 
         if (__local_id == __wgroup_size - 1 && __adjusted_global_id - __wgroup_size < __n)
-            __wg_assigner(__wg_sums_acc, __group_id, __local_acc, __local_id);
+            __wg_assigner(__wg_sums_ptr, __group_id, __local_acc, __local_id);
    }
 
     template <typename _NDItemId, typename _Size, typename _AccLocal, typename _InAcc, typename _OutAcc,
-              typename _WGSumsAcc, typename _SizePerWG, typename _WGSize, typename _ItersPerWG>
+              typename _WGSumsPtr, typename _SizePerWG, typename _WGSize, typename _ItersPerWG>
     void
     operator()(_NDItemId __item, _Size __n, _AccLocal& __local_acc, const _InAcc& __acc, _OutAcc& __out_acc,
-               _WGSumsAcc& __wg_sums_acc, _SizePerWG __size_per_wg, _WGSize __wgroup_size,
+               _WGSumsPtr* __wg_sums_ptr, _SizePerWG __size_per_wg, _WGSize __wgroup_size,
                _ItersPerWG __iters_per_wg, _InitType __init = __no_init_value{}) const
     {
-        scan_impl(__item, __n, __local_acc, __acc, __out_acc, __wg_sums_acc, __size_per_wg, __wgroup_size,
+        scan_impl(__item, __n, __local_acc, __acc, __out_acc, __wg_sums_ptr, __size_per_wg, __wgroup_size,
                   __iters_per_wg, __init, __has_known_identity<_BinaryOperation, _Tp>{});
     }
 };
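
Note on the storage layout used above. The shape these hunks adopt can be shown in isolation: one device allocation holding per-work-group scratch entries plus a trailing result slot, with the host copying back only that slot. The program below is a minimal, self-contained SYCL 2020 sketch of that idea; it is deliberately independent of oneDPL, and all names in it (scratch_and_result, wg_size, and so on) are illustrative assumptions rather than oneDPL API.

// Minimal sketch (not oneDPL code): one allocation = [per-group scratch | result].
#include <sycl/sycl.hpp>

#include <cstdio>
#include <vector>

int main()
{
    // In-order queue so the two submits below execute one after the other.
    sycl::queue q{sycl::property::queue::in_order{}};

    constexpr std::size_t n = 1 << 16;
    constexpr std::size_t wg_size = 256;
    constexpr std::size_t n_groups = n / wg_size; // n is a multiple of wg_size here

    std::vector<int> host_in(n, 1);
    int* in = sycl::malloc_device<int>(n, q);
    // Indices [0, n_groups) hold per-work-group partial sums (the "scratch");
    // index n_groups is the final result slot.
    int* scratch_and_result = sycl::malloc_device<int>(n_groups + 1, q);
    q.copy(host_in.data(), in, n);

    // Pass 1: each work-group reduces its chunk into its own scratch slot.
    q.submit([&](sycl::handler& cgh) {
        cgh.parallel_for(sycl::nd_range<1>{n, wg_size}, [=](sycl::nd_item<1> it) {
            int partial = sycl::reduce_over_group(it.get_group(), in[it.get_global_id(0)], sycl::plus<>{});
            if (it.get_local_id(0) == 0)
                scratch_and_result[it.get_group(0)] = partial;
        });
    });

    // Pass 2: combine the scratch entries and write the trailing result slot,
    // analogous to what __global_scan does through __ret_ptr[0].
    q.submit([&](sycl::handler& cgh) {
        cgh.single_task([=] {
            int total = 0;
            for (std::size_t g = 0; g < n_groups; ++g)
                total += scratch_and_result[g];
            scratch_and_result[n_groups] = total;
        });
    });

    // The host reads back only the result slot, never the scratch region.
    int result = 0;
    q.copy(scratch_and_result + n_groups, &result, 1).wait();
    std::printf("sum = %d (expected %zu)\n", result, n);

    sycl::free(in, q);
    sycl::free(scratch_and_result, q);
    return 0;
}

The patch follows the same pattern: __local_scan fills the scratch region through __temp_ptr, __global_scan combines those entries and writes the final value through __ret_ptr[0], and the returned __future carries the __result_and_scratch_storage object so the caller can fetch that single value without touching the rest of the scratch.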