Skip to content

Commit

Permalink
type / variable naming and clang-format
Browse files Browse the repository at this point in the history
Signed-off-by: Dan Hoeflinger <dan.hoeflinger@intel.com>
  • Loading branch information
danhoeflinger committed Jul 17, 2024
1 parent 530ec43 commit d072968
Showing 1 changed file with 46 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,23 @@ __sub_group_scan_partial(const _SubGroup& __sub_group, _ValueType& __value, _Bin
}

template <std::uint8_t __sub_group_size, bool __is_inclusive, bool __init_present, bool __capture_output,
std::uint32_t __max_inputs_per_item, typename _SubGroup, typename _GenInput, typename _ScanPred,
std::uint32_t __max_inputs_per_item, typename _SubGroup, typename _GenInput, typename _ScanInputTransform,
typename _BinaryOp, typename _WriteOp, typename _LazyValueType, typename _InRng, typename _OutRng>
void
__scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_input, _ScanPred __scan_pred,
_BinaryOp __binary_op, _WriteOp __write_op, _LazyValueType& __sub_group_carry,
_InRng __in_rng, _OutRng __out_rng, std::size_t __start_idx, std::size_t __n,
std::uint32_t __iters_per_item, std::size_t __subgroup_start_idx,
std::uint32_t __sub_group_id, std::uint32_t __active_subgroups)
__scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_input,
_ScanInputTransform __scan_input_transform, _BinaryOp __binary_op, _WriteOp __write_op,
_LazyValueType& __sub_group_carry, _InRng __in_rng, _OutRng __out_rng,
std::size_t __start_idx, std::size_t __n, std::uint32_t __iters_per_item,
std::size_t __subgroup_start_idx, std::uint32_t __sub_group_id,
std::uint32_t __active_subgroups)
{
bool __is_full_block = (__iters_per_item == __max_inputs_per_item);
bool __is_full_thread = __subgroup_start_idx + __iters_per_item * __sub_group_size <= __n;
if (__is_full_thread && __is_full_block)
{
auto __v = __gen_input(__in_rng, __start_idx);
__sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __scan_pred(__v), __binary_op,
__sub_group_carry);
__sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __scan_input_transform(__v),
__binary_op, __sub_group_carry);
if constexpr (__capture_output)
{
__write_op(__out_rng, __start_idx, __v);
Expand All @@ -173,8 +174,8 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp
for (std::uint32_t __j = 1; __j < __max_inputs_per_item; __j++)
{
__v = __gen_input(__in_rng, __start_idx + __j * __sub_group_size);
__sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>(__sub_group, __scan_pred(__v),
__binary_op, __sub_group_carry);
__sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>(
__sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry);
if constexpr (__capture_output)
{
__write_op(__out_rng, __start_idx + __j * __sub_group_size, __v);
Expand All @@ -184,17 +185,17 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp
else if (__is_full_thread)
{
auto __v = __gen_input(__in_rng, __start_idx);
__sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __scan_pred(__v), __binary_op,
__sub_group_carry);
__sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __scan_input_transform(__v),
__binary_op, __sub_group_carry);
if constexpr (__capture_output)
{
__write_op(__out_rng, __start_idx, __v);
}
for (std::uint32_t __j = 1; __j < __iters_per_item; __j++)
{
__v = __gen_input(__in_rng, __start_idx + __j * __sub_group_size);
__sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>(__sub_group, __scan_pred(__v),
__binary_op, __sub_group_carry);
__sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>(
__sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry);
if constexpr (__capture_output)
{
__write_op(__out_rng, __start_idx + __j * __sub_group_size, __v);
Expand All @@ -211,7 +212,8 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp
{
auto __v = __gen_input(__in_rng, __start_idx);
__sub_group_scan_partial<__sub_group_size, __is_inclusive, __init_present>(
__sub_group, __scan_pred(__v), __binary_op, __sub_group_carry, __n - __subgroup_start_idx);
__sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry,
__n - __subgroup_start_idx);
if constexpr (__capture_output)
{
if (__start_idx < __n)
Expand All @@ -221,8 +223,8 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp
else
{
auto __v = __gen_input(__in_rng, __start_idx);
__sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(__sub_group, __scan_pred(__v),
__binary_op, __sub_group_carry);
__sub_group_scan<__sub_group_size, __is_inclusive, __init_present>(
__sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry);
if constexpr (__capture_output)
{
__write_op(__out_rng, __start_idx, __v);
Expand All @@ -233,7 +235,7 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp
auto __local_idx = __start_idx + __j * __sub_group_size;
__v = __gen_input(__in_rng, __local_idx);
__sub_group_scan<__sub_group_size, __is_inclusive, /*__init_present=*/true>(
__sub_group, __scan_pred(__v), __binary_op, __sub_group_carry);
__sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry);
if constexpr (__capture_output)
{
__write_op(__out_rng, __local_idx, __v);
Expand All @@ -244,7 +246,7 @@ __scan_through_elements_helper(const _SubGroup& __sub_group, _GenInput __gen_inp
auto __local_idx = (__offset < __n) ? __offset : __n - 1;
__v = __gen_input(__in_rng, __local_idx);
__sub_group_scan_partial<__sub_group_size, __is_inclusive, /*__init_present=*/true>(
__sub_group, __scan_pred(__v), __binary_op, __sub_group_carry,
__sub_group, __scan_input_transform(__v), __binary_op, __sub_group_carry,
__n - (__subgroup_start_idx + (__iters - 1) * __sub_group_size));
if constexpr (__capture_output)
{
Expand Down Expand Up @@ -393,20 +395,19 @@ struct __parallel_reduce_then_scan_reduce_submitter<__sub_group_size, __max_inpu
const _GenReduceInput __gen_reduce_input;
const _ReduceOp __reduce_op;
_InitType __init;

};

template <std::size_t __sub_group_size, std::size_t __max_inputs_per_item, bool __is_inclusive,
typename _GenReduceInput, typename _ReduceOp, typename _GenScanInput, typename _ScanPred, typename _WriteOp,
typename _InitType, typename _KernelName>
typename _GenReduceInput, typename _ReduceOp, typename _GenScanInput, typename _ScanInputTransform,
typename _WriteOp, typename _InitType, typename _KernelName>
struct __parallel_reduce_then_scan_scan_submitter;

template <std::size_t __sub_group_size, std::size_t __max_inputs_per_item, bool __is_inclusive,
typename _GenReduceInput, typename _ReduceOp, typename _GenScanInput, typename _ScanPred, typename _WriteOp,
typename _InitType, typename... _KernelName>
struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs_per_item, __is_inclusive,
_GenReduceInput, _ReduceOp, _GenScanInput, _ScanPred, _WriteOp,
_InitType, __internal::__optional_kernel_name<_KernelName...>>
typename _GenReduceInput, typename _ReduceOp, typename _GenScanInput, typename _ScanInputTransform,
typename _WriteOp, typename _InitType, typename... _KernelName>
struct __parallel_reduce_then_scan_scan_submitter<
__sub_group_size, __max_inputs_per_item, __is_inclusive, _GenReduceInput, _ReduceOp, _GenScanInput,
_ScanInputTransform, _WriteOp, _InitType, __internal::__optional_kernel_name<_KernelName...>>
{
template <typename _ExecutionPolicy, typename _InRng, typename _OutRng, typename _TmpStorageAcc>
auto
Expand Down Expand Up @@ -633,20 +634,20 @@ struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs
__scan_through_elements_helper<__sub_group_size, __is_inclusive,
/*__init_present=*/true,
/*__capture_output=*/true, __max_inputs_per_item>(
__sub_group, __gen_scan_input, __scan_pred, __reduce_op, __write_op, __sub_group_carry,
__in_rng, __out_rng, __start_idx, __n, __inputs_per_item, __subgroup_start_idx, __sub_group_id,
__active_subgroups);
__sub_group, __gen_scan_input, __scan_input_transform, __reduce_op, __write_op,
__sub_group_carry, __in_rng, __out_rng, __start_idx, __n, __inputs_per_item,
__subgroup_start_idx, __sub_group_id, __active_subgroups);
}
else // first group first block, no subgroup carry
{
__scan_through_elements_helper<__sub_group_size, __is_inclusive,
/*__init_present=*/false,
/*__capture_output=*/true, __max_inputs_per_item>(
__sub_group, __gen_scan_input, __scan_pred, __reduce_op, __write_op, __sub_group_carry,
__in_rng, __out_rng, __start_idx, __n, __inputs_per_item, __subgroup_start_idx, __sub_group_id,
__active_subgroups);
__sub_group, __gen_scan_input, __scan_input_transform, __reduce_op, __write_op,
__sub_group_carry, __in_rng, __out_rng, __start_idx, __n, __inputs_per_item,
__subgroup_start_idx, __sub_group_id, __active_subgroups);
}
//If within the last active group and subgroup of the block, use the 0th work item of the subgroup
//If within the last active group and subgroup of the block, use the 0th work item of the subgroup
// to write out the last carry out for either the return value or the next block
if (__sub_group_local_id == 0 && (__active_groups == __g + 1) &&
(__active_subgroups == __sub_group_id + 1))
Expand Down Expand Up @@ -677,29 +678,29 @@ struct __parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs
const _GenReduceInput __gen_reduce_input;
const _ReduceOp __reduce_op;
const _GenScanInput __gen_scan_input;
const _ScanPred __scan_pred;
const _ScanInputTransform __scan_input_transform;
const _WriteOp __write_op;
_InitType __init;

};

// General scan-like algorithm helpers
// _GenReduceInput - a function which accepts the input range and index to generate the data needed by the main output
// used in the reduction operation (to calculate the global carries)
// _GenScanInput - a function which accepts the input range and index to generate the data needed by the final scan
// and write operations, for scan patterns
// _ScanPred - a unary function applied to the ouput of `_GenScanInput` to extract the component used in the scan, but
// _ScanInputTransform - a unary function applied to the ouput of `_GenScanInput` to extract the component used in the scan, but
// not the part only required for the final write operation
// _ReduceOp - a binary function which is used in the reduction and scan operations
// _WriteOp - a function which accepts output range, index, and output of `_GenScanInput` applied to the input range
// and performs the final write to output operation
template <typename _ExecutionPolicy, typename _InRng, typename _OutRng, typename _GenReduceInput, typename _ReduceOp,
typename _GenScanInput, typename _ScanPred, typename _WriteOp, typename _InitType, typename _Inclusive>
typename _GenScanInput, typename _ScanInputTransform, typename _WriteOp, typename _InitType,
typename _Inclusive>
auto
__parallel_transform_reduce_then_scan(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec,
_InRng&& __in_rng, _OutRng&& __out_rng, _GenReduceInput __gen_reduce_input,
_ReduceOp __reduce_op, _GenScanInput __gen_scan_input, _ScanPred __scan_pred,
_WriteOp __write_op,
_ReduceOp __reduce_op, _GenScanInput __gen_scan_input,
_ScanInputTransform __scan_input_transform, _WriteOp __write_op,
_InitType __init /*TODO mask assigners for generalization go here*/, _Inclusive)
{
using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
Expand Down Expand Up @@ -735,24 +736,23 @@ __parallel_transform_reduce_then_scan(oneapi::dpl::__internal::__device_backend_
const auto __block_size = (__n < __max_inputs_per_block) ? __n : __max_inputs_per_block;
const auto __num_blocks = __n / __block_size + (__n % __block_size != 0);


__result_and_scratch_storage<_ExecutionPolicy, _ValueType> __result_and_scratch{
__exec, __num_sub_groups_global + 1};
__result_and_scratch_storage<_ExecutionPolicy, _ValueType> __result_and_scratch{__exec,
__num_sub_groups_global + 1};

// Reduce and scan step implementations
using _ReduceSubmitter =
__parallel_reduce_then_scan_reduce_submitter<__sub_group_size, __max_inputs_per_item, __inclusive,
_GenReduceInput, _ReduceOp, _InitType, _ReduceKernel>;
using _ScanSubmitter =
__parallel_reduce_then_scan_scan_submitter<__sub_group_size, __max_inputs_per_item, __inclusive,
_GenReduceInput, _ReduceOp, _GenScanInput, _ScanPred, _WriteOp,
_InitType, _ScanKernel>;
_GenReduceInput, _ReduceOp, _GenScanInput, _ScanInputTransform,
_WriteOp, _InitType, _ScanKernel>;
// TODO: remove below before merging. used for convenience now
// clang-format off
_ReduceSubmitter __reduce_submitter{__max_inputs_per_block, __num_sub_groups_local,
__num_sub_groups_global, __num_work_items, __n, __gen_reduce_input, __reduce_op, __init};
_ScanSubmitter __scan_submitter{__max_inputs_per_block, __num_sub_groups_local,
__num_sub_groups_global, __num_work_items, __num_blocks, __n, __gen_reduce_input, __reduce_op, __gen_scan_input, __scan_pred,
__num_sub_groups_global, __num_work_items, __num_blocks, __n, __gen_reduce_input, __reduce_op, __gen_scan_input, __scan_input_transform,
__write_op, __init};
// clang-format on

Expand Down

0 comments on commit d072968

Please sign in to comment.