From 75d0f10b7200bf1e855d6829080188d10c28b2cc Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Thu, 1 Aug 2024 01:16:43 +0200 Subject: [PATCH] Prune CUB's ChainedPolicy by __CUDA_ARCH_LIST__ --- cub/cub/util_device.cuh | 76 ++++++++++++++++++++++++++++- cub/test/catch2_test_util_device.cu | 72 +++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 1 deletion(-) diff --git a/cub/cub/util_device.cuh b/cub/cub/util_device.cuh index 714aa014ceb..740f384af70 100644 --- a/cub/cub/util_device.cuh +++ b/cub/cub/util_device.cuh @@ -358,7 +358,8 @@ struct SmVersionCacheTag {}; /** - * \brief Retrieves the PTX virtual architecture that will be used on \p device (major * 100 + minor * 10). + * \brief Retrieves the PTX virtual architecture that will be used on \p device (major * 100 + minor * 10). This value + * must be one of __CUDA_ARCH_LIST__. * * \note This function may cache the result internally. * \note This function is thread safe. @@ -635,11 +636,68 @@ struct ChainedPolicy template CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Invoke(int device_ptx_version, FunctorT& op) { +#ifdef __CUDA_ARCH_LIST__ + return runtime_to_compiletime<__CUDA_ARCH_LIST__>(device_ptx_version, op); +#else if (device_ptx_version < PolicyPtxVersion) { return PrevPolicyT::Invoke(device_ptx_version, op); } return op.template Invoke(); +#endif + } + +private: + template + friend struct ChainedPolicy; // let us call invoke_static of other ChainedPolicy instantiations + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t runtime_to_compiletime(int device_ptx_version, FunctorT& op) + { + // we instantiate invoke_static for each CudaArches, but only call the one matching device_ptx_version + cudaError_t e = cudaSuccess; + const cudaError_t dummy[] = { + (device_ptx_version == CudaArches ? (e = invoke_static(op, ::cuda::std::true_type{})) + : cudaSuccess)...}; + (void) dummy; + return e; + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT& op, ::cuda::std::true_type) + { + // TODO(bgruber): drop diagnostic suppression in C++17 + _CCCL_DIAG_PUSH + _CCCL_DIAG_SUPPRESS_MSVC(4127) // suppress Conditional Expression is Constant + _CCCL_IF_CONSTEXPR (DevicePtxVersion < PolicyPtxVersion) + { + // TODO(bgruber): drop boolean tag dispatches in C++17, since _CCCL_IF_CONSTEXPR will discard this branch properly + return PrevPolicyT::template invoke_static( + op, ::cuda::std::bool_constant<(DevicePtxVersion < PolicyPtxVersion)>{}); + } + else + { + return DoInvoke(op, ::cuda::std::bool_constant= PolicyPtxVersion>{}); + } + _CCCL_DIAG_POP + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT&, ::cuda::std::false_type) + { + _LIBCUDACXX_UNREACHABLE(); + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t DoInvoke(FunctorT& op, ::cuda::std::true_type) + { + return op.template Invoke(); + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t DoInvoke(FunctorT&, ::cuda::std::false_type) + { + _LIBCUDACXX_UNREACHABLE(); } }; @@ -647,6 +705,9 @@ struct ChainedPolicy template struct ChainedPolicy { + template + friend struct ChainedPolicy; // befriend primary template, so it can call invoke_static + /// The policy for the active compiler pass using ActivePolicy = PolicyT; @@ -656,6 +717,19 @@ struct ChainedPolicy { return op.template Invoke(); } + +private: + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT& op, ::cuda::std::true_type) + { + return op.template Invoke(); + } + + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t invoke_static(FunctorT&, ::cuda::std::false_type) + { + _LIBCUDACXX_UNREACHABLE(); + } }; CUB_NAMESPACE_END diff --git a/cub/test/catch2_test_util_device.cu b/cub/test/catch2_test_util_device.cu index c59c076ec50..9d0b9f6fab1 100644 --- a/cub/test/catch2_test_util_device.cu +++ b/cub/test/catch2_test_util_device.cu @@ -87,3 +87,75 @@ CUB_TEST("CUB correctly identifies the ptx version the kernel was compiled for", REQUIRE(ptx_version == kernel_cuda_arch); REQUIRE(host_ptx_version == kernel_cuda_arch); } + +#ifdef __CUDA_ARCH_LIST__ +CUB_TEST("PtxVersion returns a value from __CUDA_ARCH_LIST__", "[util][dispatch]") +{ + int ptx_version = 0; + cub::PtxVersion(ptx_version); + const auto arch_list = std::vector{__CUDA_ARCH_LIST__}; + REQUIRE(std::find(arch_list.begin(), arch_list.end(), ptx_version) != arch_list.end()); +} +#endif + +#ifdef __CUDA_ARCH_LIST__ +// We list policies for all virtual architectures that __CUDA_ARCH_LIST__ can contain, so the actual architectures the +// tests are compiled for should match to one of those +struct policy_hub +{ +# define GEN_POLICY(cur, prev) \ + struct policy##cur : cub::ChainedPolicy \ + { \ + static constexpr int value = cur; \ + } + // for the list of supported architectures, see libcudacxx/include/nv/target + GEN_POLICY(350, 350); + GEN_POLICY(370, 350); + GEN_POLICY(500, 370); + GEN_POLICY(520, 500); + GEN_POLICY(530, 520); + GEN_POLICY(600, 530); + GEN_POLICY(610, 600); + GEN_POLICY(620, 610); + GEN_POLICY(700, 620); + GEN_POLICY(720, 700); + GEN_POLICY(750, 720); + GEN_POLICY(800, 750); + GEN_POLICY(860, 800); + GEN_POLICY(870, 860); + GEN_POLICY(890, 870); + GEN_POLICY(900, 890); + GEN_POLICY(1000, 900); + // add more policies here when new architectures emerge + GEN_POLICY(2000, 1000); // non-existing architecture, just to test pruning +# undef GEN_POLICY + + using max_policy = policy2000; +}; + +// Check that selected is one of arches +template +struct check +{ + static_assert(::cuda::std::_Or<::cuda::std::bool_constant...>::value, ""); + using type = cudaError_t; +}; + +struct Closure +{ + // We need to fail template instantiation if ActivePolicy::value is not one from the __CUDA_ARCH_LIST__ + template + _CCCL_HOST_DEVICE auto Invoke() const -> typename check::type + { + return cudaSuccess; + } +}; + +CUB_TEST("ChainedPolicy prunes based on __CUDA_ARCH_LIST__", "[util][dispatch]") +{ + int ptx_version = 0; + cub::PtxVersion(ptx_version); + Closure c; + policy_hub::max_policy::Invoke(ptx_version, c); +} +#endif