Skip to content

Commit

Permalink
Implement std::bit_cast (#2258)
Browse files Browse the repository at this point in the history
* Implement `std::bit_cast`

This backport C++20 `std::bit_cast` to be available in all standard modes.

As this requires compiler builtin support, we have a non-constexpr workaround with the usual memcpy implementation.

Fixes #2257

* Add additional contraint in the fallback mode

* Use bit_cast in cub

* Formatting fix?

* Fix typo
  • Loading branch information
miscco authored and bernhardmgruber committed Aug 28, 2024
1 parent 792c4e1 commit df56483
Show file tree
Hide file tree
Showing 23 changed files with 561 additions and 54 deletions.
13 changes: 0 additions & 13 deletions cub/test/c2h/utility.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,6 @@
namespace c2h
{

/**
* Return a value of type `T0` with the same bitwise representation of `in`.
* Types `To` and `From` must be the same size.
*/
template <typename To, typename From>
__host__ __device__ To bit_cast(const From& in)
{
static_assert(sizeof(To) == sizeof(From), "Types must be same size.");
To out;
memcpy(&out, &in, sizeof(To));
return out;
}

// TODO(bgruber): duplicated version of thrust/testing/unittest/system.h
inline std::string demangle(const char* name)
{
Expand Down
4 changes: 3 additions & 1 deletion cub/test/catch2_radix_sort_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
#include <thrust/scan.h>
#include <thrust/sequence.h>

#include <cuda/std/bit>

#include <array>
#include <climits>
#include <cstdint>
Expand Down Expand Up @@ -199,7 +201,7 @@ c2h::host_vector<KeyT> get_striped_keys(const c2h::host_vector<KeyT>& h_keys, in

for (std::size_t i = 0; i < h_keys.size(); i++)
{
bit_ordered_t key = c2h::bit_cast<bit_ordered_t>(h_keys[i]);
bit_ordered_t key = ::cuda::std::bit_cast<bit_ordered_t>(h_keys[i]);

_CCCL_IF_CONSTEXPR (traits_t::CATEGORY == cub::FLOATING_POINT)
{
Expand Down
7 changes: 3 additions & 4 deletions cub/test/catch2_test_device_histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@
#include <cub/device/device_histogram.cuh>
#include <cub/iterator/counting_input_iterator.cuh>

#include <cuda/std/__algorithm/copy.h>
#include <cuda/std/__cccl/dialect.h>
#include <cuda/std/__cccl/execution_space.h>
#include <cuda/std/__algorithm_>
#include <cuda/std/array>
#include <cuda/std/bit>
#include <cuda/std/type_traits>

#include <algorithm>
Expand Down Expand Up @@ -213,7 +212,7 @@ struct bit_and_anything
_CCCL_HOST_DEVICE auto operator()(const T& a, const T& b) const -> T
{
using U = typename cub::Traits<T>::UnsignedBits;
return c2h::bit_cast<T>(static_cast<U>(c2h::bit_cast<U>(a) & c2h::bit_cast<U>(b)));
return ::cuda::std::bit_cast<T>(static_cast<U>(::cuda::std::bit_cast<U>(a) & ::cuda::std::bit_cast<U>(b)));
}
};

Expand Down
4 changes: 2 additions & 2 deletions cub/test/catch2_test_device_radix_sort_keys.cu
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ CUB_TEST("DeviceRadixSort::SortKeys: negative zero handling", "[keys][radix][sor
using bits_t = typename cub::Traits<key_t>::UnsignedBits;

constexpr std::size_t num_bits = sizeof(key_t) * CHAR_BIT;
const key_t positive_zero = c2h::bit_cast<key_t>(bits_t(0));
const key_t negative_zero = c2h::bit_cast<key_t>(bits_t(1) << (num_bits - 1));
const key_t positive_zero = ::cuda::std::bit_cast<key_t>(bits_t(0));
const key_t negative_zero = ::cuda::std::bit_cast<key_t>(bits_t(1) << (num_bits - 1));

constexpr std::size_t max_num_items = 1 << 18;
const std::size_t num_items = GENERATE_COPY(take(1, random(max_num_items / 2, max_num_items)));
Expand Down
5 changes: 3 additions & 2 deletions cub/test/catch2_test_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_CCCL_NV_DIAG_SUPPRESS(177) // catch2 may contain unused variableds
#endif // nvcc-11

#include <cuda/std/bit>
#include <cuda/std/cmath>
#include <cuda/std/utility>

Expand Down Expand Up @@ -133,8 +134,8 @@ struct bitwise_equal
bool operator()(const T& a, const T& b) const
{
using bits_t = typename cub::Traits<T>::UnsignedBits;
bits_t a_bits = c2h::bit_cast<bits_t>(a);
bits_t b_bits = c2h::bit_cast<bits_t>(b);
bits_t a_bits = ::cuda::std::bit_cast<bits_t>(a);
bits_t b_bits = ::cuda::std::bit_cast<bits_t>(b);
return a_bits == b_bits;
}
};
Expand Down
4 changes: 2 additions & 2 deletions libcudacxx/include/cuda/std/__algorithm/copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
#include <cuda/std/__type_traits/is_same.h>
#include <cuda/std/__type_traits/is_trivially_copyable.h>
#include <cuda/std/__type_traits/remove_const.h>
#include <cuda/std/detail/libcxx/include/cstdint>
#include <cuda/std/detail/libcxx/include/cstdlib>
#include <cuda/std/cstdint>
#include <cuda/std/cstdlib>
#include <cuda/std/detail/libcxx/include/cstring>

_LIBCUDACXX_BEGIN_NAMESPACE_STD
Expand Down
58 changes: 58 additions & 0 deletions libcudacxx/include/cuda/std/__bit/bit_cast.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___BIT_BIT_CAST_H
#define _LIBCUDACXX___BIT_BIT_CAST_H

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/__type_traits/enable_if.h>
#include <cuda/std/__type_traits/is_trivially_copyable.h>
#include <cuda/std/__type_traits/is_trivially_default_constructible.h>
#include <cuda/std/detail/libcxx/include/cstring>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

#if defined(_LIBCUDACXX_BIT_CAST)
# define _LIBCUDACXX_CONSTEXPR_BIT_CAST constexpr
#else // ^^^ _LIBCUDACXX_BIT_CAST ^^^ / vvv !_LIBCUDACXX_BIT_CAST vvv
# define _LIBCUDACXX_CONSTEXPR_BIT_CAST
#endif // !_LIBCUDACXX_BIT_CAST

template <class _To,
class _From,
__enable_if_t<(sizeof(_To) == sizeof(_From)), int> = 0,
__enable_if_t<_CCCL_TRAIT(is_trivially_copyable, _To), int> = 0,
__enable_if_t<_CCCL_TRAIT(is_trivially_copyable, _From), int> = 0>
_CCCL_NODISCARD _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_BIT_CAST _To bit_cast(const _From& __from) noexcept
{
#if defined(_LIBCUDACXX_BIT_CAST)
return _LIBCUDACXX_BIT_CAST(_To, __from);
#else // ^^^ _LIBCUDACXX_BIT_CAST ^^^ / vvv !_LIBCUDACXX_BIT_CAST vvv
static_assert(_CCCL_TRAIT(is_trivially_default_constructible, _To),
"The compiler does not support __builtin_bit_cast, so bit_cast additionally requires the destination "
"type to be trivially constructible");
_To __temp;
_CUDA_VSTD::memcpy(&__temp, &__from, sizeof(_To));
return __temp;
#endif // !_LIBCUDACXX_BIT_CAST
}

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _LIBCUDACXX___BIT_BIT_CAST_H
3 changes: 3 additions & 0 deletions libcudacxx/include/cuda/std/__cccl/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@
#endif // defined(_CCCL_CUDACC) && _CCCL_CUDACC_VER < 1103000
#if !defined(_CCCL_CUDA_COMPILER) || (defined(_CCCL_CUDACC) && _CCCL_CUDACC_VER < 1104000)
# define _CCCL_CUDACC_BELOW_11_4
#endif // defined(_CCCL_CUDACC) && _CCCL_CUDACC_VER < 11074000
#if !defined(_CCCL_CUDA_COMPILER) || (defined(_CCCL_CUDACC) && _CCCL_CUDACC_VER < 1107000)
# define _CCCL_CUDACC_BELOW_11_7
#endif // defined(_CCCL_CUDACC) && _CCCL_CUDACC_VER < 1104000
#if !defined(_CCCL_CUDA_COMPILER) || (defined(_CCCL_CUDACC) && _CCCL_CUDACC_VER < 1108000)
# define _CCCL_CUDACC_BELOW_11_8
Expand Down
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/std/__cuda/barrier.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

#include <cuda/std/__atomic/api/owned.h>
#include <cuda/std/__type_traits/void_t.h> // _CUDA_VSTD::void_t
#include <cuda/std/detail/libcxx/include/cstdlib> // _LIBCUDACXX_UNREACHABLE
#include <cuda/std/cstdlib> // _LIBCUDACXX_UNREACHABLE

#if defined(_CCCL_CUDA_COMPILER)
# include <cuda/ptx> // cuda::ptx::*
Expand Down
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/std/__exception/terminate.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# include <exception>
#endif // !_CCCL_COMPILER_NVRTC

#include <cuda/std/detail/libcxx/include/cstdlib>
#include <cuda/std/cstdlib>

_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4702) // unreachable code
Expand Down
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/std/__expected/expected.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
#include <cuda/std/__utility/in_place.h>
#include <cuda/std/__utility/move.h>
#include <cuda/std/__utility/swap.h>
#include <cuda/std/cstdlib>
#include <cuda/std/detail/libcxx/include/__assert>
#include <cuda/std/detail/libcxx/include/cstdlib>
#include <cuda/std/initializer_list>

#if _CCCL_STD_VER > 2011
Expand Down
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/std/__iterator/advance.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
#include <cuda/std/__iterator/iterator_traits.h>
#include <cuda/std/__utility/convert_to_integral.h>
#include <cuda/std/__utility/move.h>
#include <cuda/std/cstdlib>
#include <cuda/std/detail/libcxx/include/__assert>
#include <cuda/std/detail/libcxx/include/cstdlib>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

Expand Down
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/std/__iterator/move_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#include <cuda/std/__type_traits/is_reference.h>
#include <cuda/std/__type_traits/remove_reference.h>
#include <cuda/std/__utility/move.h>
#include <cuda/std/detail/libcxx/include/cstdlib>
#include <cuda/std/cstdlib>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

Expand Down
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/std/__memory/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include <cuda/std/__type_traits/is_volatile.h>
#include <cuda/std/__utility/forward.h>
#include <cuda/std/cstddef>
#include <cuda/std/detail/libcxx/include/cstdlib>
#include <cuda/std/cstdlib>

#if defined(_CCCL_HAS_CONSTEXPR_ALLOCATION) && !defined(_CCCL_COMPILER_NVRTC)
# include <memory>
Expand Down
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/std/__ranges/size.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include <cuda/std/__utility/auto_cast.h>
#include <cuda/std/__utility/declval.h>
#include <cuda/std/cstddef>
#include <cuda/std/detail/libcxx/include/cstdlib>
#include <cuda/std/cstdlib>

_LIBCUDACXX_BEGIN_NAMESPACE_RANGES

Expand Down
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/std/__ranges/subrange.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
#include <cuda/std/__type_traits/remove_const.h>
#include <cuda/std/__type_traits/remove_pointer.h>
#include <cuda/std/__utility/move.h>
#include <cuda/std/cstdlib>
#include <cuda/std/detail/libcxx/include/__assert>
#include <cuda/std/detail/libcxx/include/cstdlib>

#if _CCCL_STD_VER >= 2017 && !defined(_CCCL_COMPILER_MSVC_2017)

Expand Down
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/std/__utility/unreachable.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# pragma system_header
#endif // no system header

#include <cuda/std/detail/libcxx/include/cstdlib>
#include <cuda/std/cstdlib>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

Expand Down
1 change: 1 addition & 0 deletions libcudacxx/include/cuda/std/bit
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# pragma system_header
#endif // no system header

#include <cuda/std/__bit/bit_cast.h>
#include <cuda/std/__bit/clz.h>
#include <cuda/std/__bit/ctz.h>
#include <cuda/std/__bit/popc.h>
Expand Down
5 changes: 4 additions & 1 deletion libcudacxx/include/cuda/std/detail/libcxx/include/__config
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,10 @@ extern "C++" {
# define _LIBCUDACXX_ADDRESSOF(...) __builtin_addressof(__VA_ARGS__)
# endif // __check_builtin(builtin_addressof)

# if __check_builtin(builtin_bit_cast) || (defined(_CCCL_COMPILER_MSVC) && _MSC_VER > 1925)
// MSVC supports __builtin_bit_cast from 19.25 on
// clang-9 supports __builtin_bit_cast but it is not a constant expression
# if (__check_builtin(builtin_bit_cast) || (defined(_CCCL_COMPILER_MSVC) && _MSC_VER > 1925)) \
&& !defined(_CCCL_CUDACC_BELOW_11_7) && !(defined(_CCCL_COMPILER_CLANG) && __clang_major__ < 10)
# define _LIBCUDACXX_BIT_CAST(...) __builtin_bit_cast(__VA_ARGS__)
# endif // __check_builtin(builtin_bit_cast)

Expand Down
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/std/detail/libcxx/include/variant
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ C++20
#include <cuda/std/__utility/swap.h>
#include <cuda/std/__variant/monostate.h>
#include <cuda/std/cstddef>
#include <cuda/std/detail/libcxx/include/cstdlib>
#include <cuda/std/cstdlib>
#include <cuda/std/initializer_list>
#include <cuda/std/tuple>
#include <cuda/std/version>
Expand Down
39 changes: 20 additions & 19 deletions libcudacxx/include/cuda/std/version
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
#endif

#if _CCCL_STD_VER >= 2014
# define __cccl_lib_bit_cast 201806L
# define __cccl_lib_chrono_udls 201304L
# define __cccl_lib_complex_udls 201309L
# ifdef _LIBCUDACXX_IS_CONSTANT_EVALUATED
# define __cccl_lib_constexpr_complex 201711L
# endif
# endif // _LIBCUDACXX_IS_CONSTANT_EVALUATED
# define __cccl_lib_concepts 202002L
# define __cccl_lib_exchange_function 201304L
# define __cccl_lib_expected 202211L
Expand All @@ -50,9 +51,9 @@
// # define __cccl_lib_quoted_string_io 201304L
# define __cccl_lib_result_of_sfinae 201210L
# define __cccl_lib_robust_nonmodifying_seq_ops 201304L
# if !defined(_LIBCUDACXX_HAS_NO_THREADS)
# ifndef _LIBCUDACXX_HAS_NO_THREADS
// # define __cccl_lib_shared_timed_mutex 201402L
# endif
# endif // !_LIBCUDACXX_HAS_NO_THREADS
# define __cccl_lib_span 202002L
// # define __cccl_lib_string_udls 201304L
# define __cccl_lib_transformation_trait_aliases 201304L
Expand All @@ -62,17 +63,17 @@
#endif // _CCCL_STD_VER >= 2014

#if _CCCL_STD_VER >= 2017
# if defined(_LIBCUDACXX_ADDRESSOF)
# ifdef _LIBCUDACXX_ADDRESSOF
# define __cccl_lib_addressof_constexpr 201603L
# endif
# endif // _LIBCUDACXX_ADDRESSOF
// # define __cccl_lib_allocator_traits_is_always_equal 201411L
// # define __cccl_lib_any 201606L
# define __cccl_lib_apply 201603L
# define __cccl_lib_array_constexpr 201603L
# define __cccl_lib_as_const 201510L
# if !defined(_LIBCUDACXX_HAS_NO_THREADS)
# ifndef _LIBCUDACXX_HAS_NO_THREADS
# define __cccl_lib_atomic_is_always_lock_free 201603L
# endif
# endif // _LIBCUDACXX_HAS_NO_THREADS
# define __cccl_lib_bind_front 201907L
# define __cccl_lib_bool_constant 201505L
// # define __cccl_lib_boyer_moore_searcher 201603L
Expand All @@ -84,15 +85,15 @@
// # define __cccl_lib_filesystem 201703L
# define __cccl_lib_gcd_lcm 201606L
# define __cccl_lib_hardware_interference_size 201703L
# if defined(_LIBCUDACXX_HAS_UNIQUE_OBJECT_REPRESENTATIONS)
# ifdef _LIBCUDACXX_HAS_UNIQUE_OBJECT_REPRESENTATIONS
# define __cccl_lib_has_unique_object_representations 201606L
# endif
# endif // _LIBCUDACXX_HAS_UNIQUE_OBJECT_REPRESENTATIONS
# define __cccl_lib_hypot 201603L
// # define __cccl_lib_incomplete_container_elements 201505L
# define __cccl_lib_invoke 201411L
# if !defined(_LIBCUDACXX_HAS_NO_IS_AGGREGATE)
# ifndef _LIBCUDACXX_HAS_NO_IS_AGGREGATE
# define __cccl_lib_is_aggregate 201703L
# endif
# endif // _LIBCUDACXX_HAS_NO_IS_AGGREGATE
# define __cccl_lib_is_invocable 201703L
# define __cccl_lib_is_swappable 201603L
# define __cccl_lib_launder 201606L
Expand Down Expand Up @@ -129,23 +130,23 @@
# define __cccl_lib_atomic_flag_test 201907L
# define __cccl_lib_atomic_float 201711L
# define __cccl_lib_atomic_lock_free_type_aliases 201907L
# if !defined(_LIBCUDACXX_HAS_NO_THREADS)
# ifndef _LIBCUDACXX_HAS_NO_THREADS
# define __cccl_lib_atomic_ref 201806L
# endif
# endif // _LIBCUDACXX_HAS_NO_THREADS
// # define __cccl_lib_atomic_shared_ptr 201711L
# define __cccl_lib_atomic_value_initialization 201911L
# if !defined(_LIBCUDACXX_AVAILABILITY_DISABLE_FTM___cpp_lib_atomic_wait)
# ifndef _LIBCUDACXX_AVAILABILITY_DISABLE_FTM___cpp_lib_atomic_wait
# define __cccl_lib_atomic_wait 201907L
# endif
# endif // _LIBCUDACXX_AVAILABILITY_DISABLE_FTM___cpp_lib_atomic_wait
# if !defined(_LIBCUDACXX_HAS_NO_THREADS) && !defined(_LIBCUDACXX_AVAILABILITY_DISABLE_FTM___cpp_lib_barrier)
# define __cccl_lib_barrier 201907L
# endif
# define __cccl_lib_bit_cast 201806L
# define __cccl_lib_bitops 201907L
# define __cccl_lib_bounded_array_traits 201902L
# if !defined(_LIBCUDACXX_NO_HAS_CHAR8_T)
# ifndef _LIBCUDACXX_NO_HAS_CHAR8_T
# define __cccl_lib_char8_t 201811L
# endif
# endif // _LIBCUDACXX_NO_HAS_CHAR8_T
// # define __cccl_lib_constexpr_algorithms 201806L
// # define __cccl_lib_constexpr_dynamic_alloc 201907L
# define __cccl_lib_constexpr_functional 201907L
Expand Down Expand Up @@ -175,9 +176,9 @@
// # define __cccl_lib_int_pow2 202002L
// # define __cccl_lib_integer_comparison_functions 202002L
// # define __cccl_lib_interpolate 201902L
# if defined(_LIBCUDACXX_IS_CONSTANT_EVALUATED)
# ifdef _LIBCUDACXX_IS_CONSTANT_EVALUATED
# define __cccl_lib_is_constant_evaluated 201811L
# endif
# endif // _LIBCUDACXX_IS_CONSTANT_EVALUATED
// # define __cccl_lib_is_layout_compatible 201907L
# define __cccl_lib_is_nothrow_convertible 201806L
// # define __cccl_lib_is_pointer_interconvertible 201907L
Expand Down
Loading

0 comments on commit df56483

Please sign in to comment.