Skip to content

Commit

Permalink
Replace `::std::` qualifications with `std::` and replace `assert` with an exception
Browse files Browse the repository at this point in the history
  • Loading branch information
adamfidel committed Apr 17, 2024
1 parent eb6fb65 commit 2ff21c9
Showing 1 changed file with 38 additions and 37 deletions.
75 changes: 38 additions & 37 deletions include/oneapi/dpl/experimental/kt/single_pass_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ struct __scan_status_flag

bool __is_full = __tile_flag == __full_status;
auto __is_full_ballot = sycl::ext::oneapi::group_ballot(__subgroup, __is_full);
::std::uint32_t __is_full_ballot_bits{};
std::uint32_t __is_full_ballot_bits{};
__is_full_ballot.extract_bits(__is_full_ballot_bits);

_AtomicValueT __tile_value_atomic(
Expand Down Expand Up @@ -141,7 +141,7 @@ struct __lookback_init_submitter<_FlagType, _Type, _BinaryOp,
template <typename _StatusFlags, typename _PartialValues>
sycl::event
operator()(sycl::queue __q, _StatusFlags&& __status_flags, _PartialValues&& __partial_values,
::std::size_t __status_flags_size, ::std::uint16_t __status_flag_padding) const
std::size_t __status_flags_size, std::uint16_t __status_flag_padding) const
{
using _KernelName = __lookback_init_kernel<_Name..., _Type, _BinaryOp>;

Expand All @@ -156,11 +156,11 @@ struct __lookback_init_submitter<_FlagType, _Type, _BinaryOp,
}
};

template <::std::uint16_t __data_per_workitem, ::std::uint16_t __workgroup_size, typename _Type, typename _FlagType,
template <std::uint16_t __data_per_workitem, std::uint16_t __workgroup_size, typename _Type, typename _FlagType,
typename _KernelName>
struct __lookback_submitter;

template <::std::uint16_t __data_per_workitem, ::std::uint16_t __workgroup_size, typename _Type, typename _FlagType,
template <std::uint16_t __data_per_workitem, std::uint16_t __workgroup_size, typename _Type, typename _FlagType,
typename _InRng, typename _OutRng, typename _BinaryOp, typename _StatusFlags, typename _StatusValues,
typename _TileVals>
struct __lookback_kernel_func
Expand All @@ -171,12 +171,12 @@ struct __lookback_kernel_func
_InRng __in_rng;
_OutRng __out_rng;
_BinaryOp __binary_op;
::std::size_t __n;
std::size_t __n;
_StatusFlags __status_flags;
::std::size_t __status_flags_size;
std::size_t __status_flags_size;
_StatusValues __status_vals_full;
_StatusValues __status_vals_partial;
::std::size_t __current_num_items;
std::size_t __current_num_items;
_TileVals __tile_vals;

[[sycl::reqd_sub_group_size(SUBGROUP_SIZE)]] void
Expand All @@ -186,7 +186,7 @@ struct __lookback_kernel_func
auto __subgroup = __item.get_sub_group();
auto __local_id = __item.get_local_id(0);

::std::uint32_t __tile_id = 0;
std::uint32_t __tile_id = 0;

// Obtain unique ID for this work-group that will be used in decoupled lookback
if (__group.leader())
Expand All @@ -199,7 +199,7 @@ struct __lookback_kernel_func

__tile_id = sycl::group_broadcast(__group, __tile_id, 0);

::std::size_t __current_offset = static_cast<::std::size_t>(__tile_id) * __elems_in_tile;
std::size_t __current_offset = static_cast<std::size_t>(__tile_id) * __elems_in_tile;
auto __out_begin = __out_rng.begin() + __current_offset;

if (__current_offset >= __n)
Expand Down Expand Up @@ -265,7 +265,7 @@ struct __lookback_kernel_func
}
};

template <::std::uint16_t __data_per_workitem, ::std::uint16_t __workgroup_size, typename _Type, typename _FlagType,
template <std::uint16_t __data_per_workitem, std::uint16_t __workgroup_size, typename _Type, typename _FlagType,
typename... _Name>
struct __lookback_submitter<__data_per_workitem, __workgroup_size, _Type, _FlagType,
oneapi::dpl::__par_backend_hetero::__internal::__optional_kernel_name<_Name...>>
Expand All @@ -274,15 +274,15 @@ struct __lookback_submitter<__data_per_workitem, __workgroup_size, _Type, _FlagT
template <typename _InRng, typename _OutRng, typename _BinaryOp, typename _StatusFlags, typename _StatusValues>
sycl::event
operator()(sycl::queue __q, sycl::event __prev_event, _InRng&& __in_rng, _OutRng&& __out_rng, _BinaryOp __binary_op,
::std::size_t __n, _StatusFlags&& __status_flags, ::std::size_t __status_flags_size,
std::size_t __n, _StatusFlags&& __status_flags, std::size_t __status_flags_size,
_StatusValues&& __status_vals_full, _StatusValues&& __status_vals_partial,
::std::size_t __current_num_items) const
std::size_t __current_num_items) const
{
using _LocalAccessorType = sycl::local_accessor<_Type, 1>;
using _KernelFunc =
__lookback_kernel_func<__data_per_workitem, __workgroup_size, _Type, _FlagType, ::std::decay_t<_InRng>,
::std::decay_t<_OutRng>, ::std::decay_t<_BinaryOp>, ::std::decay_t<_StatusFlags>,
::std::decay_t<_StatusValues>, ::std::decay_t<_LocalAccessorType>>;
__lookback_kernel_func<__data_per_workitem, __workgroup_size, _Type, _FlagType, std::decay_t<_InRng>,
std::decay_t<_OutRng>, std::decay_t<_BinaryOp>, std::decay_t<_StatusFlags>,
std::decay_t<_StatusValues>, std::decay_t<_LocalAccessorType>>;
using _KernelName = __lookback_kernel<_Name..., _KernelFunc>;

static constexpr std::uint32_t __elems_in_tile = __workgroup_size * __data_per_workitem;
Expand Down Expand Up @@ -314,7 +314,7 @@ __single_pass_scan(sycl::queue __queue, _InRange&& __in_rng, _OutRange&& __out_r
using _LookbackKernel =
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__lookback_kernel<_KernelName>>;

const ::std::size_t __n = __in_rng.size();
const std::size_t __n = __in_rng.size();

if (__n == 0)
return sycl::event{};
Expand All @@ -335,44 +335,45 @@ __single_pass_scan(sycl::queue __queue, _InRange&& __in_rng, _OutRange&& __out_r
return oneapi::dpl::__par_backend_hetero::__parallel_transform_scan_single_group(
oneapi::dpl::__internal::__device_backend_tag{},
oneapi::dpl::execution::__dpl::make_device_policy<typename _KernelParam::kernel_name>(__queue),
::std::forward<_InRange>(__in_rng), ::std::forward<_OutRange>(__out_rng), __n,
std::forward<_InRange>(__in_rng), std::forward<_OutRange>(__out_rng), __n,
oneapi::dpl::__internal::__no_op{}, unseq_backend::__no_init_value<_Type>{}, __binary_op,
::std::true_type{});
std::true_type{});
}

constexpr ::std::size_t __workgroup_size = _KernelParam::workgroup_size;
constexpr ::std::size_t __data_per_workitem = _KernelParam::data_per_workitem;
constexpr std::size_t __workgroup_size = _KernelParam::workgroup_size;
constexpr std::size_t __data_per_workitem = _KernelParam::data_per_workitem;

// Avoid non_uniform n by padding up to a multiple of workgroup_size
::std::size_t __elems_in_tile = __workgroup_size * __data_per_workitem;
::std::size_t __num_wgs = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __elems_in_tile);
std::size_t __elems_in_tile = __workgroup_size * __data_per_workitem;
std::size_t __num_wgs = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __elems_in_tile);

constexpr int __status_flag_padding = SUBGROUP_SIZE;
::std::size_t __status_flags_size = __num_wgs + 1 + __status_flag_padding;
std::size_t __status_flags_size = __num_wgs + 1 + __status_flag_padding;

::std::size_t __mem_align_pad = sizeof(_Type);
::std::size_t __status_flags_bytes = __status_flags_size * sizeof(_FlagStorageType);
::std::size_t __status_vals_full_offset_bytes = __status_flags_size * sizeof(_Type);
::std::size_t __status_vals_partial_offset_bytes = __status_flags_size * sizeof(_Type);
::std::size_t __mem_bytes = __status_flags_bytes + __status_vals_full_offset_bytes +
std::size_t __mem_align_pad = sizeof(_Type);
std::size_t __status_flags_bytes = __status_flags_size * sizeof(_FlagStorageType);
std::size_t __status_vals_full_offset_bytes = __status_flags_size * sizeof(_Type);
std::size_t __status_vals_partial_offset_bytes = __status_flags_size * sizeof(_Type);
std::size_t __mem_bytes = __status_flags_bytes + __status_vals_full_offset_bytes +
__status_vals_partial_offset_bytes + __mem_align_pad;

::std::byte* __device_mem = reinterpret_cast<::std::byte*>(sycl::malloc_device(__mem_bytes, __queue));
assert(__device_mem);
std::byte* __device_mem = reinterpret_cast<std::byte*>(sycl::malloc_device(__mem_bytes, __queue));
if (!__device_mem)
throw std::bad_alloc();

_FlagStorageType* __status_flags = reinterpret_cast<_FlagStorageType*>(__device_mem);
::std::size_t __remainder = __mem_bytes - __status_flags_bytes;
std::size_t __remainder = __mem_bytes - __status_flags_bytes;
void* __vals_base_ptr = reinterpret_cast<void*>(__device_mem + __status_flags_bytes);
void* __vals_aligned_ptr = ::std::align(::std::alignment_of_v<_Type>, __status_vals_full_offset_bytes,
void* __vals_aligned_ptr = std::align(std::alignment_of_v<_Type>, __status_vals_full_offset_bytes,
__vals_base_ptr, __remainder);
_Type* __status_vals_full = reinterpret_cast<_Type*>(__vals_aligned_ptr);
_Type* __status_vals_partial = reinterpret_cast<_Type*>(__status_vals_full + __status_vals_full_offset_bytes);

auto __fill_event = __lookback_init_submitter<_FlagType, _Type, _BinaryOp, _LookbackInitKernel>{}(
__queue, __status_flags, __status_vals_partial, __status_flags_size, __status_flag_padding);

::std::size_t __current_num_wgs = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __elems_in_tile);
::std::size_t __current_num_items = __current_num_wgs * __workgroup_size;
std::size_t __current_num_wgs = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __elems_in_tile);
std::size_t __current_num_items = __current_num_wgs * __workgroup_size;

auto __prev_event =
__lookback_submitter<__data_per_workitem, __workgroup_size, _Type, _FlagType, _LookbackKernel>{}(
Expand Down Expand Up @@ -405,10 +406,10 @@ sycl::event
inclusive_scan(sycl::queue __queue, _InRng&& __in_rng, _OutRng&& __out_rng, _BinaryOp __binary_op,
_KernelParam __param = {})
{
auto __in_view = oneapi::dpl::__ranges::views::all(::std::forward<_InRng>(__in_rng));
auto __out_view = oneapi::dpl::__ranges::views::all(::std::forward<_OutRng>(__out_rng));
auto __in_view = oneapi::dpl::__ranges::views::all(std::forward<_InRng>(__in_rng));
auto __out_view = oneapi::dpl::__ranges::views::all(std::forward<_OutRng>(__out_rng));

return __impl::__single_pass_scan<true>(__queue, ::std::move(__in_view), ::std::move(__out_view), __binary_op,
return __impl::__single_pass_scan<true>(__queue, std::move(__in_view), std::move(__out_view), __binary_op,
__param);
}

Expand Down

0 comments on commit 2ff21c9

Please sign in to comment.