diff --git a/ci-overview.md b/ci-overview.md index b23f49abf7..17033b9cbb 100644 --- a/ci-overview.md +++ b/ci-overview.md @@ -65,7 +65,7 @@ The syntax of the build and test scripts is the same: ./ci/test_thrust.sh #examples -./ci/build_thrust.sh g++ 17 70;80;86 +./ci/build_thrust.sh g++ 17 "70;80;86" ``` In summary, the heart of our build and test jobs is the corresponding build or test script. diff --git a/cub/cub/detail/detect_cuda_runtime.cuh b/cub/cub/detail/detect_cuda_runtime.cuh index a2af93b718..b8e776db74 100644 --- a/cub/cub/detail/detect_cuda_runtime.cuh +++ b/cub/cub/detail/detect_cuda_runtime.cuh @@ -27,20 +27,14 @@ ******************************************************************************/ /** - * \file + * @file * Utilities for CUDA dynamic parallelism. */ #pragma once -#include - #include -CUB_NAMESPACE_BEGIN -namespace detail -{ - #ifdef DOXYGEN_SHOULD_SKIP_THIS // Only parse this during doxygen passes: /** @@ -111,6 +105,3 @@ namespace detail #endif #endif // Do not document - -} // namespace detail -CUB_NAMESPACE_END diff --git a/cub/cub/device/dispatch/dispatch_adjacent_difference.cuh b/cub/cub/device/dispatch/dispatch_adjacent_difference.cuh index 25f5c25e7f..b005f361d7 100644 --- a/cub/cub/device/dispatch/dispatch_adjacent_difference.cuh +++ b/cub/cub/device/dispatch/dispatch_adjacent_difference.cuh @@ -42,22 +42,14 @@ CUB_NAMESPACE_BEGIN - -template -void __global__ DeviceAdjacentDifferenceInitKernel(InputIteratorT first, - InputT *result, - OffsetT num_tiles, - int items_per_tile) +template +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceAdjacentDifferenceInitKernel(InputIteratorT first, + InputT *result, + OffsetT num_tiles, + int items_per_tile) { const int tile_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - AgentDifferenceInitT::Process(tile_idx, - first, - result, - num_tiles, - items_per_tile); + AgentDifferenceInitT::Process(tile_idx, first, result, num_tiles, items_per_tile); } template -void __global__ 
+CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceAdjacentDifferenceDifferenceKernel(InputIteratorT input, InputT *first_tile_previous, OutputIteratorT result, diff --git a/cub/cub/device/dispatch/dispatch_batch_memcpy.cuh b/cub/cub/device/dispatch/dispatch_batch_memcpy.cuh index 4dfddce59e..04384ae045 100644 --- a/cub/cub/device/dispatch/dispatch_batch_memcpy.cuh +++ b/cub/cub/device/dispatch/dispatch_batch_memcpy.cuh @@ -70,9 +70,10 @@ struct AgentBatchMemcpyLargeBuffersPolicy template -__global__ void InitTileStateKernel(BufferOffsetScanTileStateT buffer_offset_scan_tile_state, - BlockOffsetScanTileStateT block_offset_scan_tile_state, - TileOffsetT num_tiles) +CUB_DETAIL_KERNEL_ATTRIBUTES void +InitTileStateKernel(BufferOffsetScanTileStateT buffer_offset_scan_tile_state, + BlockOffsetScanTileStateT block_offset_scan_tile_state, + TileOffsetT num_tiles) { // Initialize tile status buffer_offset_scan_tile_state.InitializeStatus(num_tiles); @@ -93,12 +94,13 @@ template __launch_bounds__(int(ChainedPolicyT::ActivePolicy::AgentLargeBufferPolicyT::BLOCK_THREADS)) - __global__ void MultiBlockBatchMemcpyKernel(InputBufferIt input_buffer_it, - OutputBufferIt output_buffer_it, - BufferSizeIteratorT buffer_sizes, - BufferTileOffsetItT buffer_tile_offsets, - TileT buffer_offset_tile, - TileOffsetT last_tile_offset) + CUB_DETAIL_KERNEL_ATTRIBUTES + void MultiBlockBatchMemcpyKernel(InputBufferIt input_buffer_it, + OutputBufferIt output_buffer_it, + BufferSizeIteratorT buffer_sizes, + BufferTileOffsetItT buffer_tile_offsets, + TileT buffer_offset_tile, + TileOffsetT last_tile_offset) { using StatusWord = typename TileT::StatusWord; using ActivePolicyT = typename ChainedPolicyT::ActivePolicy::AgentLargeBufferPolicyT; @@ -219,16 +221,17 @@ template __launch_bounds__(int(ChainedPolicyT::ActivePolicy::AgentSmallBufferPolicyT::BLOCK_THREADS)) - __global__ void BatchMemcpyKernel(InputBufferIt input_buffer_it, - OutputBufferIt output_buffer_it, - BufferSizeIteratorT buffer_sizes, - 
BufferOffsetT num_buffers, - BlevBufferSrcsOutItT blev_buffer_srcs, - BlevBufferDstsOutItT blev_buffer_dsts, - BlevBufferSizesOutItT blev_buffer_sizes, - BlevBufferTileOffsetsOutItT blev_buffer_tile_offsets, - BLevBufferOffsetTileState blev_buffer_scan_state, - BLevBlockOffsetTileState blev_block_scan_state) + CUB_DETAIL_KERNEL_ATTRIBUTES + void BatchMemcpyKernel(InputBufferIt input_buffer_it, + OutputBufferIt output_buffer_it, + BufferSizeIteratorT buffer_sizes, + BufferOffsetT num_buffers, + BlevBufferSrcsOutItT blev_buffer_srcs, + BlevBufferDstsOutItT blev_buffer_dsts, + BlevBufferSizesOutItT blev_buffer_sizes, + BlevBufferTileOffsetsOutItT blev_buffer_tile_offsets, + BLevBufferOffsetTileState blev_buffer_scan_state, + BLevBlockOffsetTileState blev_block_scan_state) { // Internal type used for storing a buffer's size using BufferSizeT = cub::detail::value_t; diff --git a/cub/cub/device/dispatch/dispatch_histogram.cuh b/cub/cub/device/dispatch/dispatch_histogram.cuh index b0b8d6fa87..400ea36305 100644 --- a/cub/cub/device/dispatch/dispatch_histogram.cuh +++ b/cub/cub/device/dispatch/dispatch_histogram.cuh @@ -89,7 +89,7 @@ CUB_NAMESPACE_BEGIN * Drain queue descriptor for dynamically mapping tile data onto thread blocks */ template -__global__ void +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceHistogramInitKernel(ArrayWrapper num_output_bins_wrapper, ArrayWrapper d_output_histograms_wrapper, GridQueue tile_queue) @@ -193,8 +193,8 @@ template -__launch_bounds__(int(ChainedPolicyT::ActivePolicy::AgentHistogramPolicyT::BLOCK_THREADS)) __global__ - void DeviceHistogramSweepKernel( +__launch_bounds__(int(ChainedPolicyT::ActivePolicy::AgentHistogramPolicyT::BLOCK_THREADS)) + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceHistogramSweepKernel( SampleIteratorT d_samples, ArrayWrapper num_output_bins_wrapper, ArrayWrapper num_privatized_bins_wrapper, diff --git a/cub/cub/device/dispatch/dispatch_merge_sort.cuh b/cub/cub/device/dispatch/dispatch_merge_sort.cuh index 
7b6c73a1dd..015a25ed01 100644 --- a/cub/cub/device/dispatch/dispatch_merge_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_merge_sort.cuh @@ -38,7 +38,6 @@ CUB_NAMESPACE_BEGIN - template -void __global__ __launch_bounds__(ChainedPolicyT::ActivePolicy::MergeSortPolicy::BLOCK_THREADS) -DeviceMergeSortBlockSortKernel(bool ping, - KeyInputIteratorT keys_in, - ValueInputIteratorT items_in, - KeyIteratorT keys_out, - ValueIteratorT items_out, - OffsetT keys_count, - KeyT *tmp_keys_out, - ValueT *tmp_items_out, - CompareOpT compare_op, - char *vshmem) +__launch_bounds__(ChainedPolicyT::ActivePolicy::MergeSortPolicy::BLOCK_THREADS) + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceMergeSortBlockSortKernel(bool ping, + KeyInputIteratorT keys_in, + ValueInputIteratorT items_in, + KeyIteratorT keys_out, + ValueIteratorT items_out, + OffsetT keys_count, + KeyT *tmp_keys_out, + ValueT *tmp_items_out, + CompareOpT compare_op, + char *vshmem) { extern __shared__ char shmem[]; using ActivePolicyT = typename ChainedPolicyT::ActivePolicy::MergeSortPolicy; @@ -95,19 +94,16 @@ DeviceMergeSortBlockSortKernel(bool ping, agent.Process(); } -template -__global__ void DeviceMergeSortPartitionKernel(bool ping, - KeyIteratorT keys_ping, - KeyT *keys_pong, - OffsetT keys_count, - OffsetT num_partitions, - OffsetT *merge_partitions, - CompareOpT compare_op, - OffsetT target_merged_tiles_number, - int items_per_tile) +template +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceMergeSortPartitionKernel(bool ping, + KeyIteratorT keys_ping, + KeyT *keys_pong, + OffsetT keys_count, + OffsetT num_partitions, + OffsetT *merge_partitions, + CompareOpT compare_op, + OffsetT target_merged_tiles_number, + int items_per_tile) { OffsetT partition_idx = blockDim.x * blockIdx.x + threadIdx.x; @@ -136,17 +132,17 @@ template -void __global__ __launch_bounds__(ChainedPolicyT::ActivePolicy::MergeSortPolicy::BLOCK_THREADS) -DeviceMergeSortMergeKernel(bool ping, - KeyIteratorT keys_ping, - ValueIteratorT items_ping, - OffsetT 
keys_count, - KeyT *keys_pong, - ValueT *items_pong, - CompareOpT compare_op, - OffsetT *merge_partitions, - OffsetT target_merged_tiles_number, - char *vshmem) +__launch_bounds__(ChainedPolicyT::ActivePolicy::MergeSortPolicy::BLOCK_THREADS) + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceMergeSortMergeKernel(bool ping, + KeyIteratorT keys_ping, + ValueIteratorT items_ping, + OffsetT keys_count, + KeyT *keys_pong, + ValueT *items_pong, + CompareOpT compare_op, + OffsetT *merge_partitions, + OffsetT target_merged_tiles_number, + char *vshmem) { extern __shared__ char shmem[]; diff --git a/cub/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/cub/device/dispatch/dispatch_radix_sort.cuh index 9a75e6fe0b..d5d2ef93b7 100644 --- a/cub/cub/device/dispatch/dispatch_radix_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_radix_sort.cuh @@ -33,9 +33,6 @@ #pragma once -#include -#include - #include #include #include @@ -52,6 +49,9 @@ #include +#include +#include + // suppress warnings triggered by #pragma unroll: // "warning: loop not unrolled: the optimizer was unable to perform the requested transformation; the transformation might be disabled or specified as part of an unsupported transformation ordering [-Wpass-failed=transform-warning]" #if defined(__clang__) @@ -79,7 +79,7 @@ template < __launch_bounds__ (int((ALT_DIGIT_BITS) ? int(ChainedPolicyT::ActivePolicy::AltUpsweepPolicy::BLOCK_THREADS) : int(ChainedPolicyT::ActivePolicy::UpsweepPolicy::BLOCK_THREADS))) -__global__ void DeviceRadixSortUpsweepKernel( +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceRadixSortUpsweepKernel( const KeyT *d_keys, ///< [in] Input keys buffer OffsetT *d_spine, ///< [out] Privatized (per block) digit histograms (striped, i.e., 0s counts from each block, then 1s counts from each block, etc.) 
OffsetT /*num_items*/, ///< [in] Total number of input data items @@ -138,7 +138,7 @@ template < typename ChainedPolicyT, ///< Chained tuning policy typename OffsetT> ///< Signed integer type for global offsets __launch_bounds__ (int(ChainedPolicyT::ActivePolicy::ScanPolicy::BLOCK_THREADS), 1) -__global__ void RadixSortScanBinsKernel( +CUB_DETAIL_KERNEL_ATTRIBUTES void RadixSortScanBinsKernel( OffsetT *d_spine, ///< [in,out] Privatized (per block) digit histograms (striped, i.e., 0s counts from each block, then 1s counts from each block, etc.) int num_counts) ///< [in] Total number of bin-counts { @@ -191,7 +191,7 @@ template < __launch_bounds__ (int((ALT_DIGIT_BITS) ? int(ChainedPolicyT::ActivePolicy::AltDownsweepPolicy::BLOCK_THREADS) : int(ChainedPolicyT::ActivePolicy::DownsweepPolicy::BLOCK_THREADS))) -__global__ void DeviceRadixSortDownsweepKernel( +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceRadixSortDownsweepKernel( const KeyT *d_keys_in, ///< [in] Input keys buffer KeyT *d_keys_out, ///< [in] Output keys buffer const ValueT *d_values_in, ///< [in] Input values buffer @@ -255,7 +255,7 @@ template < typename OffsetT, ///< Signed integer type for global offsets typename DecomposerT = detail::identity_decomposer_t> __launch_bounds__ (int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THREADS), 1) -__global__ void DeviceRadixSortSingleTileKernel( +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceRadixSortSingleTileKernel( const KeyT *d_keys_in, ///< [in] Input keys buffer KeyT *d_keys_out, ///< [in] Output keys buffer const ValueT *d_values_in, ///< [in] Input values buffer @@ -380,7 +380,7 @@ template < __launch_bounds__ (int((ALT_DIGIT_BITS) ? 
ChainedPolicyT::ActivePolicy::AltSegmentedPolicy::BLOCK_THREADS : ChainedPolicyT::ActivePolicy::SegmentedPolicy::BLOCK_THREADS)) -__global__ void DeviceSegmentedRadixSortKernel( +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedRadixSortKernel( const KeyT *d_keys_in, ///< [in] Input keys buffer KeyT *d_keys_out, ///< [in] Output keys buffer const ValueT *d_values_in, ///< [in] Input values buffer @@ -552,7 +552,7 @@ template -__global__ __launch_bounds__(ChainedPolicyT::ActivePolicy::HistogramPolicy::BLOCK_THREADS) +CUB_DETAIL_KERNEL_ATTRIBUTES __launch_bounds__(ChainedPolicyT::ActivePolicy::HistogramPolicy::BLOCK_THREADS) void DeviceRadixSortHistogramKernel(OffsetT *d_bins_out, const KeyT *d_keys_in, OffsetT num_items, @@ -576,7 +576,7 @@ template < typename PortionOffsetT, typename AtomicOffsetT = PortionOffsetT, typename DecomposerT = detail::identity_decomposer_t> -__global__ void __launch_bounds__(ChainedPolicyT::ActivePolicy::OnesweepPolicy::BLOCK_THREADS) +CUB_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(ChainedPolicyT::ActivePolicy::OnesweepPolicy::BLOCK_THREADS) DeviceRadixSortOnesweepKernel (AtomicOffsetT* d_lookback, AtomicOffsetT* d_ctrs, OffsetT* d_bins_out, const OffsetT* d_bins_in, KeyT* d_keys_out, const KeyT* d_keys_in, ValueT* d_values_out, @@ -600,7 +600,7 @@ DeviceRadixSortOnesweepKernel template < typename ChainedPolicyT, typename OffsetT> -__global__ void DeviceRadixSortExclusiveSumKernel(OffsetT* d_bins) +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceRadixSortExclusiveSumKernel(OffsetT* d_bins) { typedef typename ChainedPolicyT::ActivePolicy::ExclusiveSumPolicy ExclusiveSumPolicyT; const int RADIX_BITS = ExclusiveSumPolicyT::RADIX_BITS; diff --git a/cub/cub/device/dispatch/dispatch_reduce.cuh b/cub/cub/device/dispatch/dispatch_reduce.cuh index 698ce0552e..2dbcdc76fd 100644 --- a/cub/cub/device/dispatch/dispatch_reduce.cuh +++ b/cub/cub/device/dispatch/dispatch_reduce.cuh @@ -153,12 +153,12 @@ template 
-__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ReducePolicy::BLOCK_THREADS)) -__global__ void DeviceReduceKernel(InputIteratorT d_in, - AccumT* d_out, - OffsetT num_items, - GridEvenShare even_share, - ReductionOpT reduction_op) +__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ReducePolicy::BLOCK_THREADS)) + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceReduceKernel(InputIteratorT d_in, + AccumT *d_out, + OffsetT num_items, + GridEvenShare even_share, + ReductionOpT reduction_op) { // Thread block type for reducing input tiles using AgentReduceT = @@ -232,12 +232,12 @@ template -__launch_bounds__(int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THREADS), 1) -__global__ void DeviceReduceSingleTileKernel(InputIteratorT d_in, - OutputIteratorT d_out, - OffsetT num_items, - ReductionOpT reduction_op, - InitT init) +__launch_bounds__(int(ChainedPolicyT::ActivePolicy::SingleTilePolicy::BLOCK_THREADS), 1) // + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceReduceSingleTileKernel(InputIteratorT d_in, + OutputIteratorT d_out, + OffsetT num_items, + ReductionOpT reduction_op, + InitT init) { // Thread block type for reducing input tiles using AgentReduceT = @@ -358,15 +358,15 @@ template -__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ReducePolicy::BLOCK_THREADS)) -__global__ void DeviceSegmentedReduceKernel( - InputIteratorT d_in, - OutputIteratorT d_out, - BeginOffsetIteratorT d_begin_offsets, - EndOffsetIteratorT d_end_offsets, - int /*num_segments*/, - ReductionOpT reduction_op, - InitT init) +__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ReducePolicy::BLOCK_THREADS)) + CUB_DETAIL_KERNEL_ATTRIBUTES + void DeviceSegmentedReduceKernel(InputIteratorT d_in, + OutputIteratorT d_out, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets, + int /*num_segments*/, + ReductionOpT reduction_op, + InitT init) { // Thread block type for reducing input tiles using AgentReduceT = diff --git a/cub/cub/device/dispatch/dispatch_reduce_by_key.cuh 
b/cub/cub/device/dispatch/dispatch_reduce_by_key.cuh index a29c1376a4..5040e39f7b 100644 --- a/cub/cub/device/dispatch/dispatch_reduce_by_key.cuh +++ b/cub/cub/device/dispatch/dispatch_reduce_by_key.cuh @@ -131,7 +131,8 @@ template -__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ReduceByKeyPolicyT::BLOCK_THREADS)) __global__ +__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ReduceByKeyPolicyT::BLOCK_THREADS)) + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceReduceByKeyKernel(KeysInputIteratorT d_keys_in, UniqueOutputIteratorT d_unique_out, ValuesInputIteratorT d_values_in, diff --git a/cub/cub/device/dispatch/dispatch_rle.cuh b/cub/cub/device/dispatch/dispatch_rle.cuh index 06c6fc90f7..4401a7c124 100644 --- a/cub/cub/device/dispatch/dispatch_rle.cuh +++ b/cub/cub/device/dispatch/dispatch_rle.cuh @@ -119,15 +119,15 @@ template -__launch_bounds__(int(ChainedPolicyT::ActivePolicy::RleSweepPolicyT::BLOCK_THREADS)) __global__ - void DeviceRleSweepKernel(InputIteratorT d_in, - OffsetsOutputIteratorT d_offsets_out, - LengthsOutputIteratorT d_lengths_out, - NumRunsOutputIteratorT d_num_runs_out, - ScanTileStateT tile_status, - EqualityOpT equality_op, - OffsetT num_items, - int num_tiles) +__launch_bounds__(int(ChainedPolicyT::ActivePolicy::RleSweepPolicyT::BLOCK_THREADS)) + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceRleSweepKernel(InputIteratorT d_in, + OffsetsOutputIteratorT d_offsets_out, + LengthsOutputIteratorT d_lengths_out, + NumRunsOutputIteratorT d_num_runs_out, + ScanTileStateT tile_status, + EqualityOpT equality_op, + OffsetT num_items, + int num_tiles) { using AgentRlePolicyT = typename ChainedPolicyT::ActivePolicy::RleSweepPolicyT; diff --git a/cub/cub/device/dispatch/dispatch_scan.cuh b/cub/cub/device/dispatch/dispatch_scan.cuh index f16f1c0fd9..6893f24e1d 100644 --- a/cub/cub/device/dispatch/dispatch_scan.cuh +++ b/cub/cub/device/dispatch/dispatch_scan.cuh @@ -68,7 +68,7 @@ CUB_NAMESPACE_BEGIN * Number of tiles */ template -__global__ void 
DeviceScanInitKernel(ScanTileStateT tile_state, int num_tiles) +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanInitKernel(ScanTileStateT tile_state, int num_tiles) { // Initialize tile status tile_state.InitializeStatus(num_tiles); @@ -94,9 +94,9 @@ __global__ void DeviceScanInitKernel(ScanTileStateT tile_state, int num_tiles) * (i.e., length of `d_selected_out`) */ template -__global__ void DeviceCompactInitKernel(ScanTileStateT tile_state, - int num_tiles, - NumSelectedIteratorT d_num_selected_out) +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceCompactInitKernel(ScanTileStateT tile_state, + int num_tiles, + NumSelectedIteratorT d_num_selected_out) { // Initialize tile status tile_state.InitializeStatus(num_tiles); @@ -165,13 +165,13 @@ template __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS)) - __global__ void DeviceScanKernel(InputIteratorT d_in, - OutputIteratorT d_out, - ScanTileStateT tile_state, - int start_tile, - ScanOpT scan_op, - InitValueT init_value, - OffsetT num_items) + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanKernel(InputIteratorT d_in, + OutputIteratorT d_out, + ScanTileStateT tile_state, + int start_tile, + ScanOpT scan_op, + InitValueT init_value, + OffsetT num_items) { using RealInitValueT = typename InitValueT::value_type; typedef typename ChainedPolicyT::ActivePolicy::ScanPolicyT ScanPolicyT; diff --git a/cub/cub/device/dispatch/dispatch_scan_by_key.cuh b/cub/cub/device/dispatch/dispatch_scan_by_key.cuh index b70e49be27..62df5c6b91 100644 --- a/cub/cub/device/dispatch/dispatch_scan_by_key.cuh +++ b/cub/cub/device/dispatch/dispatch_scan_by_key.cuh @@ -124,17 +124,17 @@ template > -__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanByKeyPolicyT::BLOCK_THREADS)) -__global__ void DeviceScanByKeyKernel(KeysInputIteratorT d_keys_in, - KeyT *d_keys_prev_in, - ValuesInputIteratorT d_values_in, - ValuesOutputIteratorT d_values_out, - ScanByKeyTileStateT tile_state, - int start_tile, - EqualityOp equality_op, - ScanOpT 
scan_op, - InitValueT init_value, - OffsetT num_items) +__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanByKeyPolicyT::BLOCK_THREADS)) + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanByKeyKernel(KeysInputIteratorT d_keys_in, + KeyT *d_keys_prev_in, + ValuesInputIteratorT d_values_in, + ValuesOutputIteratorT d_values_out, + ScanByKeyTileStateT tile_state, + int start_tile, + EqualityOp equality_op, + ScanOpT scan_op, + InitValueT init_value, + OffsetT num_items) { using ScanByKeyPolicyT = typename ChainedPolicyT::ActivePolicy::ScanByKeyPolicyT; @@ -166,12 +166,12 @@ __global__ void DeviceScanByKeyKernel(KeysInputIteratorT d_keys_in, } template -__global__ void DeviceScanByKeyInitKernel( - ScanTileStateT tile_state, - KeysInputIteratorT d_keys_in, - cub::detail::value_t *d_keys_prev_in, - unsigned items_per_tile, - int num_tiles) +CUB_DETAIL_KERNEL_ATTRIBUTES void +DeviceScanByKeyInitKernel(ScanTileStateT tile_state, + KeysInputIteratorT d_keys_in, + cub::detail::value_t *d_keys_prev_in, + unsigned items_per_tile, + int num_tiles) { // Initialize tile status tile_state.InitializeStatus(num_tiles); diff --git a/cub/cub/device/dispatch/dispatch_segmented_sort.cuh b/cub/cub/device/dispatch/dispatch_segmented_sort.cuh index 2eec9290bb..8cc2d01697 100644 --- a/cub/cub/device/dispatch/dispatch_segmented_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_segmented_sort.cuh @@ -104,7 +104,7 @@ template __launch_bounds__(ChainedPolicyT::ActivePolicy::LargeSegmentPolicy::BLOCK_THREADS) -__global__ void DeviceSegmentedSortFallbackKernel( + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortFallbackKernel( const KeyT *d_keys_in_orig, KeyT *d_keys_out_orig, cub::detail::device_double_buffer d_keys_double_buffer, @@ -299,18 +299,18 @@ template __launch_bounds__(ChainedPolicyT::ActivePolicy::SmallAndMediumSegmentedSortPolicyT::BLOCK_THREADS) -__global__ void DeviceSegmentedSortKernelSmall( - unsigned int small_segments, - unsigned int medium_segments, - unsigned int 
medium_blocks, - const unsigned int *d_small_segments_indices, - const unsigned int *d_medium_segments_indices, - const KeyT *d_keys_in, - KeyT *d_keys_out, - const ValueT *d_values_in, - ValueT *d_values_out, - BeginOffsetIteratorT d_begin_offsets, - EndOffsetIteratorT d_end_offsets) + CUB_DETAIL_KERNEL_ATTRIBUTES + void DeviceSegmentedSortKernelSmall(unsigned int small_segments, + unsigned int medium_segments, + unsigned int medium_blocks, + const unsigned int *d_small_segments_indices, + const unsigned int *d_medium_segments_indices, + const KeyT *d_keys_in, + KeyT *d_keys_out, + const ValueT *d_values_in, + ValueT *d_values_out, + BeginOffsetIteratorT d_begin_offsets, + EndOffsetIteratorT d_end_offsets) { const unsigned int tid = threadIdx.x; const unsigned int bid = blockIdx.x; @@ -428,7 +428,7 @@ template __launch_bounds__(ChainedPolicyT::ActivePolicy::LargeSegmentPolicy::BLOCK_THREADS) -__global__ void DeviceSegmentedSortKernelLarge( + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortKernelLarge( const unsigned int *d_segments_indices, const KeyT *d_keys_in_orig, KeyT *d_keys_out_orig, @@ -687,7 +687,7 @@ template -__launch_bounds__(1) __global__ void +__launch_bounds__(1) CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentedSortContinuationKernel( LargeKernelT large_kernel, SmallKernelT small_kernel, diff --git a/cub/cub/device/dispatch/dispatch_select_if.cuh b/cub/cub/device/dispatch/dispatch_select_if.cuh index 56fa86e2ad..6d7dba3186 100644 --- a/cub/cub/device/dispatch/dispatch_select_if.cuh +++ b/cub/cub/device/dispatch/dispatch_select_if.cuh @@ -131,16 +131,16 @@ template -__launch_bounds__(int(ChainedPolicyT::ActivePolicy::SelectIfPolicyT::BLOCK_THREADS)) __global__ - void DeviceSelectSweepKernel(InputIteratorT d_in, - FlagsInputIteratorT d_flags, - SelectedOutputIteratorT d_selected_out, - NumSelectedIteratorT d_num_selected_out, - ScanTileStateT tile_status, - SelectOpT select_op, - EqualityOpT equality_op, - OffsetT num_items, - int num_tiles) 
+__launch_bounds__(int(ChainedPolicyT::ActivePolicy::SelectIfPolicyT::BLOCK_THREADS)) + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSelectSweepKernel(InputIteratorT d_in, + FlagsInputIteratorT d_flags, + SelectedOutputIteratorT d_selected_out, + NumSelectedIteratorT d_num_selected_out, + ScanTileStateT tile_status, + SelectOpT select_op, + EqualityOpT equality_op, + OffsetT num_items, + int num_tiles) { using AgentSelectIfPolicyT = typename ChainedPolicyT::ActivePolicy::SelectIfPolicyT; diff --git a/cub/cub/device/dispatch/dispatch_spmv_orig.cuh b/cub/cub/device/dispatch/dispatch_spmv_orig.cuh index c38c4bfb48..227c2a42ca 100644 --- a/cub/cub/device/dispatch/dispatch_spmv_orig.cuh +++ b/cub/cub/device/dispatch/dispatch_spmv_orig.cuh @@ -64,37 +64,33 @@ CUB_NAMESPACE_BEGIN /** * Spmv search kernel. Identifies merge path starting coordinates for each tile. */ -template < - typename AgentSpmvPolicyT, ///< Parameterized SpmvPolicy tuning policy type - typename ValueT, ///< Matrix and vector value type - typename OffsetT> ///< Signed integer type for sequence offsets -__global__ void DeviceSpmv1ColKernel( - SpmvParams spmv_params) ///< [in] SpMV input parameter bundle +template ///< Signed integer type for sequence offsets +CUB_DETAIL_KERNEL_ATTRIBUTES void +DeviceSpmv1ColKernel(SpmvParams spmv_params) ///< [in] SpMV input parameter bundle { - typedef CacheModifiedInputIterator< - AgentSpmvPolicyT::VECTOR_VALUES_LOAD_MODIFIER, - ValueT, - OffsetT> - VectorValueIteratorT; - - VectorValueIteratorT wrapped_vector_x(spmv_params.d_vector_x); + typedef CacheModifiedInputIterator + VectorValueIteratorT; - int row_idx = (blockIdx.x * blockDim.x) + threadIdx.x; - if (row_idx < spmv_params.num_rows) - { - OffsetT end_nonzero_idx = spmv_params.d_row_end_offsets[row_idx]; - OffsetT nonzero_idx = spmv_params.d_row_end_offsets[row_idx - 1]; + VectorValueIteratorT wrapped_vector_x(spmv_params.d_vector_x); - ValueT value = 0.0; - if (end_nonzero_idx != nonzero_idx) - { - value = 
spmv_params.d_values[nonzero_idx] * wrapped_vector_x[spmv_params.d_column_indices[nonzero_idx]]; - } + int row_idx = (blockIdx.x * blockDim.x) + threadIdx.x; + if (row_idx < spmv_params.num_rows) + { + OffsetT end_nonzero_idx = spmv_params.d_row_end_offsets[row_idx]; + OffsetT nonzero_idx = spmv_params.d_row_end_offsets[row_idx - 1]; - spmv_params.d_vector_y[row_idx] = value; + ValueT value = 0.0; + if (end_nonzero_idx != nonzero_idx) + { + value = spmv_params.d_values[nonzero_idx] * + wrapped_vector_x[spmv_params.d_column_indices[nonzero_idx]]; } -} + spmv_params.d_vector_y[row_idx] = value; + } +} /** * Spmv search kernel. Identifies merge path starting coordinates for each tile. @@ -104,7 +100,7 @@ template < typename OffsetT, ///< Signed integer type for sequence offsets typename CoordinateT, ///< Merge path coordinate type typename SpmvParamsT> ///< SpmvParams type -__global__ void DeviceSpmvSearchKernel( +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmvSearchKernel( int num_merge_tiles, ///< [in] Number of SpMV merge tiles (spmv grid size) CoordinateT* d_tile_coordinates, ///< [out] Pointer to the temporary array of tile starting coordinates SpmvParamsT spmv_params) ///< [in] SpMV input parameter bundle @@ -158,7 +154,7 @@ template < bool HAS_ALPHA, ///< Whether the input parameter Alpha is 1 bool HAS_BETA> ///< Whether the input parameter Beta is 0 __launch_bounds__ (int(SpmvPolicyT::BLOCK_THREADS)) -__global__ void DeviceSpmvKernel( +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSpmvKernel( SpmvParams spmv_params, ///< [in] SpMV input parameter bundle CoordinateT* d_tile_coordinates, ///< [in] Pointer to the temporary array of tile starting coordinates KeyValuePair* d_tile_carry_pairs, ///< [out] Pointer to the temporary array carry-out dot product row-ids, one per block @@ -191,7 +187,8 @@ __global__ void DeviceSpmvKernel( template ///< Whether the input parameter Beta is 0 -__global__ void DeviceSpmvEmptyMatrixKernel(SpmvParams spmv_params) 
+CUB_DETAIL_KERNEL_ATTRIBUTES void +DeviceSpmvEmptyMatrixKernel(SpmvParams spmv_params) { const int row = static_cast(threadIdx.x + blockIdx.x * blockDim.x); @@ -218,7 +215,7 @@ template < typename OffsetT, ///< Signed integer type for global offsets typename ScanTileStateT> ///< Tile status interface type __launch_bounds__ (int(AgentSegmentFixupPolicyT::BLOCK_THREADS)) -__global__ void DeviceSegmentFixupKernel( +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceSegmentFixupKernel( PairsInputIteratorT d_pairs_in, ///< [in] Pointer to the array carry-out dot product row-ids, one per spmv block AggregatesOutputIteratorT d_aggregates_out, ///< [in,out] Output value aggregates OffsetT num_items, ///< [in] Total number of items to select from diff --git a/cub/cub/device/dispatch/dispatch_three_way_partition.cuh b/cub/cub/device/dispatch/dispatch_three_way_partition.cuh index 2277956e24..52f8dec7cd 100644 --- a/cub/cub/device/dispatch/dispatch_three_way_partition.cuh +++ b/cub/cub/device/dispatch/dispatch_three_way_partition.cuh @@ -59,7 +59,8 @@ template -__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ThreeWayPartitionPolicy::BLOCK_THREADS)) __global__ +__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ThreeWayPartitionPolicy::BLOCK_THREADS)) + CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceThreeWayPartitionKernel(InputIteratorT d_in, FirstOutputIteratorT d_first_part_out, SecondOutputIteratorT d_second_part_out, @@ -122,9 +123,10 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ThreeWayPartitionPolicy::BLO * (i.e., length of @p d_selected_out) */ template -__global__ void DeviceThreeWayPartitionInitKernel(ScanTileStateT tile_state, - int num_tiles, - NumSelectedIteratorT d_num_selected_out) +CUB_DETAIL_KERNEL_ATTRIBUTES void +DeviceThreeWayPartitionInitKernel(ScanTileStateT tile_state, + int num_tiles, + NumSelectedIteratorT d_num_selected_out) { // Initialize tile status tile_state.InitializeStatus(num_tiles); diff --git 
a/cub/cub/device/dispatch/dispatch_unique_by_key.cuh b/cub/cub/device/dispatch/dispatch_unique_by_key.cuh index e70d28f229..c924e71ef7 100644 --- a/cub/cub/device/dispatch/dispatch_unique_by_key.cuh +++ b/cub/cub/device/dispatch/dispatch_unique_by_key.cuh @@ -60,7 +60,7 @@ template < typename EqualityOpT, ///< Equality operator type typename OffsetT> ///< Signed integer type for global offsets __launch_bounds__ (int(ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT::BLOCK_THREADS)) -__global__ void DeviceUniqueByKeySweepKernel( +CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceUniqueByKeySweepKernel( KeyInputIteratorT d_keys_in, ///< [in] Pointer to the input sequence of keys ValueInputIteratorT d_values_in, ///< [in] Pointer to the input sequence of values KeyOutputIteratorT d_keys_out, ///< [out] Pointer to the output sequence of selected data items diff --git a/cub/cub/util_device.cuh b/cub/cub/util_device.cuh index d8caaedbb4..c7e15cafe0 100644 --- a/cub/cub/util_device.cuh +++ b/cub/cub/util_device.cuh @@ -70,7 +70,7 @@ CUB_NAMESPACE_BEGIN * \brief Empty kernel for querying PTX manifest metadata (e.g., version) for the current device */ template -__global__ void EmptyKernel(void) { } +CUB_DETAIL_KERNEL_ATTRIBUTES void EmptyKernel(void) { } #endif // DOXYGEN_SHOULD_SKIP_THIS diff --git a/cub/cub/util_macro.cuh b/cub/cub/util_macro.cuh index c486aa439f..d8f46f0907 100644 --- a/cub/cub/util_macro.cuh +++ b/cub/cub/util_macro.cuh @@ -32,9 +32,11 @@ #pragma once -#include +#include +#include -#include "util_namespace.cuh" +#include +#include // _LIBCUDACXX_HIDDEN, _LIBCUDACXX_{CLANG,GCC}_DIAGNOSTIC_IGNORED CUB_NAMESPACE_BEGIN @@ -113,6 +115,19 @@ constexpr __host__ __device__ auto max CUB_PREVENT_MACRO_SUBSTITUTION(T &&t, #define CUB_STATIC_ASSERT(cond, msg) typedef int CUB_CAT(cub_static_assert, __LINE__)[(cond) ? 
1 : -1] #endif +#ifndef CUB_DETAIL_KERNEL_ATTRIBUTES +#define CUB_DETAIL_KERNEL_ATTRIBUTES __global__ _LIBCUDACXX_HIDDEN +#endif + +/** + * @def CUB_DISABLE_KERNEL_VISIBILITY_WARNING_SUPPRESSION + * If defined, the default suppression of kernel visibility attribute warning is disabled. + */ +#if !defined(CUB_DISABLE_KERNEL_VISIBILITY_WARNING_SUPPRESSION) +_LIBCUDACXX_GCC_DIAGNOSTIC_IGNORED("-Wattributes") +_LIBCUDACXX_CLANG_DIAGNOSTIC_IGNORED("-Wattributes") +#endif + /** @} */ // end group UtilModule CUB_NAMESPACE_END diff --git a/cub/cub/util_namespace.cuh b/cub/cub/util_namespace.cuh index cc8e353767..27ff12dbba 100644 --- a/cub/cub/util_namespace.cuh +++ b/cub/cub/util_namespace.cuh @@ -38,7 +38,8 @@ // This is not used by this file; this is a hack so that we can detect the // CUB version from Thrust on older versions of CUB that did not have // version.cuh. -#include "version.cuh" +#include +#include // Prior to 1.13.1, only the PREFIX/POSTFIX macros were used. Notify users // that they must now define the qualifier macro, too. @@ -161,23 +162,25 @@ #define CUB_DETAIL_MAGIC_NS_NAME(...) 
CUB_DETAIL_IDENTITY(CUB_DETAIL_APPLY(CUB_DETAIL_DISPATCH, CUB_DETAIL_COUNT(__VA_ARGS__))(__VA_ARGS__)) #endif // !defined(CUB_DETAIL_MAGIC_NS_NAME) -#if defined(CUB_DISABLE_NAMESPACE_MAGIC) -#if !defined(CUB_WRAPPED_NAMESPACE) -#if !defined(CUB_IGNORE_NAMESPACE_MAGIC_ERROR) -#error "Disabling namespace magic is unsafe without wrapping namespace" -#endif // !defined(CUB_IGNORE_NAMESPACE_MAGIC_ERROR) -#endif // !defined(CUB_WRAPPED_NAMESPACE) -#define CUB_DETAIL_MAGIC_NS_BEGIN -#define CUB_DETAIL_MAGIC_NS_END +// clang-format off +#if defined(CUB_DISABLE_NAMESPACE_MAGIC) || defined(CUB_WRAPPED_NAMESPACE) +# if !defined(CUB_WRAPPED_NAMESPACE) +# if !defined(CUB_IGNORE_NAMESPACE_MAGIC_ERROR) +# error "Disabling namespace magic is unsafe without wrapping namespace" +# endif // !defined(CUB_IGNORE_NAMESPACE_MAGIC_ERROR) +# endif // !defined(CUB_WRAPPED_NAMESPACE) +# define CUB_DETAIL_MAGIC_NS_BEGIN +# define CUB_DETAIL_MAGIC_NS_END #else // not defined(CUB_DISABLE_NAMESPACE_MAGIC) -#if defined(_NVHPC_CUDA) -#define CUB_DETAIL_MAGIC_NS_BEGIN inline namespace CUB_DETAIL_MAGIC_NS_NAME(CUB_VERSION, NV_TARGET_SM_INTEGER_LIST) { -#define CUB_DETAIL_MAGIC_NS_END } -#else // not defined(_NVHPC_CUDA) -#define CUB_DETAIL_MAGIC_NS_BEGIN inline namespace CUB_DETAIL_MAGIC_NS_NAME(CUB_VERSION, __CUDA_ARCH_LIST__) { -#define CUB_DETAIL_MAGIC_NS_END } -#endif // not defined(_NVHPC_CUDA) +# if defined(_NVHPC_CUDA) +# define CUB_DETAIL_MAGIC_NS_BEGIN inline namespace CUB_DETAIL_MAGIC_NS_NAME(CUB_VERSION, NV_TARGET_SM_INTEGER_LIST) { +# define CUB_DETAIL_MAGIC_NS_END } +# else // not defined(_NVHPC_CUDA) +# define CUB_DETAIL_MAGIC_NS_BEGIN inline namespace CUB_DETAIL_MAGIC_NS_NAME(CUB_VERSION, __CUDA_ARCH_LIST__) { +# define CUB_DETAIL_MAGIC_NS_END } +# endif // not defined(_NVHPC_CUDA) #endif // not defined(CUB_DISABLE_NAMESPACE_MAGIC) +// clang-format on /** * \def CUB_NAMESPACE_BEGIN @@ -189,7 +192,7 @@ CUB_NS_PREFIX \ namespace cub \ { \ - CUB_DETAIL_MAGIC_NS_BEGIN + 
CUB_DETAIL_MAGIC_NS_BEGIN /** * \def CUB_NAMESPACE_END diff --git a/cub/docs/developer_overview.rst b/cub/docs/developer_overview.rst index c60886d0d0..892d0dc8f0 100644 --- a/cub/docs/developer_overview.rst +++ b/cub/docs/developer_overview.rst @@ -714,3 +714,33 @@ we introduced ``cub::detail::temporary_storage::layout``: // `allocation_2` alias `allocation_1`, safe to use in stream order use(allocation_2.get(), stream); + +Symbol visibility +==================================== + +Using CUB/Thrust in shared libraries is a known source of issues. +For a while, the solution to these issues consisted of wrapping CUB/Thrust namespaces with +the ``THRUST_CUB_WRAPPED_NAMESPACE`` macro so that different shared libraries have different symbols. +This solution has poor discoverability, +since issues present themselves in the form of segmentation faults, hangs, wrong results, etc. +To eliminate the symbol visibility issues on our end, we adhere to the following rules: + + #. Hiding kernel launchers: + it's important that kernel launchers like Thrust ``triple_chevron`` always reside in the same + library as the API using these kernel launchers. + + #. Hiding all kernels: + it's important that kernels always reside in the same library as the API using these kernels. + + #. Incorporating GPU architectures into symbol names: + it's important that kernels compiled for a given GPU architecture are always used by the host + API compiled for that architecture. + +To satisfy (1), ``thrust::cuda_cub::launcher::triple_chevron`` visibility is hidden. + +To satisfy (2), instead of annotating kernels as ``__global__`` we annotate them as +``CUB_DETAIL_KERNEL_ATTRIBUTES``. Apart from annotating a kernel as a global function, the macro +contains a hidden visibility attribute. + +To satisfy (3), CUB symbols are placed inside an inline namespace containing the set of +GPU architectures for which the TU is being compiled. 
diff --git a/thrust/thrust/detail/config/namespace.h b/thrust/thrust/detail/config/namespace.h index 9c79046169..91b9f879cd 100644 --- a/thrust/thrust/detail/config/namespace.h +++ b/thrust/thrust/detail/config/namespace.h @@ -16,6 +16,9 @@ #pragma once +#include +#include + /** * \file namespace.h * \brief Utilities that allow `thrust::` to be placed inside an @@ -84,6 +87,84 @@ #define THRUST_NS_QUALIFIER ::thrust #endif +// clang-format off +#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA +# if !defined(THRUST_DETAIL_ABI_NS_NAME) +# define THRUST_DETAIL_COUNT_N(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ + _14, _15, _16, _17, _18, _19, _20, N, ...) \ + N +# define THRUST_DETAIL_COUNT(...) \ + THRUST_DETAIL_IDENTITY(THRUST_DETAIL_COUNT_N(__VA_ARGS__, 20, 19, 18, 17, 16, 15, 14, 13, 12, \ + 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1)) +# define THRUST_DETAIL_IDENTITY(N) N +# define THRUST_DETAIL_APPLY(MACRO, ...) THRUST_DETAIL_IDENTITY(MACRO(__VA_ARGS__)) +# define THRUST_DETAIL_ABI_NS_NAME1(P1) \ + THRUST_##P1##_NS +# define THRUST_DETAIL_ABI_NS_NAME2(P1, P2) \ + THRUST_##P1##_##P2##_NS +# define THRUST_DETAIL_ABI_NS_NAME3(P1, P2, P3) \ + THRUST_##P1##_##P2##_##P3##_NS +# define THRUST_DETAIL_ABI_NS_NAME4(P1, P2, P3, P4) \ + THRUST_##P1##_##P2##_##P3##_##P4##_NS +# define THRUST_DETAIL_ABI_NS_NAME5(P1, P2, P3, P4, P5) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_NS +# define THRUST_DETAIL_ABI_NS_NAME6(P1, P2, P3, P4, P5, P6) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_NS +# define THRUST_DETAIL_ABI_NS_NAME7(P1, P2, P3, P4, P5, P6, P7) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_NS +# define THRUST_DETAIL_ABI_NS_NAME8(P1, P2, P3, P4, P5, P6, P7, P8) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_NS +# define THRUST_DETAIL_ABI_NS_NAME9(P1, P2, P3, P4, P5, P6, P7, P8, P9) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_##P9##_NS +# define THRUST_DETAIL_ABI_NS_NAME10(P1, P2, P3, P4, P5, P6, 
P7, P8, P9, P10) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_##P9##_##P10##_NS +# define THRUST_DETAIL_ABI_NS_NAME11(P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_##P9##_##P10##_##P11##_NS +# define THRUST_DETAIL_ABI_NS_NAME12(P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_##P9##_##P10##_##P11##_##P12##_NS +# define THRUST_DETAIL_ABI_NS_NAME13(P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12, P13) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_##P9##_##P10##_##P11##_##P12##_##P13##_NS +# define THRUST_DETAIL_ABI_NS_NAME14(P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12, P13, P14) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_##P9##_##P10##_##P11##_##P12##_##P13##_##P14##_NS +# define THRUST_DETAIL_ABI_NS_NAME15(P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12, P13, P14, P15) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_##P9##_##P10##_##P11##_##P12##_##P13##_##P14##_##P15##_NS +# define THRUST_DETAIL_ABI_NS_NAME16(P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12, P13, P14, P15, P16) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_##P9##_##P10##_##P11##_##P12##_##P13##_##P14##_##P15##_##P16##_NS +# define THRUST_DETAIL_ABI_NS_NAME17(P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12, P13, P14, P15, P16, P17) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_##P9##_##P10##_##P11##_##P12##_##P13##_##P14##_##P15##_##P16##_##P17##_NS +# define THRUST_DETAIL_ABI_NS_NAME18(P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12, P13, P14, P15, P16, P17, P18) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_##P9##_##P10##_##P11##_##P12##_##P13##_##P14##_##P15##_##P16##_##P17##_##P18##_NS +# define THRUST_DETAIL_ABI_NS_NAME19(P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12, P13, P14, P15, P16, P17, 
P18, P19) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_##P9##_##P10##_##P11##_##P12##_##P13##_##P14##_##P15##_##P16##_##P17##_##P18##_##P19##_NS +# define THRUST_DETAIL_ABI_NS_NAME20(P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12, P13, P14, P15, P16, P17, P18, P19, P20) \ + THRUST_##P1##_##P2##_##P3##_##P4##_##P5##_##P6##_##P7##_##P8##_##P9##_##P10##_##P11##_##P12##_##P13##_##P14##_##P15##_##P16##_##P17##_##P18##_##P19##_##P20##_NS +# define THRUST_DETAIL_DISPATCH(N) THRUST_DETAIL_ABI_NS_NAME ## N +# define THRUST_DETAIL_ABI_NS_NAME(...) THRUST_DETAIL_IDENTITY(THRUST_DETAIL_APPLY(THRUST_DETAIL_DISPATCH, THRUST_DETAIL_COUNT(__VA_ARGS__))(__VA_ARGS__)) +# endif // !defined(THRUST_DETAIL_ABI_NS_NAME) + +# if defined(THRUST_DISABLE_ABI_NAMESPACE) || defined(THRUST_WRAPPED_NAMESPACE) +# if !defined(THRUST_WRAPPED_NAMESPACE) +# if !defined(THRUST_IGNORE_ABI_NAMESPACE_ERROR) +# error "Disabling ABI namespace is unsafe without wrapping namespace" +# endif // !defined(THRUST_IGNORE_ABI_NAMESPACE_ERROR) +# endif // !defined(THRUST_WRAPPED_NAMESPACE) +# define THRUST_DETAIL_ABI_NS_BEGIN +# define THRUST_DETAIL_ABI_NS_END +# else // not defined(THRUST_DISABLE_ABI_NAMESPACE) +# if defined(_NVHPC_CUDA) +# define THRUST_DETAIL_ABI_NS_BEGIN inline namespace THRUST_DETAIL_ABI_NS_NAME(THRUST_VERSION, NV_TARGET_SM_INTEGER_LIST) { +# define THRUST_DETAIL_ABI_NS_END } +# else // not defined(_NVHPC_CUDA) +# define THRUST_DETAIL_ABI_NS_BEGIN inline namespace THRUST_DETAIL_ABI_NS_NAME(THRUST_VERSION, __CUDA_ARCH_LIST__) { +# define THRUST_DETAIL_ABI_NS_END } +# endif // not defined(_NVHPC_CUDA) +# endif // not defined(THRUST_DISABLE_ABI_NAMESPACE) +#else // THRUST_DEVICE_SYSTEM != THRUST_DEVICE_SYSTEM_CUDA +# define THRUST_DETAIL_ABI_NS_BEGIN +# define THRUST_DETAIL_ABI_NS_END +#endif // THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA +// clang-format on + /** * \def THRUST_NAMESPACE_BEGIN * This macro is used to open a `thrust::` namespace block, along with any 
@@ -93,7 +174,8 @@ #define THRUST_NAMESPACE_BEGIN \ THRUST_NS_PREFIX \ namespace thrust \ - { + { \ + THRUST_DETAIL_ABI_NS_BEGIN /** * \def THRUST_NAMESPACE_END @@ -102,6 +184,7 @@ * This macro is defined by Thrust and may not be overridden. */ #define THRUST_NAMESPACE_END \ + THRUST_DETAIL_ABI_NS_END \ } /* end namespace thrust */ \ THRUST_NS_POSTFIX diff --git a/thrust/thrust/system/cuda/config.h b/thrust/thrust/system/cuda/config.h index f6c8b9cb38..f29a72ac86 100644 --- a/thrust/thrust/system/cuda/config.h +++ b/thrust/thrust/system/cuda/config.h @@ -101,6 +101,7 @@ #define THRUST_DEVICE_FUNCTION __device__ __forceinline__ #define THRUST_HOST_FUNCTION __host__ __forceinline__ #define THRUST_FUNCTION __host__ __device__ __forceinline__ + #if 0 #define THRUST_ARGS(...) __VA_ARGS__ #define THRUST_STRIP_PARENS(X) X diff --git a/thrust/thrust/system/cuda/detail/core/agent_launcher.h b/thrust/thrust/system/cuda/detail/core/agent_launcher.h index dbb26f33f7..825628e8b0 100644 --- a/thrust/thrust/system/cuda/detail/core/agent_launcher.h +++ b/thrust/thrust/system/cuda/detail/core/agent_launcher.h @@ -31,9 +31,12 @@ #include #if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC -#include #include #include +#include + +#include // _LIBCUDACXX_HIDDEN, _LIBCUDACXX_{CLANG,GCC}_DIAGNOSTIC_IGNORED + #include #include @@ -42,11 +45,23 @@ THRUST_NAMESPACE_BEGIN namespace cuda_cub { namespace core { +/** + * @def THRUST_DISABLE_KERNEL_VISIBILITY_WARNING_SUPPRESSION + * If defined, the default suppression of kernel visibility attribute warning is disabled. 
+ */ +#if !defined(THRUST_DISABLE_KERNEL_VISIBILITY_WARNING_SUPPRESSION) +_LIBCUDACXX_GCC_DIAGNOSTIC_IGNORED("-Wattributes") +_LIBCUDACXX_CLANG_DIAGNOSTIC_IGNORED("-Wattributes") +#endif + +#ifndef THRUST_DETAIL_KERNEL_ATTRIBUTES +#define THRUST_DETAIL_KERNEL_ATTRIBUTES __global__ _LIBCUDACXX_HIDDEN +#endif #if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) #if 0 template - void __global__ + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(Args... args) { @@ -55,105 +70,105 @@ namespace core { } #else template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0) { extern __shared__ char shmem[]; Agent::entry(x0, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1) { extern __shared__ char shmem[]; Agent::entry(x0, x1, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2) { extern __shared__ char shmem[]; Agent::entry(x0, x1, x2, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2, _3 x3) { extern __shared__ char shmem[]; Agent::entry(x0, x1, x2, x3, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2, _3 x3, _4 x4) { extern __shared__ char shmem[]; Agent::entry(x0, x1, x2, x3, x4, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + 
THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5) { extern __shared__ char shmem[]; Agent::entry(x0, x1, x2, x3, x4, x5, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6) { extern __shared__ char shmem[]; Agent::entry(x0, x1, x2, x3, x4, x5, x6, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7) { extern __shared__ char shmem[]; Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8) { extern __shared__ char shmem[]; Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, x8, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8, _9 x9) { extern __shared__ char shmem[]; Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8, _9 x9, _xA xA) { extern __shared__ char shmem[]; Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, xA, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + 
THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8, _9 x9, _xA xA, _xB xB) { extern __shared__ char shmem[]; Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, xA, xB, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8, _9 x9, _xA xA, _xB xB, _xC xC) { extern __shared__ char shmem[]; Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, xA, xB, xC, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8, _9 x9, _xA xA, _xB xB, _xC xC, _xD xD) { extern __shared__ char shmem[]; Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, xA, xB, xC, xD, shmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent(_0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8, _9 x9, _xA xA, _xB xB, _xC xC, _xD xD, _xE xE) { extern __shared__ char shmem[]; @@ -166,7 +181,7 @@ namespace core { #if 0 template - void __global__ + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, Args... 
args) { @@ -176,7 +191,7 @@ namespace core { } #else template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0) { extern __shared__ char shmem[]; @@ -184,7 +199,7 @@ namespace core { Agent::entry(x0, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1) { extern __shared__ char shmem[]; @@ -192,7 +207,7 @@ namespace core { Agent::entry(x0, x1, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2) { extern __shared__ char shmem[]; @@ -200,7 +215,7 @@ namespace core { Agent::entry(x0, x1, x2, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2, _3 x3) { extern __shared__ char shmem[]; @@ -208,7 +223,7 @@ namespace core { Agent::entry(x0, x1, x2, x3, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2, _3 x3, _4 x4) { extern __shared__ char shmem[]; @@ -216,7 +231,7 @@ namespace core { Agent::entry(x0, x1, x2, x3, x4, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5) { extern __shared__ char shmem[]; @@ -224,7 +239,7 @@ 
namespace core { Agent::entry(x0, x1, x2, x3, x4, x5, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6) { extern __shared__ char shmem[]; @@ -232,7 +247,7 @@ namespace core { Agent::entry(x0, x1, x2, x3, x4, x5, x6, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7) { extern __shared__ char shmem[]; @@ -240,7 +255,7 @@ namespace core { Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8) { extern __shared__ char shmem[]; @@ -248,7 +263,7 @@ namespace core { Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, x8, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8, _9 x9) { extern __shared__ char shmem[]; @@ -256,7 +271,7 @@ namespace core { Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8, _9 x9, _xA xA) { extern __shared__ char shmem[]; @@ -264,7 +279,7 @@ namespace core { Agent::entry(x0, x1, x2, 
x3, x4, x5, x6, x7, x8, x9, xA, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8, _9 x9, _xA xA, _xB xB) { extern __shared__ char shmem[]; @@ -272,7 +287,7 @@ namespace core { Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, xA, xB, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8, _9 x9, _xA xA, _xB xB, _xC xC) { extern __shared__ char shmem[]; @@ -280,7 +295,7 @@ namespace core { Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, xA, xB, xC, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8, _9 x9, _xA xA, _xB xB, _xC xC, _xD xD) { extern __shared__ char shmem[]; @@ -288,7 +303,7 @@ namespace core { Agent::entry(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, xA, xB, xC, xD, vshmem); } template - void __global__ __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) + THRUST_DETAIL_KERNEL_ATTRIBUTES void __launch_bounds__(Agent::ptx_plan::BLOCK_THREADS) _kernel_agent_vshmem(char* vshmem, _0 x0, _1 x1, _2 x2, _3 x3, _4 x4, _5 x5, _6 x6, _7 x7, _8 x8, _9 x9, _xA xA, _xB xB, _xC xC, _xD xD, _xE xE) { extern __shared__ char shmem[]; @@ -299,71 +314,71 @@ namespace core { #else #if 0 template - void __global__ _kernel_agent(Args... args) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(Args... args) {} template - void __global__ _kernel_agent_vshmem(char*, Args... 
args) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*, Args... args) {} #else template - void __global__ _kernel_agent(_0) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0) {} template - void __global__ _kernel_agent(_0,_1) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0,_1) {} template - void __global__ _kernel_agent(_0,_1,_2) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0,_1,_2) {} template - void __global__ _kernel_agent(_0,_1,_2,_3) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0,_1,_2,_3) {} template - void __global__ _kernel_agent(_0,_1,_2,_3, _4) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0,_1,_2,_3, _4) {} template - void __global__ _kernel_agent(_0,_1,_2,_3, _4, _5) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0,_1,_2,_3, _4, _5) {} template - void __global__ _kernel_agent(_0,_1,_2,_3, _4, _5, _6) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0,_1,_2,_3, _4, _5, _6) {} template - void __global__ _kernel_agent(_0,_1,_2,_3, _4, _5, _6, _7) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0,_1,_2,_3, _4, _5, _6, _7) {} template - void __global__ _kernel_agent(_0,_1,_2,_3, _4, _5, _6, _7, _8) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0,_1,_2,_3, _4, _5, _6, _7, _8) {} template - void __global__ _kernel_agent(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9) {} template - void __global__ _kernel_agent(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA) {} template - void __global__ _kernel_agent(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB) {} template - void __global__ _kernel_agent(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB,_xC) {} + 
THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB,_xC) {} template - void __global__ _kernel_agent(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB,_xC, _xD) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB,_xC, _xD) {} template - void __global__ _kernel_agent(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB,_xC, _xD, _xE) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB,_xC, _xD, _xE) {} //////////////////////////////////////////////////////////// template - void __global__ _kernel_agent_vshmem(char*,_0) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0) {} template - void __global__ _kernel_agent_vshmem(char*,_0,_1) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0,_1) {} template - void __global__ _kernel_agent_vshmem(char*,_0,_1,_2) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0,_1,_2) {} template - void __global__ _kernel_agent_vshmem(char*,_0,_1,_2,_3) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0,_1,_2,_3) {} template - void __global__ _kernel_agent_vshmem(char*,_0,_1,_2,_3, _4) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0,_1,_2,_3, _4) {} template - void __global__ _kernel_agent_vshmem(char*,_0,_1,_2,_3, _4, _5) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0,_1,_2,_3, _4, _5) {} template - void __global__ _kernel_agent_vshmem(char*,_0,_1,_2,_3, _4, _5, _6) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0,_1,_2,_3, _4, _5, _6) {} template - void __global__ _kernel_agent_vshmem(char*,_0,_1,_2,_3, _4, _5, _6, _7) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0,_1,_2,_3, _4, _5, _6, _7) {} template - void __global__ _kernel_agent_vshmem(char*,_0,_1,_2,_3, _4, _5, _6, _7, _8) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES 
void _kernel_agent_vshmem(char*,_0,_1,_2,_3, _4, _5, _6, _7, _8) {} template - void __global__ _kernel_agent_vshmem(char*,_0, _1, _2, _3, _4, _5, _6, _7, _8, _9) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0, _1, _2, _3, _4, _5, _6, _7, _8, _9) {} template - void __global__ _kernel_agent_vshmem(char*,_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA) {} template - void __global__ _kernel_agent_vshmem(char*,_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB) {} template - void __global__ _kernel_agent_vshmem(char*,_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB, _xC) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB, _xC) {} template - void __global__ _kernel_agent_vshmem(char*,_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB, _xC, _xD) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB, _xC, _xD) {} template - void __global__ _kernel_agent_vshmem(char*,_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB, _xC, _xD, _xE) {} + THRUST_DETAIL_KERNEL_ATTRIBUTES void _kernel_agent_vshmem(char*,_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _xA, _xB, _xC, _xD, _xE) {} #endif #endif @@ -1139,7 +1154,7 @@ namespace core { }; -} // namespace core -} +} // namespace core +} // namespace cuda_cub THRUST_NAMESPACE_END #endif diff --git a/thrust/thrust/system/cuda/detail/core/triple_chevron_launch.h b/thrust/thrust/system/cuda/detail/core/triple_chevron_launch.h index 65a7283b74..28ac1a2305 100644 --- a/thrust/thrust/system/cuda/detail/core/triple_chevron_launch.h +++ b/thrust/thrust/system/cuda/detail/core/triple_chevron_launch.h @@ -29,15 +29,17 @@ #include #include #include -#include +#include // 
_LIBCUDACXX_HIDDEN + +#include THRUST_NAMESPACE_BEGIN namespace cuda_cub { namespace launcher { - struct triple_chevron + struct _LIBCUDACXX_HIDDEN triple_chevron { typedef size_t Size; dim3 const grid;