diff --git a/include/oneapi/dpl/experimental/kt/single_pass_scan.h b/include/oneapi/dpl/experimental/kt/single_pass_scan.h index 94d0474f402..635ed869132 100644 --- a/include/oneapi/dpl/experimental/kt/single_pass_scan.h +++ b/include/oneapi/dpl/experimental/kt/single_pass_scan.h @@ -332,7 +332,7 @@ __single_pass_scan(sycl::queue __queue, _InRange&& __in_rng, _OutRange&& __out_r auto __n_uniform = ::oneapi::dpl::__internal::__dpl_bit_ceil(__n); // Perform a single-work group scan if the input is small - if (oneapi::dpl::__par_backend_hetero::__group_scan_fits_in_slm<_Type>(__queue, __n, __n_uniform)) + if (oneapi::dpl::__par_backend_hetero::__group_scan_fits_in_slm<_Type>(__queue, __n, __n_uniform, /*limit=*/16384)) { return oneapi::dpl::__par_backend_hetero::__parallel_transform_scan_single_group( oneapi::dpl::__internal::__device_backend_tag{}, diff --git a/include/oneapi/dpl/internal/exclusive_scan_by_segment_impl.h b/include/oneapi/dpl/internal/exclusive_scan_by_segment_impl.h index cd33872aa0d..266fad7e410 100644 --- a/include/oneapi/dpl/internal/exclusive_scan_by_segment_impl.h +++ b/include/oneapi/dpl/internal/exclusive_scan_by_segment_impl.h @@ -161,7 +161,8 @@ exclusive_scan_by_segment_impl(__internal::__hetero_tag<_BackendTag>, Policy&& p transform_inclusive_scan(::std::move(policy2), make_zip_iterator(_temp.get(), _flags.get()), make_zip_iterator(_temp.get(), _flags.get()) + n, make_zip_iterator(result, _flags.get()), internal::segmented_scan_fun(binary_op), - oneapi::dpl::__internal::__no_op(), ::std::make_tuple(init, FlagType(1))); + oneapi::dpl::__internal::__no_op(), + oneapi::dpl::__internal::make_tuple(init, FlagType(1))); return result + n; } diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h index 76c6abb1b34..27ce403a018 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h @@ -38,6 +38,7 @@ #include "parallel_backend_sycl_reduce.h" #include "parallel_backend_sycl_merge.h" #include "parallel_backend_sycl_merge_sort.h" +#include "parallel_backend_sycl_reduce_then_scan.h" #include "execution_sycl_defs.h" #include "sycl_iterator.h" #include "unseq_backend_sycl.h" @@ -753,10 +754,9 @@ __parallel_transform_scan_base(oneapi::dpl::__internal::__device_backend_tag, _E template bool -__group_scan_fits_in_slm(const sycl::queue& __queue, ::std::size_t __n, ::std::size_t __n_uniform) +__group_scan_fits_in_slm(const sycl::queue& __queue, std::size_t __n, std::size_t __n_uniform, + std::size_t __single_group_upper_limit) { - constexpr int __single_group_upper_limit = 16384; - // Pessimistically only use half of the memory to take into account memory used by compiled kernel const ::std::size_t __max_slm_size = __queue.get_device().template get_info() / 2; @@ -765,6 +765,37 @@ __group_scan_fits_in_slm(const sycl::queue& __queue, ::std::size_t __n, ::std::s return (__n <= __single_group_upper_limit && __max_slm_size >= __req_slm_size); } +template +struct __gen_transform_input +{ + template + auto + operator()(const _InRng& __in_rng, std::size_t __id) const + { + // We explicitly convert __in_rng[__id] to the value type of _InRng to properly handle the case where we + // process zip_iterator input where the reference type is a tuple of a references. This prevents the caller + // from modifying the input range when altering the return of this functor. 
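+        // For scalar value types this simply copies the loaded element before applying __unary_op.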
+ using _ValueType = oneapi::dpl::__internal::__value_t<_InRng>; + return __unary_op(_ValueType{__in_rng[__id]}); + } + _UnaryOp __unary_op; +}; + +struct __simple_write_to_id +{ + template + void + operator()(_OutRng& __out_rng, std::size_t __id, const _ValueType& __v) const + { + // Use of an explicit cast to our internal tuple type is required to resolve conversion issues between our + // internal tuple and std::tuple. If the underlying type is not a tuple, then the type will just be passed through. + using _ConvertedTupleType = + typename oneapi::dpl::__internal::__get_tuple_type, + std::decay_t>::__type; + __out_rng[__id] = static_cast<_ConvertedTupleType>(__v); + } +}; + template auto @@ -773,24 +804,46 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag __backen _InitType __init, _BinaryOperation __binary_op, _Inclusive) { using _Type = typename _InitType::__value_type; + // Reduce-then-scan is dependent on sycl::shift_group_right which requires the underlying type to be trivially + // copyable. If this is not met, then we must fallback to the multi pass scan implementation. The single + // work-group implementation requires a fundamental type which must also be trivially copyable. + if constexpr (std::is_trivially_copyable_v<_Type>) + { + bool __use_reduce_then_scan = oneapi::dpl::__par_backend_hetero::__is_gpu_with_sg_32(__exec); - // Next power of 2 greater than or equal to __n - auto __n_uniform = __n; - if ((__n_uniform & (__n_uniform - 1)) != 0) - __n_uniform = oneapi::dpl::__internal::__dpl_bit_floor(__n) << 1; + // TODO: Consider re-implementing single group scan to support types without known identities. This could also + // allow us to use single wg scan for the last block of reduce-then-scan if it is sufficiently small. + constexpr bool __can_use_group_scan = unseq_backend::__has_known_identity<_BinaryOperation, _Type>::value; + if constexpr (__can_use_group_scan) + { + // Next power of 2 greater than or equal to __n + std::size_t __n_uniform = oneapi::dpl::__internal::__dpl_bit_ceil(__n); - constexpr bool __can_use_group_scan = unseq_backend::__has_known_identity<_BinaryOperation, _Type>::value; - if constexpr (__can_use_group_scan) - { - if (__group_scan_fits_in_slm<_Type>(__exec.queue(), __n, __n_uniform)) + // Empirically found values for reduce-then-scan and multi pass scan implementation for single wg cutoff + std::size_t __single_group_upper_limit = __use_reduce_then_scan ? 
2048 : 16384; + if (__group_scan_fits_in_slm<_Type>(__exec.queue(), __n, __n_uniform, __single_group_upper_limit)) + { + return __parallel_transform_scan_single_group( + __backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__in_rng), + std::forward<_Range2>(__out_rng), __n, __unary_op, __init, __binary_op, _Inclusive{}); + } + } + if (__use_reduce_then_scan) { - return __parallel_transform_scan_single_group( - __backend_tag, std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range1>(__in_rng), - ::std::forward<_Range2>(__out_rng), __n, __unary_op, __init, __binary_op, _Inclusive{}); + using _GenInput = oneapi::dpl::__par_backend_hetero::__gen_transform_input<_UnaryOperation>; + using _ScanInputTransform = oneapi::dpl::__internal::__no_op; + using _WriteOp = oneapi::dpl::__par_backend_hetero::__simple_write_to_id; + + _GenInput __gen_transform{__unary_op}; + + return __parallel_transform_reduce_then_scan( + __backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__in_rng), + std::forward<_Range2>(__out_rng), __gen_transform, __binary_op, __gen_transform, _ScanInputTransform{}, + _WriteOp{}, __init, _Inclusive{}); } } - // Either we can't use group scan or this input is too big for one workgroup + //else use multi pass scan implementation using _Assigner = unseq_backend::__scan_assigner; using _NoAssign = unseq_backend::__scan_no_assign; using _UnaryFunctor = unseq_backend::walk_n<_ExecutionPolicy, _UnaryOperation>; diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h new file mode 100644 index 00000000000..ed979f77ba3 --- /dev/null +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h @@ -0,0 +1,839 @@ +// -*- C++ -*- +//===-- parallel_backend_sycl_reduce_then_scan.h ---------------------------------===// +// +// Copyright (C) Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// This file incorporates work covered by the following copyright and permission +// notice: +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// +//===----------------------------------------------------------------------===// + +#ifndef _ONEDPL_PARALLEL_BACKEND_SYCL_REDUCE_THEN_SCAN_H +#define _ONEDPL_PARALLEL_BACKEND_SYCL_REDUCE_THEN_SCAN_H + +#include +#include +#include + +#include "sycl_defs.h" +#include "parallel_backend_sycl_utils.h" +#include "execution_sycl_defs.h" +#include "unseq_backend_sycl.h" +#include "utils_ranges_sycl.h" + +#include "../../utils.h" + +namespace oneapi +{ +namespace dpl +{ +namespace __par_backend_hetero +{ + +template +void +__exclusive_sub_group_masked_scan(const __dpl_sycl::__sub_group& __sub_group, _MaskOp __mask_fn, + _InitBroadcastId __init_broadcast_id, _ValueType& __value, _BinaryOp __binary_op, + _LazyValueType& __init_and_carry) +{ + std::uint8_t __sub_group_local_id = __sub_group.get_local_linear_id(); + _ONEDPL_PRAGMA_UNROLL + for (std::uint8_t __shift = 1; __shift <= __sub_group_size / 2; __shift <<= 1) + { + _ValueType __partial_carry_in = sycl::shift_group_right(__sub_group, __value, __shift); + if (__mask_fn(__sub_group_local_id, __shift)) + { + __value = __binary_op(__partial_carry_in, __value); + } + } + _LazyValueType __old_init; + if constexpr (__init_present) + { + __value = __binary_op(__init_and_carry.__v, __value); + if (__sub_group_local_id == 0) + __old_init.__setup(__init_and_carry.__v); + __init_and_carry.__v = sycl::group_broadcast(__sub_group, __value, __init_broadcast_id); + } + else + { + __init_and_carry.__setup(sycl::group_broadcast(__sub_group, __value, __init_broadcast_id)); + } + + __value = sycl::shift_group_right(__sub_group, __value, 1); + if constexpr (__init_present) + { + if (__sub_group_local_id == 0) + { + __value = __old_init.__v; + __old_init.__destroy(); + } + } + //return by reference __value and __init_and_carry +} + +template +void +__inclusive_sub_group_masked_scan(const __dpl_sycl::__sub_group& __sub_group, _MaskOp __mask_fn, + _InitBroadcastId __init_broadcast_id, _ValueType& __value, _BinaryOp __binary_op, + _LazyValueType& __init_and_carry) +{ + std::uint8_t __sub_group_local_id = __sub_group.get_local_linear_id(); + _ONEDPL_PRAGMA_UNROLL + for (std::uint8_t __shift = 1; __shift <= __sub_group_size / 2; __shift <<= 1) + { + _ValueType __partial_carry_in = sycl::shift_group_right(__sub_group, __value, __shift); + if (__mask_fn(__sub_group_local_id, __shift)) + { + __value = __binary_op(__partial_carry_in, __value); + } + } + if constexpr (__init_present) + { + __value = __binary_op(__init_and_carry.__v, __value); + __init_and_carry.__v = sycl::group_broadcast(__sub_group, __value, __init_broadcast_id); + } + else + { + __init_and_carry.__setup(sycl::group_broadcast(__sub_group, __value, __init_broadcast_id)); + } + //return by reference __value and __init_and_carry +} + +template +void +__sub_group_masked_scan(const __dpl_sycl::__sub_group& __sub_group, _MaskOp __mask_fn, + _InitBroadcastId __init_broadcast_id, _ValueType& __value, _BinaryOp __binary_op, + _LazyValueType& __init_and_carry) +{ + if constexpr (__is_inclusive) + { + __inclusive_sub_group_masked_scan<__sub_group_size, __init_present>(__sub_group, __mask_fn, __init_broadcast_id, + __value, __binary_op, __init_and_carry); + } + else + { + __exclusive_sub_group_masked_scan<__sub_group_size, __init_present>(__sub_group, __mask_fn, __init_broadcast_id, + __value, __binary_op, __init_and_carry); + } +} + +template +void +__sub_group_scan(const __dpl_sycl::__sub_group& __sub_group, _ValueType& __value, _BinaryOp __binary_op, + _LazyValueType& __init_and_carry) +{ + auto __mask_fn = 
[](auto __sub_group_local_id, auto __offset) { return __sub_group_local_id >= __offset; }; + constexpr std::uint8_t __init_broadcast_id = __sub_group_size - 1; + __sub_group_masked_scan<__sub_group_size, __is_inclusive, __init_present>( + __sub_group, __mask_fn, __init_broadcast_id, __value, __binary_op, __init_and_carry); +} + +template +void +__sub_group_scan_partial(const __dpl_sycl::__sub_group& __sub_group, _ValueType& __value, _BinaryOp __binary_op, + _LazyValueType& __init_and_carry, _SizeType __elements_to_process) +{ + auto __mask_fn = [__elements_to_process](auto __sub_group_local_id, auto __offset) { + return __sub_group_local_id >= __offset && __sub_group_local_id < __elements_to_process; + }; + std::uint8_t __init_broadcast_id = __elements_to_process - 1; + __sub_group_masked_scan<__sub_group_size, __is_inclusive, __init_present>( + __sub_group, __mask_fn, __init_broadcast_id, __value, __binary_op, __init_and_carry); +} + +template +void +__scan_through_elements_helper(const __dpl_sycl::__sub_group& __sub_group, _GenInput __gen_input, + _ScanInputTransform __scan_input_transform, _BinaryOp __binary_op, _WriteOp __write_op, + _LazyValueType& __sub_group_carry, const _InRng& __in_rng, _OutRng& __out_rng, + std::size_t __start_id, std::size_t __n, std::uint32_t __iters_per_item, + std::size_t __subgroup_start_id, std::uint32_t __sub_group_id, + std::uint32_t __active_subgroups) +{ + using _GenInputType = std::invoke_result_t<_GenInput, _InRng, std::size_t>; + + bool __is_full_block = (__iters_per_item == __max_inputs_per_item); + bool __is_full_thread = __subgroup_start_id + __iters_per_item * __sub_group_size <= __n; + if (__is_full_thread) + { + _GenInputType __v = __gen_input(__in_rng, __start_id); + __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __scan_input_transform(__v), + __binary_op, __sub_group_carry); + if constexpr (__capture_output) + { + __write_op(__out_rng, __start_id, __v); + } + + if (__is_full_block) + { + // For full block and full thread, we can unroll the loop + _ONEDPL_PRAGMA_UNROLL + for (std::uint32_t __j = 1; __j < __max_inputs_per_item; __j++) + { + __v = __gen_input(__in_rng, __start_id + __j * __sub_group_size); + __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>( + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry); + if constexpr (__capture_output) + { + __write_op(__out_rng, __start_id + __j * __sub_group_size, __v); + } + } + } + else + { + // For full thread but not full block, we can't unroll the loop, but we + // can proceed without special casing for partial subgroups. + for (std::uint32_t __j = 1; __j < __iters_per_item; __j++) + { + __v = __gen_input(__in_rng, __start_id + __j * __sub_group_size); + __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>( + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry); + if constexpr (__capture_output) + { + __write_op(__out_rng, __start_id + __j * __sub_group_size, __v); + } + } + } + } + else + { + // For partial thread, we need to handle the partial subgroup at the end of the range + if (__sub_group_id < __active_subgroups) + { + std::uint32_t __iters = + oneapi::dpl::__internal::__dpl_ceiling_div(__n - __subgroup_start_id, __sub_group_size); + + if (__iters == 1) + { + std::size_t __local_id = (__start_id < __n) ? 
__start_id : __n - 1; + _GenInputType __v = __gen_input(__in_rng, __local_id); + __sub_group_scan_partial<__sub_group_size, __is_inclusive, __init_present>( + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry, + __n - __subgroup_start_id); + if constexpr (__capture_output) + { + if (__start_id < __n) + __write_op(__out_rng, __start_id, __v); + } + } + else + { + _GenInputType __v = __gen_input(__in_rng, __start_id); + __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>( + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry); + if constexpr (__capture_output) + { + __write_op(__out_rng, __start_id, __v); + } + + for (std::uint32_t __j = 1; __j < __iters - 1; __j++) + { + std::size_t __local_id = __start_id + __j * __sub_group_size; + __v = __gen_input(__in_rng, __local_id); + __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>( + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry); + if constexpr (__capture_output) + { + __write_op(__out_rng, __local_id, __v); + } + } + + std::size_t __offset = __start_id + (__iters - 1) * __sub_group_size; + std::size_t __local_id = (__offset < __n) ? __offset : __n - 1; + __v = __gen_input(__in_rng, __local_id); + __sub_group_scan_partial<__sub_group_size, __is_inclusive, /*__init_present=*/true>( + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry, + __n - (__subgroup_start_id + (__iters - 1) * __sub_group_size)); + if constexpr (__capture_output) + { + if (__offset < __n) + __write_op(__out_rng, __offset, __v); + } + } + } + } +} + +template +class __reduce_then_scan_reduce_kernel; + +template +class __reduce_then_scan_scan_kernel; + +template +struct __parallel_reduce_then_scan_reduce_submitter; + +template +struct __parallel_reduce_then_scan_reduce_submitter<__sub_group_size, __max_inputs_per_item, __is_inclusive, + _GenReduceInput, _ReduceOp, _InitType, + __internal::__optional_kernel_name<_KernelName...>> +{ + // Step 1 - SubGroupReduce is expected to perform sub-group reductions to global memory + // input buffer + template + sycl::event + operator()(_ExecutionPolicy&& __exec, const sycl::nd_range<1> __nd_range, _InRng&& __in_rng, + _TmpStorageAcc& __scratch_container, const sycl::event& __prior_event, + const std::uint32_t __inputs_per_sub_group, const std::uint32_t __inputs_per_item, + const std::size_t __block_num) const + { + using _InitValueType = typename _InitType::__value_type; + return __exec.queue().submit([&, this](sycl::handler& __cgh) { + __dpl_sycl::__local_accessor<_InitValueType> __sub_group_partials(__num_sub_groups_local, __cgh); + __cgh.depends_on(__prior_event); + oneapi::dpl::__ranges::__require_access(__cgh, __in_rng); + auto __temp_acc = __scratch_container.__get_scratch_acc(__cgh); + __cgh.parallel_for<_KernelName...>( + __nd_range, [=, *this](sycl::nd_item<1> __ndi) [[sycl::reqd_sub_group_size(__sub_group_size)]] { + _InitValueType* __temp_ptr = _TmpStorageAcc::__get_usm_or_buffer_accessor_ptr(__temp_acc); + std::size_t __group_id = __ndi.get_group(0); + __dpl_sycl::__sub_group __sub_group = __ndi.get_sub_group(); + std::uint32_t __sub_group_id = __sub_group.get_group_linear_id(); + std::uint8_t __sub_group_local_id = __sub_group.get_local_linear_id(); + + oneapi::dpl::__internal::__lazy_ctor_storage<_InitValueType> __sub_group_carry; + std::size_t __group_start_id = + (__block_num * __max_block_size) + (__group_id * __inputs_per_sub_group * __num_sub_groups_local); + + std::size_t 
__max_inputs_in_group = __inputs_per_sub_group * __num_sub_groups_local; + std::uint32_t __inputs_in_group = std::min(__n - __group_start_id, __max_inputs_in_group); + std::uint32_t __active_subgroups = + oneapi::dpl::__internal::__dpl_ceiling_div(__inputs_in_group, __inputs_per_sub_group); + std::size_t __subgroup_start_id = __group_start_id + (__sub_group_id * __inputs_per_sub_group); + + std::size_t __start_id = __subgroup_start_id + __sub_group_local_id; + + if (__sub_group_id < __active_subgroups) + { + // adjust for lane-id + // compute sub-group local prefix on T0..63, K samples/T, send to accumulator kernel + __scan_through_elements_helper<__sub_group_size, __is_inclusive, + /*__init_present=*/false, + /*__capture_output=*/false, __max_inputs_per_item>( + __sub_group, __gen_reduce_input, oneapi::dpl::__internal::__no_op{}, __reduce_op, nullptr, + __sub_group_carry, __in_rng, /*unused*/ __in_rng, __start_id, __n, __inputs_per_item, + __subgroup_start_id, __sub_group_id, __active_subgroups); + if (__sub_group_local_id == 0) + __sub_group_partials[__sub_group_id] = __sub_group_carry.__v; + __sub_group_carry.__destroy(); + } + __dpl_sycl::__group_barrier(__ndi); + + // compute sub-group local prefix sums on (T0..63) carries + // and store to scratch space at the end of dst; next + // accumulator kernel takes M thread carries from scratch + // to compute a prefix sum on global carries + if (__sub_group_id == 0) + { + __start_id = (__group_id * __num_sub_groups_local); + std::uint8_t __iters = + oneapi::dpl::__internal::__dpl_ceiling_div(__active_subgroups, __sub_group_size); + if (__iters == 1) + { + // fill with unused dummy values to avoid overruning input + std::uint32_t __load_id = std::min(std::uint32_t{__sub_group_local_id}, __active_subgroups - 1); + _InitValueType __v = __sub_group_partials[__load_id]; + __sub_group_scan_partial<__sub_group_size, /*__is_inclusive=*/true, /*__init_present=*/false>( + __sub_group, __v, __reduce_op, __sub_group_carry, __active_subgroups); + if (__sub_group_local_id < __active_subgroups) + __temp_ptr[__start_id + __sub_group_local_id] = __v; + } + else + { + std::uint32_t __reduction_scan_id = __sub_group_local_id; + // need to pull out first iteration tp avoid identity + _InitValueType __v = __sub_group_partials[__reduction_scan_id]; + __sub_group_scan<__sub_group_size, /*__is_inclusive=*/true, /*__init_present=*/false>( + __sub_group, __v, __reduce_op, __sub_group_carry); + __temp_ptr[__start_id + __reduction_scan_id] = __v; + __reduction_scan_id += __sub_group_size; + + for (std::uint32_t __i = 1; __i < __iters - 1; __i++) + { + __v = __sub_group_partials[__reduction_scan_id]; + __sub_group_scan<__sub_group_size, /*__is_inclusive=*/true, /*__init_present=*/true>( + __sub_group, __v, __reduce_op, __sub_group_carry); + __temp_ptr[__start_id + __reduction_scan_id] = __v; + __reduction_scan_id += __sub_group_size; + } + // If we are past the input range, then the previous value of v is passed to the sub-group scan. + // It does not affect the result as our sub_group_scan will use a mask to only process in-range elements. 
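+                        // The element count passed to __sub_group_scan_partial below,
+                        // __active_subgroups - (__iters - 1) * __sub_group_size, masks the scan to the remaining
+                        // in-range lanes.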
+ + // fill with unused dummy values to avoid overruning input + std::uint32_t __load_id = std::min(__reduction_scan_id, __num_sub_groups_local - 1); + + __v = __sub_group_partials[__load_id]; + __sub_group_scan_partial<__sub_group_size, /*__is_inclusive=*/true, /*__init_present=*/true>( + __sub_group, __v, __reduce_op, __sub_group_carry, + __active_subgroups - ((__iters - 1) * __sub_group_size)); + if (__reduction_scan_id < __num_sub_groups_local) + __temp_ptr[__start_id + __reduction_scan_id] = __v; + } + + __sub_group_carry.__destroy(); + } + }); + }); + } + + // Constant parameters throughout all blocks + const std::uint32_t __max_block_size; + const std::uint32_t __num_sub_groups_local; + const std::uint32_t __num_sub_groups_global; + const std::uint32_t __num_work_items; + const std::size_t __n; + + const _GenReduceInput __gen_reduce_input; + const _ReduceOp __reduce_op; + _InitType __init; +}; + +template +struct __parallel_reduce_then_scan_scan_submitter; + +template +struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs_per_item, __is_inclusive, _ReduceOp, + _GenScanInput, _ScanInputTransform, _WriteOp, _InitType, + __internal::__optional_kernel_name<_KernelName...>> +{ + using _InitValueType = typename _InitType::__value_type; + + _InitValueType + __get_block_carry_in(const std::size_t __block_num, _InitValueType* __tmp_ptr) const + { + return __tmp_ptr[__num_sub_groups_global + (__block_num % 2)]; + } + + template + void + __set_block_carry_out(const std::size_t __block_num, _InitValueType* __tmp_ptr, + const _ValueType __block_carry_out) const + { + __tmp_ptr[__num_sub_groups_global + 1 - (__block_num % 2)] = __block_carry_out; + } + + template + sycl::event + operator()(_ExecutionPolicy&& __exec, const sycl::nd_range<1> __nd_range, _InRng&& __in_rng, _OutRng&& __out_rng, + _TmpStorageAcc& __scratch_container, const sycl::event& __prior_event, + const std::uint32_t __inputs_per_sub_group, const std::uint32_t __inputs_per_item, + const std::size_t __block_num) const + { + std::uint32_t __inputs_in_block = std::min(__n - __block_num * __max_block_size, std::size_t{__max_block_size}); + std::uint32_t __active_groups = oneapi::dpl::__internal::__dpl_ceiling_div( + __inputs_in_block, __inputs_per_sub_group * __num_sub_groups_local); + return __exec.queue().submit([&, this](sycl::handler& __cgh) { + // We need __num_sub_groups_local + 1 temporary SLM locations to store intermediate results: + // __num_sub_groups_local for each sub-group partial from the reduce kernel + + // 1 element for the accumulated block-local carry-in from previous groups in the block + __dpl_sycl::__local_accessor<_InitValueType> __sub_group_partials(__num_sub_groups_local + 1, __cgh); + __cgh.depends_on(__prior_event); + oneapi::dpl::__ranges::__require_access(__cgh, __in_rng, __out_rng); + auto __temp_acc = __scratch_container.__get_scratch_acc(__cgh); + auto __res_acc = __scratch_container.__get_result_acc(__cgh); + + __cgh.parallel_for<_KernelName...>( + __nd_range, [=, *this] (sycl::nd_item<1> __ndi) [[sycl::reqd_sub_group_size(__sub_group_size)]] { + _InitValueType* __tmp_ptr = _TmpStorageAcc::__get_usm_or_buffer_accessor_ptr(__temp_acc); + _InitValueType* __res_ptr = + _TmpStorageAcc::__get_usm_or_buffer_accessor_ptr(__res_acc, __num_sub_groups_global + 2); + std::uint32_t __group_id = __ndi.get_group(0); + __dpl_sycl::__sub_group __sub_group = __ndi.get_sub_group(); + std::uint32_t __sub_group_id = __sub_group.get_group_linear_id(); + std::uint8_t __sub_group_local_id = 
__sub_group.get_local_linear_id(); + + std::size_t __group_start_id = + (__block_num * __max_block_size) + (__group_id * __inputs_per_sub_group * __num_sub_groups_local); + + std::size_t __max_inputs_in_group = __inputs_per_sub_group * __num_sub_groups_local; + std::uint32_t __inputs_in_group = std::min(__n - __group_start_id, __max_inputs_in_group); + std::uint32_t __active_subgroups = + oneapi::dpl::__internal::__dpl_ceiling_div(__inputs_in_group, __inputs_per_sub_group); + oneapi::dpl::__internal::__lazy_ctor_storage<_InitValueType> __carry_last; + + // propagate carry in from previous block + oneapi::dpl::__internal::__lazy_ctor_storage<_InitValueType> __sub_group_carry; + + // on the first sub-group in a work-group (assuming S subgroups in a work-group): + // 1. load S sub-group local carry prefix sums (T0..TS-1) to SLM + // 2. load 32, 64, 96, etc. TS-1 work-group carry-outs (32 for WG num<32, 64 for WG num<64, etc.), + // and then compute the prefix sum to generate global carry out + // for each WG, i.e., prefix sum on TS-1 carries over all WG. + // 3. on each WG select the adjacent neighboring WG carry in + // 4. on each WG add the global carry-in to S sub-group local prefix sums to + // get a T-local global carry in + // 5. recompute T-local prefix values, add the T-local global carries, + // and then write back the final values to memory + if (__sub_group_id == 0) + { + // step 1) load to SLM the WG-local S prefix sums + // on WG T-local carries + // 0: T0 carry, 1: T0 + T1 carry, 2: T0 + T1 + T2 carry, ... + // S: sum(T0 carry...TS carry) + std::uint8_t __iters = + oneapi::dpl::__internal::__dpl_ceiling_div(__active_subgroups, __sub_group_size); + std::size_t __subgroups_before_my_group = __group_id * __num_sub_groups_local; + std::uint32_t __load_reduction_id = __sub_group_local_id; + std::uint8_t __i = 0; + for (; __i < __iters - 1; __i++) + { + __sub_group_partials[__load_reduction_id] = + __tmp_ptr[__subgroups_before_my_group + __load_reduction_id]; + __load_reduction_id += __sub_group_size; + } + if (__load_reduction_id < __active_subgroups) + { + __sub_group_partials[__load_reduction_id] = + __tmp_ptr[__subgroups_before_my_group + __load_reduction_id]; + } + + // step 2) load 32, 64, 96, etc. work-group carry outs on every work-group; then + // compute the prefix in a sub-group to get global work-group carries + // memory accesses: gather(63, 127, 191, 255, ...) + std::uint32_t __offset = __num_sub_groups_local - 1; + // only need 32 carries for WGs0..WG32, 64 for WGs32..WGs64, etc. + if (__group_id > 0) + { + // only need the last element from each scan of num_sub_groups_local subgroup reductions + const std::size_t __elements_to_process = __subgroups_before_my_group / __num_sub_groups_local; + const std::size_t __pre_carry_iters = + oneapi::dpl::__internal::__dpl_ceiling_div(__elements_to_process, __sub_group_size); + if (__pre_carry_iters == 1) + { + // single partial scan + std::size_t __proposed_id = __num_sub_groups_local * __sub_group_local_id + __offset; + std::size_t __remaining_elements = __elements_to_process; + std::size_t __reduction_id = (__proposed_id < __subgroups_before_my_group) + ? 
__proposed_id + : __subgroups_before_my_group - 1; + _InitValueType __value = __tmp_ptr[__reduction_id]; + __sub_group_scan_partial<__sub_group_size, /*__is_inclusive=*/true, + /*__init_present=*/false>(__sub_group, __value, __reduce_op, + __carry_last, __remaining_elements); + } + else + { + // multiple iterations + // first 1 full + std::uint32_t __reduction_id = __num_sub_groups_local * __sub_group_local_id + __offset; + std::uint32_t __reduction_id_increment = __num_sub_groups_local * __sub_group_size; + _InitValueType __value = __tmp_ptr[__reduction_id]; + __sub_group_scan<__sub_group_size, /*__is_inclusive=*/true, /*__init_present=*/false>( + __sub_group, __value, __reduce_op, __carry_last); + __reduction_id += __reduction_id_increment; + // then some number of full iterations + for (std::uint32_t __i = 1; __i < __pre_carry_iters - 1; __i++) + { + __value = __tmp_ptr[__reduction_id]; + __sub_group_scan<__sub_group_size, /*__is_inclusive=*/true, /*__init_present=*/true>( + __sub_group, __value, __reduce_op, __carry_last); + __reduction_id += __reduction_id_increment; + } + + // final partial iteration + + std::size_t __remaining_elements = + __elements_to_process - ((__pre_carry_iters - 1) * __sub_group_size); + // fill with unused dummy values to avoid overruning input + std::size_t __final_reduction_id = + std::min(std::size_t{__reduction_id}, __subgroups_before_my_group - 1); + __value = __tmp_ptr[__final_reduction_id]; + __sub_group_scan_partial<__sub_group_size, /*__is_inclusive=*/true, + /*__init_present=*/true>(__sub_group, __value, __reduce_op, + __carry_last, __remaining_elements); + } + + // steps 3+4) load global carry in from neighbor work-group + // and apply to local sub-group prefix carries + std::size_t __carry_offset = __sub_group_local_id; + + std::uint8_t __iters = + oneapi::dpl::__internal::__dpl_ceiling_div(__active_subgroups, __sub_group_size); + + std::uint8_t __i = 0; + for (; __i < __iters - 1; ++__i) + { + __sub_group_partials[__carry_offset] = + __reduce_op(__carry_last.__v, __sub_group_partials[__carry_offset]); + __carry_offset += __sub_group_size; + } + if (__i * __sub_group_size + __sub_group_local_id < __active_subgroups) + { + __sub_group_partials[__carry_offset] = + __reduce_op(__carry_last.__v, __sub_group_partials[__carry_offset]); + __carry_offset += __sub_group_size; + } + if (__sub_group_local_id == 0) + __sub_group_partials[__active_subgroups] = __carry_last.__v; + __carry_last.__destroy(); + } + } + + __dpl_sycl::__group_barrier(__ndi); + + // Get inter-work group and adjusted for intra-work group prefix + bool __sub_group_carry_initialized = true; + if (__block_num == 0) + { + if (__sub_group_id > 0) + { + _InitValueType __value = + __sub_group_partials[std::min(__sub_group_id - 1, __active_subgroups - 1)]; + oneapi::dpl::unseq_backend::__init_processing<_InitValueType>{}(__init, __value, __reduce_op); + __sub_group_carry.__setup(__value); + } + else if (__group_id > 0) + { + _InitValueType __value = __sub_group_partials[__active_subgroups]; + oneapi::dpl::unseq_backend::__init_processing<_InitValueType>{}(__init, __value, __reduce_op); + __sub_group_carry.__setup(__value); + } + else // zeroth block, group and subgroup + { + if constexpr (std::is_same_v<_InitType, + oneapi::dpl::unseq_backend::__no_init_value<_InitValueType>>) + { + // This is the only case where we still don't have a carry in. No init value, 0th block, + // group, and subgroup. This changes the final scan through elements below. 
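+                            // In this case __scan_through_elements_helper below is instantiated with
+                            // __init_present = false.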
+ __sub_group_carry_initialized = false; + } + else + { + __sub_group_carry.__setup(__init.__value); + } + } + } + else + { + if (__sub_group_id > 0) + { + _InitValueType __value = + __sub_group_partials[std::min(__sub_group_id - 1, __active_subgroups - 1)]; + __sub_group_carry.__setup(__reduce_op(__get_block_carry_in(__block_num, __tmp_ptr), __value)); + } + else if (__group_id > 0) + { + __sub_group_carry.__setup(__reduce_op(__get_block_carry_in(__block_num, __tmp_ptr), + __sub_group_partials[__active_subgroups])); + } + else + { + __sub_group_carry.__setup(__get_block_carry_in(__block_num, __tmp_ptr)); + } + } + + // step 5) apply global carries + std::size_t __subgroup_start_id = __group_start_id + (__sub_group_id * __inputs_per_sub_group); + std::size_t __start_id = __subgroup_start_id + __sub_group_local_id; + + if (__sub_group_carry_initialized) + { + __scan_through_elements_helper<__sub_group_size, __is_inclusive, + /*__init_present=*/true, + /*__capture_output=*/true, __max_inputs_per_item>( + __sub_group, __gen_scan_input, __scan_input_transform, __reduce_op, __write_op, + __sub_group_carry, __in_rng, __out_rng, __start_id, __n, __inputs_per_item, __subgroup_start_id, + __sub_group_id, __active_subgroups); + } + else // first group first block, no subgroup carry + { + __scan_through_elements_helper<__sub_group_size, __is_inclusive, + /*__init_present=*/false, + /*__capture_output=*/true, __max_inputs_per_item>( + __sub_group, __gen_scan_input, __scan_input_transform, __reduce_op, __write_op, + __sub_group_carry, __in_rng, __out_rng, __start_id, __n, __inputs_per_item, __subgroup_start_id, + __sub_group_id, __active_subgroups); + } + // If within the last active group and sub-group of the block, use the 0th work-item of the sub-group + // to write out the last carry out for either the return value or the next block + if (__sub_group_local_id == 0 && (__active_groups == __group_id + 1) && + (__active_subgroups == __sub_group_id + 1)) + { + if (__block_num + 1 == __num_blocks) + { + __res_ptr[0] = __sub_group_carry.__v; + } + else + { + // capture the last carry out for the next block + __set_block_carry_out(__block_num, __tmp_ptr, __sub_group_carry.__v); + } + } + __sub_group_carry.__destroy(); + }); + }); + } + + const std::uint32_t __max_block_size; + const std::uint32_t __num_sub_groups_local; + const std::uint32_t __num_sub_groups_global; + const std::uint32_t __num_work_items; + const std::size_t __num_blocks; + const std::size_t __n; + + const _ReduceOp __reduce_op; + const _GenScanInput __gen_scan_input; + const _ScanInputTransform __scan_input_transform; + const _WriteOp __write_op; + _InitType __init; +}; + +// reduce_then_scan requires subgroup size of 32, and performs well only on devices with fast coordinated subgroup +// operations. We do not want to run this scan on CPU targets, as they are not performant with this algorithm. 
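+// The helper below therefore requires both a GPU device and support for a sub-group size of 32; callers that fail
+// this check fall back to the multi-pass scan implementation.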
+template +bool +__is_gpu_with_sg_32(const _ExecutionPolicy& __exec) +{ + const bool __dev_has_sg32 = oneapi::dpl::__internal::__supports_sub_group_size(__exec, 32); + return (__exec.queue().get_device().is_gpu() && __dev_has_sg32); +} + +// General scan-like algorithm helpers +// _GenReduceInput - a function which accepts the input range and index to generate the data needed by the main output +// used in the reduction operation (to calculate the global carries) +// _GenScanInput - a function which accepts the input range and index to generate the data needed by the final scan +// and write operations, for scan patterns +// _ScanInputTransform - a unary function applied to the output of `_GenScanInput` to extract the component used in the +// scan, but not the part only required for the final write operation +// _ReduceOp - a binary function which is used in the reduction and scan operations +// _WriteOp - a function which accepts output range, index, and output of `_GenScanInput` applied to the input range +// and performs the final write to output operation +template +auto +__parallel_transform_reduce_then_scan(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, + _InRng&& __in_rng, _OutRng&& __out_rng, _GenReduceInput __gen_reduce_input, + _ReduceOp __reduce_op, _GenScanInput __gen_scan_input, + _ScanInputTransform __scan_input_transform, _WriteOp __write_op, _InitType __init, + _Inclusive) +{ + using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>; + using _ReduceKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider< + __reduce_then_scan_reduce_kernel<_CustomName>>; + using _ScanKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider< + __reduce_then_scan_scan_kernel<_CustomName>>; + using _ValueType = typename _InitType::__value_type; + + constexpr std::uint8_t __sub_group_size = 32; + constexpr std::uint8_t __block_size_scale = std::max(std::size_t{1}, sizeof(double) / sizeof(_ValueType)); + // Empirically determined maximum. May be less for non-full blocks. + constexpr std::uint16_t __max_inputs_per_item = 64 * __block_size_scale; + constexpr bool __inclusive = _Inclusive::value; + + const std::uint32_t __max_work_group_size = oneapi::dpl::__internal::__max_work_group_size(__exec, 8192); + // Round down to nearest multiple of the subgroup size + const std::uint32_t __work_group_size = (__max_work_group_size / __sub_group_size) * __sub_group_size; + + // TODO: Investigate potentially basing this on some scale of the number of compute units. 128 work-groups has been + // found to be reasonable number for most devices. + constexpr std::uint32_t __num_work_groups = 128; + const std::uint32_t __num_work_items = __num_work_groups * __work_group_size; + const std::uint32_t __num_sub_groups_local = __work_group_size / __sub_group_size; + const std::uint32_t __num_sub_groups_global = __num_sub_groups_local * __num_work_groups; + const std::size_t __n = __in_rng.size(); + const std::uint32_t __max_inputs_per_block = __work_group_size * __max_inputs_per_item * __num_work_groups; + std::size_t __inputs_remaining = __n; + + // reduce_then_scan kernel is not built to handle "empty". + // These trivial end cases should be handled at a higher level. 
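+    // (For example, __pattern_transform_scan_base in this patch returns early when the input range is empty.)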
+ assert(__inputs_remaining > 0); + const std::uint32_t __max_inputs_per_subgroup = __max_inputs_per_block / __num_sub_groups_global; + std::uint32_t __evenly_divided_remaining_inputs = + std::max(std::size_t{__sub_group_size}, + oneapi::dpl::__internal::__dpl_bit_ceil(__inputs_remaining) / __num_sub_groups_global); + std::uint32_t __inputs_per_sub_group = + __inputs_remaining >= __max_inputs_per_block ? __max_inputs_per_subgroup : __evenly_divided_remaining_inputs; + std::uint32_t __inputs_per_item = __inputs_per_sub_group / __sub_group_size; + const std::size_t __block_size = std::min(__inputs_remaining, std::size_t{__max_inputs_per_block}); + const std::size_t __num_blocks = __inputs_remaining / __block_size + (__inputs_remaining % __block_size != 0); + + // We need temporary storage for reductions of each sub-group (__num_sub_groups_global). + // Additionally, we need two elements for the block carry-out to prevent a race condition + // between reading and writing the block carry-out within a single kernel. + __result_and_scratch_storage<_ExecutionPolicy, _ValueType> __result_and_scratch{__exec, 1, + __num_sub_groups_global + 2}; + + // Reduce and scan step implementations + using _ReduceSubmitter = + __parallel_reduce_then_scan_reduce_submitter<__sub_group_size, __max_inputs_per_item, __inclusive, + _GenReduceInput, _ReduceOp, _InitType, _ReduceKernel>; + using _ScanSubmitter = + __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs_per_item, __inclusive, _ReduceOp, + _GenScanInput, _ScanInputTransform, _WriteOp, _InitType, + _ScanKernel>; + _ReduceSubmitter __reduce_submitter{__max_inputs_per_block, + __num_sub_groups_local, + __num_sub_groups_global, + __num_work_items, + __n, + __gen_reduce_input, + __reduce_op, + __init}; + _ScanSubmitter __scan_submitter{__max_inputs_per_block, + __num_sub_groups_local, + __num_sub_groups_global, + __num_work_items, + __num_blocks, + __n, + __reduce_op, + __gen_scan_input, + __scan_input_transform, + __write_op, + __init}; + sycl::event __event; + // Data is processed in 2-kernel blocks to allow contiguous input segment to persist in LLC between the first and second kernel for accelerators + // with sufficiently large L2 / L3 caches. + for (std::size_t __b = 0; __b < __num_blocks; ++__b) + { + std::uint32_t __workitems_in_block = oneapi::dpl::__internal::__dpl_ceiling_div( + std::min(__inputs_remaining, std::size_t{__max_inputs_per_block}), __inputs_per_item); + std::uint32_t __workitems_in_block_round_up_workgroup = + oneapi::dpl::__internal::__dpl_ceiling_div(__workitems_in_block, __work_group_size) * __work_group_size; + auto __global_range = sycl::range<1>(__workitems_in_block_round_up_workgroup); + auto __local_range = sycl::range<1>(__work_group_size); + auto __kernel_nd_range = sycl::nd_range<1>(__global_range, __local_range); + // 1. Reduce step - Reduce assigned input per sub-group, compute and apply intra-wg carries, and write to global memory. + __event = __reduce_submitter(__exec, __kernel_nd_range, __in_rng, __result_and_scratch, __event, + __inputs_per_sub_group, __inputs_per_item, __b); + // 2. Scan step - Compute intra-wg carries, determine sub-group carry-ins, and perform full input block scan. 
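+        // Each submission below depends on the previous event, so the reduce and scan kernels of successive blocks
+        // execute in order while sharing the same scratch storage.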
+ __event = __scan_submitter(__exec, __kernel_nd_range, __in_rng, __out_rng, __result_and_scratch, __event, + __inputs_per_sub_group, __inputs_per_item, __b); + __inputs_remaining -= std::min(__inputs_remaining, __block_size); + // We only need to resize these parameters prior to the last block as it is the only non-full case. + if (__b + 2 == __num_blocks) + { + __evenly_divided_remaining_inputs = + std::max(std::size_t{__sub_group_size}, + oneapi::dpl::__internal::__dpl_bit_ceil(__inputs_remaining) / __num_sub_groups_global); + __inputs_per_sub_group = __inputs_remaining >= __max_inputs_per_block ? __max_inputs_per_subgroup + : __evenly_divided_remaining_inputs; + __inputs_per_item = __inputs_per_sub_group / __sub_group_size; + } + } + return __future(__event, __result_and_scratch); +} + +} // namespace __par_backend_hetero +} // namespace dpl +} // namespace oneapi + +#endif // _ONEDPL_PARALLEL_BACKEND_SYCL_REDUCE_THEN_SCAN_H diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h index c7d46dd2057..9bd195a80a9 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "../../iterator_impl.h" @@ -94,6 +95,15 @@ __max_compute_units(const _ExecutionPolicy& __policy) return __policy.queue().get_device().template get_info(); } +template +bool +__supports_sub_group_size(const _ExecutionPolicy& __exec, std::size_t __target_size) +{ + const std::vector __subgroup_sizes = + __exec.queue().get_device().template get_info(); + return std::find(__subgroup_sizes.begin(), __subgroup_sizes.end(), __target_size) != __subgroup_sizes.end(); +} + //----------------------------------------------------------------------------- // Kernel run-time information helpers //----------------------------------------------------------------------------- diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_defs.h b/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_defs.h index 97f206d57ff..83c44a8a07d 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_defs.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_defs.h @@ -59,6 +59,8 @@ // TODO: determine which compiler configurations provide subgroup load/store #define _ONEDPL_SYCL_SUB_GROUP_LOAD_STORE_PRESENT false +#define _ONEDPL_SYCL_SUB_GROUP_PRESENT (_ONEDPL_LIBSYCL_VERSION >= 50700) + // Macro to check if we are compiling for SPIR-V devices. This macro must only be used within // SYCL kernels for determining SPIR-V compilation. Using this macro on the host may lead to incorrect behavior. 
#ifndef _ONEDPL_DETECT_SPIRV_COMPILATION // Check if overridden for testing @@ -140,6 +142,12 @@ template using __minimum = sycl::ONEAPI::minimum<_T>; #endif // _ONEDPL_SYCL2020_FUNCTIONAL_OBJECTS_PRESENT +#if _ONEDPL_SYCL_SUB_GROUP_PRESENT +using __sub_group = sycl::sub_group; +#else +using __sub_group = sycl::ONEAPI::sub_group; +#endif + template constexpr auto __get_buffer_size(const _Buffer& __buffer) diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_traits.h b/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_traits.h index 2144c454864..b820741ea00 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_traits.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_traits.h @@ -233,11 +233,21 @@ struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(oneapi::dpl::__internal:: namespace oneapi::dpl::__par_backend_hetero { +template +struct __gen_transform_input; + template struct __early_exit_find_or; } // namespace oneapi::dpl::__par_backend_hetero +template +struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(oneapi::dpl::__par_backend_hetero::__gen_transform_input, + _UnaryOp)> + : oneapi::dpl::__internal::__are_all_device_copyable<_UnaryOp> +{ +}; + template struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(oneapi::dpl::__par_backend_hetero::__early_exit_find_or, _ExecutionPolicy, _Pred)> diff --git a/include/oneapi/dpl/pstl/hetero/numeric_ranges_impl_hetero.h b/include/oneapi/dpl/pstl/hetero/numeric_ranges_impl_hetero.h index f85dc811ab2..e00a89441f2 100644 --- a/include/oneapi/dpl/pstl/hetero/numeric_ranges_impl_hetero.h +++ b/include/oneapi/dpl/pstl/hetero/numeric_ranges_impl_hetero.h @@ -91,35 +91,15 @@ oneapi::dpl::__internal::__difference_t<_Range2> __pattern_transform_scan_base(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2, _UnaryOperation __unary_op, _InitType __init, _BinaryOperation __binary_op, _Inclusive) { - if (__rng1.empty()) + oneapi::dpl::__internal::__difference_t<_Range1> __n = __rng1.size(); + if (__n == 0) return 0; - oneapi::dpl::__internal::__difference_t<_Range2> __rng1_size = __rng1.size(); - - using _Type = typename _InitType::__value_type; - using _Assigner = unseq_backend::__scan_assigner; - using _NoAssign = unseq_backend::__scan_no_assign; - using _UnaryFunctor = unseq_backend::walk_n<_ExecutionPolicy, _UnaryOperation>; - using _NoOpFunctor = unseq_backend::walk_n<_ExecutionPolicy, oneapi::dpl::__internal::__no_op>; - - _Assigner __assign_op; - _NoAssign __no_assign_op; - _NoOpFunctor __get_data_op; - - oneapi::dpl::__par_backend_hetero::__parallel_transform_scan_base( - _BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range1>(__rng1), - ::std::forward<_Range2>(__rng2), __binary_op, __init, - // local scan - unseq_backend::__scan<_Inclusive, _ExecutionPolicy, _BinaryOperation, _UnaryFunctor, _Assigner, _Assigner, - _NoOpFunctor, _InitType>{__binary_op, _UnaryFunctor{__unary_op}, __assign_op, __assign_op, - __get_data_op}, - // scan between groups - unseq_backend::__scan>{ - __binary_op, _NoOpFunctor{}, __no_assign_op, __assign_op, __get_data_op}, - // global scan - unseq_backend::__global_scan_functor<_Inclusive, _BinaryOperation, _InitType>{__binary_op, __init}) + + oneapi::dpl::__par_backend_hetero::__parallel_transform_scan( + _BackendTag{}, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), + std::forward<_Range2>(__rng2), __n, __unary_op, __init, __binary_op, _Inclusive{}) .__deferrable_wait(); - return __rng1_size; + return __n; } template >, "__brick_reduce_idx 
is not device copyable with device copyable types"); + //__gen_transform_input + static_assert( + sycl::is_device_copyable_v>, + "__gen_transform_input is not device copyable with device copyable types"); + // __early_exit_find_or static_assert( sycl::is_device_copyable_v< @@ -343,6 +348,11 @@ test_non_device_copyable() oneapi::dpl::unseq_backend::__brick_reduce_idx>, "__brick_reduce_idx is device copyable with non device copyable types"); + //__gen_transform_input + static_assert( + !sycl::is_device_copyable_v>, + "__gen_transform_input is device copyable with non device copyable types"); + // __early_exit_find_or static_assert( !sycl::is_device_copyable_v