From d0729686eb2d0c9264980614e7481bb088f4f1e5 Mon Sep 17 00:00:00 2001 From: Dan Hoeflinger Date: Wed, 17 Jul 2024 16:28:57 -0400 Subject: [PATCH] type / variable naming and clang-format Signed-off-by: Dan Hoeflinger --- .../parallel_backend_sycl_reduce_then_scan.h | 92 +++++++++---------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h index 65cd7a6fbcb..2979120ad1a 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h @@ -148,22 +148,23 @@ __sub_group_scan_partial(const _SubGroup& __sub_group, _ValueType& __value, _Bin } template void -__scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_input, _ScanPred __scan_pred, - _BinaryOp __binary_op, _WriteOp __write_op, _LazyValueType& __sub_group_carry, - _InRng __in_rng, _OutRng __out_rng, std::size_t __start_idx, std::size_t __n, - std::uint32_t __iters_per_item, std::size_t __subgroup_start_idx, - std::uint32_t __sub_group_id, std::uint32_t __active_subgroups) +__scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_input, + _ScanInputTransform __scan_input_transform, _BinaryOp __binary_op, _WriteOp __write_op, + _LazyValueType& __sub_group_carry, _InRng __in_rng, _OutRng __out_rng, + std::size_t __start_idx, std::size_t __n, std::uint32_t __iters_per_item, + std::size_t __subgroup_start_idx, std::uint32_t __sub_group_id, + std::uint32_t __active_subgroups) { bool __is_full_block = (__iters_per_item == __max_inputs_per_item); bool __is_full_thread = __subgroup_start_idx + __iters_per_item * __sub_group_size <= __n; if (__is_full_thread && __is_full_block) { auto __v = __gen_input(__in_rng, __start_idx); - __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, 
__scan_pred(__v), __binary_op, - __sub_group_carry); + __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __scan_input_transform(__v), + __binary_op, __sub_group_carry); if constexpr (__capture_output) { __write_op(__out_rng, __start_idx, __v); @@ -173,8 +174,8 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp for (std::uint32_t __j = 1; __j < __max_inputs_per_item; __j++) { __v = __gen_input(__in_rng, __start_idx + __j * __sub_group_size); - __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>(__sub_group, __scan_pred(__v), - __binary_op, __sub_group_carry); + __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>( + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry); if constexpr (__capture_output) { __write_op(__out_rng, __start_idx + __j * __sub_group_size, __v); @@ -184,8 +185,8 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp else if (__is_full_thread) { auto __v = __gen_input(__in_rng, __start_idx); - __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __scan_pred(__v), __binary_op, - __sub_group_carry); + __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __scan_input_transform(__v), + __binary_op, __sub_group_carry); if constexpr (__capture_output) { __write_op(__out_rng, __start_idx, __v); @@ -193,8 +194,8 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp for (std::uint32_t __j = 1; __j < __iters_per_item; __j++) { __v = __gen_input(__in_rng, __start_idx + __j * __sub_group_size); - __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>(__sub_group, __scan_pred(__v), - __binary_op, __sub_group_carry); + __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>( + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry); if constexpr 
(__capture_output) { __write_op(__out_rng, __start_idx + __j * __sub_group_size, __v); @@ -211,7 +212,8 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp { auto __v = __gen_input(__in_rng, __start_idx); __sub_group_scan_partial<__sub_group_size, __is_inclusive, __init_present>( - __sub_group, __scan_pred(__v), __binary_op, __sub_group_carry, __n - __subgroup_start_idx); + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry, + __n - __subgroup_start_idx); if constexpr (__capture_output) { if (__start_idx < __n) @@ -221,8 +223,8 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp else { auto __v = __gen_input(__in_rng, __start_idx); - __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __scan_pred(__v), - __binary_op, __sub_group_carry); + __sub_group_scan<__sub_group_size, __is_inclusive, __init_present>( + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry); if constexpr (__capture_output) { __write_op(__out_rng, __start_idx, __v); @@ -233,7 +235,7 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp auto __local_idx = __start_idx + __j * __sub_group_size; __v = __gen_input(__in_rng, __local_idx); __sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>( - __sub_group, __scan_pred(__v), __binary_op, __sub_group_carry); + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry); if constexpr (__capture_output) { __write_op(__out_rng, __local_idx, __v); @@ -244,7 +246,7 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp auto __local_idx = (__offset < __n) ? 
__offset : __n - 1; __v = __gen_input(__in_rng, __local_idx); __sub_group_scan_partial<__sub_group_size, __is_inclusive, /*__init_present=*/true>( - __sub_group, __scan_pred(__v), __binary_op, __sub_group_carry, + __sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry, __n - (__subgroup_start_idx + (__iters - 1) * __sub_group_size)); if constexpr (__capture_output) { @@ -393,20 +395,19 @@ struct __parallel_reduce_then_scan_reduce_submitter<__sub_group_size, __max_inpu const _GenReduceInput __gen_reduce_input; const _ReduceOp __reduce_op; _InitType __init; - }; template + typename _GenReduceInput, typename _ReduceOp, typename _GenScanInput, typename _ScanInputTransform, + typename _WriteOp, typename _InitType, typename _KernelName> struct __parallel_reduce_then_scan_scan_submitter; template -struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs_per_item, __is_inclusive, - _GenReduceInput, _ReduceOp, _GenScanInput, _ScanPred, _WriteOp, - _InitType, __internal::__optional_kernel_name<_KernelName...>> + typename _GenReduceInput, typename _ReduceOp, typename _GenScanInput, typename _ScanInputTransform, + typename _WriteOp, typename _InitType, typename... 
_KernelName> +struct __parallel_reduce_then_scan_scan_submitter< + __sub_group_size, __max_inputs_per_item, __is_inclusive, _GenReduceInput, _ReduceOp, _GenScanInput, + _ScanInputTransform, _WriteOp, _InitType, __internal::__optional_kernel_name<_KernelName...>> { template auto @@ -633,20 +634,20 @@ struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs __scan_through_elements_helper<__sub_group_size, __is_inclusive, /*__init_present=*/true, /*__capture_output=*/true, __max_inputs_per_item>( - __sub_group, __gen_scan_input, __scan_pred, __reduce_op, __write_op, __sub_group_carry, - __in_rng, __out_rng, __start_idx, __n, __inputs_per_item, __subgroup_start_idx, __sub_group_id, - __active_subgroups); + __sub_group, __gen_scan_input, __scan_input_transform, __reduce_op, __write_op, + __sub_group_carry, __in_rng, __out_rng, __start_idx, __n, __inputs_per_item, + __subgroup_start_idx, __sub_group_id, __active_subgroups); } else // first group first block, no subgroup carry { __scan_through_elements_helper<__sub_group_size, __is_inclusive, /*__init_present=*/false, /*__capture_output=*/true, __max_inputs_per_item>( - __sub_group, __gen_scan_input, __scan_pred, __reduce_op, __write_op, __sub_group_carry, - __in_rng, __out_rng, __start_idx, __n, __inputs_per_item, __subgroup_start_idx, __sub_group_id, - __active_subgroups); + __sub_group, __gen_scan_input, __scan_input_transform, __reduce_op, __write_op, + __sub_group_carry, __in_rng, __out_rng, __start_idx, __n, __inputs_per_item, + __subgroup_start_idx, __sub_group_id, __active_subgroups); } - //If within the last active group and subgroup of the block, use the 0th work item of the subgroup + //If within the last active group and subgroup of the block, use the 0th work item of the subgroup // to write out the last carry out for either the return value or the next block if (__sub_group_local_id == 0 && (__active_groups == __g + 1) && (__active_subgroups == __sub_group_id + 1)) @@ -677,10 +678,9 
@@ struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs const _GenReduceInput __gen_reduce_input; const _ReduceOp __reduce_op; const _GenScanInput __gen_scan_input; - const _ScanPred __scan_pred; + const _ScanInputTransform __scan_input_transform; const _WriteOp __write_op; _InitType __init; - }; // General scan-like algorithm helpers @@ -688,18 +688,19 @@ struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs // used in the reduction operation (to calculate the global carries) // _GenScanInput - a function which accepts the input range and index to generate the data needed by the final scan // and write operations, for scan patterns -// _ScanPred - a unary function applied to the ouput of `_GenScanInput` to extract the component used in the scan, but +// _ScanInputTransform - a unary function applied to the output of `_GenScanInput` to extract the component used in the scan, but // not the part only required for the final write operation // _ReduceOp - a binary function which is used in the reduction and scan operations // _WriteOp - a function which accepts output range, index, and output of `_GenScanInput` applied to the input range // and performs the final write to output operation template + typename _GenScanInput, typename _ScanInputTransform, typename _WriteOp, typename _InitType, + typename _Inclusive> auto __parallel_transform_reduce_then_scan(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _InRng&& __in_rng, _OutRng&& __out_rng, _GenReduceInput __gen_reduce_input, - _ReduceOp __reduce_op, _GenScanInput __gen_scan_input, _ScanPred __scan_pred, - _WriteOp __write_op, + _ReduceOp __reduce_op, _GenScanInput __gen_scan_input, + _ScanInputTransform __scan_input_transform, _WriteOp __write_op, _InitType __init /*TODO mask assigners for generalization go here*/, _Inclusive) { using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>; @@ -735,9 +736,8 @@ 
__parallel_transform_reduce_then_scan(oneapi::dpl::__internal::__device_backend_ const auto __block_size = (__n < __max_inputs_per_block) ? __n : __max_inputs_per_block; const auto __num_blocks = __n / __block_size + (__n % __block_size != 0); - - __result_and_scratch_storage<_ExecutionPolicy, _ValueType> __result_and_scratch{ - __exec, __num_sub_groups_global + 1}; + __result_and_scratch_storage<_ExecutionPolicy, _ValueType> __result_and_scratch{__exec, + __num_sub_groups_global + 1}; // Reduce and scan step implementations using _ReduceSubmitter = @@ -745,14 +745,14 @@ __parallel_transform_reduce_then_scan(oneapi::dpl::__internal::__device_backend_ _GenReduceInput, _ReduceOp, _InitType, _ReduceKernel>; using _ScanSubmitter = __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs_per_item, __inclusive, - _GenReduceInput, _ReduceOp, _GenScanInput, _ScanPred, _WriteOp, - _InitType, _ScanKernel>; + _GenReduceInput, _ReduceOp, _GenScanInput, _ScanInputTransform, + _WriteOp, _InitType, _ScanKernel>; // TODO: remove below before merging. used for convenience now // clang-format off _ReduceSubmitter __reduce_submitter{__max_inputs_per_block, __num_sub_groups_local, __num_sub_groups_global, __num_work_items, __n, __gen_reduce_input, __reduce_op, __init}; _ScanSubmitter __scan_submitter{__max_inputs_per_block, __num_sub_groups_local, - __num_sub_groups_global, __num_work_items, __num_blocks, __n, __gen_reduce_input, __reduce_op, __gen_scan_input, __scan_pred, + __num_sub_groups_global, __num_work_items, __num_blocks, __n, __gen_reduce_input, __reduce_op, __gen_scan_input, __scan_input_transform, __write_op, __init}; // clang-format on