diff --git a/include/oneapi/dpl/internal/exclusive_scan_by_segment_impl.h b/include/oneapi/dpl/internal/exclusive_scan_by_segment_impl.h index cd33872aa0d..266fad7e410 100644 --- a/include/oneapi/dpl/internal/exclusive_scan_by_segment_impl.h +++ b/include/oneapi/dpl/internal/exclusive_scan_by_segment_impl.h @@ -161,7 +161,8 @@ exclusive_scan_by_segment_impl(__internal::__hetero_tag<_BackendTag>, Policy&& p transform_inclusive_scan(::std::move(policy2), make_zip_iterator(_temp.get(), _flags.get()), make_zip_iterator(_temp.get(), _flags.get()) + n, make_zip_iterator(result, _flags.get()), internal::segmented_scan_fun(binary_op), - oneapi::dpl::__internal::__no_op(), ::std::make_tuple(init, FlagType(1))); + oneapi::dpl::__internal::__no_op(), + oneapi::dpl::__internal::make_tuple(init, FlagType(1))); return result + n; } diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h index cfde868e8d1..6eda3ed237f 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h @@ -577,7 +577,7 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W template auto - operator()(const _Policy& __policy, _InRng&& __in_rng, _OutRng&& __out_rng, ::std::size_t __n, _InitType __init, + operator()(_Policy&& __policy, _InRng&& __in_rng, _OutRng&& __out_rng, ::std::size_t __n, _InitType __init, _BinaryOperation __bin_op, _UnaryOp __unary_op) { using _ValueType = ::std::uint16_t; @@ -589,7 +589,7 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W constexpr ::std::uint32_t __elems_per_wg = _ElemsPerItem * _WGSize; - sycl::buffer<_Size> __res(sycl::range<1>(1)); + __result_and_scratch_storage<_Policy, _Size> __result{__policy, 0}; auto __event = __policy.queue().submit([&](sycl::handler& __hdl) { oneapi::dpl::__ranges::__require_access(__hdl, __in_rng, __out_rng); @@ -598,10 +598,12 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W // predicate on each element of the input range. The second half stores the index of the output // range to copy elements of the input range. 
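The first hunk above swaps ::std::make_tuple for oneapi::dpl::__internal::make_tuple, presumably so the initial element passed to the segmented scan has the same tuple type that dereferencing the zip_iterator produces (oneDPL's internal, device-friendly tuple rather than std::tuple). A minimal sketch of the distinction; the include and FlagType below are assumptions for illustration only:

// --- illustration, not part of the patch ---
#include <tuple>
#include <oneapi/dpl/iterator> // brings in oneDPL's internal tuple utilities transitively

using FlagType = int; // placeholder flag type for the sketch
auto std_init = std::make_tuple(0, FlagType(1));                     // std::tuple<int, int>
auto dpl_init = oneapi::dpl::__internal::make_tuple(0, FlagType(1)); // internal tuple, matches *make_zip_iterator(...)
// --- end illustration ---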
auto __lacc = __dpl_sycl::__local_accessor<_ValueType>(sycl::range<1>{__elems_per_wg * 2}, __hdl); - auto __res_acc = __res.template get_access(__hdl); + auto __res_acc = __result.__get_result_acc(__hdl); __hdl.parallel_for<_ScanKernelName...>( sycl::nd_range<1>(_WGSize, _WGSize), [=](sycl::nd_item<1> __self_item) { + auto __res_ptr = + __result_and_scratch_storage<_Policy, _Size>::__get_usm_or_buffer_accessor_ptr(__res_acc); const auto& __group = __self_item.get_group(); const auto& __subgroup = __self_item.get_sub_group(); // This kernel is only launched for sizes less than 2^16 @@ -656,11 +658,11 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W if (__item_id == 0) { // Add predicate of last element to account for the scan's exclusivity - __res_acc[0] = __lacc[__elems_per_wg + __n - 1] + __lacc[__n - 1]; + __res_ptr[0] = __lacc[__elems_per_wg + __n - 1] + __lacc[__n - 1]; } }); }); - return __future(__event, __res); + return __future(__event, __result); } }; @@ -774,6 +776,82 @@ __group_scan_fits_in_slm(const sycl::queue& __queue, ::std::size_t __n, ::std::s return (__n <= __single_group_upper_limit && __max_slm_size >= __req_slm_size); } +template +struct __gen_transform_input +{ + template + auto + operator()(InRng&& __in_rng, std::size_t __idx) const + { + using _ValueType = oneapi::dpl::__internal::__value_t; + using _OutValueType = oneapi::dpl::__internal::__decay_with_tuple_specialization_t::type>; + return _OutValueType{__unary_op(__in_rng[__idx])}; + } + _UnaryOp __unary_op; +}; + +struct __simple_write_to_idx +{ + template + void + operator()(_OutRng&& __out, std::size_t __idx, const ValueType& __v) const + { + __out[__idx] = __v; + } +}; + +template +struct __gen_count_pred +{ + template + _SizeType + operator()(_InRng&& __in_rng, _SizeType __idx) const + { + return __pred(__in_rng[__idx]) ? _SizeType{1} : _SizeType{0}; + } + _Predicate __pred; +}; + +template +struct __gen_expand_count_pred +{ + template + auto + operator()(_InRng&& __in_rng, _SizeType __idx) const + { + // Explicitly creating this element type is necessary to avoid modifying the input data when _InRng is a + // zip_iterator which will return a tuple of references when dereferenced. With this explicit type, we copy + // the values of zipped the input types rather than their references. + using _ElementType = + oneapi::dpl::__internal::__decay_with_tuple_specialization_t>; + _ElementType ele = __in_rng[__idx]; + bool mask = __pred(ele); + return std::tuple(mask ? _SizeType{1} : _SizeType{0}, mask, ele); + } + _Predicate __pred; +}; + +struct __get_zeroth_element +{ + template + auto& + operator()(_Tp&& __a) const + { + return std::get<0>(std::forward<_Tp>(__a)); + } +}; + +struct __write_to_idx_if +{ + template + void + operator()(_OutRng&& __out, _SizeType __idx, const ValueType& __v) const + { + if (std::get<1>(__v)) + __out[std::get<0>(__v) - 1] = std::get<2>(__v); + } +}; + template auto @@ -782,39 +860,62 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag __backen _InitType __init, _BinaryOperation __binary_op, _Inclusive) { using _Type = typename _InitType::__value_type; - - // Next power of 2 greater than or equal to __n - auto __n_uniform = __n; - if ((__n_uniform & (__n_uniform - 1)) != 0) - __n_uniform = oneapi::dpl::__internal::__dpl_bit_floor(__n) << 1; - - // TODO: can we reimplement this with support fort non-identities as well? 
We can then use in reduce-then-scan - // for the last block if it is sufficiently small - constexpr bool __can_use_group_scan = unseq_backend::__has_known_identity<_BinaryOperation, _Type>::value; - if constexpr (__can_use_group_scan) + // Reduce-then-scan is dependent on sycl::shift_group_right which requires the underlying type to be trivially + // copyable. If this is not met, then we must fallback to the legacy implementation. The single work-group implementation + // requires a fundamental type which must also be trivially copyable. + if constexpr (std::is_trivially_copyable_v<_Type>) { - if (__group_scan_fits_in_slm<_Type>(__exec.queue(), __n, __n_uniform)) + // Next power of 2 greater than or equal to __n + auto __n_uniform = __n; + if ((__n_uniform & (__n_uniform - 1)) != 0) + __n_uniform = oneapi::dpl::__internal::__dpl_bit_floor(__n) << 1; + + // TODO: can we reimplement this with support for non-identities as well? We can then use in reduce-then-scan + // for the last block if it is sufficiently small + constexpr bool __can_use_group_scan = unseq_backend::__has_known_identity<_BinaryOperation, _Type>::value; + if constexpr (__can_use_group_scan) { - return __parallel_transform_scan_single_group( - __backend_tag, std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range1>(__in_rng), - ::std::forward<_Range2>(__out_rng), __n, __unary_op, __init, __binary_op, _Inclusive{}); + if (__group_scan_fits_in_slm<_Type>(__exec.queue(), __n, __n_uniform)) + { + return __parallel_transform_scan_single_group( + __backend_tag, std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range1>(__in_rng), + ::std::forward<_Range2>(__out_rng), __n, __unary_op, __init, __binary_op, _Inclusive{}); + } } + oneapi::dpl::__par_backend_hetero::__gen_transform_input<_UnaryOperation> __gen_transform{__unary_op}; + return __future(__parallel_transform_reduce_then_scan( + __backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__in_rng), + std::forward<_Range2>(__out_rng), __gen_transform, __binary_op, __gen_transform, + oneapi::dpl::__internal::__no_op{}, __simple_write_to_idx{}, __init, _Inclusive{}) + .event()); + } + else + { + using _Assigner = unseq_backend::__scan_assigner; + using _NoAssign = unseq_backend::__scan_no_assign; + using _UnaryFunctor = unseq_backend::walk_n<_ExecutionPolicy, _UnaryOperation>; + using _NoOpFunctor = unseq_backend::walk_n<_ExecutionPolicy, oneapi::dpl::__internal::__no_op>; + + _Assigner __assign_op; + _NoAssign __no_assign_op; + _NoOpFunctor __get_data_op; + + return __future( + __parallel_transform_scan_base( + __backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__in_rng), + std::forward<_Range2>(__out_rng), __binary_op, __init, + // local scan + unseq_backend::__scan<_Inclusive, _ExecutionPolicy, _BinaryOperation, _UnaryFunctor, _Assigner, + _Assigner, _NoOpFunctor, _InitType>{__binary_op, _UnaryFunctor{__unary_op}, + __assign_op, __assign_op, __get_data_op}, + // scan between groups + unseq_backend::__scan>{ + __binary_op, _NoOpFunctor{}, __no_assign_op, __assign_op, __get_data_op}, + // global scan + unseq_backend::__global_scan_functor<_Inclusive, _BinaryOperation, _InitType>{__binary_op, __init}) + .event()); } - - // TODO: Reintegrate once support has been added - //// Either we can't use group scan or this input is too big for one workgroup - //using _Assigner = unseq_backend::__scan_assigner; - //using _NoAssign = unseq_backend::__scan_no_assign; - //using _UnaryFunctor = unseq_backend::walk_n<_ExecutionPolicy, 
_UnaryOperation>; - //using _NoOpFunctor = unseq_backend::walk_n<_ExecutionPolicy, oneapi::dpl::__internal::__no_op>; - - //_Assigner __assign_op; - //_NoAssign __no_assign_op; - //_NoOpFunctor __get_data_op; - return __future(__parallel_transform_reduce_then_scan(__backend_tag, ::std::forward<_ExecutionPolicy>(__exec), - ::std::forward<_Range1>(__in_rng), ::std::forward<_Range2>(__out_rng), - __binary_op, __unary_op, __init, _Inclusive{}) - .event()); } template @@ -915,15 +1016,14 @@ __parallel_copy_if(oneapi::dpl::__internal::__device_backend_tag __backend_tag, // The kernel stores n integers for the predicate and another n integers for the offsets const auto __req_slm_size = sizeof(::std::uint16_t) * __n_uniform * 2; - constexpr ::std::uint16_t __single_group_upper_limit = 16384; + constexpr ::std::uint16_t __single_group_upper_limit = 2048; ::std::size_t __max_wg_size = oneapi::dpl::__internal::__max_work_group_size(__exec); if (__n <= __single_group_upper_limit && __max_slm_size >= __req_slm_size && __max_wg_size >= _SingleGroupInvoker::__targeted_wg_size) { - using _SizeBreakpoints = - ::std::integer_sequence<::std::uint16_t, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384>; + using _SizeBreakpoints = ::std::integer_sequence<::std::uint16_t, 16, 32, 64, 128, 256, 512, 1024, 2048>; return __par_backend_hetero::__static_monotonic_dispatcher<_SizeBreakpoints>::__dispatch( _SingleGroupInvoker{}, __n, ::std::forward<_ExecutionPolicy>(__exec), __n, ::std::forward<_InRng>(__in_rng), @@ -932,13 +1032,15 @@ __parallel_copy_if(oneapi::dpl::__internal::__device_backend_tag __backend_tag, else { using _ReduceOp = ::std::plus<_Size>; - using CreateOp = unseq_backend::__create_mask<_Pred, _Size>; - using CopyOp = unseq_backend::__copy_by_mask<_ReduceOp, oneapi::dpl::__internal::__pstl_assign, - /*inclusive*/ ::std::true_type, 1>; - return __parallel_scan_copy(__backend_tag, ::std::forward<_ExecutionPolicy>(__exec), - ::std::forward<_InRng>(__in_rng), ::std::forward<_OutRng>(__out_rng), __n, - CreateOp{__pred}, CopyOp{}); + return __parallel_transform_reduce_then_scan( + __backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng), + std::forward<_OutRng>(__out_rng), oneapi::dpl::__par_backend_hetero::__gen_count_pred<_Pred>{__pred}, + _ReduceOp{}, oneapi::dpl::__par_backend_hetero::__gen_expand_count_pred<_Pred>{__pred}, + oneapi::dpl::__par_backend_hetero::__get_zeroth_element{}, + oneapi::dpl::__par_backend_hetero::__write_to_idx_if{}, + oneapi::dpl::unseq_backend::__no_init_value<_Size>{}, + /*_Inclusive=*/std::true_type{}); } } diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h index b46f86f4826..2979120ad1a 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h @@ -148,10 +148,11 @@ __sub_group_scan_partial(const _SubGroup& __sub_group, _ValueType& __value, _Bin } template + std::uint32_t __max_inputs_per_item, typename _SubGroup, typename _GenInput, typename _ScanInputTransform, + typename _BinaryOp, typename _WriteOp, typename _LazyValueType, typename _InRng, typename _OutRng> void -__scan_through_elements_helper(const _SubGroup& __sub_group, _UnaryOp __unary_op, _BinaryOp __binary_op, +__scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_input, + _ScanInputTransform __scan_input_transform, 
_BinaryOp __binary_op, _WriteOp __write_op, _LazyValueType& __sub_group_carry, _InRng __in_rng, _OutRng __out_rng, std::size_t __start_idx, std::size_t __n, std::uint32_t __iters_per_item, std::size_t __subgroup_start_idx, std::uint32_t __sub_group_id, @@ -161,43 +162,43 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _UnaryOp __unary_op bool __is_full_thread = __subgroup_start_idx + __iters_per_item * __sub_group_size <= __n; if (__is_full_thread && __is_full_block) { - auto __v = __unary_op(__in_rng[__start_idx]); - __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __v, __binary_op, - __sub_group_carry); + auto __v = __gen_input(__in_rng, __start_idx); + __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __scan_input_transform(__v), + __binary_op, __sub_group_carry); if constexpr (__capture_output) { - __out_rng[__start_idx] = __v; + __write_op(__out_rng, __start_idx, __v); } _ONEDPL_PRAGMA_UNROLL for (std::uint32_t __j = 1; __j < __max_inputs_per_item; __j++) { - __v = __unary_op(__in_rng[__start_idx + __j * __sub_group_size]); - __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>(__sub_group, __v, __binary_op, - __sub_group_carry); + __v = __gen_input(__in_rng, __start_idx + __j * __sub_group_size); + __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>( + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry); if constexpr (__capture_output) { - __out_rng[__start_idx + __j * __sub_group_size] = __v; + __write_op(__out_rng, __start_idx + __j * __sub_group_size, __v); } } } else if (__is_full_thread) { - auto __v = __unary_op(__in_rng[__start_idx]); - __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __v, __binary_op, - __sub_group_carry); + auto __v = __gen_input(__in_rng, __start_idx); + __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __scan_input_transform(__v), + __binary_op, __sub_group_carry); if constexpr (__capture_output) { - __out_rng[__start_idx] = __v; + __write_op(__out_rng, __start_idx, __v); } for (std::uint32_t __j = 1; __j < __iters_per_item; __j++) { - __v = __unary_op(__in_rng[__start_idx + __j * __sub_group_size]); - __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>(__sub_group, __v, __binary_op, - __sub_group_carry); + __v = __gen_input(__in_rng, __start_idx + __j * __sub_group_size); + __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>( + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry); if constexpr (__capture_output) { - __out_rng[__start_idx + __j * __sub_group_size] = __v; + __write_op(__out_rng, __start_idx + __j * __sub_group_size, __v); } } } @@ -209,47 +210,48 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _UnaryOp __unary_op if (__iters == 1) { - auto __v = __unary_op(__in_rng[__start_idx]); + auto __v = __gen_input(__in_rng, __start_idx); __sub_group_scan_partial<__sub_group_size, __is_inclusive, __init_present>( - __sub_group, __v, __binary_op, __sub_group_carry, __n - __subgroup_start_idx); + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry, + __n - __subgroup_start_idx); if constexpr (__capture_output) { if (__start_idx < __n) - __out_rng[__start_idx] = __v; + __write_op(__out_rng, __start_idx, __v); } } else { - auto __v = __unary_op(__in_rng[__start_idx]); - __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, 
__v, __binary_op, - __sub_group_carry); + auto __v = __gen_input(__in_rng, __start_idx); + __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>( + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry); if constexpr (__capture_output) { - __out_rng[__start_idx] = __v; + __write_op(__out_rng, __start_idx, __v); } for (int __j = 1; __j < __iters - 1; __j++) { auto __local_idx = __start_idx + __j * __sub_group_size; - __v = __unary_op(__in_rng[__local_idx]); + __v = __gen_input(__in_rng, __local_idx); __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>( - __sub_group, __v, __binary_op, __sub_group_carry); + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry); if constexpr (__capture_output) { - __out_rng[__local_idx] = __v; + __write_op(__out_rng, __local_idx, __v); } } auto __offset = __start_idx + (__iters - 1) * __sub_group_size; auto __local_idx = (__offset < __n) ? __offset : __n - 1; - __v = __unary_op(__in_rng[__local_idx]); + __v = __gen_input(__in_rng, __local_idx); __sub_group_scan_partial<__sub_group_size, __is_inclusive, /*__init_present=*/true>( - __sub_group, __v, __binary_op, __sub_group_carry, + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry, __n - (__subgroup_start_idx + (__iters - 1) * __sub_group_size)); if constexpr (__capture_output) { if (__offset < __n) - __out_rng[__offset] = __v; + __write_op(__out_rng, __offset, __v); } } } @@ -263,43 +265,42 @@ template class __reduce_then_scan_scan_kernel; template + typename _GenReduceInput, typename _ReduceOp, typename _InitType, typename _KernelName> struct __parallel_reduce_then_scan_reduce_submitter; template + typename _GenReduceInput, typename _ReduceOp, typename _InitType, typename... 
_KernelName> struct __parallel_reduce_then_scan_reduce_submitter<__sub_group_size, __max_inputs_per_item, __is_inclusive, - _BinaryOperation, _UnaryOperation, _InitType, + _GenReduceInput, _ReduceOp, _InitType, __internal::__optional_kernel_name<_KernelName...>> { // Step 1 - SubGroupReduce is expected to perform sub-group reductions to global memory // input buffer template auto - operator()(_ExecutionPolicy&& __exec, _InRng&& __in_rng, _TmpStorageAcc __tmp_storage, - const sycl::event& __prior_event, const std::size_t __inputs_per_sub_group, - const std::size_t __inputs_per_item, const std::size_t __block_num, const bool __is_full_block) const + operator()(_ExecutionPolicy&& __exec, const sycl::nd_range<1> __nd_range, _InRng&& __in_rng, + _TmpStorageAcc __scratch_container, const sycl::event& __prior_event, + const std::size_t __inputs_per_sub_group, const std::size_t __inputs_per_item, + const std::size_t __block_num) const { - using _InValueType = oneapi::dpl::__internal::__value_t<_InRng>; + using _InitValueType = typename _InitType::__value_type; return __exec.queue().submit([&, this](sycl::handler& __cgh) { - sycl::local_accessor<_InValueType> __sub_group_partials(__num_sub_groups_local, __cgh); + sycl::local_accessor<_InitValueType> __sub_group_partials(__num_sub_groups_local, __cgh); __cgh.depends_on(__prior_event); oneapi::dpl::__ranges::__require_access(__cgh, __in_rng); + auto __temp_acc = __scratch_container.__get_scratch_acc(__cgh); __cgh.parallel_for<_KernelName...>(__nd_range, [=, *this](sycl::nd_item<1> __ndi) [[sycl::reqd_sub_group_size( __sub_group_size)]] { - auto __id = __ndi.get_global_id(0); - auto __lid = __ndi.get_local_id(0); + auto __temp_ptr = _TmpStorageAcc::__get_usm_or_buffer_accessor_ptr(__temp_acc); auto __g = __ndi.get_group(0); auto __sub_group = __ndi.get_sub_group(); auto __sub_group_id = __sub_group.get_group_linear_id(); auto __sub_group_local_id = __sub_group.get_local_linear_id(); - oneapi::dpl::__internal::__lazy_ctor_storage<_InValueType> __sub_group_carry; + oneapi::dpl::__internal::__lazy_ctor_storage<_InitValueType> __sub_group_carry; std::size_t __group_start_idx = (__block_num * __max_block_size) + (__g * __inputs_per_sub_group * __num_sub_groups_local); - if (__n <= __group_start_idx) - return; // exit early for empty groups (TODO: avoid launching these?) 
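Both submitters now take the __result_and_scratch_storage container and recover a raw pointer inside the kernel via __get_usm_or_buffer_accessor_ptr, so a single kernel body serves USM-backed and buffer-backed storage. A simplified mock of that idea (names and layout are assumptions, not oneDPL's actual implementation):

// --- illustration, not part of the patch ---
#include <sycl/sycl.hpp>

template <typename _T>
struct __scratch_storage_sketch
{
    _T* __usm = nullptr;                          // set when USM device memory is available
    sycl::buffer<_T, 1> __buf{sycl::range<1>(1)}; // fallback backing store

    auto
    __get_scratch_acc(sycl::handler& __cgh)
    {
        return sycl::accessor{__buf, __cgh, sycl::read_write};
    }

    template <typename _Acc>
    _T*
    __get_ptr(_Acc __acc) const
    {
        // Kernels always see a plain pointer, so the submitters need only one code path.
        return __usm ? __usm : __acc.template get_multi_ptr<sycl::access::decorated::no>().get();
    }
};
// --- end illustration ---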
std::size_t __elements_in_group = std::min(__n - __group_start_idx, std::size_t(__num_sub_groups_local * __inputs_per_sub_group)); @@ -316,8 +317,9 @@ struct __parallel_reduce_then_scan_reduce_submitter<__sub_group_size, __max_inpu __scan_through_elements_helper<__sub_group_size, __is_inclusive, /*__init_present=*/false, /*__capture_output=*/false, __max_inputs_per_item>( - __sub_group, __unary_op, __binary_op, __sub_group_carry, __in_rng, nullptr, __start_idx, __n, - __inputs_per_item, __subgroup_start_idx, __sub_group_id, __active_subgroups); + __sub_group, __gen_reduce_input, oneapi::dpl::__internal::__no_op{}, __reduce_op, nullptr, + __sub_group_carry, __in_rng, nullptr, __start_idx, __n, __inputs_per_item, __subgroup_start_idx, + __sub_group_id, __active_subgroups); if (__sub_group_local_id == 0) __sub_group_partials[__sub_group_id] = __sub_group_carry.__v; __sub_group_carry.__destroy(); @@ -342,25 +344,25 @@ struct __parallel_reduce_then_scan_reduce_submitter<__sub_group_size, __max_inpu : (__active_subgroups - 1); // else is unused dummy value auto __v = __sub_group_partials[__load_idx]; __sub_group_scan_partial<__sub_group_size, /*__is_inclusive=*/true, /*__init_present=*/false>( - __sub_group, __v, __binary_op, __sub_group_carry, + __sub_group, __v, __reduce_op, __sub_group_carry, __active_subgroups - __subgroup_start_idx); if (__sub_group_local_id < __active_subgroups) - __tmp_storage[__start_idx + __sub_group_local_id] = __v; + __temp_ptr[__start_idx + __sub_group_local_id] = __v; } else { //need to pull out first iteration tp avoid identity auto __v = __sub_group_partials[__sub_group_local_id]; __sub_group_scan<__sub_group_size, /*__is_inclusive=*/true, /*__init_present=*/false>( - __sub_group, __v, __binary_op, __sub_group_carry); - __tmp_storage[__start_idx + __sub_group_local_id] = __v; + __sub_group, __v, __reduce_op, __sub_group_carry); + __temp_ptr[__start_idx + __sub_group_local_id] = __v; for (int __i = 1; __i < __iters - 1; __i++) { __v = __sub_group_partials[__i * __sub_group_size + __sub_group_local_id]; __sub_group_scan<__sub_group_size, /*__is_inclusive=*/true, /*__init_present=*/true>( - __sub_group, __v, __binary_op, __sub_group_carry); - __tmp_storage[__start_idx + __i * __sub_group_size + __sub_group_local_id] = __v; + __sub_group, __v, __reduce_op, __sub_group_carry); + __temp_ptr[__start_idx + __i * __sub_group_size + __sub_group_local_id] = __v; } // If we are past the input range, then the previous value of v is passed to the sub-group scan. // It does not affect the result as our sub_group_scan will use a mask to only process in-range elements. 
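The partial-group paths above size their work by the number of active sub-groups, which the kernels derive from the elements remaining in the group. A worked example with assumed sizes:

// --- illustration, not part of the patch ---
#include <algorithm>
#include <cstddef>

int
main()
{
    std::size_t __n = 10000;                  // total inputs
    std::size_t __group_start_idx = 9728;     // first input owned by this work-group
    std::size_t __num_sub_groups_local = 32;  // sub-groups per work-group
    std::size_t __inputs_per_sub_group = 128; // inputs owned by each sub-group

    std::size_t __elements_in_group =
        std::min(__n - __group_start_idx, __num_sub_groups_local * __inputs_per_sub_group); // 272
    std::size_t __active_subgroups =
        (__elements_in_group + __inputs_per_sub_group - 1) / __inputs_per_sub_group; // ceil(272 / 128) = 3
    // Only 3 of the 32 sub-groups have input; the partial scans mask out the rest.
    return static_cast<int>(__active_subgroups);
}
// --- end illustration ---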
@@ -372,9 +374,9 @@ struct __parallel_reduce_then_scan_reduce_submitter<__sub_group_size, __max_inpu __v = __sub_group_partials[__load_idx]; __sub_group_scan_partial<__sub_group_size, /*__is_inclusive=*/true, /*__init_present=*/true>( - __sub_group, __v, __binary_op, __sub_group_carry, __num_sub_groups_local); + __sub_group, __v, __reduce_op, __sub_group_carry, __num_sub_groups_local); if (__proposed_idx < __num_sub_groups_local) - __tmp_storage[__start_idx + __proposed_idx] = __v; + __temp_ptr[__start_idx + __proposed_idx] = __v; } __sub_group_carry.__destroy(); @@ -384,47 +386,53 @@ struct __parallel_reduce_then_scan_reduce_submitter<__sub_group_size, __max_inpu } // Constant parameters throughout all blocks - const sycl::nd_range<1> __nd_range; - const std::size_t __max_block_size; const std::size_t __num_sub_groups_local; const std::size_t __num_sub_groups_global; const std::size_t __num_work_items; const std::size_t __n; - const _BinaryOperation __binary_op; - const _UnaryOperation __unary_op; + const _GenReduceInput __gen_reduce_input; + const _ReduceOp __reduce_op; _InitType __init; - - // TODO: Add the mask functors here to generalize for scan-based algorithms }; template + typename _GenReduceInput, typename _ReduceOp, typename _GenScanInput, typename _ScanInputTransform, + typename _WriteOp, typename _InitType, typename _KernelName> struct __parallel_reduce_then_scan_scan_submitter; template -struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs_per_item, __is_inclusive, - _BinaryOperation, _UnaryOperation, _InitType, - __internal::__optional_kernel_name<_KernelName...>> + typename _GenReduceInput, typename _ReduceOp, typename _GenScanInput, typename _ScanInputTransform, + typename _WriteOp, typename _InitType, typename... 
_KernelName> +struct __parallel_reduce_then_scan_scan_submitter< + __sub_group_size, __max_inputs_per_item, __is_inclusive, _GenReduceInput, _ReduceOp, _GenScanInput, + _ScanInputTransform, _WriteOp, _InitType, __internal::__optional_kernel_name<_KernelName...>> { template auto - operator()(_ExecutionPolicy&& __exec, _InRng&& __in_rng, _OutRng&& __out_rng, _TmpStorageAcc __tmp_storage, - const sycl::event& __prior_event, const std::size_t __inputs_per_sub_group, - const std::size_t __inputs_per_item, const std::size_t __block_num, const bool __is_full_block) const + operator()(_ExecutionPolicy&& __exec, const sycl::nd_range<1> __nd_range, _InRng&& __in_rng, _OutRng&& __out_rng, + _TmpStorageAcc __scratch_container, const sycl::event& __prior_event, + const std::size_t __inputs_per_sub_group, const std::size_t __inputs_per_item, + const std::size_t __block_num) const { - using _InValueType = oneapi::dpl::__internal::__value_t<_InRng>; + std::size_t __elements_in_block = std::min(__n - __block_num * __max_block_size, std::size_t(__max_block_size)); + std::size_t __active_groups = oneapi::dpl::__internal::__dpl_ceiling_div( + __elements_in_block, __inputs_per_sub_group * __num_sub_groups_local); using _InitValueType = typename _InitType::__value_type; return __exec.queue().submit([&, this](sycl::handler& __cgh) { - sycl::local_accessor<_InValueType> __sub_group_partials(__num_sub_groups_local + 1, __cgh); + sycl::local_accessor<_InitValueType> __sub_group_partials(__num_sub_groups_local + 1, __cgh); __cgh.depends_on(__prior_event); oneapi::dpl::__ranges::__require_access(__cgh, __in_rng, __out_rng); + auto __temp_acc = __scratch_container.__get_scratch_acc(__cgh); + auto __res_acc = __scratch_container.__get_result_acc(__cgh); + __cgh.parallel_for<_KernelName...>(__nd_range, [=, *this](sycl::nd_item<1> __ndi) [[sycl::reqd_sub_group_size( __sub_group_size)]] { - auto __id = __ndi.get_global_id(0); + auto __tmp_ptr = _TmpStorageAcc::__get_usm_or_buffer_accessor_ptr(__temp_acc); + auto __res_ptr = + _TmpStorageAcc::__get_usm_or_buffer_accessor_ptr(__res_acc, __num_sub_groups_global + 1); auto __lid = __ndi.get_local_id(0); auto __g = __ndi.get_group(0); auto __sub_group = __ndi.get_sub_group(); @@ -433,18 +441,16 @@ struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs auto __group_start_idx = (__block_num * __max_block_size) + (__g * __inputs_per_sub_group * __num_sub_groups_local); - if (__n <= __group_start_idx) - return; // exit early for empty groups (TODO: avoid launching these?) std::size_t __elements_in_group = std::min(__n - __group_start_idx, std::size_t(__num_sub_groups_local * __inputs_per_sub_group)); std::uint32_t __active_subgroups = oneapi::dpl::__internal::__dpl_ceiling_div(__elements_in_group, __inputs_per_sub_group); - oneapi::dpl::__internal::__lazy_ctor_storage<_InValueType> __carry_last; - oneapi::dpl::__internal::__lazy_ctor_storage<_InValueType> __value; + oneapi::dpl::__internal::__lazy_ctor_storage<_InitValueType> __carry_last; + oneapi::dpl::__internal::__lazy_ctor_storage<_InitValueType> __value; // propogate carry in from previous block - oneapi::dpl::__internal::__lazy_ctor_storage<_InValueType> __sub_group_carry; + oneapi::dpl::__internal::__lazy_ctor_storage<_InitValueType> __sub_group_carry; // on the first sub-group in a work-group (assuming S subgroups in a work-group): // 1. 
load S sub-group local carry pfix sums (T0..TS-1) to slm @@ -469,12 +475,12 @@ struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs for (; __i < __iters - 1; __i++) { __sub_group_partials[__i * __sub_group_size + __sub_group_local_id] = - __tmp_storage[__subgroups_before_my_group + __i * __sub_group_size + __sub_group_local_id]; + __tmp_ptr[__subgroups_before_my_group + __i * __sub_group_size + __sub_group_local_id]; } if (__i * __sub_group_size + __sub_group_local_id < __active_subgroups) { __sub_group_partials[__i * __sub_group_size + __sub_group_local_id] = - __tmp_storage[__subgroups_before_my_group + __i * __sub_group_size + __sub_group_local_id]; + __tmp_ptr[__subgroups_before_my_group + __i * __sub_group_size + __sub_group_local_id]; } // step 2) load 32, 64, 96, etc. work-group carry outs on every work-group; then @@ -496,27 +502,27 @@ struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs auto __reduction_idx = (__proposed_idx < __subgroups_before_my_group) ? __proposed_idx : __subgroups_before_my_group - 1; - __value.__setup(__tmp_storage[__reduction_idx]); + __value.__setup(__tmp_ptr[__reduction_idx]); __sub_group_scan_partial<__sub_group_size, /*__is_inclusive=*/true, - /*__init_present=*/false>(__sub_group, __value.__v, __binary_op, + /*__init_present=*/false>(__sub_group, __value.__v, __reduce_op, __carry_last, __remaining_elements); } else { // multiple iterations // first 1 full - __value.__setup(__tmp_storage[__num_sub_groups_local * __sub_group_local_id + __offset]); + __value.__setup(__tmp_ptr[__num_sub_groups_local * __sub_group_local_id + __offset]); __sub_group_scan<__sub_group_size, /*__is_inclusive=*/true, /*__init_present=*/false>( - __sub_group, __value.__v, __binary_op, __carry_last); + __sub_group, __value.__v, __reduce_op, __carry_last); // then some number of full iterations for (int __i = 1; __i < __pre_carry_iters - 1; __i++) { auto __reduction_idx = __i * __num_sub_groups_local * __sub_group_size + __num_sub_groups_local * __sub_group_local_id + __offset; - __value.__v = __tmp_storage[__reduction_idx]; + __value.__v = __tmp_ptr[__reduction_idx]; __sub_group_scan<__sub_group_size, /*__is_inclusive=*/true, /*__init_present=*/true>( - __sub_group, __value.__v, __binary_op, __carry_last); + __sub_group, __value.__v, __reduce_op, __carry_last); } // final partial iteration @@ -527,9 +533,9 @@ struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs auto __reduction_idx = (__proposed_idx < __subgroups_before_my_group) ? 
__proposed_idx : __subgroups_before_my_group - 1; - __value.__v = __tmp_storage[__reduction_idx]; + __value.__v = __tmp_ptr[__reduction_idx]; __sub_group_scan_partial<__sub_group_size, /*__is_inclusive=*/true, - /*__init_present=*/true>(__sub_group, __value.__v, __binary_op, + /*__init_present=*/true>(__sub_group, __value.__v, __reduce_op, __carry_last, __remaining_elements); } } @@ -552,13 +558,13 @@ struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs for (; __i < __iters - 1; ++__i) { __sub_group_partials[__carry_offset + __sub_group_local_id] = - __binary_op(__carry_last.__v, __sub_group_partials[__carry_offset + __sub_group_local_id]); + __reduce_op(__carry_last.__v, __sub_group_partials[__carry_offset + __sub_group_local_id]); __carry_offset += __sub_group_size; } if (__i * __sub_group_size + __sub_group_local_id < __active_subgroups) { __sub_group_partials[__carry_offset + __sub_group_local_id] = - __binary_op(__carry_last.__v, __sub_group_partials[__carry_offset + __sub_group_local_id]); + __reduce_op(__carry_last.__v, __sub_group_partials[__carry_offset + __sub_group_local_id]); __carry_offset += __sub_group_size; } if (__sub_group_local_id == 0) @@ -577,18 +583,19 @@ struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs if (__sub_group_id > 0) { auto __value = __sub_group_partials[__sub_group_id - 1]; - oneapi::dpl::unseq_backend::__init_processing<_InitValueType>{}(__init, __value, __binary_op); + oneapi::dpl::unseq_backend::__init_processing<_InitValueType>{}(__init, __value, __reduce_op); __sub_group_carry.__setup(__value); } else if (__g > 0) { auto __value = __sub_group_partials[__active_subgroups]; - oneapi::dpl::unseq_backend::__init_processing<_InitValueType>{}(__init, __value, __binary_op); + oneapi::dpl::unseq_backend::__init_processing<_InitValueType>{}(__init, __value, __reduce_op); __sub_group_carry.__setup(__value); } else { - if constexpr (std::is_same_v<_InitType, oneapi::dpl::unseq_backend::__no_init_value<_InitValueType>>) + if constexpr (std::is_same_v<_InitType, + oneapi::dpl::unseq_backend::__no_init_value<_InitValueType>>) { // This is the only case where we still don't have a carry in. No init value, 0th block, // group, and subgroup. This changes the final scan through elements below. @@ -604,109 +611,96 @@ struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs { if (__sub_group_id > 0) { - if constexpr (__is_inclusive) - __sub_group_carry.__setup(__binary_op(__out_rng[__block_num * __max_block_size - 1], - __sub_group_partials[__sub_group_id - 1])); - else // The last block wrote an exclusive result, so we must make it inclusive. - { - // Grab the last element from the previous block that has been cached in temporary - // storage in the second kernel of the previous block. - _InValueType __last_block_element = __unary_op(__tmp_storage[__num_sub_groups_global]); - __sub_group_carry.__setup(__binary_op( - __binary_op(__out_rng[__block_num * __max_block_size - 1], __last_block_element), - __sub_group_partials[__sub_group_id - 1])); - } + __sub_group_carry.__setup( + __reduce_op(__tmp_ptr[__num_sub_groups_global], __sub_group_partials[__sub_group_id - 1])); } else if (__g > 0) { - if constexpr (__is_inclusive) - __sub_group_carry.__setup(__binary_op(__out_rng[__block_num * __max_block_size - 1], - __sub_group_partials[__active_subgroups])); - else // The last block wrote an exclusive result, so we must make it inclusive. 
- { - // Grab the last element from the previous block that has been cached in temporary - // storage in the second kernel of the previous block. - _InValueType __last_block_element = __unary_op(__tmp_storage[__num_sub_groups_global]); - __sub_group_carry.__setup(__binary_op( - __binary_op(__out_rng[__block_num * __max_block_size - 1], __last_block_element), - __sub_group_partials[__active_subgroups])); - } + __sub_group_carry.__setup( + __reduce_op(__tmp_ptr[__num_sub_groups_global], __sub_group_partials[__active_subgroups])); } else { - if constexpr (__is_inclusive) - __sub_group_carry.__setup(__out_rng[__block_num * __max_block_size - 1]); - else // The last block wrote an exclusive result, so we must make it inclusive. - { - // Grab the last element from the previous block that has been cached in temporary - // storage in the second kernel of the previous block. - _InValueType __last_block_element = __unary_op(__tmp_storage[__num_sub_groups_global]); - __sub_group_carry.__setup( - __binary_op(__out_rng[__block_num * __max_block_size - 1], __last_block_element)); - } - } - } - // For the exclusive scan case: - // Have the last item in the group store the last element - // in the block to temporary storage for use in the next block. - // This is required to support in-place exclusive scans as the input values will be overwritten. - if constexpr (!__is_inclusive) - { - auto __global_id = __ndi.get_global_linear_id(); - if (__global_id == __num_work_items - 1) - { - std::size_t __last_idx_in_block = std::min(__n - 1, __max_block_size * (__block_num + 1) - 1); - __tmp_storage[__num_sub_groups_global] = __in_rng[__last_idx_in_block]; + __sub_group_carry.__setup(__tmp_ptr[__num_sub_groups_global]); } } // step 5) apply global carries - size_t __subgroup_start_idx = __group_start_idx + (__sub_group_id * __inputs_per_sub_group); - size_t __start_idx = __subgroup_start_idx + __sub_group_local_id; + std::size_t __subgroup_start_idx = __group_start_idx + (__sub_group_id * __inputs_per_sub_group); + std::size_t __start_idx = __subgroup_start_idx + __sub_group_local_id; if (__sub_group_carry_initialized) { __scan_through_elements_helper<__sub_group_size, __is_inclusive, /*__init_present=*/true, /*__capture_output=*/true, __max_inputs_per_item>( - __sub_group, __unary_op, __binary_op, __sub_group_carry, __in_rng, __out_rng, __start_idx, __n, - __inputs_per_item, __subgroup_start_idx, __sub_group_id, __active_subgroups); - - __sub_group_carry.__destroy(); + __sub_group, __gen_scan_input, __scan_input_transform, __reduce_op, __write_op, + __sub_group_carry, __in_rng, __out_rng, __start_idx, __n, __inputs_per_item, + __subgroup_start_idx, __sub_group_id, __active_subgroups); } else // first group first block, no subgroup carry { __scan_through_elements_helper<__sub_group_size, __is_inclusive, /*__init_present=*/false, /*__capture_output=*/true, __max_inputs_per_item>( - __sub_group, __unary_op, __binary_op, __sub_group_carry, __in_rng, __out_rng, __start_idx, __n, - __inputs_per_item, __subgroup_start_idx, __sub_group_id, __active_subgroups); + __sub_group, __gen_scan_input, __scan_input_transform, __reduce_op, __write_op, + __sub_group_carry, __in_rng, __out_rng, __start_idx, __n, __inputs_per_item, + __subgroup_start_idx, __sub_group_id, __active_subgroups); + } + //If within the last active group and subgroup of the block, use the 0th work item of the subgroup + // to write out the last carry out for either the return value or the next block + if (__sub_group_local_id == 0 && (__active_groups == __g 
+ 1) && + (__active_subgroups == __sub_group_id + 1)) + { + if (__block_num + 1 == __num_blocks) + { + __res_ptr[0] = __sub_group_carry.__v; + } + else + { + //capture the last carry out for the next block + __tmp_ptr[__num_sub_groups_global] = __sub_group_carry.__v; + } } + + __sub_group_carry.__destroy(); }); }); } - const sycl::nd_range<1> __nd_range; - const std::size_t __max_block_size; const std::size_t __num_sub_groups_local; const std::size_t __num_sub_groups_global; const std::size_t __num_work_items; + const std::size_t __num_blocks; const std::size_t __n; - const _BinaryOperation __binary_op; - const _UnaryOperation __unary_op; + const _GenReduceInput __gen_reduce_input; + const _ReduceOp __reduce_op; + const _GenScanInput __gen_scan_input; + const _ScanInputTransform __scan_input_transform; + const _WriteOp __write_op; _InitType __init; - - // TODO: Add the mask functors here to generalize for scan-based algorithms }; -template +// General scan-like algorithm helpers +// _GenReduceInput - a function which accepts the input range and index to generate the data needed by the main output +// used in the reduction operation (to calculate the global carries) +// _GenScanInput - a function which accepts the input range and index to generate the data needed by the final scan +// and write operations, for scan patterns +// _ScanInputTransform - a unary function applied to the ouput of `_GenScanInput` to extract the component used in the scan, but +// not the part only required for the final write operation +// _ReduceOp - a binary function which is used in the reduction and scan operations +// _WriteOp - a function which accepts output range, index, and output of `_GenScanInput` applied to the input range +// and performs the final write to output operation +template auto __parallel_transform_reduce_then_scan(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, - _InRng&& __in_rng, _OutRng&& __out_rng, _BinaryOperation __binary_op, - _UnaryOperation __unary_op, + _InRng&& __in_rng, _OutRng&& __out_rng, _GenReduceInput __gen_reduce_input, + _ReduceOp __reduce_op, _GenScanInput __gen_scan_input, + _ScanInputTransform __scan_input_transform, _WriteOp __write_op, _InitType __init /*TODO mask assigners for generalization go here*/, _Inclusive) { using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>; @@ -739,28 +733,27 @@ __parallel_transform_reduce_then_scan(oneapi::dpl::__internal::__device_backend_ : std::max(__sub_group_size, oneapi::dpl::__internal::__dpl_bit_ceil(__num_remaining) / __num_sub_groups_global); auto __inputs_per_item = __inputs_per_sub_group / __sub_group_size; - const auto __global_range = sycl::range<1>(__num_work_items); - const auto __local_range = sycl::range<1>(__work_group_size); - const auto __kernel_nd_range = sycl::nd_range<1>(__global_range, __local_range); const auto __block_size = (__n < __max_inputs_per_block) ? __n : __max_inputs_per_block; const auto __num_blocks = __n / __block_size + (__n % __block_size != 0); - // TODO: Use the trick in reduce to wrap in a shared_ptr with custom deleter to support asynchronous frees. 
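The carry handling above removes the old inclusive/exclusive special cases: each block's scan kernel now stores the block's inclusive carry in the extra scratch slot (__tmp_ptr[__num_sub_groups_global]), the next block folds that carry into its own partials with __reduce_op, and the final block writes it to __res_ptr[0]. Because the carry lives in scratch rather than being reconstructed from the output range, in-place exclusive scans no longer need to cache the last input element. A serial host-side model of the chain (illustration only):

// --- illustration, not part of the patch ---
#include <functional>
#include <vector>

template <typename _T, typename _BinOp>
_T
__scan_blocks_sketch(const std::vector<std::vector<_T>>& __blocks, std::vector<_T>& __out, _BinOp __op)
{
    bool __have_carry = false;
    _T __carry{}; // plays the role of __tmp_ptr[__num_sub_groups_global]
    for (const auto& __block : __blocks)
    {
        for (const _T& __v : __block)
        {
            __carry = __have_carry ? __op(__carry, __v) : __v; // running inclusive value
            __have_carry = true;
            __out.push_back(__carry); // __write_op places this at the proper output index
        }
        // Between blocks only __carry survives; the next block starts from it.
    }
    return __carry; // what the last block stores for the caller (__res_ptr[0])
}
// Usage: __scan_blocks_sketch<int>({{1, 2}, {3, 4}}, __out, std::plus<int>{}) appends
// 1, 3, 6, 10 to __out and returns 10.
// --- end illustration ---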
- _ValueType* __tmp_storage = sycl::malloc_device<_ValueType>(__num_sub_groups_global + 1, __exec.queue()); + __result_and_scratch_storage<_ExecutionPolicy, _ValueType> __result_and_scratch{__exec, + __num_sub_groups_global + 1}; // Reduce and scan step implementations using _ReduceSubmitter = __parallel_reduce_then_scan_reduce_submitter<__sub_group_size, __max_inputs_per_item, __inclusive, - _BinaryOperation, _UnaryOperation, _InitType, _ReduceKernel>; + _GenReduceInput, _ReduceOp, _InitType, _ReduceKernel>; using _ScanSubmitter = __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs_per_item, __inclusive, - _BinaryOperation, _UnaryOperation, _InitType, _ScanKernel>; + _GenReduceInput, _ReduceOp, _GenScanInput, _ScanInputTransform, + _WriteOp, _InitType, _ScanKernel>; // TODO: remove below before merging. used for convenience now // clang-format off - _ReduceSubmitter __reduce_submitter{__kernel_nd_range, __max_inputs_per_block, __num_sub_groups_local, - __num_sub_groups_global, __num_work_items, __n, __binary_op, __unary_op, __init}; - _ScanSubmitter __scan_submitter{__kernel_nd_range, __max_inputs_per_block, __num_sub_groups_local, - __num_sub_groups_global, __num_work_items, __n, __binary_op, __unary_op, __init}; + _ReduceSubmitter __reduce_submitter{__max_inputs_per_block, __num_sub_groups_local, + __num_sub_groups_global, __num_work_items, __n, __gen_reduce_input, __reduce_op, __init}; + _ScanSubmitter __scan_submitter{__max_inputs_per_block, __num_sub_groups_local, + __num_sub_groups_global, __num_work_items, __num_blocks, __n, __gen_reduce_input, __reduce_op, __gen_scan_input, __scan_input_transform, + __write_op, __init}; // clang-format on sycl::event __event; @@ -768,13 +761,19 @@ __parallel_transform_reduce_then_scan(oneapi::dpl::__internal::__device_backend_ // with sufficiently large L2 / L3 caches. for (std::size_t __b = 0; __b < __num_blocks; ++__b) { - bool __is_full_block = __inputs_per_item == __max_inputs_per_item; + auto __elements_in_block = oneapi::dpl::__internal::__dpl_ceiling_div( + std::min(__num_remaining, __max_inputs_per_block), __inputs_per_item); + auto __ele_in_block_round_up_workgroup = + oneapi::dpl::__internal::__dpl_ceiling_div(__elements_in_block, __work_group_size) * __work_group_size; + auto __global_range = sycl::range<1>(__ele_in_block_round_up_workgroup); + auto __local_range = sycl::range<1>(__work_group_size); + auto __kernel_nd_range = sycl::nd_range<1>(__global_range, __local_range); // 1. Reduce step - Reduce assigned input per sub-group, compute and apply intra-wg carries, and write to global memory. - __event = __reduce_submitter(__exec, __in_rng, __tmp_storage, __event, __inputs_per_sub_group, - __inputs_per_item, __b, __is_full_block); + __event = __reduce_submitter(__exec, __kernel_nd_range, __in_rng, __result_and_scratch, __event, + __inputs_per_sub_group, __inputs_per_item, __b); // 2. Scan step - Compute intra-wg carries, determine sub-group carry-ins, and perform full input block scan. - __event = __scan_submitter(__exec, __in_rng, __out_rng, __tmp_storage, __event, __inputs_per_sub_group, - __inputs_per_item, __b, __is_full_block); + __event = __scan_submitter(__exec, __kernel_nd_range, __in_rng, __out_rng, __result_and_scratch, __event, + __inputs_per_sub_group, __inputs_per_item, __b); if (__num_remaining > __block_size) { // Resize for the next block. 
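Each block now gets its own nd_range sized to the inputs it actually covers, replacing the single fixed launch range and the in-kernel early exit for empty groups that the earlier hunks removed. A worked example of the sizing with assumed parameters:

// --- illustration, not part of the patch ---
#include <cstddef>

int
main()
{
    std::size_t __num_remaining = 100000;       // inputs left from this block onward
    std::size_t __max_inputs_per_block = 65536; // block capacity
    std::size_t __inputs_per_item = 8;          // inputs handled by each work item
    std::size_t __work_group_size = 256;

    std::size_t __inputs_this_block =
        __num_remaining < __max_inputs_per_block ? __num_remaining : __max_inputs_per_block; // 65536
    std::size_t __items = (__inputs_this_block + __inputs_per_item - 1) / __inputs_per_item; // 8192
    std::size_t __global_range =
        (__items + __work_group_size - 1) / __work_group_size * __work_group_size; // 8192, already a multiple of 256
    // The second block covers the remaining 34464 inputs: ceil(34464 / 8) = 4308 items,
    // rounded up to 4352 work items, instead of relaunching the full first-block range.
    return static_cast<int>(__global_range);
}
// --- end illustration ---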
@@ -789,10 +788,7 @@ __parallel_transform_reduce_then_scan(oneapi::dpl::__internal::__device_backend_
             __inputs_per_item = __inputs_per_sub_group / __sub_group_size;
         }
     }
-    // TODO: Remove to make asynchronous. Depends on completing async USM free TODO.
-    __event.wait();
-    sycl::free(__tmp_storage, __exec.queue());
-    return __future(__event);
+    return __future(__event, __result_and_scratch);
 }
 
 } // namespace __par_backend_hetero
diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h
index e0b153e31e2..b49a5559108 100644
--- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h
+++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h
@@ -510,6 +510,7 @@ struct __usm_or_buffer_accessor
 template
 struct __result_and_scratch_storage
 {
+    using __value_type = _T;
   private:
     using __sycl_buffer_t = sycl::buffer<_T, 1>;
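To close the loop on the __parallel_copy_if rewiring earlier in the patch: __gen_expand_count_pred produces a (count, mask, element) tuple, __get_zeroth_element restricts the scan to the count component, and __write_to_idx_if scatters the kept elements. A serial host-side model of that composition (illustration only; the device path spreads the same steps across the reduce and scan kernels):

// --- illustration, not part of the patch ---
#include <cstddef>
#include <tuple>
#include <vector>

int
main()
{
    std::vector<int> __in{3, 8, 1, 9, 4};
    std::vector<int> __out(__in.size());
    auto __pred = [](int __x) { return __x > 3; };

    std::size_t __count = 0; // inclusive scan over the count component (__get_zeroth_element)
    for (std::size_t __i = 0; __i < __in.size(); ++__i)
    {
        bool __mask = __pred(__in[__i]);
        auto __v = std::tuple(__mask ? std::size_t{1} : std::size_t{0}, __mask, __in[__i]); // __gen_expand_count_pred
        __count += std::get<0>(__v);
        if (std::get<1>(__v)) // __write_to_idx_if
            __out[__count - 1] = std::get<2>(__v);
    }
    // __out begins with {8, 9, 4}; __count == 3 is the element count the scan kernel
    // stores as the algorithm's result.
    return static_cast<int>(__count);
}
// --- end illustration ---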