diff --git a/CMakeLists.txt b/CMakeLists.txt index c1f6e93d..98033426 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,10 @@ cmake_minimum_required(VERSION 3.17) -project(cudnn_frontend VERSION 1.7.0) +project(cudnn_frontend VERSION 1.8.0) option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF) option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON) -option(CUDNN_FRONTEND_BUILD_UNIT_TESTS "Defines if unittests are built or not." ON) +option(CUDNN_FRONTEND_BUILD_TESTS "Defines if unittests are built or not." ON) if(MSVC OR MSYS OR MINGW) option(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS "Defines if python bindings are built or not." OFF) @@ -28,13 +28,11 @@ target_include_directories( ) # Find the cuda compiler -find_package(CUDAToolkit) +find_package(CUDAToolkit REQUIRED) -target_link_libraries( +target_include_directories( cudnn_frontend INTERFACE - - CUDA::cudart - CUDA::nvrtc + ${CUDAToolkit_INCLUDE_DIRS} ) target_compile_features(cudnn_frontend INTERFACE cxx_std_17) @@ -47,7 +45,7 @@ if (CUDNN_FRONTEND_BUILD_SAMPLES) add_subdirectory(samples) endif() -if (CUDNN_FRONTEND_BUILD_UNIT_TESTS) +if (CUDNN_FRONTEND_BUILD_TESTS) add_subdirectory(test) endif() diff --git a/README.md b/README.md index dfd88a27..9519f7ea 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ To provide a custom CUDNN installation path, use environment variable: `CUDNN_PA #### Checking the installation To test whether installation is successful, run: ``` -pytest test/python_fe +pytest test/python ``` NOTE: Only v1.0 API is exposed via python bindings. @@ -95,6 +95,8 @@ To skip building samples, use `-DCUDNN_FRONTEND_BUILD_SAMPLES=OFF`. To skip building python bindings, use `-DCUDNN_FRONTEND_BUILD_PYTHON_BINDINGS=OFF`. +To add debug symbols, use `-DCMAKE_BUILD_TYPE=Debug`. + In case, you have a stale cmake cache and want to update the cudnn/cuda paths, please delete the cmake cache (or build directory and redo the above steps). ## Debugging diff --git a/cmake/cuDNN.cmake b/cmake/cuDNN.cmake index f6640ca9..d948eb16 100644 --- a/cmake/cuDNN.cmake +++ b/cmake/cuDNN.cmake @@ -2,7 +2,7 @@ add_library(CUDNN::cudnn_all INTERFACE IMPORTED) find_path( CUDNN_INCLUDE_DIR cudnn.h - HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_INCLUDE_DIRS} + HINTS $ENV{CUDNN_INCLUDE_PATH} ${CUDNN_INCLUDE_PATH} $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_INCLUDE_DIRS} PATH_SUFFIXES include REQUIRED ) @@ -14,7 +14,7 @@ string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}") function(find_cudnn_library NAME) find_library( ${NAME}_LIBRARY ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}" - HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_LIBRARY_DIR} + HINTS $ENV{CUDNN_LIBRARY_PATH} ${CUDNN_LIBRARY_PATH} $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_LIBRARY_DIR} PATH_SUFFIXES lib64 lib/x64 lib REQUIRED ) diff --git a/docs/cuda_graphs.md b/docs/cuda_graphs.md new file mode 100644 index 00000000..7bd18dcd --- /dev/null +++ b/docs/cuda_graphs.md @@ -0,0 +1,31 @@ + + +### `populate_cuda_graph` + +The `populate_cuda_graph` function is a member function of the `Graph` class. It is used to populate a CUDA graph with the necessary data and operations. + +#### Parameters + +- `handle`: A cuDNN handle. +- `uid_to_device_ptrs`: A map of tensor UIDs to device pointers. +- `workspace`: A pointer to the workspace memory. +- `cudnn_cuda_graph`: A pointer to the CUDA graph. + +#### Return Value + +- An `error_t` object indicating the success or failure of the function. + +### `update_cuda_graph` + +The `update_cuda_graph` function is a member function of the `Graph` class. It is used to update a CUDA graph with the necessary data and operations. + +#### Parameters + +- `handle`: A cuDNN handle. +- `uid_to_device_ptrs`: A map of tensor UIDs to device pointers. +- `workspace`: A pointer to the workspace memory. +- `cudnn_cuda_graph`: A pointer to the CUDA graph. + +#### Return Value + +- An `error_t` object indicating the success or failure of the function. diff --git a/docs/operations/Attention.md b/docs/operations/Attention.md index e62428d8..0293959f 100644 --- a/docs/operations/Attention.md +++ b/docs/operations/Attention.md @@ -15,9 +15,11 @@ using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2 - Python sample: [samples/python/50_scaled_dot_product_attention.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb) +- Python sample with paged caches: [samples/python/samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb) + - C++ sample: [samples/cpp/sdpa](https://github.com/NVIDIA/cudnn-frontend/tree/main/samples/cpp/sdpa) -- Python tests: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py) +- Python tests: [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py) #### Configurable Options: @@ -38,22 +40,35 @@ using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2 - `dropout mask` that matches the attention weights' dimensions, indicating which weights to drop. The dimensions that are passed as 1 will apply a broadcasted dropout mask. - `dropout scale` used to adjust the scale of the remaining weights accordingly, such as $1 / (1 - \text{dropout probability})$. - Packed layout: With packed layout, the query, key, value, and output tensor should be [ragged tensors](https://www.tensorflow.org/guide/ragged_tensor), which are tensors with nested variable length lists as inner dimensions. Users must pass another tensor called ragged offset tensor using the `Tensor_attributes.set_ragged_offset()` method. the ragged offset tensor must be a tensor of size $(B + 1, 1, 1, 1)$ that contains the nested tensor's offset in terms of number of elements (not bytes). The last value of the offset tensor specifies the offset of the past-the-end element of the ragged tensor. See Appendix A for more information on the supported layouts. +- Paged attention: with paged K and/or V caches, the K/V blocks no longer need to be contiguous, allowing users to better utilize memory by avoiding fragmentation. + - Users must therefore: + - Pass a `page table k` tensor containing offsets to the container with K blocks. This is optional, and only needed if the K cache is paged. + - Pass a `page table v` tensor containing offsets to the container with V blocks. This is optional, and only needed if the V cache is paged. + - Pass anything required for `Padding mask` above (i.e., per-batch sequence lengths for both K and V caches). This is needed if at least one of the K/V caches are paged. + - Optionally, but recommended, pass the maximum sequence length for the K/V caches. When omitted, it will be (over)estimated, which could result in a corrupted graph in some corner cases. + - Offsets to the K/V containers will be calculcated as + - $Kcache[b,h,s,d] = K[page\ table\ k[b,1,s / bs_k, 1],h,s\ mod\ bs_{k},d]$ + - $Vcache[b,h,s,d] = V[page\ table\ v[b,1,s / bs_v, 1],h,s\ mod\ bs_{v},d]$ + - See also the [PagedAttention paper](https://arxiv.org/abs/2309.06180). ##### Input Tensors: -| Tensor Name | Device | Data Type | Dimensions | -|-------------------------------------|------------|----------------|----------------------------------------------------------------------------------------------------------------| -| Q | GPU | FP16 or BF16 | $(B, H_{q}, S_{q}, D_{qk})$ | -| K | GPU | FP16 or BF16 | $(B, H_{k}, S_{kv}, D_{qk})$ | -| V | GPU | FP16 or BF16 | $(B, H_{v}, S_{kv}, D_{v})$ | -| (Bias mask) Bias Mask | GPU | FP16 or BF16 | $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ | -| (Padding mask) Sequence Length Q | GPU | INT32 | $(B, 1, 1, 1)$ | -| (Padding mask) Sequence Length KV | GPU | INT32 | $(B, 1, 1, 1)$ | -| (Philoc RNG Dropout) Seed | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ | -| (Philoc RNG Dropout) Offset | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ | -| (Custom Dropout Mask) Mask | GPU | FP16 or BF16 | $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ | -| (Custom Dropout Mask) Scale | GPU | FP32 | $(1, 1, 1, 1)$ | -| (Packed Layout) Ragged Offset | GPU | INT32 | $(B + 1, 1, 1, 1)$ | +| Tensor Name | Device | Data Type | Dimensions | +|------------------------------------------------|------------|----------------|----------------------------------------------------------------------------------------------------------------| +| Q | GPU | FP16 or BF16 | $(B, H_{q}, S_{q}, D_{qk})$ | +| K | GPU | FP16 or BF16 | $(B, H_{k}, S_{kv}, D_{qk})$, or $(num\_blocks_{k}, H_{k}, bs_{k}, D_{qk})$ in case of paged K cache | +| V | GPU | FP16 or BF16 | $(B, H_{v}, S_{kv}, D_{v})$, or $(num\_blocks_{v}, H_{v}, bs_{v}, D_{v})$ in case of paged V cache | +| (Bias mask) Bias Mask | GPU | FP16 or BF16 | $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ | +| (Padding mask/Paged Caches) Sequence Length Q | GPU | INT32 | $(B, 1, 1, 1)$ | +| (Padding mask/Paged Caches) Sequence Length KV | GPU | INT32 | $(B, 1, 1, 1)$ | +| (Philoc RNG Dropout) Seed | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ | +| (Philoc RNG Dropout) Offset | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ | +| (Custom Dropout Mask) Mask | GPU | FP16 or BF16 | $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ | +| (Custom Dropout Mask) Scale | GPU | FP32 | $(1, 1, 1, 1)$ | +| (Packed Layout) Ragged Offset | GPU | INT32 | $(B + 1, 1, 1, 1)$ | +| (Paged Attention) Page Table K | GPU | INT32 | $(B, 1, ceil(S_{kv}/bs_{k}), 1)$ | +| (Paged Attention) Page Table V | GPU | INT32 | $(B, 1, ceil(S_{kv}/bs_{v}), 1)$ | +| (Paged Attention) Max Sequence Length KV | CPU | INT32 or INT64 | $(1, 1, 1, 1)$ | ##### Output Tensors @@ -73,6 +88,10 @@ Where, - $S_{kv}$ is the sequence length of the key and value - $D_{qk}$ is the embedding dimension per head of query and key - $D_{v}$ is the embedding dimension per head of value +- $bs_{k}$ is the (power of 2) block size of the K container +- $bs_{v}$ is the (power of 2) block size of the V container +- $num\_blocks_{k}$ is the number of blocks in the K container +- $num\_blocks_{v}$ is the number of blocks in the V container #### Group-query attention (GQA) and Multi-query attention (MQA) @@ -146,6 +165,16 @@ set_dropout(std::shared_ptr mask, SDPA_attributes& set_compute_data_type(DataType_t value); + +SDPA_attributes& +set_paged_attention_k_table(std::shared_ptr value); + +SDPA_attributes& +set_paged_attention_v_table(std::shared_ptr value); + +SDPA_attributes& +set_paged_attention_max_seq_len_kv(int const value); + ``` #### Python API: @@ -153,8 +182,8 @@ set_compute_data_type(DataType_t value); ``` Args: q (cudnn_tensor): The query data. - k (cudnn_tensor): The key data. - v (cudnn_tensor): The value data. + k (cudnn_tensor): The key data. When page_table_k is provided, 'k' is a container of non-contiguous key data. + v (cudnn_tensor): The value data. When page_table_v is provided, 'v' is a container of non-contiguous value data. is_inference (bool): Whether it is an inference step or training step. attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. @@ -166,6 +195,9 @@ Args: use_causal_mask_bottom_right (Optional[bool]): Whether to use bottom right aligned causal mask. Default is False. dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. rng_dump (Optional[cudnn_tensor]): Debug tensor used to output the Philox RNG dropout mask + paged_attention_k_table (Optional[cudnn_tensor]): The page table to look up offsets into 'k' + paged_attention_v_table (Optional[cudnn_tensor]): The page table to look up offsets into 'v' + paged_attention_max_seq_len_kv (Optional[integer]): The maximum sequence length for k/v caches when paged attention is active. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): The name of the operation. @@ -182,7 +214,7 @@ This operation computes gradient tensors for scaled dot product attention (SDPA) - C++ sample: [samples/cpp/sdpa](https://github.com/NVIDIA/cudnn-frontend/tree/main/samples/cpp/sdpa) -- Python tests: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py) +- Python tests: [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py) #### Configurable Options: diff --git a/include/cudnn_frontend/cudnn_interface.h b/include/cudnn_frontend/cudnn_interface.h index fc61de6c..17ff44d2 100644 --- a/include/cudnn_frontend/cudnn_interface.h +++ b/include/cudnn_frontend/cudnn_interface.h @@ -17,6 +17,81 @@ namespace cudnn_frontend { +namespace detail { +inline void +assign_uid(graph::Tensor_attributes* const tensor, + int64_t& potential_uid, + std::unordered_set const& used_uids) { + // get_next_potential_uid + while (used_uids.find(potential_uid) != used_uids.end()) { + ++potential_uid; + } + + tensor->set_uid(potential_uid); + ++potential_uid; // increment, as used its used now +} + +// TODO: Always returns OK. Can the status and error message be accessed from tensor descriptor? +inline error_t +create_cudnn_tensor( + std::shared_ptr const& props, + std::unordered_map>& tensors, + int64_t& potential_uid, + std::unordered_set const& used_uids) { + // Assign tensor a uid + if (props->has_uid() == false) { + assign_uid(props.get(), potential_uid, used_uids); + } + + // Check whether backend tensor already created + auto tensor_uid = props->get_uid(); + if (tensors.find(tensor_uid) != tensors.end()) { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Backend Tensor named '" << props->get_name() << "' with UID " << tensor_uid + << " already created."); + return {error_code_t::OK, ""}; + } + CUDNN_FE_LOG_LABEL_ENDL("INFO: Creating Backend Tensor named '" << props->get_name() << "' with UID " + << tensor_uid); + + auto&& tensor_builder = cudnn_frontend::TensorBuilder(); + + tensor_builder.setDim(props->get_dim().size(), props->get_dim().data()) + .setStrides(props->get_stride().size(), props->get_stride().data()) + .setId(tensor_uid) + .setAlignment(16) + .setDataType(props->get_data_type()) + .setVirtual(props->get_is_virtual()) + .setByValue(props->get_is_pass_by_value()) + .setReorderType(props->get_reordering_type()); + + if (auto ragged_offset_props = props->get_ragged_offset()) { + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(ragged_offset_props, tensors, potential_uid, used_uids)); + tensor_builder.setRaggedOffset(tensors.at(ragged_offset_props->get_uid())); + } + +#ifdef NV_CUDNN_DISABLE_EXCEPTION + // disable exception macro is defined. Calling build will not throw. + // Check status of desc and return error. + auto tensor = tensor_builder.build(); + RETURN_CUDNN_FRONTEND_ERROR_IF( + tensor.get_status() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, tensor.get_error()); + tensors.emplace(tensor_uid, std::make_shared(std::move(tensor))); +#else + // build() can throw + // wrap in try catch + try { + auto tensor = tensor_builder.build(); + tensors.emplace(tensor_uid, std::make_shared(std::move(tensor))); + } catch (cudnn_frontend::cudnnException& e) { + RETURN_CUDNN_FRONTEND_ERROR_IF( + e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + } +#endif + + return {error_code_t::OK, ""}; +} +} // namespace detail + class ICudnn { protected: using uid_t = int64_t; @@ -43,78 +118,6 @@ class ICudnn { bool is_dynamic_shape_enabled = false; std::shared_ptr kernel_cache = nullptr; - void - assign_uid(graph::Tensor_attributes* const tensor, - int64_t& potential_uid, - std::unordered_set const& used_uids) const { - // get_next_potential_uid - while (used_uids.find(potential_uid) != used_uids.end()) { - ++potential_uid; - } - - tensor->set_uid(potential_uid); - ++potential_uid; // increment, as used its used now - } - - // TODO: Always returns OK. Can the status and error message be accessed from tensor descriptor? - error_t - create_cudnn_tensor(std::shared_ptr const& props, - std::unordered_map>& tensors, - int64_t& potential_uid, - std::unordered_set const& used_uids) const { - // Assign tensor a uid - if (props->has_uid() == false) { - assign_uid(props.get(), potential_uid, used_uids); - } - - // Check whether backend tensor already created - auto tensor_uid = props->get_uid(); - if (tensors.find(tensor_uid) != tensors.end()) { - CUDNN_FE_LOG_LABEL_ENDL("INFO: Backend Tensor named '" << props->get_name() << "' with UID " << tensor_uid - << " already created."); - return {error_code_t::OK, ""}; - } - CUDNN_FE_LOG_LABEL_ENDL("INFO: Creating Backend Tensor named '" << props->get_name() << "' with UID " - << tensor_uid); - - auto&& tensor_builder = cudnn_frontend::TensorBuilder(); - - tensor_builder.setDim(props->get_dim().size(), props->get_dim().data()) - .setStrides(props->get_stride().size(), props->get_stride().data()) - .setId(tensor_uid) - .setAlignment(16) - .setDataType(props->get_data_type()) - .setVirtual(props->get_is_virtual()) - .setByValue(props->get_is_pass_by_value()) - .setReorderType(props->get_reordering_type()); - - if (auto ragged_offset_props = props->get_ragged_offset()) { - CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(ragged_offset_props, tensors, potential_uid, used_uids)); - tensor_builder.setRaggedOffset(tensors.at(ragged_offset_props->get_uid())); - } - -#ifdef NV_CUDNN_DISABLE_EXCEPTION - // disable exception macro is defined. Calling build will not throw. - // Check status of desc and return error. - auto tensor = tensor_builder.build(); - RETURN_CUDNN_FRONTEND_ERROR_IF( - tensor.get_status() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, tensor.get_error()); - tensors.emplace(tensor_uid, std::make_shared(std::move(tensor))); -#else - // build() can throw - // wrap in try catch - try { - auto tensor = tensor_builder.build(); - tensors.emplace(tensor_uid, std::make_shared(std::move(tensor))); - } catch (cudnn_frontend::cudnnException& e) { - RETURN_CUDNN_FRONTEND_ERROR_IF( - e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); - } -#endif - - return {error_code_t::OK, ""}; - } - error_t create_cudnn_operation_graph(cudnnHandle_t handle) { std::vector cudnn_operations; diff --git a/include/cudnn_frontend/graph_helpers.h b/include/cudnn_frontend/graph_helpers.h index 15e1ac31..c84a38ad 100644 --- a/include/cudnn_frontend/graph_helpers.h +++ b/include/cudnn_frontend/graph_helpers.h @@ -120,6 +120,20 @@ typedef struct [[nodiscard]] error_object { } \ CUDNN_FRONTEND_WHILE_FALSE +#define CHECK_CU_ERROR(x) \ + do { \ + if (auto cu_retval = x; cu_retval != CUDA_SUCCESS) { \ + std::stringstream error_msg; \ + const char* error_code_string; \ + detail::cu_get_error_string(cu_retval, &error_code_string); \ + error_msg << #x << " failed with " << error_code_string; \ + getLogger() << "[cudnn_frontend] ERROR: " << error_msg.str() << " at " << __FILE__ << ":" << __LINE__ \ + << std::endl; \ + return {error_code_t::CUDA_API_FAILED, error_msg.str()}; \ + } \ + } \ + CUDNN_FRONTEND_WHILE_FALSE + NLOHMANN_JSON_SERIALIZE_ENUM(error_code_t, { {error_code_t::OK, "OK"}, diff --git a/include/cudnn_frontend/graph_interface.h b/include/cudnn_frontend/graph_interface.h index 1d84850b..1304f8ed 100644 --- a/include/cudnn_frontend/graph_interface.h +++ b/include/cudnn_frontend/graph_interface.h @@ -30,7 +30,7 @@ namespace cudnn_frontend::graph { -class Graph : public INode { +class Graph : public ICudnn, public INode { private: std::unordered_set> full_graph_inputs; std::unordered_set used_uids; @@ -75,6 +75,9 @@ class Graph : public INode { (is_dynamic_shape_enabled || kernel_cache != nullptr) && detail::get_backend_version() < 90400, error_code_t::GRAPH_NOT_SUPPORTED, "Dynamic shapes or kernel caching enabled, but cuDNN version < 9.4!"); + RETURN_CUDNN_FRONTEND_ERROR_IF(((is_dynamic_shape_enabled == false) && (kernel_cache != nullptr)), + error_code_t::GRAPH_NOT_SUPPORTED, + "Kernel caching enabled but dynamic shapes is disabled"); return {error_code_t::OK, ""}; } @@ -99,8 +102,9 @@ class Graph : public INode { virtual error_t collect_tensors_in_workspace_node( - std::unordered_map>> &worskspace_modifications, - int64_t &) const { + std::unordered_map>> + &worskspace_modifications, + int64_t &) const override { for (auto [uid, value] : deserialized_workspace_modifications) { worskspace_modifications.emplace(uid, value); } @@ -137,6 +141,8 @@ class Graph : public INode { tensor_to_pointer_map.emplace(uid, nv_bfloat16_value_ptr); } else if (int32_t *int32_t_value_ptr = std::get_if(&value)) { tensor_to_pointer_map.emplace(uid, int32_t_value_ptr); + } else if (int64_t *int64_t_value_ptr = std::get_if(&value)) { + tensor_to_pointer_map.emplace(uid, int64_t_value_ptr); } else if (float *float_value_ptr = std::get_if(&value)) { tensor_to_pointer_map.emplace(uid, float_value_ptr); } else { @@ -223,6 +229,293 @@ class Graph : public INode { public: Graph() : INode(detail::Context{}) {} + error_t + update_cuda_graph(cudnnHandle_t handle, + std::unordered_map, void *> &tensor_to_pointer_map, + void *workspace, + cudaGraph_t cudnn_cuda_graph) { + // First get all the uids from the map + std::unordered_map tensor_uid_to_pointer_map; + tensor_uid_to_pointer_map.reserve(tensor_to_pointer_map.size()); + for (auto const &[tensor, pointer] : tensor_to_pointer_map) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), pointer); + } + + return update_cuda_graph(handle, tensor_uid_to_pointer_map, workspace, cudnn_cuda_graph); + } + + error_t + update_cuda_graph(cudnnHandle_t handle, + std::unordered_map &uid_to_device_ptrs, + void *workspace, + cudaGraph_t cudnn_cuda_graph) { + // Initializes this cudnn graph + RETURN_CUDNN_FRONTEND_ERROR_IF( + cudnn_cuda_graph == nullptr, error_code_t::INVALID_VALUE, "cudnn_cuda_graph should not be a nullptr"); + + size_t num_root_nodes; + CHECK_CUDA_ERROR(detail::cuda_graph_get_root_nodes(cudnn_cuda_graph, nullptr, &num_root_nodes)); + RETURN_CUDNN_FRONTEND_ERROR_IF( + num_root_nodes != 1, error_code_t::INVALID_VALUE, "cudnn_cuda_graph should have exactly 1 root node."); + + cudaGraphNode_t current_node = nullptr; + CHECK_CUDA_ERROR(detail::cuda_graph_get_root_nodes(cudnn_cuda_graph, ¤t_node, &num_root_nodes)); + + /////////////////////////////////////// + //// PASS BY VALUE TENSOR HANDLING //// + /////////////////////////////////////// + // Add pass_by_value data pointers to uid_to_pointer map + // object lifetime is controlled by tensor_to_pass_by_value which means the pointer should stay valid while + // making the cuda graph. cuda graph will then keep a copy of the kernel parameters, meaning that at the time of + // launching the cuda_graph executable, tensor_to_pass_by_value being deallocated does not affect these cpu + // value's. + // No cuda graph nodes are required for handling fe owned pass by value tensors. + std::unordered_map tensor_to_pass_by_value; + CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_subtree(tensor_to_pass_by_value)); + CHECK_CUDNN_FRONTEND_ERROR( + extend_tensor_map_with_pass_by_value_tensors_(uid_to_device_ptrs, tensor_to_pass_by_value)); + + // Make sure device pointer is provided for all uids expected for this plan + std::vector device_ptrs; + std::vector uids; + + device_ptrs.reserve(variant_pack_uids.size()); + uids.reserve(variant_pack_uids.size()); + + for (auto const &uid : variant_pack_uids) { + auto search = uid_to_device_ptrs.find(uid); + RETURN_CUDNN_FRONTEND_ERROR_IF(search == uid_to_device_ptrs.end(), + error_code_t::INVALID_VARIANT_PACK, + "Uid " + std::to_string(uid) + " does not exist in variant pack."); + device_ptrs.push_back(search->second); + uids.push_back(uid); + } + + //////////////////////////// + //// WORKSPACE HANDLING //// + //////////////////////////// + // Get all types of extra calls that FE has to do on user workspace. + std::unordered_map>> workspace_modifications; + int64_t workspace_offset = 0; + CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_subtree(workspace_modifications, workspace_offset)); + + for (auto const &[uid, data] : workspace_modifications) { + (void)uid; + const auto &[operation_type, offset, vec_data] = data; + + // 0 means memcpy + if (operation_type == 0) { + CHECK_CUDA_ERROR( + detail::cuda_graph_add_memcpy_node_set_params_1D(current_node, + static_cast(workspace) + offset, + vec_data.data(), + vec_data.size() * sizeof(float), + cudaMemcpyHostToDevice)); + } + // 1 means memset + else if (operation_type == 1) { + // offset from workspace + void *device_ptr = static_cast(workspace) + offset; + int64_t memset_size = static_cast(vec_data[0]); + + cudaMemsetParams params; + params.dst = device_ptr; + params.elementSize = sizeof(char); + params.value = 0x0; + params.width = memset_size; + params.height = 1; // 1D memset currently + params.pitch = 0; // unused + + CHECK_CUDA_ERROR(detail::cuda_graph_add_memset_node_set_params(current_node, ¶ms)); + } + // Other values do not correspond to cuda APIs + + CHECK_CUDA_ERROR(detail::cuda_graph_node_get_dependent_nodes(current_node, nullptr, &num_root_nodes)); + RETURN_CUDNN_FRONTEND_ERROR_IF( + num_root_nodes != 1, error_code_t::INVALID_VALUE, "cudnn_cuda_graph should have exactly 1 root node."); + CHECK_CUDA_ERROR(detail::cuda_graph_node_get_dependent_nodes(current_node, ¤t_node, &num_root_nodes)); + } + + /////////////////// + //// BE GRAPH //// + /////////////////// + cudaGraph_t backend_cuda_graph; + CHECK_CUDA_ERROR(detail::cuda_graph_child_graph_node_get_graph(current_node, &backend_cuda_graph)); + + detail::backend_descriptor variant_pack_descriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR); + RETURN_CUDNN_FRONTEND_ERROR_IF(variant_pack_descriptor.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + "Failed to create variant pack's backend descriptor."); + + CHECK_CUDNN_FRONTEND_ERROR(create_variant_pack(variant_pack_descriptor, device_ptrs, uids, workspace)); + + int64_t candidate = plans.candidate; + CHECK_CUDNN_FRONTEND_ERROR(plans.is_plan_index_executable(candidate)); + CHECK_CUDNN_ERROR(detail::update_cuda_graph(handle, + plans.execution_plans[candidate]->get_raw_desc(), + variant_pack_descriptor.get_ptr(), + backend_cuda_graph)); + + // There should be nothing after the backend graph + CHECK_CUDA_ERROR(detail::cuda_graph_node_get_dependent_nodes(current_node, nullptr, &num_root_nodes)); + RETURN_CUDNN_FRONTEND_ERROR_IF(num_root_nodes != 0, + error_code_t::INVALID_VALUE, + "cudnn_cuda_graph should have no graph nodes after the backend graph node."); + + return {error_code_t::OK, ""}; + } + + error_t + populate_cuda_graph(cudnnHandle_t handle, + std::unordered_map, void *> &tensor_to_pointer_map, + void *workspace, + cudaGraph_t cudnn_cuda_graph) { + // First get all the uids from the map + std::unordered_map tensor_uid_to_pointer_map; + tensor_uid_to_pointer_map.reserve(tensor_to_pointer_map.size()); + for (auto const &[tensor, pointer] : tensor_to_pointer_map) { + tensor_uid_to_pointer_map.emplace(tensor->get_uid(), pointer); + } + + return populate_cuda_graph(handle, tensor_uid_to_pointer_map, workspace, cudnn_cuda_graph); + } + + error_t + populate_cuda_graph(cudnnHandle_t handle, + std::unordered_map &uid_to_device_ptrs, + void *workspace, + cudaGraph_t cudnn_cuda_graph) { + // Check if the cuda graph is empty + size_t numNodes = 0; + CHECK_CU_ERROR(detail::cu_graph_get_nodes(cudnn_cuda_graph, nullptr, &numNodes)); + RETURN_CUDNN_FRONTEND_ERROR_IF(numNodes != 0, + error_code_t::INVALID_VALUE, + "cuda graph provided to populate is not empty. cuDNN requires it to be empty " + "for the corresponding update APIs to work correctly."); + + // This function makes linear cuda graphs. And that makes it easy to walk + // the graph when updating it. + // So just keeping track of the last node in the cuda graph is sufficient. + cudaGraphNode_t last_node = nullptr; + + /////////////////////////////////////// + //// PASS BY VALUE TENSOR HANDLING //// + /////////////////////////////////////// + // Add pass_by_value data pointers to uid_to_pointer map + // object lifetime is controlled by tensor_to_pass_by_value which means the pointer should stay valid while + // making the cuda graph. cuda graph will then keep a copy of the kernel parameters, meaning that at the time of + // launching the cuda_graph executable, tensor_to_pass_by_value being deallocated does not affect these cpu + // value's. + // No cuda graph nodes are required for handling fe owned pass by value tensors. + std::unordered_map tensor_to_pass_by_value; + CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_subtree(tensor_to_pass_by_value)); + CHECK_CUDNN_FRONTEND_ERROR( + extend_tensor_map_with_pass_by_value_tensors_(uid_to_device_ptrs, tensor_to_pass_by_value)); + + ///////////////////////////////// + //// WORKSPACE HANDLING //// + ///////////////////////////////// + // Get all types of extra calls that FE has to do on user workspace. + std::unordered_map>> workspace_modifications; + int64_t workspace_offset = 0; + CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_subtree(workspace_modifications, workspace_offset)); + + for (auto const &[uid, data] : workspace_modifications) { + (void)uid; + const auto &[operation_type, offset, vec_data] = data; + + cudaGraphNode_t node = nullptr; + + // 0 means memcpy + if (operation_type == 0) { + CHECK_CUDA_ERROR(detail::cuda_graph_add_memcpy_node_1D(&node, + cudnn_cuda_graph, + &last_node, + last_node != nullptr, + static_cast(workspace) + offset, + vec_data.data(), + vec_data.size() * sizeof(float), + cudaMemcpyHostToDevice)); + } + // 1 means memset + else if (operation_type == 1) { + // offset from workspace + void *device_ptr = static_cast(workspace) + offset; + int64_t memset_size = static_cast(vec_data[0]); + + cudaMemsetParams params; + params.dst = device_ptr; + params.elementSize = sizeof(char); + params.value = 0x0; + params.width = memset_size; + params.height = 1; // 1D memset currently + params.pitch = 0; // unused + + CHECK_CUDA_ERROR(detail::cuda_graph_add_memset_node( + &node, cudnn_cuda_graph, &last_node, last_node != nullptr, ¶ms)); + } + // Other values do not correspond to cuda APIs + + last_node = node; + } + + ////////////// + // BE graph // + ////////////// + + // Get the BE's cuda graph + + // Make sure device pointer is provided for all uids expected for this plan + std::vector device_ptrs; + device_ptrs.reserve(variant_pack_uids.size()); + std::vector uids; + uids.reserve(variant_pack_uids.size()); + for (auto const &uid : variant_pack_uids) { + auto search = uid_to_device_ptrs.find(uid); + RETURN_CUDNN_FRONTEND_ERROR_IF(search == uid_to_device_ptrs.end(), + error_code_t::INVALID_VARIANT_PACK, + "Uid " + std::to_string(uid) + " does not exist in variant pack."); + device_ptrs.push_back(search->second); + uids.push_back(uid); + } + + // Create the variant pack to pass to backend + detail::backend_descriptor variant_pack_descriptor(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR); + RETURN_CUDNN_FRONTEND_ERROR_IF(variant_pack_descriptor.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + "Failed to create variant pack's backend descriptor."); + CHECK_CUDNN_FRONTEND_ERROR(create_variant_pack(variant_pack_descriptor, device_ptrs, uids, workspace)); + + // Get the plan candidate. It only makes to sense to make cuda graph after execution plan has been built. + // And in that case the candidate would have been set. + int64_t candidate = plans.candidate; + CHECK_CUDNN_FRONTEND_ERROR(plans.is_plan_index_executable(candidate)); + + // Finally get the backend cuda graph. + cudaGraph_t backend_cuda_graph; + // Initialize the cudnn cuda graph. + // The responsibility to destroy is on the user. + detail::cu_graph_create(&backend_cuda_graph, 0); // 0 is just what the API says to pass + + CHECK_CUDNN_ERROR(detail::populate_cuda_graph(handle, + plans.execution_plans[candidate]->get_raw_desc(), + variant_pack_descriptor.get_ptr(), + backend_cuda_graph)); + + // Clone BE graph into a graph_node + // This same call also places the newly created into FE's graph + // TODO: BE graph is at the end, so put in appropriate dependencies + cudaGraphNode_t backend_cuda_graph_node; + detail::cuda_graph_add_child_graph_node( + &backend_cuda_graph_node, cudnn_cuda_graph, &last_node, last_node != nullptr, backend_cuda_graph); + + // Destroy the BE graph as it now has been cloned into a node + // It was initialized by internals of backend, but the responsibility to destroy it is on FE. + CHECK_CUDA_ERROR(detail::cuda_graph_destroy(backend_cuda_graph)); + + return {error_code_t::OK, ""}; + } + error_t validate() { CUDNN_FE_LOG_LABEL_ENDL(""); @@ -1545,25 +1838,6 @@ Graph::rmsnorm_backward(std::shared_ptr dy, return {DX, DScale, DBias}; } -// inline std::array, 2> -// Graph::scaled_dot_product_attention(std::shared_ptr q, -// std::shared_ptr k, -// std::shared_ptr v, -// Scaled_dot_product_attention_attributes options) { -// // Make required output tensors -// auto O = options.outputs.O = output_tensor(options.get_name() + "_output"); -// auto S = options.outputs.S = output_tensor(options.get_name() + "_softmax_output"); - -// // Set inputs -// options.inputs.Q = q; -// options.inputs.K = k; -// options.inputs.V = v; - -// sub_nodes.emplace_back(std::make_unique(std::move(options), context)); - -// return {O, S}; -// } - inline std::array, 2> Graph::sdpa(std::shared_ptr q, std::shared_ptr k, diff --git a/include/cudnn_frontend/graph_properties.h b/include/cudnn_frontend/graph_properties.h index 748681b3..b8719d65 100644 --- a/include/cudnn_frontend/graph_properties.h +++ b/include/cudnn_frontend/graph_properties.h @@ -29,7 +29,7 @@ class Tensor_attributes { // In approach 1, users provide a value to embed into the graph. // In approach 2, users set is_pass_by_value boolean and then pass a pointer to scalar value with execute() API. // A closed set of types that are allowed to be passed by value. - using pass_by_values_t = std::variant; + using pass_by_values_t = std::variant; error_t validate() const { @@ -123,6 +123,13 @@ class Tensor_attributes { data_type = DataType_t::INT32; } + Tensor_attributes(int64_t const& scalar) { + pass_by_value = scalar; + is_pass_by_value = true; + dim = stride = {1}; + data_type = DataType_t::INT64; + } + std::string get_name() const { return name; @@ -1408,6 +1415,7 @@ class SDPA_attributes : public Attributes { std::optional sliding_window_length; std::optional dropout_probability; std::optional attn_scale_value; + std::optional max_seq_len_kv; public: enum class input_names { @@ -1422,6 +1430,8 @@ class SDPA_attributes : public Attributes { Offset, Dropout_mask, Dropout_scale, + Page_table_K, + Page_table_V }; std::unordered_map> inputs; enum class output_names { O, Stats, RNG_DUMP }; @@ -1447,7 +1457,7 @@ class SDPA_attributes : public Attributes { SDPA_attributes& set_attn_scale(std::shared_ptr value) { - inputs[SDPA_attributes::input_names::Attn_scale] = value; + inputs[SDPA_attributes::input_names::Attn_scale] = std::move(value); return *this; } @@ -1459,7 +1469,7 @@ class SDPA_attributes : public Attributes { SDPA_attributes& set_bias(std::shared_ptr value) { - inputs[SDPA_attributes::input_names::Bias] = value; + inputs[SDPA_attributes::input_names::Bias] = std::move(value); return *this; } @@ -1477,13 +1487,13 @@ class SDPA_attributes : public Attributes { SDPA_attributes& set_seq_len_q(std::shared_ptr value) { - inputs[SDPA_attributes::input_names::SEQ_LEN_Q] = value; + inputs[SDPA_attributes::input_names::SEQ_LEN_Q] = std::move(value); return *this; } SDPA_attributes& set_seq_len_kv(std::shared_ptr value) { - inputs[SDPA_attributes::input_names::SEQ_LEN_KV] = value; + inputs[SDPA_attributes::input_names::SEQ_LEN_KV] = std::move(value); return *this; } @@ -1510,22 +1520,40 @@ class SDPA_attributes : public Attributes { std::shared_ptr seed, std::shared_ptr offset) { dropout_probability = probability; - inputs[SDPA_attributes::input_names::Seed] = seed; - inputs[SDPA_attributes::input_names::Offset] = offset; + inputs[SDPA_attributes::input_names::Seed] = std::move(seed); + inputs[SDPA_attributes::input_names::Offset] = std::move(offset); return *this; } SDPA_attributes& set_dropout(std::shared_ptr mask, std::shared_ptr scale) { - inputs[SDPA_attributes::input_names::Dropout_mask] = mask; - inputs[SDPA_attributes::input_names::Dropout_scale] = scale; + inputs[SDPA_attributes::input_names::Dropout_mask] = std::move(mask); + inputs[SDPA_attributes::input_names::Dropout_scale] = std::move(scale); return *this; } // For debugging purposes only. SDPA_attributes& set_rng_dump(std::shared_ptr value) { - outputs[SDPA_attributes::output_names::RNG_DUMP] = value; + outputs[SDPA_attributes::output_names::RNG_DUMP] = std::move(value); + return *this; + } + + SDPA_attributes& + set_paged_attention_k_table(std::shared_ptr value) { + inputs[SDPA_attributes::input_names::Page_table_K] = std::move(value); + return *this; + } + + SDPA_attributes& + set_paged_attention_v_table(std::shared_ptr value) { + inputs[SDPA_attributes::input_names::Page_table_V] = std::move(value); + return *this; + } + + SDPA_attributes& + set_paged_attention_max_seq_len_kv(int const value) { + max_seq_len_kv = value; return *this; } }; @@ -1657,6 +1685,9 @@ class SDPA_backward_attributes : public Attributes { std::optional dropout_probability; std::optional attn_scale_value; + std::optional max_total_seq_len_q; + std::optional max_total_seq_len_kv; + bool is_deterministic_algorithm = false; public: @@ -1741,6 +1772,18 @@ class SDPA_backward_attributes : public Attributes { return *this; } + SDPA_backward_attributes& + set_max_total_seq_len_q(int64_t const value) { + max_total_seq_len_q = value; + return *this; + } + + SDPA_backward_attributes& + set_max_total_seq_len_kv(int64_t const value) { + max_total_seq_len_kv = value; + return *this; + } + SDPA_backward_attributes& set_causal_mask(bool const value) { causal_mask = value; @@ -2074,6 +2117,19 @@ class Slice_attributes : public Attributes { } }; +class PagedCacheLoad_attributes : public Attributes { + friend class Attributes; + friend class PagedCacheLoadNode; + friend class INode; + + public: + enum class input_names { container, seqLen, pageTable }; + std::unordered_map> inputs; + enum class output_names { yOut }; + std::unordered_map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(PagedCacheLoad_attributes, name, compute_data_type, inputs, outputs) +}; + } // namespace graph } // namespace cudnn_frontend \ No newline at end of file diff --git a/include/cudnn_frontend/node/batchnorm.h b/include/cudnn_frontend/node/batchnorm.h index c870dcb6..54f1b345 100644 --- a/include/cudnn_frontend/node/batchnorm.h +++ b/include/cudnn_frontend/node/batchnorm.h @@ -76,7 +76,7 @@ class BatchNormNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/batchnorm_inference.h b/include/cudnn_frontend/node/batchnorm_inference.h index 8ffeb6c7..50e8e97f 100644 --- a/include/cudnn_frontend/node/batchnorm_inference.h +++ b/include/cudnn_frontend/node/batchnorm_inference.h @@ -50,7 +50,7 @@ class BatchnormInferenceNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/bn_finalize.h b/include/cudnn_frontend/node/bn_finalize.h index bbab9566..24606447 100644 --- a/include/cudnn_frontend/node/bn_finalize.h +++ b/include/cudnn_frontend/node/bn_finalize.h @@ -59,7 +59,7 @@ class BatchNormFinalizeNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/conv_dgrad.h b/include/cudnn_frontend/node/conv_dgrad.h index 2e256899..f108e554 100644 --- a/include/cudnn_frontend/node/conv_dgrad.h +++ b/include/cudnn_frontend/node/conv_dgrad.h @@ -70,7 +70,7 @@ class DgradNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/conv_fprop.h b/include/cudnn_frontend/node/conv_fprop.h index 2859b586..d0faa483 100644 --- a/include/cudnn_frontend/node/conv_fprop.h +++ b/include/cudnn_frontend/node/conv_fprop.h @@ -83,7 +83,7 @@ class ConvolutionNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/conv_wgrad.h b/include/cudnn_frontend/node/conv_wgrad.h index 95dd6399..1bc4e1c0 100644 --- a/include/cudnn_frontend/node/conv_wgrad.h +++ b/include/cudnn_frontend/node/conv_wgrad.h @@ -66,7 +66,7 @@ class WgradNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/dbn.h b/include/cudnn_frontend/node/dbn.h index 49b79047..00bd5fdd 100644 --- a/include/cudnn_frontend/node/dbn.h +++ b/include/cudnn_frontend/node/dbn.h @@ -69,7 +69,7 @@ class DBNNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/dbn_weight.h b/include/cudnn_frontend/node/dbn_weight.h index be0d7b81..89a72b57 100644 --- a/include/cudnn_frontend/node/dbn_weight.h +++ b/include/cudnn_frontend/node/dbn_weight.h @@ -73,7 +73,7 @@ class DBNWeightNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/dln.h b/include/cudnn_frontend/node/dln.h index cf43f716..fd86c131 100644 --- a/include/cudnn_frontend/node/dln.h +++ b/include/cudnn_frontend/node/dln.h @@ -92,7 +92,7 @@ class DLNNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/genstats.h b/include/cudnn_frontend/node/genstats.h index 4ccf1dda..4b4bb893 100644 --- a/include/cudnn_frontend/node/genstats.h +++ b/include/cudnn_frontend/node/genstats.h @@ -65,7 +65,7 @@ class GenstatsNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/instancenorm.h b/include/cudnn_frontend/node/instancenorm.h index f50b8583..1b6bb597 100644 --- a/include/cudnn_frontend/node/instancenorm.h +++ b/include/cudnn_frontend/node/instancenorm.h @@ -91,7 +91,7 @@ class InstanceNormNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { @@ -242,7 +242,7 @@ class DINNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/layernorm.h b/include/cudnn_frontend/node/layernorm.h index 2f09f57c..66ba5404 100644 --- a/include/cudnn_frontend/node/layernorm.h +++ b/include/cudnn_frontend/node/layernorm.h @@ -135,7 +135,7 @@ class LayerNormNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/matmul.h b/include/cudnn_frontend/node/matmul.h index e5aecd00..ff676ba7 100644 --- a/include/cudnn_frontend/node/matmul.h +++ b/include/cudnn_frontend/node/matmul.h @@ -64,7 +64,7 @@ class MatmulNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/paged_cache_load.h b/include/cudnn_frontend/node/paged_cache_load.h new file mode 100644 index 00000000..2e5320c7 --- /dev/null +++ b/include/cudnn_frontend/node/paged_cache_load.h @@ -0,0 +1,123 @@ +#pragma once + +#include "../../cudnn_frontend_Heuristics.h" +#include "../../cudnn_frontend_Logging.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +#include "pointwise.h" +#include "reduction.h" + +namespace cudnn_frontend::graph { + +class PagedCacheLoadNode : public NodeCRTP { + public: + PagedCacheLoad_attributes attributes; + + PagedCacheLoadNode(PagedCacheLoad_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::PAGED_CACHE_LOAD; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + managed_backend_descriptor_t& raw_operations, + std::unordered_map>& tensors) const override final { + CUDNN_FRONTEND_UNUSED(raw_operations); + + auto&& paged_cache_load_operation_builder = + cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(container, PagedCacheLoad_attributes::input_names::container); + paged_cache_load_operation_builder.setcontainerDesc(*(tensors.at(container->second->get_uid()))); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(pageTable, PagedCacheLoad_attributes::input_names::pageTable); + paged_cache_load_operation_builder.setpageTableDesc(*(tensors.at(pageTable->second->get_uid()))); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(seqLen, PagedCacheLoad_attributes::input_names::seqLen); + paged_cache_load_operation_builder.setsequenceDesc(*(tensors.at(seqLen->second->get_uid()))); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(yOut, PagedCacheLoad_attributes::output_names::yOut); + paged_cache_load_operation_builder.setyDesc(*(tensors.at(yOut->second->get_uid()))); + +#ifdef NV_CUDNN_DISABLE_EXCEPTION + // disable exception macro is defined. Calling build will not throw. + // Check status of desc and return error. + auto operation = paged_cache_load_operation_builder.build(); + RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + operation.get_error()); + operations.push_back(std::make_shared(std::move(operation))); +#else + // build() can throw + // wrap in try catch + try { + auto operation = paged_cache_load_operation_builder.build(); + operations.push_back(std::make_shared(std::move(operation))); + } catch (cudnn_frontend::cudnnException& e) { + RETURN_CUDNN_FRONTEND_ERROR_IF( + e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + } +#endif + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + + return {error_code_t::OK, ""}; + } + + error_t + pre_validate_node() const override final { + CUDNN_FE_LOG_LABEL_ENDL("INFO: Validating PagedCacheLoadNode " << attributes.name << "..."); + auto const yOut_dims = attributes.outputs.at(PagedCacheLoad_attributes::output_names::yOut)->get_dim(); + auto const yOut_strides = attributes.outputs.at(PagedCacheLoad_attributes::output_names::yOut)->get_stride(); + auto const container_dims = attributes.inputs.at(PagedCacheLoad_attributes::input_names::container)->get_dim(); + auto const pageTable_dims = attributes.inputs.at(PagedCacheLoad_attributes::input_names::pageTable)->get_dim(); + + // In the backend, the k-cache is passed as K^T and has dims [B,H,D,S], while v-cache has dims [B,H,S,D] + // Use the strides to distinguish. + auto yIsTransposed = yOut_strides[2] == 1; + auto s_kv = !yIsTransposed ? yOut_dims[2] : yOut_dims[3]; + + auto block_size = container_dims[2]; + auto table_size = pageTable_dims[2]; + RETURN_CUDNN_FRONTEND_ERROR_IF( + (s_kv + (block_size - 1)) / block_size != table_size, + error_code_t::INVALID_VALUE, + "Paged cache load: mismatch between max sequence length, block size and page table size"); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + return {error_code_t::OK, ""}; + } + +#ifndef CUDNN_FRONTEND_SKIP_JSON_LIB + virtual void + serialize(json& j) const override final { + j = attributes; + } +#endif +}; + +inline void +INode::paged_cache_load(std::shared_ptr container, + std::shared_ptr seqLen, + std::shared_ptr pageTable, + PagedCacheLoad_attributes attributes, + std::shared_ptr yOut) { + attributes.inputs[PagedCacheLoad_attributes::input_names::container] = std::move(container); + attributes.inputs[PagedCacheLoad_attributes::input_names::seqLen] = std::move(seqLen); + attributes.inputs[PagedCacheLoad_attributes::input_names::pageTable] = std::move(pageTable); + attributes.outputs[PagedCacheLoad_attributes::output_names::yOut] = std::move(yOut); + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/pointwise.h b/include/cudnn_frontend/node/pointwise.h index 9acd8778..2de33294 100644 --- a/include/cudnn_frontend/node/pointwise.h +++ b/include/cudnn_frontend/node/pointwise.h @@ -45,11 +45,28 @@ class PointwiseNode : public NodeCRTP { } if (out_0_tensor->get_stride().empty()) { - auto input_stride = attributes.inputs.at(Pointwise_attributes::input_names::IN_0)->get_stride(); - std::vector stride_order; - CHECK_CUDNN_FRONTEND_ERROR( - detail::generate_stride_order_preserving_format(input_stride, output_dim.size(), stride_order)); - out_0_tensor->set_stride(detail::generate_stride(output_dim, stride_order)); + for (const auto& [input_name, input_tensor] : attributes.inputs) { + if (input_tensor == nullptr) { + continue; + } + if (input_tensor->get_dim() == out_0_tensor->get_dim()) { + CUDNN_FE_LOG_LABEL_ENDL("INFO:" << out_0_tensor->get_name() << " stride computed from " + << input_tensor->get_name()); + out_0_tensor->set_stride(input_tensor->get_stride()); + break; + } + } + if (out_0_tensor->get_stride().empty() && out_0_tensor->get_is_virtual()) { + // If the tensor is virtual the strides are immaterial + auto input_stride = attributes.inputs.at(Pointwise_attributes::input_names::IN_0)->get_stride(); + std::vector stride_order; + CHECK_CUDNN_FRONTEND_ERROR( + detail::generate_stride_order_preserving_format(input_stride, output_dim.size(), stride_order)); + out_0_tensor->set_stride(detail::generate_stride(output_dim, stride_order)); + } + RETURN_CUDNN_FRONTEND_ERROR_IF(out_0_tensor->get_stride().empty(), + error_code_t::SHAPE_DEDUCTION_FAILED, + "Pointwise output strides could not be computed"); } return {error_code_t::OK, ""}; @@ -57,7 +74,7 @@ class PointwiseNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/reduction.h b/include/cudnn_frontend/node/reduction.h index bfaa020f..882134af 100644 --- a/include/cudnn_frontend/node/reduction.h +++ b/include/cudnn_frontend/node/reduction.h @@ -48,7 +48,7 @@ class ReductionNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/resample.h b/include/cudnn_frontend/node/resample.h index 6ab93fa5..37be58e8 100644 --- a/include/cudnn_frontend/node/resample.h +++ b/include/cudnn_frontend/node/resample.h @@ -96,7 +96,7 @@ class ResampleNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/reshape.h b/include/cudnn_frontend/node/reshape.h index 664cb3f3..370771b7 100644 --- a/include/cudnn_frontend/node/reshape.h +++ b/include/cudnn_frontend/node/reshape.h @@ -55,7 +55,7 @@ class ReshapeNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/rmsnorm.h b/include/cudnn_frontend/node/rmsnorm.h index c8fe0450..77f13583 100644 --- a/include/cudnn_frontend/node/rmsnorm.h +++ b/include/cudnn_frontend/node/rmsnorm.h @@ -76,7 +76,7 @@ class RMSNormNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { @@ -233,7 +233,7 @@ class DRMSNormNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/rng.h b/include/cudnn_frontend/node/rng.h index 549c08e7..20bd5544 100644 --- a/include/cudnn_frontend/node/rng.h +++ b/include/cudnn_frontend/node/rng.h @@ -56,7 +56,7 @@ class RngNode : public NodeCRTP { error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>& operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& tensors) const override final { diff --git a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h index e2ef388d..ded68f9e 100644 --- a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h +++ b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h @@ -12,6 +12,7 @@ #include "pointwise.h" #include "rng.h" #include "softmax.h" +#include "paged_cache_load.h" namespace cudnn_frontend::graph { @@ -99,6 +100,16 @@ class SDPANode : public NodeCRTP { bool const is_dropout_custom = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); bool const is_dropout = attributes.dropout_probability.has_value() || is_dropout_custom; + auto page_table_v_it = attributes.inputs.find(input_names::Page_table_V); + auto page_table_k_it = attributes.inputs.find(input_names::Page_table_K); + bool const is_paged = ((page_table_k_it) != attributes.inputs.end() && page_table_k_it->second != nullptr) || + ((page_table_v_it) != attributes.inputs.end() && page_table_v_it->second != nullptr); + + auto const& rng_tensor = attributes.outputs.find(output_names::RNG_DUMP); + bool const is_rng = (rng_tensor != attributes.outputs.end() && rng_tensor->second != nullptr); + + bool const max_seq_kv_explicit = attributes.max_seq_len_kv.has_value(); + // validation TODO: // - validate stats has valid dims @@ -174,6 +185,29 @@ class SDPANode : public NodeCRTP { error_code_t::ATTRIBUTE_NOT_SET, "Dropout probability cannot be 1 as corresponding scale wont be well formed."); + // validate options for paged attention + RETURN_CUDNN_FRONTEND_ERROR_IF(is_paged && is_ragged, + error_code_t::GRAPH_NOT_SUPPORTED, + "Paged caches are not supported in combination with ragged offsets."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(is_paged && (!has_seq_len_q || !has_seq_len_kv || !attributes.padding_mask), + error_code_t::GRAPH_NOT_SUPPORTED, + "Paged caches can only be used in combination with padding mask and variable sequence lengths for both Q and KV."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(!is_paged && max_seq_kv_explicit, + error_code_t::GRAPH_NOT_SUPPORTED, "When not using paged attention, there is no need to explicitly set max kv sequence length."); + + if (max_seq_kv_explicit){ + auto max_seq_kv = attributes.max_seq_len_kv.value(); + + RETURN_CUDNN_FRONTEND_ERROR_IF(is_bias && (bias_mask->second->get_dim()[3] != max_seq_kv), + error_code_t::GRAPH_NOT_SUPPORTED, "Value set through set_paged_attention_max_seq_len_kv is incompatible with the sequence length of the bias"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(is_rng && + rng_tensor->second->get_dim()[3] != max_seq_kv, + error_code_t::GRAPH_NOT_SUPPORTED, "Value set through set_paged_attention_max_seq_len_kv is incompatible with the sequence length of the RNG_DUMP"); + } + // version specific validation RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 8906 && ((s_kv % 64 != 0) || (d_qk % 64 != 0)), error_code_t::GRAPH_NOT_SUPPORTED, @@ -194,6 +228,10 @@ class SDPANode : public NodeCRTP { RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90200 && attributes.sliding_window_length.has_value(), error_code_t::GRAPH_NOT_SUPPORTED, "For cuDNN version below 9.2.0, sliding window attention is not supported"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 && is_paged, + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.5.0, paged caches are not supported"); // validate that datatype is set for the graph @@ -236,29 +274,102 @@ class SDPANode : public NodeCRTP { // Gather dim to fill properties of virtual tensors auto const& q_dim = attributes.inputs[input_names::Q]->get_dim(); auto b = q_dim[0]; - auto h = q_dim[1]; + auto h_q = q_dim[1]; auto s_q = q_dim[2]; + auto d_qk = q_dim[3]; auto const& k_dim = attributes.inputs[input_names::K]->get_dim(); - auto s_kv = k_dim[2]; + auto h_k = k_dim[1]; + auto const& v_dim = attributes.inputs[input_names::V]->get_dim(); + auto h_v = v_dim[1]; + auto d_v = v_dim[3]; - // cuDNN frontend API attention requires Q, K, V where - // Q = {b, h_q, s_q, d_qk} - // K = {b, h_k, s_kv, d_qk} - // V = {b, h_v, s_kv, d_v} - // but cuDNN backend API attention requires Q, KT, V - // Q = {b, h_q, s_q, d_qk} - // KT = {b, h_k, d_qk, s_kv} - // V = {b, h_v, s_kv, d_v} - // So the code below maps the K->KT - std::vector temp_vec; + bool is_paged_k = attributes.inputs[input_names::Page_table_K] != nullptr; + bool is_paged_v = attributes.inputs[input_names::Page_table_V] != nullptr; - temp_vec = attributes.inputs[input_names::K]->get_dim(); - std::swap(temp_vec[2], temp_vec[3]); - attributes.inputs[input_names::K]->set_dim(temp_vec); + // Infer s_kv + int64_t s_kv = -1; - temp_vec = attributes.inputs[input_names::K]->get_stride(); - std::swap(temp_vec[2], temp_vec[3]); - attributes.inputs[input_names::K]->set_stride(temp_vec); + // If s_kv was set explicitly, use that + if (attributes.max_seq_len_kv.has_value()) { + s_kv = attributes.max_seq_len_kv.value(); + } + // When one of K or V cache are paged, s_kv can be extracted directly + else if (!is_paged_k) { + s_kv = k_dim[2]; + + } else if (!is_paged_v) { + s_kv = v_dim[2]; + } else { + CUDNN_FE_LOG_LABEL_ENDL( + "WARNING: maximum kv sequence length is being inferred. To set it explicitly, please use " + "\"set_paged_attention_max_seq_len_kv\""); + + // If there is a bias, extract it from there + if (attributes.inputs[input_names::Bias] != nullptr) { + s_kv = attributes.inputs[input_names::Bias]->get_dim()[3]; + // If there is an rng_dump output, extract it from there + } else if (attributes.outputs.find(output_names::RNG_DUMP) != attributes.outputs.end() && + attributes.outputs[output_names::RNG_DUMP] != nullptr) { + s_kv = attributes.outputs[output_names::RNG_DUMP]->get_dim()[3]; + // When both caches are paged, and the above failed, we need to infer s_kv from the page table and + // container + } else { + // [b, 1, ceil(s_kv/block_size), 1] + auto page_table_dim_k = attributes.inputs[input_names::Page_table_K]->get_dim(); + // [b, h_k, block_size, d_k] + auto container_dim_k = attributes.inputs[input_names::K]->get_dim(); + int64_t s_k = page_table_dim_k[2] * container_dim_k[2]; + + // [b, 1, ceil(s_kv/block_size), 1] + auto page_table_dim_v = attributes.inputs[input_names::Page_table_V]->get_dim(); + // [b, h_v, block_size, d_v] + auto container_dim_v = attributes.inputs[input_names::V]->get_dim(); + int64_t s_v = page_table_dim_v[2] * container_dim_v[2]; + + s_kv = std::min(s_k, s_v); + } + } + + std::shared_ptr k_cache; + if (!is_paged_k) { + // 1. map K->KT + // cuDNN frontend API attention requires Q, K, V where + // Q = {b, h_q, s_q, d_qk} + // K = {b, h_k, s_kv, d_qk} + // V = {b, h_v, s_kv, d_v} + // but cuDNN backend API attention requires Q, KT, V + // Q = {b, h_q, s_q, d_qk} + // KT = {b, h_k, d_qk, s_kv} + // V = {b, h_v, s_kv, d_v} + // So the code below maps the K->KT + std::vector temp_vec; + + temp_vec = attributes.inputs[input_names::K]->get_dim(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::K]->set_dim(temp_vec); + + temp_vec = attributes.inputs[input_names::K]->get_stride(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::K]->set_stride(temp_vec); + + // 2. Set k_cache + k_cache = attributes.inputs[input_names::K]; + } else { + // Create a paged cache load operation + auto paged_cache_load_attributes_k = PagedCacheLoad_attributes(); + // Need to create virtual tensor descriptor for yOut here as it cannot be inferred + // K-cache has BHDS layout + k_cache = std::make_shared(); + k_cache->set_dim({b, h_k, d_qk, s_kv}) + .set_stride({d_qk * s_kv * h_k, d_qk * s_kv, 1, d_qk}) + .set_data_type(attributes.inputs[input_names::K]->get_data_type()); + k_cache->set_is_virtual(true); + paged_cache_load(attributes.inputs[input_names::K], + attributes.inputs[input_names::SEQ_LEN_KV], + attributes.inputs[input_names::Page_table_K], + paged_cache_load_attributes_k, + k_cache); + } std::shared_ptr last_output; @@ -271,10 +382,9 @@ class SDPANode : public NodeCRTP { bmm1_attributes.set_padding(0.0); } - auto const& bmm1_output = - matmul(attributes.inputs[input_names::Q], attributes.inputs[input_names::K], bmm1_attributes); + auto const& bmm1_output = matmul(attributes.inputs[input_names::Q], k_cache, bmm1_attributes); // Setting dim and strides as pointwise op wont have knowledge of how to do it for mha. - bmm1_output->set_dim({b, h, s_q, s_kv}).set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}); + bmm1_output->set_dim({b, h_q, s_q, s_kv}).set_stride({h_q * s_q * s_kv, s_q * s_kv, s_kv, 1}); last_output = bmm1_output; // Optional scale @@ -323,11 +433,11 @@ class SDPANode : public NodeCRTP { // Multiply by alibi slope alibi_slopes = std::make_shared(); - alibi_slopes->set_dim({1, h, 1, 1}) - .set_stride({h, 1, 1, 1}) + alibi_slopes->set_dim({1, h_q, 1, 1}) + .set_stride({h_q, 1, 1, 1}) // Hard code data type float as FE itself will compute and place in variant pack later .set_data_type(DataType_t::FLOAT); - alibi_slopes_size = h * sizeof(float); + alibi_slopes_size = h_q * sizeof(float); auto mul_attributes = Pointwise_attributes().set_name("mul").set_mode(PointwiseMode_t::MUL); auto const& alibi_mask = pointwise(sub_output, alibi_slopes, mul_attributes); @@ -589,8 +699,8 @@ class SDPANode : public NodeCRTP { .set_bernoulli_probability(1.0 - attributes.dropout_probability.value())); rng_output // Hard coding dim and strides as rng output can no inputs to infer it from. - ->set_dim({b, h, s_q, s_kv}) - .set_stride({h * s_q * s_kv, s_q * s_kv, s_kv, 1}); + ->set_dim({b, h_q, s_q, s_kv}) + .set_stride({h_q * s_q * s_kv, s_q * s_kv, s_kv, 1}); } auto mask_attributes = @@ -621,12 +731,31 @@ class SDPANode : public NodeCRTP { auto const& seq_len_q = attributes.inputs[input_names::SEQ_LEN_Q]; auto const& seq_len_kv = attributes.inputs[input_names::SEQ_LEN_KV]; - auto const& V = attributes.inputs[input_names::V]; - auto const& O = attributes.outputs[output_names::O]; + // auto const& V = attributes.inputs[input_names::V]; + auto const& O = attributes.outputs[output_names::O]; + + std::shared_ptr v_cache; + + if (!is_paged_v) { + v_cache = attributes.inputs[input_names::V]; + } else { + auto paged_cache_load_attributes_v = PagedCacheLoad_attributes(); + v_cache = std::make_shared(); + v_cache->set_dim({b, h_v, s_kv, d_v}) + .set_stride({d_v * s_kv * h_v, d_v * s_kv, d_v, 1}) + .set_data_type(attributes.inputs[input_names::V]->get_data_type()); + v_cache->set_is_virtual(true); + paged_cache_load(attributes.inputs[input_names::V], + attributes.inputs[input_names::SEQ_LEN_KV], + attributes.inputs[input_names::Page_table_V], + paged_cache_load_attributes_v, + v_cache); + } + auto bmm2_attributes = Matmul_attributes().set_name("bmm2").set_m_override(seq_len_q).set_k_override(seq_len_kv); // Special non-functional-style call. Needed because output already created and provided to user. - matmul(last_output, V, bmm2_attributes, O); + matmul(last_output, v_cache, bmm2_attributes, O); return {error_code_t::OK, ""}; } @@ -662,7 +791,8 @@ class SDPANode : public NodeCRTP { virtual error_t collect_tensors_in_workspace_node( - std::unordered_map>>& workspace_modifications, + std::unordered_map>>& + workspace_modifications, int64_t& offset) const override final { if (attributes.alibi_mask) { CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(Q, input_names::Q); @@ -764,6 +894,8 @@ class SDPABackwardNode : public NodeCRTP { auto const& bias_mask = attributes.inputs.find(input_names::Bias); bool const is_bias = (bias_mask != attributes.inputs.end() && bias_mask->second != nullptr); + auto const& dbias_mask = attributes.outputs.find(output_names::dBias); + bool const is_dbias = (dbias_mask != attributes.outputs.end() && dbias_mask->second != nullptr); auto const& dropout_mask = attributes.inputs.find(input_names::Dropout_mask); bool const is_dropout_custom = (dropout_mask != attributes.inputs.end()) && (dropout_mask->second != nullptr); @@ -773,11 +905,22 @@ class SDPABackwardNode : public NodeCRTP { // - validate stats has valid dims // - validate Q and dQ have the same dims - // validate basic dimension requirements - RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 128) || (d_qk % 8 != 0) || (d_v > 128) || (d_v % 8 != 0), - error_code_t::GRAPH_NOT_SUPPORTED, - "Num hidden_dim shoud be less than 128 and hidden_dim should be multiple of 8"); + cudaDeviceProp prop; + int device; + CHECK_CUDA_ERROR(detail::cuda_get_device(&device)); + CHECK_CUDA_ERROR(detail::cuda_get_device_properties(&prop, device)); + if (prop.major >= 9) { + // validate basic dimension requirements + RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 256) || (d_qk % 8 != 0) || (d_v > 256) || (d_v % 8 != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "Num hidden_dim shoud be less than 256 and hidden_dim should be multiple of 8"); + } else { + // validate basic dimension requirements + RETURN_CUDNN_FRONTEND_ERROR_IF((d_qk > 128) || (d_qk % 8 != 0) || (d_v > 128) || (d_v % 8 != 0), + error_code_t::GRAPH_NOT_SUPPORTED, + "Num hidden_dim shoud be less than 128 and hidden_dim should be multiple of 8"); + } RETURN_CUDNN_FRONTEND_ERROR_IF((h_q % h_k != 0) || (h_q % h_v != 0), error_code_t::GRAPH_NOT_SUPPORTED, "For group-query attention, number of heads for key and query must be a factor of number of heads for query"); @@ -806,6 +949,11 @@ class SDPABackwardNode : public NodeCRTP { error_code_t::ATTRIBUTE_NOT_SET, "seq_len_q and seq_len_kv needs to be set only if padding mask is enabled."); + // validate options for max_total_seq_len + RETURN_CUDNN_FRONTEND_ERROR_IF((attributes.max_total_seq_len_q.has_value() || attributes.max_total_seq_len_kv.has_value()) && !is_ragged, + error_code_t::GRAPH_NOT_SUPPORTED, + "max_total_seq_len_q is only supported with packed layout"); + // validate options for bottom right causal mask RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.causal_mask && attributes.causal_mask_bottom_right, error_code_t::GRAPH_NOT_SUPPORTED, @@ -866,14 +1014,25 @@ class SDPABackwardNode : public NodeCRTP { error_code_t::GRAPH_NOT_SUPPORTED, "For cuDNN version below 9.2.0, sliding window attention is not supported"); - RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 && is_bias && attributes.padding_mask, + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 && is_dbias && attributes.padding_mask, error_code_t::GRAPH_NOT_SUPPORTED, "For cuDNN version below 9.5.0, dBias with variable sequence lengths is not supported"); - RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 && is_bias && ((s_q % 64 != 0) || (s_kv % 64 != 0)), + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90500 && is_dbias && ((s_q % 64 != 0) || (s_kv % 64 != 0)), error_code_t::GRAPH_NOT_SUPPORTED, "For cuDNN version below 9.5.0, dBias not support s_q/s_kv which aren't multiple of 64"); + RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 90600 && is_ragged && ((h_q != h_k) || (h_q != h_v)), + error_code_t::GRAPH_NOT_SUPPORTED, + "For cuDNN version below 9.6.0, group-query attention with raggged offset is not supported"); + + if (detail::get_backend_version() < 90600 && (attributes.max_total_seq_len_q.has_value() || attributes.max_total_seq_len_kv.has_value())) { + CUDNN_FE_LOG_LABEL_ENDL( + "WARNING: sdpa_backward.attributes.max_total_seq_len has been set, but cuDNN version is below 9.6.0 " + "which does not support max_total_seq_len_q. The workspace memory size required to execute this graph " + "may be unexpectedly large"); + } + // validate that datatype is set for the graph RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_intermediate_data_type() == DataType_t::NOT_SET, error_code_t::ATTRIBUTE_NOT_SET, @@ -970,7 +1129,7 @@ class SDPABackwardNode : public NodeCRTP { // ---------------------input tensor workarounds--------------------------- - bool use_workspace_opt = false; + bool use_dp_workspace = false; if (detail::get_backend_version() >= 8905 && detail::get_backend_version() < 90000) { // workspace optimization is enabled by default when: @@ -1011,11 +1170,11 @@ class SDPABackwardNode : public NodeCRTP { int64_t required_dp_workspace_bytes = b * h_q * workspace_s_q * workspace_s_kv * 2; if (max_dp_workspace_bytes == -1) { - use_workspace_opt = true; + use_dp_workspace = true; } else if (max_dp_workspace_bytes == 0) { - use_workspace_opt = false; + use_dp_workspace = false; } else { - use_workspace_opt = (required_dp_workspace_bytes <= max_dp_workspace_bytes); + use_dp_workspace = (required_dp_workspace_bytes <= max_dp_workspace_bytes); } } } @@ -1024,16 +1183,7 @@ class SDPABackwardNode : public NodeCRTP { // - dBias is enabled (dBias is only supported on workspace implementation) // - the user force requests deterministic algorithm if (attributes.outputs[output_names::dBias] || attributes.is_deterministic_algorithm) { - use_workspace_opt = true; - } - - // non-virtual dQ_accum is how the backend API signals workspace optimization - if (!use_workspace_opt) { - dQ_accum = std::make_shared(); - dQ_accum->set_is_virtual(false); - dQ_accum->set_dim({b, h_q, s_q, d_qk}).set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1}); - dQ_accum->set_data_type(DataType_t::FLOAT).set_reordering_type(TensorReordering_t::F16x16); - dQ_accum_size = b * h_q * s_q * d_qk * sizeof(float); + use_dp_workspace = true; } // --------------RNG node-------------------- @@ -1083,6 +1233,21 @@ class SDPABackwardNode : public NodeCRTP { last_output->set_dim({b, h_q, s_q, 1}).set_stride({h_q * s_q, s_q, 1, 1}); softmax_sum = last_output; + softmax_sum->set_is_virtual(false); + softmax_sum->set_dim({b, h_q, s_q, 1}); + softmax_sum->set_data_type(DataType_t::FLOAT); + + if (attributes.inputs[input_names::Stats]->get_ragged_offset() && attributes.max_total_seq_len_q.has_value() && + detail::get_backend_version() >= 90600) { + // sized TH1 softmax_sum + softmax_sum->set_stride(attributes.inputs[input_names::Stats]->get_stride()); + softmax_sum->set_ragged_offset(attributes.inputs[input_names::Stats]->get_ragged_offset()); + softmax_sum_size = attributes.max_total_seq_len_q.value() * h_q * 1 * sizeof(float); + } else { + // sized BHS1 softmax_sum + softmax_sum->set_stride({h_q * s_q, s_q, 1, 1}); + softmax_sum_size = b * h_q * s_q * 1 * sizeof(float); + } // --------------"Q @ KT => exp_softmax => dV" chain-------------------- @@ -1510,29 +1675,44 @@ class SDPABackwardNode : public NodeCRTP { last_output->set_ragged_offset(attributes.inputs[input_names::K]->get_ragged_offset()); } - matmul(dS_output, - last_output, - Matmul_attributes() - .set_name("matmul_dS_K") - .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) - .set_k_override(attributes.inputs[input_names::SEQ_LEN_KV]), - (dQ_accum) ? dQ_accum : attributes.outputs[output_names::dQ]); + if (!use_dp_workspace) { + dQ_accum = std::make_shared(); + dQ_accum->set_is_virtual(false); + dQ_accum->set_dim({b, h_q, s_q, d_qk}); + dQ_accum->set_data_type(DataType_t::FLOAT); + + if (attributes.outputs[output_names::dQ]->get_ragged_offset() && + attributes.max_total_seq_len_q.has_value() && detail::get_backend_version() >= 90600) { + // sized THD dQ_accum + dQ_accum->set_stride(attributes.outputs[output_names::dQ]->get_stride()); + dQ_accum->set_ragged_offset(attributes.outputs[output_names::dQ]->get_ragged_offset()); + dQ_accum_size = attributes.max_total_seq_len_q.value() * + (attributes.outputs[output_names::dQ]->get_stride())[2] * sizeof(float); + } else { + // sized BHSD dQ_accum + dQ_accum->set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1}); + dQ_accum_size = b * h_q * s_q * d_qk * sizeof(float); + } + + matmul(dS_output, + last_output, + Matmul_attributes() + .set_name("matmul_dS_K") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_KV]), + (dQ_accum) ? dQ_accum : attributes.outputs[output_names::dQ]); - if (dQ_accum) { pointwise(dQ_accum, Pointwise_attributes().set_name("identity_dQ").set_mode(PointwiseMode_t::IDENTITY), attributes.outputs[output_names::dQ]); - } - - // ---------------------output tensor workarounds--------------------------- - - // non-virtual softmax_sum is required for below cuDNN 8.9.5 - // non-virtual softmax_sum is passed by the node - if (detail::get_backend_version() < 8905) { - softmax_sum->set_is_virtual(false); - softmax_sum->set_dim({b, h_q, s_q, 1}); - softmax_sum->set_data_type(DataType_t::FLOAT); - softmax_sum_size = b * h_q * s_q * sizeof(float); + } else { + matmul(dS_output, + last_output, + Matmul_attributes() + .set_name("matmul_dS_K") + .set_m_override(attributes.inputs[input_names::SEQ_LEN_Q]) + .set_k_override(attributes.inputs[input_names::SEQ_LEN_KV]), + attributes.outputs[output_names::dQ]); } return {error_code_t::OK, ""}; @@ -1551,7 +1731,8 @@ class SDPABackwardNode : public NodeCRTP { virtual error_t collect_tensors_in_workspace_node( - std::unordered_map>>& workspace_modifications, + std::unordered_map>>& + workspace_modifications, int64_t& offset) const override final { if (attributes.alibi_mask) { CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(Q, input_names::Q); @@ -1563,8 +1744,10 @@ class SDPABackwardNode : public NodeCRTP { } if (dQ_accum && !dQ_accum->get_is_virtual()) { - std::vector f_vec = {(float)dQ_accum_size}; - workspace_modifications.emplace(dQ_accum->get_uid(), std::make_tuple(1, offset, f_vec)); + std::vector f_vec = {(float)dQ_accum_size}; + int64_t dQ_accum_workspace_type = detail::get_backend_version() < 90600 ? 1 : 2; + workspace_modifications.emplace(dQ_accum->get_uid(), + std::make_tuple(dQ_accum_workspace_type, offset, f_vec)); offset = offset + dQ_accum_size; } diff --git a/include/cudnn_frontend/node/slice.h b/include/cudnn_frontend/node/slice.h index 7fd8b2f6..d84ce3f5 100644 --- a/include/cudnn_frontend/node/slice.h +++ b/include/cudnn_frontend/node/slice.h @@ -61,19 +61,19 @@ class SliceNode : public NodeCRTP { // But assign it a uid auto const input = attributes.inputs.at(Slice_attributes::input_names::X); if (input->has_uid() == false) { - assign_uid(input.get(), potential_uid, used_uids); + detail::assign_uid(input.get(), potential_uid, used_uids); } auto const output = attributes.outputs.at(Slice_attributes::output_names::Y); output->set_is_virtual(false); - CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(output, tensors, potential_uid, used_uids)); + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(output, tensors, potential_uid, used_uids)); return {error_code_t::OK, ""}; } error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operations, + std::unordered_set& uids_involved_in_operations, std::vector>&, managed_backend_descriptor_t& raw_operations, std::unordered_map>&) const override final { diff --git a/include/cudnn_frontend/node_interface.h b/include/cudnn_frontend/node_interface.h index 7e39cbb7..07acf447 100644 --- a/include/cudnn_frontend/node_interface.h +++ b/include/cudnn_frontend/node_interface.h @@ -36,7 +36,7 @@ class RngNode; class SoftmaxNode; // Interface for all nodes to follow. -class INode : public ICudnn { +class INode { public: // A closed set of types that are allowed to be passed by value today using pass_by_values_t = Tensor_attributes::pass_by_values_t; @@ -79,7 +79,7 @@ class INode : public ICudnn { } virtual error_t - collect_pass_by_value_tensors_node(std::unordered_map&) const { + collect_pass_by_value_tensors_node(std::unordered_map&) const { return {error_code_t::OK, ""}; }; @@ -96,8 +96,9 @@ class INode : public ICudnn { std::unordered_set const& used_uids) const = 0; virtual error_t - collect_tensors_in_workspace_node(std::unordered_map>>&, - int64_t&) const { + collect_tensors_in_workspace_node( + std::unordered_map>>&, + int64_t&) const { return {error_code_t::OK, ""}; } @@ -129,7 +130,8 @@ class INode : public ICudnn { RNG, SCALED_DOT_PRODUCT_ATTENTION, SLICE, - WGRAD + WGRAD, + PAGED_CACHE_LOAD }; Type tag; @@ -184,6 +186,13 @@ class INode : public ICudnn { Rng_attributes attributes, std::shared_ptr y); + void + paged_cache_load(std::shared_ptr container, + std::shared_ptr seqLen, + std::shared_ptr pageTable, + PagedCacheLoad_attributes attributes, + std::shared_ptr yOut); + error_t validate_subtree() { // pre validate to catch errors early @@ -225,7 +234,8 @@ class INode : public ICudnn { } error_t - collect_pass_by_value_tensors_subtree(std::unordered_map& tensor_to_pass_by_value) const { + collect_pass_by_value_tensors_subtree( + std::unordered_map& tensor_to_pass_by_value) const { CHECK_CUDNN_FRONTEND_ERROR(collect_pass_by_value_tensors_node(tensor_to_pass_by_value)); for (auto const& sub_node : sub_nodes) { CHECK_CUDNN_FRONTEND_ERROR(sub_node->collect_pass_by_value_tensors_subtree(tensor_to_pass_by_value)); @@ -235,7 +245,8 @@ class INode : public ICudnn { error_t collect_tensors_in_workspace_subtree( - std::unordered_map>>& worskspace_modifications, + std::unordered_map>>& + worskspace_modifications, int64_t& offset) const { CHECK_CUDNN_FRONTEND_ERROR(collect_tensors_in_workspace_node(worskspace_modifications, offset)); offset = get_fe_workspace_size_node(); @@ -271,7 +282,7 @@ class INode : public ICudnn { // Only INode that map to a primitive cudnn operation need to specialize. virtual error_t create_cudnn_operations( - std::unordered_set& uids_involved_in_operation, + std::unordered_set& uids_involved_in_operation, std::vector>& backend_operations, managed_backend_descriptor_t& raw_operations, std::unordered_map>& uid_to_backend_tensors) const { @@ -362,18 +373,18 @@ class NodeCRTP : public INode { error_t create_cudnn_tensors_node(std::unordered_map>& tensors, int64_t& potential_uid, - std::unordered_set const& used_uids) const { + std::unordered_set const& used_uids) const override { CUDNN_FE_LOG_LABEL_ENDL("INFO: Creating cudnn tensors for node named '" << self().attributes.name << "':"); for (auto const& [name, tensor] : self().attributes.inputs) { (void)name; if (tensor) { - CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); } } for (auto const& [name, tensor] : self().attributes.outputs) { (void)name; if (tensor) { - CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); } } @@ -382,7 +393,7 @@ class NodeCRTP : public INode { // Special case in BN where peer stats is also an input but is not present in inputs map for (auto const& tensor : self().attributes.peer_stats) { if (tensor) { - CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); + CHECK_CUDNN_FRONTEND_ERROR(detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids)); } } } diff --git a/include/cudnn_frontend/utils/serialize.h b/include/cudnn_frontend/utils/serialize.h index b89a0986..ac484a22 100644 --- a/include/cudnn_frontend/utils/serialize.h +++ b/include/cudnn_frontend/utils/serialize.h @@ -304,6 +304,18 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Rng_attributes::input_names, NLOHMANN_JSON_SERIALIZE_ENUM(Rng_attributes::output_names, {{Rng_attributes::output_names::Y, "Y"}}) +NLOHMANN_JSON_SERIALIZE_ENUM(PagedCacheLoad_attributes::input_names, + { + {PagedCacheLoad_attributes::input_names::container, "container"}, + {PagedCacheLoad_attributes::input_names::seqLen, "seqLen"}, + {PagedCacheLoad_attributes::input_names::pageTable, "pageTable"}, + }) + +NLOHMANN_JSON_SERIALIZE_ENUM(PagedCacheLoad_attributes::output_names, + { + {PagedCacheLoad_attributes::output_names::yOut, "yOut"}, + }) + NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_attributes::input_names, { {SDPA_attributes::input_names::Q, "Q"}, @@ -317,6 +329,8 @@ NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_attributes::input_names, {SDPA_attributes::input_names::Offset, "Offset"}, {SDPA_attributes::input_names::Dropout_mask, "Dropout_mask"}, {SDPA_attributes::input_names::Dropout_scale, "Dropout_scale"}, + {SDPA_attributes::input_names::Page_table_K, "Page_table_K"}, + {SDPA_attributes::input_names::Page_table_V, "Page_table_V"}, }) NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_attributes::output_names, diff --git a/include/cudnn_frontend_Operation.h b/include/cudnn_frontend_Operation.h index 52c17825..c5039b9f 100644 --- a/include/cudnn_frontend_Operation.h +++ b/include/cudnn_frontend_Operation.h @@ -174,6 +174,9 @@ class Operation_v8 : public BackendDescriptor { ManagedOpaqueDescriptor idxdesc = nullptr; ManagedOpaqueDescriptor offsetdesc = nullptr; ManagedOpaqueDescriptor seeddesc = nullptr; + ManagedOpaqueDescriptor containerdesc = nullptr; + ManagedOpaqueDescriptor pageTabledesc = nullptr; + ManagedOpaqueDescriptor sequencedesc = nullptr; std::vector peerStatdescs; cudnnBackendAttributeType_t alphabetaType = CUDNN_TYPE_FLOAT; @@ -204,19 +207,20 @@ class Operation_v8 : public BackendDescriptor { class OperationBuilder_v8 { private: Operation_v8 m_operation; - bool is_convolution_op = false; - bool is_pointwise_op = false; - bool is_matmul_op = false; - bool is_reduction_op = false; - bool is_genstats_op = false; - bool is_bn_finalize_op = false; - bool is_resample_fwd_op = false; - bool is_resample_bwd_op = false; - bool is_norm_forward_op = false; - bool is_norm_backward_op = false; - bool is_bn_bwd_weight = false; - bool is_rng_op = false; - bool is_reshape_op = false; + bool is_convolution_op = false; + bool is_pointwise_op = false; + bool is_matmul_op = false; + bool is_reduction_op = false; + bool is_genstats_op = false; + bool is_bn_finalize_op = false; + bool is_resample_fwd_op = false; + bool is_resample_bwd_op = false; + bool is_norm_forward_op = false; + bool is_norm_backward_op = false; + bool is_bn_bwd_weight = false; + bool is_rng_op = false; + bool is_reshape_op = false; + bool is_paged_cache_load_op = false; using Message_t = const char *; @@ -1641,6 +1645,70 @@ class OperationBuilder_v8 { return std::move(m_operation); } + Operation_v8 && + build_paged_cache_load_op() { +#if (CUDNN_VERSION < 90500) + set_error_and_throw_exception( + &m_operation, + CUDNN_STATUS_NOT_SUPPORTED, + "CUDNN_BACKEND_OPERATION: paged_cache_load_op operation Not supported in this version"); +#else + NV_CUDNN_FE_DYNAMIC_CHECK_BACKEND_DESCRIPTOR( + 90500, m_operation, "CUDNN_BACKEND_OPERATION: build_paged_cache_load_op requires cudnn 9.5.0"); + + // Quick helper lambda to ensure code being DRY + auto set_tensor_descriptor = [&](auto attr, const std::string &descriptor_name, auto &descriptor) { + std::string error_msg = "CUDNN_BACKEND_OPERATION: Check and Set " + descriptor_name; + auto status = CUDNN_STATUS_SUCCESS; + if (descriptor != nullptr) { + status = detail::set_attribute(m_operation.pointer->get_backend_descriptor(), + attr, + CUDNN_TYPE_BACKEND_DESCRIPTOR, + 1, + &(descriptor->get_backend_descriptor())); + } else { + status = CUDNN_STATUS_BAD_PARAM; + } + + if (status != CUDNN_STATUS_SUCCESS) { + set_error_and_throw_exception(&m_operation, status, error_msg.c_str()); + } + return status; + }; + + if (CUDNN_STATUS_SUCCESS != set_tensor_descriptor(CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_CONTAINER_DESC, + "CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_CONTAINER_DESC", + m_operation.containerdesc)) { + return std::move(m_operation); + } + + if (CUDNN_STATUS_SUCCESS != set_tensor_descriptor(CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_PAGE_TABLE_DESC, + "CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_PAGE_TABLE_DESC", + m_operation.pageTabledesc)) { + return std::move(m_operation); + } + + if (CUDNN_STATUS_SUCCESS != set_tensor_descriptor(CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_SEQUENCE_DESC, + "CUDNN_ATTR_OPERATION_PAGED_CACHE_SEQUENCE_DESC", + m_operation.sequencedesc)) { + return std::move(m_operation); + } + + if (CUDNN_STATUS_SUCCESS != set_tensor_descriptor(CUDNN_ATTR_OPERATION_PAGED_CACHE_LOAD_YDESC, + "CUDNN_ATTR_OPERATION_PAGED_CACHE_YDESC", + m_operation.ydesc)) { + return std::move(m_operation); + } + + auto status = detail::finalize(m_operation.pointer->get_backend_descriptor()); + if (status != CUDNN_STATUS_SUCCESS) { + set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnFinalize Failed"); + return std::move(m_operation); + } +#endif + return std::move(m_operation); + } + Operation_v8 && build_reshape_operation() { #if (CUDNN_VERSION >= 8700) @@ -2419,6 +2487,24 @@ class OperationBuilder_v8 { return *this; } + auto + setcontainerDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 & { + m_operation.containerdesc = tensor.get_desc(); + return *this; + } + + auto + setpageTableDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 & { + m_operation.pageTabledesc = tensor.get_desc(); + return *this; + } + + auto + setsequenceDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 & { + m_operation.sequencedesc = tensor.get_desc(); + return *this; + } + auto setNormFwdPhase(NormFwdPhase_t mode) -> OperationBuilder_v8 & { m_operation.norm_fwd_phase = mode; @@ -2808,18 +2894,19 @@ class OperationBuilder_v8 { (m_operation.op_mode == DescriptorType_t::OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) || (m_operation.op_mode == DescriptorType_t::OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)); - is_pointwise_op = (m_operation.op_mode == DescriptorType_t::OPERATION_POINTWISE_DESCRIPTOR); - is_matmul_op = (m_operation.op_mode == DescriptorType_t::OPERATION_MATMUL_DESCRIPTOR); - is_reduction_op = (m_operation.op_mode == DescriptorType_t::OPERATION_REDUCTION_DESCRIPTOR); - is_genstats_op = (m_operation.op_mode == DescriptorType_t::OPERATION_GEN_STATS_DESCRIPTOR); - is_bn_finalize_op = (m_operation.op_mode == DescriptorType_t::OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR); - is_bn_bwd_weight = (m_operation.op_mode == DescriptorType_t::OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR); - is_resample_fwd_op = (m_operation.op_mode == DescriptorType_t::OPERATION_RESAMPLE_FWD_DESCRIPTOR); - is_norm_forward_op = (m_operation.op_mode == DescriptorType_t::OPERATION_NORM_FORWARD_DESCRIPTOR); - is_norm_backward_op = (m_operation.op_mode == DescriptorType_t::OPERATION_NORM_BACKWARD_DESCRIPTOR); - is_resample_bwd_op = (m_operation.op_mode == DescriptorType_t::OPERATION_RESAMPLE_BWD_DESCRIPTOR); - is_rng_op = (m_operation.op_mode == DescriptorType_t::OPERATION_RNG_DESCRIPTOR); - is_reshape_op = (m_operation.op_mode == DescriptorType_t::OPERATION_RESHAPE_DESCRIPTOR); + is_pointwise_op = (m_operation.op_mode == DescriptorType_t::OPERATION_POINTWISE_DESCRIPTOR); + is_matmul_op = (m_operation.op_mode == DescriptorType_t::OPERATION_MATMUL_DESCRIPTOR); + is_reduction_op = (m_operation.op_mode == DescriptorType_t::OPERATION_REDUCTION_DESCRIPTOR); + is_genstats_op = (m_operation.op_mode == DescriptorType_t::OPERATION_GEN_STATS_DESCRIPTOR); + is_bn_finalize_op = (m_operation.op_mode == DescriptorType_t::OPERATION_BN_FINALIZE_STATISTICS_DESCRIPTOR); + is_bn_bwd_weight = (m_operation.op_mode == DescriptorType_t::OPERATION_BN_BWD_WEIGHTS_DESCRIPTOR); + is_resample_fwd_op = (m_operation.op_mode == DescriptorType_t::OPERATION_RESAMPLE_FWD_DESCRIPTOR); + is_norm_forward_op = (m_operation.op_mode == DescriptorType_t::OPERATION_NORM_FORWARD_DESCRIPTOR); + is_norm_backward_op = (m_operation.op_mode == DescriptorType_t::OPERATION_NORM_BACKWARD_DESCRIPTOR); + is_resample_bwd_op = (m_operation.op_mode == DescriptorType_t::OPERATION_RESAMPLE_BWD_DESCRIPTOR); + is_rng_op = (m_operation.op_mode == DescriptorType_t::OPERATION_RNG_DESCRIPTOR); + is_reshape_op = (m_operation.op_mode == DescriptorType_t::OPERATION_RESHAPE_DESCRIPTOR); + is_paged_cache_load_op = (m_operation.op_mode == DescriptorType_t::OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR); } // This constructor which takes in cudnn C backend enum for cudnnBackendDescriptorType_t will be deprecated, @@ -2865,6 +2952,8 @@ class OperationBuilder_v8 { status_ = validate_norm_op(msg); } else if (is_reshape_op) { status_ = validate_reshape_op(msg); + } else if (is_paged_cache_load_op) { + status_ = CUDNN_STATUS_SUCCESS; } else { status_ = CUDNN_STATUS_BAD_PARAM; msg = @@ -2877,8 +2966,8 @@ class OperationBuilder_v8 { } // Create the descriptor. - cudnnBackendDescriptorType_t cudnn_backend_descritpor_type; - auto status = detail::convert_to_cudnn_type(m_operation.op_mode, cudnn_backend_descritpor_type); + cudnnBackendDescriptorType_t cudnn_backend_descriptor_type; + auto status = detail::convert_to_cudnn_type(m_operation.op_mode, cudnn_backend_descriptor_type); if (status != CUDNN_STATUS_SUCCESS) { set_error_and_throw_exception( &m_operation, @@ -2886,7 +2975,7 @@ class OperationBuilder_v8 { "CUDNN_BACKEND_OPERATION: cudnnCreate Failed with Invalid backend descriptor type."); return std::move(m_operation); } - status = m_operation.initialize_managed_backend_pointer(cudnn_backend_descritpor_type); + status = m_operation.initialize_managed_backend_pointer(cudnn_backend_descriptor_type); if (status != CUDNN_STATUS_SUCCESS) { set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnCreate Failed"); return std::move(m_operation); @@ -2920,6 +3009,8 @@ class OperationBuilder_v8 { return build_resample_bwd_operation(); } else if (m_operation.op_mode == DescriptorType_t::OPERATION_RNG_DESCRIPTOR) { return build_rng_operation(); + } else if (m_operation.op_mode == DescriptorType_t::OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR) { + return build_paged_cache_load_op(); } else if (m_operation.op_mode == DescriptorType_t::OPERATION_RESHAPE_DESCRIPTOR) { return build_reshape_operation(); } else { diff --git a/include/cudnn_frontend_shim.h b/include/cudnn_frontend_shim.h index 9c12a9ef..32d3864f 100644 --- a/include/cudnn_frontend_shim.h +++ b/include/cudnn_frontend_shim.h @@ -22,11 +22,12 @@ #pragma once -#include -#include +#include + #if defined NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING #include #include +#include #endif namespace cudnn_frontend { @@ -44,23 +45,54 @@ get_symbol(const char *function_name) { return ret; } +enum class CudaLibrary { CUDART, CUDA }; + inline void * -get_cuda_symbol(const char *function_name) { - static std::mutex cuda_fe_lib_mutex; - std::lock_guard lock(cuda_fe_lib_mutex); - char *c = NULL; - c = dlerror(); - static void *dl_handle = dlopen("libcudart.so", RTLD_NOW); - c = dlerror(); - (void)c; - if (dl_handle == nullptr) { - std::string error_msg = std::string("Unable to dlopen libcudart.so") + std::string(c); - throw std::runtime_error(error_msg.c_str()); +get_cuda_symbol(CudaLibrary library, const char *function_name) { + // Static mutex to ensure thread-safety + static std::mutex cuda_lib_mutex; + // Static map to store handles for different libraries + static std::unordered_map dl_handles; + + // Determine the library name based on the provided library parameter + const char *library_name = (library == CudaLibrary::CUDART) ? "libcudart.so" : "libcuda.so"; + + // Lock the mutex to ensure thread-safe access + std::lock_guard lock(cuda_lib_mutex); + + // If the library hasn't been opened yet, open it + if (dl_handles.find(library) == dl_handles.end()) { + // Clear any existing error + dlerror(); + + // Attempt to open the specified CUDA library + void *handle = dlopen(library_name, RTLD_NOW); + const char *error = dlerror(); + if (!handle || error) { + // If opening the library fails, throw an exception with the error message + throw std::runtime_error("Unable to dlopen " + std::string(library_name) + ": " + + std::string(error ? error : "Unknown error")); + } + // Store the handle for future use + dl_handles[library] = handle; } - void *ret = dlsym(dl_handle, function_name); - return ret; + // Clear any existing error before calling dlsym + dlerror(); + + // Try to find the symbol (function) in the library + void *symbol = dlsym(dl_handles[library], function_name); + const char *error = dlerror(); + if (!symbol || error) { + // If the symbol is not found, throw an exception with details + throw std::runtime_error("Unable to find symbol " + std::string(function_name) + ": " + + std::string(error ? error : "Unknown error")); + } + + // Return the pointer to the function + return symbol; } + #endif #if defined NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING @@ -94,16 +126,118 @@ get_cuda_symbol(const char *function_name) { #endif #if defined NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING -#define NV_FE_CALL_TO_CUDA(function_name, cuda_symbol, ...) \ - static void *fptr = get_cuda_symbol(#cuda_symbol); \ - if (fptr == nullptr) { \ - throw std::runtime_error("Unable to find symbol " #cuda_symbol); \ - } \ - return reinterpret_cast(fptr)(__VA_ARGS__); + +#define NV_FE_CALL_TO_CUDA(function_name, cuda_symbol, ...) \ + return reinterpret_cast(get_cuda_symbol(CudaLibrary::CUDART, #cuda_symbol))(__VA_ARGS__); +#define NV_FE_CALL_TO_CU(function_name, cuda_symbol, ...) \ + return reinterpret_cast(get_cuda_symbol(CudaLibrary::CUDA, #cuda_symbol))(__VA_ARGS__); + #else + #define NV_FE_CALL_TO_CUDA(function_name, cuda_symbol, ...) return cuda_symbol(__VA_ARGS__); +#define NV_FE_CALL_TO_CU(function_name, cuda_symbol, ...) return cuda_symbol(__VA_ARGS__); + #endif +inline CUresult +cu_graph_create(CUgraph *pGraph, unsigned int flags) { + NV_FE_CALL_TO_CU(cu_graph_create, cuGraphCreate, pGraph, flags); +} + +inline CUresult +cu_graph_get_nodes(CUgraph hGraph, CUgraphNode *nodes, size_t *numNodes) { + NV_FE_CALL_TO_CU(cu_graph_get_nodes, cuGraphGetNodes, hGraph, nodes, numNodes); +} + +inline cudaError_t +cuda_graph_add_child_graph_node(cudaGraphNode_t *pGraphNode, + cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, + size_t numDependencies, + cudaGraph_t childGraph) { + NV_FE_CALL_TO_CUDA(cuda_graph_add_child_graph_node, + cudaGraphAddChildGraphNode, + pGraphNode, + graph, + pDependencies, + numDependencies, + childGraph); +} + +inline cudaError_t +cuda_graph_add_memcpy_node_1D(cudaGraphNode_t *pGraphNode, + cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, + size_t numDependencies, + void *dst, + const void *src, + size_t count, + cudaMemcpyKind kind) { + NV_FE_CALL_TO_CUDA(cuda_graph_add_memcpy_node_1D, + cudaGraphAddMemcpyNode1D, + pGraphNode, + graph, + pDependencies, + numDependencies, + dst, + src, + count, + kind); +} + +inline cudaError_t +cuda_graph_add_memset_node(cudaGraphNode_t *pGraphNode, + cudaGraph_t graph, + const cudaGraphNode_t *pDependencies, + size_t numDependencies, + const cudaMemsetParams *pMemsetParams) { + NV_FE_CALL_TO_CUDA(cuda_graph_add_memset_node, + cudaGraphAddMemsetNode, + pGraphNode, + graph, + pDependencies, + numDependencies, + pMemsetParams); +} + +inline cudaError_t +cuda_graph_get_root_nodes(cudaGraph_t hGraph, cudaGraphNode_t *phNodes, size_t *pNumNodes) { + NV_FE_CALL_TO_CUDA(cuda_graph_get_root_nodes, cudaGraphGetRootNodes, hGraph, phNodes, pNumNodes); +} + +inline cudaError_t +cuda_graph_child_graph_node_get_graph(cudaGraphNode_t hNode, cudaGraph_t *phGraph) { + NV_FE_CALL_TO_CUDA(cuda_graph_child_graph_node_get_graph, cudaGraphChildGraphNodeGetGraph, hNode, phGraph); +} + +inline cudaError_t +cuda_graph_node_get_dependent_nodes(cudaGraphNode_t node, + cudaGraphNode_t *pDependentNodes, + size_t *pNumDependentNodes) { + NV_FE_CALL_TO_CUDA( + cuda_graph_node_get_dependent_nodes, cudaGraphNodeGetDependentNodes, node, pDependentNodes, pNumDependentNodes); +} + +inline cudaError_t +cuda_graph_add_memcpy_node_set_params_1D(cudaGraphNode_t node, + void *dst, + const void *src, + size_t count, + cudaMemcpyKind kind) { + NV_FE_CALL_TO_CUDA( + cuda_graph_add_memcpy_node_set_params_1D, cudaGraphMemcpyNodeSetParams1D, node, dst, src, count, kind); +} + +inline cudaError_t +cuda_graph_add_memset_node_set_params(cudaGraphNode_t node, const cudaMemsetParams *pMemsetParams) { + NV_FE_CALL_TO_CUDA(cuda_graph_add_memset_node_set_params, cudaGraphMemsetNodeSetParams, node, pMemsetParams); +} + +inline cudaError_t +cuda_graph_destroy(cudaGraph_t graph) { + NV_FE_CALL_TO_CUDA(cuda_graph_destroy, cudaGraphDestroy, graph); +} + inline cudaError_t cuda_event_create(cudaEvent_t *event) { NV_FE_CALL_TO_CUDA(cuda_event_create, cudaEventCreate, event); @@ -154,17 +288,14 @@ cuda_get_error_string(cudaError_t error) { NV_FE_CALL_TO_CUDA(cuda_get_error_string, cudaGetErrorString, error); } +inline CUresult +cu_get_error_string(CUresult error, const char **pStr) { + NV_FE_CALL_TO_CU(cu_get_error_string, cuGetErrorString, error, pStr); +} + inline cudaError_t cuda_device_synchronize() { -#if defined NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING - static void *fptr = get_cuda_symbol("cudaDeviceSynchronize"); - if (fptr == nullptr) { - throw std::runtime_error("Unable to find symbol cudaDeviceSynchronize"); - } - return reinterpret_cast(fptr)(); -#else - return cudaDeviceSynchronize(); -#endif + NV_FE_CALL_TO_CUDA(cuda_device_synchronize, cudaDeviceSynchronize); } inline cudnnStatus_t @@ -266,6 +397,40 @@ execute(cudnnHandle_t handle, cudnnBackendDescriptor_t executionPlan, cudnnBacke NV_FE_CALL_TO_BACKEND(execute, cudnnBackendExecute, handle, executionPlan, variantPack); } +inline cudnnStatus_t +populate_cuda_graph(cudnnHandle_t handle, + cudnnBackendDescriptor_t executionPlan, + cudnnBackendDescriptor_t variantPack, + cudaGraph_t cuda_graph) { +#if CUDNN_VERSION >= 90500 + NV_FE_CALL_TO_BACKEND( + populate_cuda_graph, cudnnBackendPopulateCudaGraph, handle, executionPlan, variantPack, cuda_graph); +#else + (void)handle; + (void)executionPlan; + (void)variantPack; + (void)cuda_graph; + return CUDNN_STATUS_VERSION_MISMATCH; +#endif +} + +inline cudnnStatus_t +update_cuda_graph(cudnnHandle_t handle, + cudnnBackendDescriptor_t executionPlan, + cudnnBackendDescriptor_t variantPack, + cudaGraph_t cuda_graph) { +#if CUDNN_VERSION >= 90500 + NV_FE_CALL_TO_BACKEND( + update_cuda_graph, cudnnBackendUpdateCudaGraph, handle, executionPlan, variantPack, cuda_graph); +#else + (void)handle; + (void)executionPlan; + (void)variantPack; + (void)cuda_graph; + return CUDNN_STATUS_VERSION_MISMATCH; +#endif +} + inline const char * get_error_string(cudnnStatus_t status) { NV_FE_CALL_TO_BACKEND(get_error_string, cudnnGetErrorString, status); @@ -290,7 +455,7 @@ get_last_error_string_() { std::string message; - message.reserve(size); + message.resize(size); get_last_error_string(message.data(), size); diff --git a/include/cudnn_frontend_utils.h b/include/cudnn_frontend_utils.h index 6ead3a19..172f775a 100644 --- a/include/cudnn_frontend_utils.h +++ b/include/cudnn_frontend_utils.h @@ -95,14 +95,14 @@ struct nlohmann::adl_serializer { }; template <> -struct nlohmann::adl_serializer> { +struct nlohmann::adl_serializer> { static void - to_json(nlohmann::json& j, const std::variant& data) { + to_json(nlohmann::json& j, const std::variant& data) { std::visit([&](const auto& v) { j = {{"index", data.index()}, {"value", v}}; }, data); } static void - from_json(const nlohmann::json& j, std::variant& data) { + from_json(const nlohmann::json& j, std::variant& data) { if (!j.is_object() || !j.contains("index") || !j.contains("value")) { return; } @@ -111,10 +111,12 @@ struct nlohmann::adl_serializer> if (type_index == 0) { data = j.at("value").get(); } else if (type_index == 1) { - data = j.at("value").get(); + data = j.at("value").get(); } else if (type_index == 2) { - data = j.at("value").get(); + data = j.at("value").get(); } else if (type_index == 3) { + data = j.at("value").get(); + } else if (type_index == 4) { data = j.at("value").get(); } else { return; @@ -253,6 +255,11 @@ to_string(cudnnBackendBehaviorNote_t note) { return std::string("CUDNN_BEHAVIOR_NOTE_REQUIRES_BIAS_INT8x32_REORDER"); case CUDNN_BEHAVIOR_NOTE_TYPE_COUNT: return std::string("CUDNN_BEHAVIOR_NOTE_TYPE_COUNT"); + // If none of the above cases hit, its definitely strict nan prop and should raise an error. +#if (CUDNN_VERSION >= 90500) + case CUDNN_BEHAVIOR_NOTE_SUPPORTS_CUDA_GRAPH_NATIVE_API: + return std::string("CUDNN_BEHAVIOR_NOTE_SUPPORTS_CUDA_GRAPH_NATIVE_API"); +#endif #ifndef NO_DEFAULT_IN_SWITCH default: return std::string("UNKNOWN_BEHAVIOR_NOTE"); @@ -443,7 +450,8 @@ enum class DescriptorType_t { OPERATION_NORM_BACKWARD_DESCRIPTOR, OPERATION_RESHAPE_DESCRIPTOR, RNG_DESCRIPTOR, - OPERATION_RNG_DESCRIPTOR + OPERATION_RNG_DESCRIPTOR, + OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR }; enum class NormMode_t { @@ -593,6 +601,7 @@ enum class BehaviorNote_t { RUNTIME_COMPILATION, REQUIRES_FILTER_INT8x32_REORDER, REQUIRES_BIAS_INT8x32_REORDER, + SUPPORTS_CUDA_GRAPH_NATIVE_API, }; NLOHMANN_JSON_SERIALIZE_ENUM(BehaviorNote_t, @@ -600,6 +609,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(BehaviorNote_t, {BehaviorNote_t::RUNTIME_COMPILATION, "RUNTIME_COMPILATION"}, {BehaviorNote_t::REQUIRES_FILTER_INT8x32_REORDER, "REQUIRES_FILTER_INT8x32_REORDER"}, {BehaviorNote_t::REQUIRES_BIAS_INT8x32_REORDER, "REQUIRES_BIAS_INT8x32_REORDER"}, + {BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API, "SUPPORTS_CUDA_GRAPH_NATIVE_API"}, }) enum class NumericalNote_t { @@ -883,6 +893,9 @@ operator<<(std::ostream& os, const DescriptorType_t& mode) { case DescriptorType_t::OPERATION_RNG_DESCRIPTOR: os << "OPERATION_RNG_DESCRIPTOR"; break; + case DescriptorType_t::OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR: + os << "OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR"; + break; case DescriptorType_t::NOT_SET: os << "NOT_SET"; break; @@ -1285,6 +1298,14 @@ convert_to_cudnn_type(cudnn_frontend::BehaviorNote_t const mode, cudnnBackendBeh case BehaviorNote_t::REQUIRES_BIAS_INT8x32_REORDER: cudnn_mode = CUDNN_BEHAVIOR_NOTE_REQUIRES_BIAS_INT8x32_REORDER; return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API: +#if (CUDNN_VERSION >= 90500) + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(90300, cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE); + cudnn_mode = CUDNN_BEHAVIOR_NOTE_SUPPORTS_CUDA_GRAPH_NATIVE_API; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; +#else + return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; +#endif } return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; } @@ -1417,6 +1438,13 @@ convert_to_cudnn_type(cudnn_frontend::DescriptorType_t const mode, cudnnBackendD return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; #endif + case DescriptorType_t::OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR: +#if (CUDNN_VERSION >= 90500) + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(90500, cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE); + cudnn_mode = CUDNN_BACKEND_OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; +#endif + #ifndef NO_DEFAULT_IN_SWITCH default: return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; @@ -1792,6 +1820,11 @@ convert_from_cudnn_type(cudnnBackendDescriptorType_t const cudnn_mode) { return DescriptorType_t::OPERATION_RNG_DESCRIPTOR; #endif +#if (CUDNN_VERSION >= 90500) + case CUDNN_BACKEND_OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR: + return DescriptorType_t::OPERATION_PAGED_CACHE_LOAD_DESCRIPTOR; +#endif + #ifndef NO_DEFAULT_IN_SWITCH default: return DescriptorType_t::NOT_SET; diff --git a/include/cudnn_frontend_version.h b/include/cudnn_frontend_version.h index 002d7482..24468286 100644 --- a/include/cudnn_frontend_version.h +++ b/include/cudnn_frontend_version.h @@ -23,7 +23,7 @@ #pragma once #define CUDNN_FRONTEND_MAJOR_VERSION 1 -#define CUDNN_FRONTEND_MINOR_VERSION 7 +#define CUDNN_FRONTEND_MINOR_VERSION 8 #define CUDNN_FRONTEND_PATCH_VERSION 0 #define CUDNN_FRONTEND_VERSION \ ((CUDNN_FRONTEND_MAJOR_VERSION * 10000) + (CUDNN_FRONTEND_MINOR_VERSION * 100) + CUDNN_FRONTEND_PATCH_VERSION) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 39c01f18..6e874d3e 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -51,7 +51,7 @@ python_add_library( ) target_link_libraries(_compiled_module PRIVATE pybind11::headers) -target_compile_features(_compiled_module PRIVATE cxx_std_17) +target_compile_features(_compiled_module PRIVATE cxx_std_20) target_include_directories( _compiled_module diff --git a/python/cudnn/__init__.py b/python/cudnn/__init__.py index fa823359..a3eba17c 100644 --- a/python/cudnn/__init__.py +++ b/python/cudnn/__init__.py @@ -12,6 +12,7 @@ reduction_mode, behavior_note, create_handle, + create_kernel_cache, get_stream, numerical_note, set_stream, @@ -25,7 +26,7 @@ from .datatypes import _library_type, _is_torch_tensor -__version__ = "1.7.0" +__version__ = "1.8.0" def _tensor( diff --git a/python/properties.cpp b/python/properties.cpp index 443ee404..e16024a9 100644 --- a/python/properties.cpp +++ b/python/properties.cpp @@ -3,6 +3,8 @@ #include "pybind11/pybind11.h" #include "pybind11/cast.h" #include "pybind11/stl.h" +#include "pybind11/complex.h" +#include "pybind11/functional.h" #include "cudnn_frontend.h" @@ -50,6 +52,13 @@ class HandleManagement { } }; +std::shared_ptr +create_kernel_cache_helper() { + auto kernel_cache = std::make_shared(); + throw_if(kernel_cache == nullptr, cudnn_frontend::error_code_t::INVALID_VALUE, "kernel cache creation failed"); + return kernel_cache; +} + static std::string get_last_error_string() { return detail::get_last_error_string_(); @@ -109,6 +118,9 @@ init_properties(py::module_& m) { m.def("get_last_error_string", &get_last_error_string); + py::class_>(m, "kernel_cache"); + m.def("create_kernel_cache", &create_kernel_cache_helper); + m.def("create_handle", &HandleManagement::create_handle); m.def("destroy_handle", &HandleManagement::destroy_handle); m.def("get_stream", &HandleManagement::get_stream); @@ -159,8 +171,28 @@ init_properties(py::module_& m) { py::enum_(m, "behavior_note") .value("RUNTIME_COMPILATION", cudnn_frontend::BehaviorNote_t::RUNTIME_COMPILATION) .value("REQUIRES_FILTER_INT8x32_REORDER", cudnn_frontend::BehaviorNote_t::REQUIRES_FILTER_INT8x32_REORDER) - .value("REQUIRES_BIAS_INT8x32_REORDER", cudnn_frontend::BehaviorNote_t::REQUIRES_BIAS_INT8x32_REORDER); + .value("REQUIRES_BIAS_INT8x32_REORDER", cudnn_frontend::BehaviorNote_t::REQUIRES_BIAS_INT8x32_REORDER) + .value("SUPPORTS_CUDA_GRAPH_NATIVE_API", cudnn_frontend::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API); } } // namespace python_bindings } // namespace cudnn_frontend + +// namespace pybind11 { +// namespace detail { +// template <> struct type_caster> { +// public: +// PYBIND11_TYPE_CASTER(std::shared_ptr, _("KernelCachePtr")); + +// bool load(handle , bool) { +// return false; // Prevent Python -> C++ conversion +// } + +// static handle cast(std::shared_ptr src, return_value_policy, handle) { +// if (!src) return none().release(); +// return capsule(new std::shared_ptr(std::move(src)), +// [](void *ptr) { delete static_cast*>(ptr); +// }).release(); +// } +// }; +// }} // namespace pybind11::detail \ No newline at end of file diff --git a/python/pycudnn.cpp b/python/pycudnn.cpp index eb14267e..22034f51 100644 --- a/python/pycudnn.cpp +++ b/python/pycudnn.cpp @@ -59,6 +59,10 @@ throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::st void init_pygraph_submodule(py::module_ &); +// pybinds for kernel_cache class +void +create_kernel_cache_submodule(py::module_ &); + // pybinds for all properties and helpers void init_properties(py::module_ &); diff --git a/python/pygraph/pygraph.cpp b/python/pygraph/pygraph.cpp index 4575531f..a5909a6c 100644 --- a/python/pygraph/pygraph.cpp +++ b/python/pygraph/pygraph.cpp @@ -406,11 +406,53 @@ PyGraph::deserialize(py::object const& pyobj) { } } +void +PyGraph::update_cuda_graph(std::intptr_t handle, + std::unordered_map var_pack, + std::intptr_t workspace, + std::intptr_t cuda_graph) { + std::unordered_map var_pack_; + var_pack_.reserve(var_pack.size()); + for (auto const& [uid, device_pointer] : var_pack) { + var_pack_.emplace(uid, (void*)device_pointer); + } + + auto status = graph.update_cuda_graph(reinterpret_cast(handle), + var_pack_, + reinterpret_cast(workspace), + reinterpret_cast(cuda_graph)); + throw_if(status.is_bad(), status.get_code(), status.get_message()); + + return; +} + +void +PyGraph::populate_cuda_graph( + std::intptr_t handle, + std::unordered_map var_pack, + std::intptr_t workspace, + std::intptr_t cuda_graph) { + std::unordered_map var_pack_; + var_pack_.reserve(var_pack.size()); + for (auto const& [uid, device_pointer] : var_pack) { + var_pack_.emplace(uid, (void*)device_pointer); + } + + auto status = graph.populate_cuda_graph(reinterpret_cast(handle), + var_pack_, + reinterpret_cast(workspace), + reinterpret_cast(cuda_graph)); + throw_if(status.is_bad(), status.get_code(), status.get_message()); + + return; +} + void PyGraph::execute(std::unordered_map var_pack, std::intptr_t workspace, std::optional exec_handle) { std::unordered_map var_pack_; + var_pack_.reserve(var_pack.size()); for (auto const& [uid, device_pointer] : var_pack) { var_pack_.emplace(uid, (void*)device_pointer); } @@ -467,13 +509,15 @@ init_pygraph_submodule(py::module_& m) { cudnn_frontend::DataType_t, cudnn_frontend::DataType_t, std::optional, - py::object>(), + py::object, + std::shared_ptr>(), py::arg_v("name", "test_graph"), py::arg_v("io_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("intermediate_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("handle", std::nullopt), - py::arg_v("sm_count", py::none())) + py::arg_v("sm_count", py::none()), + py::arg_v("kernel_cache", nullptr)) .def("tensor_like", py::overload_cast const&, std::string const&>( &PyGraph::tensor_like), @@ -715,6 +759,7 @@ init_pygraph_submodule(py::module_& m) { Returns: cudnn_tensor: The result of reshape operation. Please set the dims for the output tensor. )pbdoc") + .def("deselect_engines", &PyGraph::deselect_engines) .def("deselect_numeric_notes", &PyGraph::deselect_numeric_notes) .def("deselect_behavior_notes", &PyGraph::deselect_behavior_notes) .def("select_numeric_notes", &PyGraph::select_numeric_notes) @@ -762,6 +807,8 @@ init_pygraph_submodule(py::module_& m) { If the graph does not have the UID, this will raise an error )pbdoc") .def("_execute", &PyGraph::execute) + .def("populate_cuda_graph", &PyGraph::populate_cuda_graph) + .def("update_cuda_graph", &PyGraph::update_cuda_graph) .def("serialize", &PyGraph::serialize) .def("deserialize", &PyGraph::deserialize) .def("_execute_plan_at_index", &PyGraph::execute_plan_at_index) diff --git a/python/pygraph/pygraph.h b/python/pygraph/pygraph.h index 55667f03..93af81ed 100644 --- a/python/pygraph/pygraph.h +++ b/python/pygraph/pygraph.h @@ -48,7 +48,8 @@ class PyGraph { cudnn_frontend::DataType_t intermediate_data_type, cudnn_frontend::DataType_t compute_data_type, std::optional handle_, - py::object sm_count) { + py::object sm_count, + std::shared_ptr kernel_cache) { graph.set_compute_data_type(compute_data_type) .set_intermediate_data_type(intermediate_data_type) .set_io_data_type(io_data_type); @@ -63,6 +64,11 @@ class PyGraph { if (sm_count.is(py::none()) == false) { graph.set_sm_count(sm_count.cast()); } + + if (kernel_cache) { + graph.set_kernel_cache(kernel_cache); + graph.set_dynamic_shape_enabled(true); + } } ~PyGraph() { @@ -285,6 +291,9 @@ class PyGraph { py::object const& sliding_window_length, py::object const& dropout, std::shared_ptr& rng_dump, + std::shared_ptr& paged_attention_k_table, + std::shared_ptr& paged_attention_v_table, + py::object const& paged_attention_max_seq_len_kv, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); @@ -303,6 +312,8 @@ class PyGraph { bool const use_padding_mask, std::shared_ptr& seq_len_q, std::shared_ptr& seq_len_kv, + py::object const& max_total_seq_len_q, + py::object const& max_total_seq_len_kv, bool const use_causal_mask, bool const use_causal_mask_bottom_right, py::object const& sliding_window_length, @@ -325,7 +336,11 @@ class PyGraph { std::shared_ptr& scale_o, bool const is_inference, py::object const& attn_scale, + bool const use_padding_mask, + std::shared_ptr& seq_len_q, + std::shared_ptr& seq_len_kv, bool const use_causal_mask, + py::object const& dropout, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); @@ -350,7 +365,11 @@ class PyGraph { std::shared_ptr& scale_dV, std::shared_ptr& scale_dP, py::object const& attn_scale, + bool const use_padding_mask, + std::shared_ptr& seq_len_q, + std::shared_ptr& seq_len_kv, bool const use_causal_mask, + py::object const& dropout, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); @@ -381,6 +400,18 @@ class PyGraph { int64_t get_workspace_size(); + void + populate_cuda_graph(std::intptr_t handle, + std::unordered_map var_pack, + std::intptr_t workspace, + std::intptr_t cuda_graph); + + void + update_cuda_graph(std::intptr_t handle, + std::unordered_map var_pack, + std::intptr_t workspace, + std::intptr_t cuda_graph); + void execute(std::unordered_map var_pack, int64_t workspace, std::optional); @@ -402,6 +433,12 @@ class PyGraph { return; } + void + deselect_engines(std::vector const& engine_names) { + graph.deselect_engines(engine_names); + return; + } + void deselect_numeric_notes(std::vector const& notes) { graph.deselect_numeric_notes(notes); diff --git a/python/pygraph/sdpa.cpp b/python/pygraph/sdpa.cpp index 5a28e26c..51a30ec9 100644 --- a/python/pygraph/sdpa.cpp +++ b/python/pygraph/sdpa.cpp @@ -28,6 +28,9 @@ PyGraph::sdpa(std::shared_ptr& q, py::object const& sliding_window_length, py::object const& dropout, std::shared_ptr& rng_dump, + std::shared_ptr& paged_attention_k_table, + std::shared_ptr& paged_attention_v_table, + py::object const& paged_attention_max_seq_len_kv, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name) { auto attributes = cudnn_frontend::graph::SDPA_attributes() @@ -42,6 +45,22 @@ PyGraph::sdpa(std::shared_ptr& q, .set_compute_data_type(compute_data_type) .set_name(name); + if (paged_attention_k_table) { + attributes.set_paged_attention_k_table(paged_attention_k_table); + } + + if (paged_attention_v_table) { + attributes.set_paged_attention_v_table(paged_attention_v_table); + } + + if (!paged_attention_max_seq_len_kv.is_none()) { + if (py::isinstance(paged_attention_max_seq_len_kv)) { + attributes.set_paged_attention_max_seq_len_kv(paged_attention_max_seq_len_kv.cast()); + } else { + throw std::runtime_error("paged_attention_max_seq_len_kv must be an integer."); + } + } + if (!attn_scale.is_none()) { if (py::isinstance(attn_scale)) { auto const attn_scale_value = attn_scale.cast(); @@ -98,6 +117,8 @@ PyGraph::sdpa(std::shared_ptr& q, } } + // Add page table attributes + auto [O, Stats] = graph.sdpa(q, k, v, attributes); return {O, Stats}; } @@ -116,6 +137,8 @@ PyGraph::sdpa_backward(std::shared_ptr bool const use_padding_mask, std::shared_ptr& seq_len_q, std::shared_ptr& seq_len_kv, + py::object const& max_total_seq_len_q, + py::object const& max_total_seq_len_kv, bool const use_causal_mask, bool const use_causal_mask_bottom_right, py::object const& sliding_window_length, @@ -152,6 +175,16 @@ PyGraph::sdpa_backward(std::shared_ptr } } + if (!max_total_seq_len_q.is_none()) { + int const max_total_seq_len_q_value = max_total_seq_len_q.cast(); + attributes.set_max_total_seq_len_q(max_total_seq_len_q_value); + } + + if (!max_total_seq_len_kv.is_none()) { + int const max_total_seq_len_kv_value = max_total_seq_len_kv.cast(); + attributes.set_max_total_seq_len_kv(max_total_seq_len_kv_value); + } + if (!sliding_window_length.is_none()) { int const sliding_window_value = sliding_window_length.cast(); attributes.set_sliding_window_length(sliding_window_value); @@ -209,11 +242,18 @@ PyGraph::sdpa_fp8(std::shared_ptr& q, std::shared_ptr& scale_o, bool const is_inference, py::object const& attn_scale, + bool const use_padding_mask, + std::shared_ptr& seq_len_q, + std::shared_ptr& seq_len_kv, bool const use_causal_mask, + py::object const& dropout, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name) { auto attributes = cudnn_frontend::graph::SDPA_fp8_attributes() .set_is_inference(is_inference) + .set_padding_mask(use_padding_mask) + .set_seq_len_q(seq_len_q) + .set_seq_len_kv(seq_len_kv) .set_causal_mask(use_causal_mask) .set_compute_data_type(compute_data_type) .set_name(name); @@ -231,6 +271,41 @@ PyGraph::sdpa_fp8(std::shared_ptr& q, } } + if (!dropout.is_none()) { + py::tuple dropout_tuple = dropout.cast(); + if ((!dropout_tuple) || (dropout_tuple.size() != 3 && dropout_tuple.size() != 2)) { + throw std::runtime_error( + "dropout must be a tuple of (float probability, a seed tensor, and an offset tensor) or (mask " + "tensor, scale tensor)"); + } + if (py::isinstance(dropout_tuple[0])) { + auto const probability = dropout_tuple[0].cast(); + auto const seed = dropout_tuple[1].cast>(); + if (!seed) { + throw std::runtime_error("dropout seed must be a cudnn_tensor."); + } + + auto const offset = dropout_tuple[2].cast>(); + if (!offset) { + throw std::runtime_error("dropout offset must be a cudnn_tensor."); + } + + attributes.set_dropout(probability, seed, offset); + } else { + auto const mask = dropout_tuple[0].cast>(); + if (!mask) { + throw std::runtime_error("dropout mask must be a cudnn_tensor."); + } + + auto const scale = dropout_tuple[1].cast>(); + if (!scale) { + throw std::runtime_error("dropout scale must be a cudnn_tensor."); + } + + attributes.set_dropout(mask, scale); + } + } + auto [o, stats, amax_s, amax_o] = graph.sdpa_fp8(q, k, v, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, attributes); return {o, stats, amax_s, amax_o}; @@ -256,10 +331,17 @@ PyGraph::sdpa_fp8_backward(std::shared_ptr& scale_dV, std::shared_ptr& scale_dP, py::object const& attn_scale, + bool const use_padding_mask, + std::shared_ptr& seq_len_q, + std::shared_ptr& seq_len_kv, bool const use_causal_mask, + py::object const& dropout, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name) { auto attributes = cudnn_frontend::graph::SDPA_fp8_backward_attributes() + .set_padding_mask(use_padding_mask) + .set_seq_len_q(seq_len_q) + .set_seq_len_kv(seq_len_kv) .set_causal_mask(use_causal_mask) .set_compute_data_type(compute_data_type) .set_name(name); @@ -277,6 +359,41 @@ PyGraph::sdpa_fp8_backward(std::shared_ptr(dropout)) { + throw std::runtime_error( + "dropout must be a tuple of (float probability, a seed tensor" + ", and an offset tensor) or (mask tensor, scale tensor)"); + } + py::tuple dropout_tuple = dropout.cast(); + if (dropout_tuple.size() != 3) { + throw std::runtime_error( + "dropout must be a tuple of (float probability, a seed tensor" + ", and an offset tensor) or (mask tensor, scale tensor)"); + } + + if (py::isinstance(dropout_tuple[0]) && py::isinstance(dropout_tuple[1], cudnn_tensor_type) && + py::isinstance(dropout_tuple[2], cudnn_tensor_type)) { + auto const probability = dropout_tuple[0].cast(); + auto const seed = dropout_tuple[1].cast>(); + auto const offset = dropout_tuple[2].cast>(); + attributes.set_dropout(probability, seed, offset); + } else if (py::isinstance(dropout_tuple[0], cudnn_tensor_type) && + py::isinstance(dropout_tuple[1], cudnn_tensor_type) && + py::isinstance(dropout_tuple[2], cudnn_tensor_type)) { + auto const mask = dropout_tuple[0].cast>(); + auto const scale = dropout_tuple[1].cast>(); + auto const scale_inv = dropout_tuple[2].cast>(); + attributes.set_dropout(mask, scale, scale_inv); + } else { + throw std::runtime_error( + "dropout must be a tuple of (float probability, a seed tensor" + ", and an offset tensor) or (mask tensor, scale tensor)"); + } + } + auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = graph.sdpa_fp8_backward(q, k, v, @@ -318,6 +435,9 @@ init_pygraph_sdpa_submodule(py::class_& m) { py::arg_v("sliding_window_length", py::none()), py::arg_v("dropout", py::none()), py::arg_v("rng_dump", nullptr), + py::arg_v("paged_attention_k_table", py::none()), + py::arg_v("paged_attention_v_table", py::none()), + py::arg_v("paged_attention_max_seq_len_kv", py::none()), py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("name", ""), R"pbdoc( @@ -325,8 +445,8 @@ init_pygraph_sdpa_submodule(py::class_& m) { Args: q (cudnn_tensor): The query data. - k (cudnn_tensor): The key data. - v (cudnn_tensor): The value data. + k (cudnn_tensor): The key data. When page_table_k is provided, 'k' is a container of non-contiguous key data. + v (cudnn_tensor): The value data. When page_table_v is provided, 'v' is a container of non-contiguous value data. is_inference (bool): Whether it is an inference step or training step. attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. @@ -339,6 +459,9 @@ init_pygraph_sdpa_submodule(py::class_& m) { sliding_window_length (Optional[int]): The length of sliding window. Default is None. dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. rng_dump (Optional[cudnn_tensor]): Debug tensor to dump the Philox RNG dropout mask. Default is None. + paged_attention_k_table (Optional[cudnn_tensor]): The page table to look up offsets into 'k' + paged_attention_v_table (Optional[cudnn_tensor]): The page table to look up offsets into 'v' + paged_attention_max_seq_len_kv (Optional[integer]): The maximum sequence length for k/v caches when paged attention is active. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): The name of the operation. @@ -361,6 +484,8 @@ init_pygraph_sdpa_submodule(py::class_& m) { py::arg_v("use_padding_mask", false), py::arg_v("seq_len_q", nullptr), py::arg_v("seq_len_kv", nullptr), + py::arg_v("max_total_seq_len_q", py::none()), + py::arg_v("max_total_seq_len_kv", py::none()), py::arg_v("use_causal_mask", false), py::arg_v("use_causal_mask_bottom_right", false), py::arg_v("sliding_window_length", py::none()), @@ -413,7 +538,11 @@ init_pygraph_sdpa_submodule(py::class_& m) { py::arg("scale_o"), py::arg("is_inference"), py::arg_v("attn_scale", py::none()), + py::arg("use_padding_mask"), + py::arg_v("seq_len_q", nullptr), + py::arg_v("seq_len_kv", nullptr), py::arg_v("use_causal_mask", false), + py::arg_v("dropout", py::none()), py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("name", ""), R"pbdoc( @@ -431,7 +560,11 @@ init_pygraph_sdpa_submodule(py::class_& m) { scale_o (cudnn_tensor): Scale factor for output. is_inference (bool): Whether it is an inference step or training step. attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. + use_padding_mask (bool): Whether it is an inference step or training step. + seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. + seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): The name of the operation. @@ -462,7 +595,11 @@ init_pygraph_sdpa_submodule(py::class_& m) { py::arg("scale_dV"), py::arg("scale_dP"), py::arg_v("attn_scale", py::none()), + py::arg_v("use_padding_mask", false), + py::arg_v("seq_len_q", nullptr), + py::arg_v("seq_len_kv", nullptr), py::arg_v("use_causal_mask", false), + py::arg_v("dropout", py::none()), py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("name", ""), R"pbdoc( @@ -488,7 +625,11 @@ init_pygraph_sdpa_submodule(py::class_& m) { scale_dV (cudnn_tensor): Scale factor for value gradient. scale_dP (cudnn_tensor): Scale factor for dP gradient. attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. + use_padding_mask (bool): Whether it is an inference step or training step. + seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. + seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): The name of the operation. diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index b3f37b0f..74c268aa 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -1,113 +1,27 @@ cmake_minimum_required(VERSION 3.18) -find_package(Catch2 QUIET) - find_package(Threads) +find_package(Catch2 QUIET) if(NOT Catch2_FOUND) - Include(FetchContent) + include(FetchContent) # Fetch and build catch2 FetchContent_Declare( - Catch2 - GIT_REPOSITORY https://github.com/catchorg/Catch2.git - GIT_TAG v3.3.2 + Catch2 + GIT_REPOSITORY https://github.com/catchorg/Catch2.git + GIT_TAG v3.3.2 ) FetchContent_MakeAvailable(Catch2) endif() -# Find cudnn include(${PROJECT_SOURCE_DIR}/cmake/cuDNN.cmake) -add_executable( - samples - - cpp/sdpa/fp16_fwd.cpp - cpp/sdpa/fp16_bwd.cpp - cpp/sdpa/fp16_cached.cpp - cpp/sdpa/fp16_benchmark.cpp - cpp/sdpa/fp16_fwd_with_custom_dropout.cpp - cpp/sdpa/fp8_fwd.cpp - cpp/sdpa/fp8_bwd.cpp - - cpp/convolution/fprop.cpp - cpp/convolution/fp8_fprop.cpp - cpp/convolution/int8_fprop.cpp - cpp/convolution/dgrads.cpp - cpp/convolution/wgrads.cpp - - cpp/matmul/matmuls.cpp - cpp/matmul/fp8_matmul.cpp - cpp/matmul/int8_matmul.cpp - cpp/matmul/mixed_matmul.cpp - - cpp/norm/batchnorm.cpp - cpp/norm/layernorm.cpp - cpp/norm/rmsnorm.cpp - - cpp/misc/serialization.cpp - cpp/misc/autotuning.cpp - cpp/misc/parallel_compilation.cpp - cpp/misc/pointwise.cpp - cpp/misc/resample.cpp - cpp/misc/slice.cpp - cpp/misc/sm_carveout.cpp - - legacy_samples/conv_sample.cpp - legacy_samples/test_list.cpp - legacy_samples/fp16_emu.cpp - legacy_samples/helpers.cpp - legacy_samples/fusion_sample.cpp - legacy_samples/fp8_sample.cpp - legacy_samples/norm_samples.cpp - legacy_samples/fused_mha_sample.cpp - legacy_samples/f16_flash_mha_sample.cpp - legacy_samples/fp8_flash_mha_sample.cpp -) - if(DEFINED ENV{NO_DEFAULT_IN_SWITCH}) message("Default case in the switch is disabled") add_compile_definitions(NO_DEFAULT_IN_SWITCH) endif() -if (MSVC) - target_compile_options( - samples PRIVATE - /W4 /WX # warning level 3 and all warnings as errors - /wd4100 # allow unused parameters - /wd4458 # local hides class member (currently a problem for all inline setters) - /wd4505 # unreferenced function with internal linkage has been removed - /wd4101 /wd4189 # unreferenced local - /bigobj # increase number of sections in .Obj file - ) -else() - target_compile_options( - samples PRIVATE - -Wall - -Wextra - -Werror - -Wno-unused-function - ) -endif() - -target_link_libraries( - samples - - PRIVATE Threads::Threads - - cudnn_frontend - _cudnn_frontend_pch - Catch2::Catch2WithMain - - - CUDNN::cudnn -) - -# cuDNN dlopen's its libraries -# Add all libraries in link line as NEEDED -set_target_properties( - samples - PROPERTIES - LINK_WHAT_YOU_USE TRUE - RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/bin -) \ No newline at end of file +# Add subdirectories for samples and legacy_samples +add_subdirectory(cpp) +add_subdirectory(legacy_samples) diff --git a/samples/README.md b/samples/README.md index 3c60a159..c430bbe3 100644 --- a/samples/README.md +++ b/samples/README.md @@ -17,6 +17,9 @@ Samples leveraging FE's Python interface are located in [samples/python](python/ * [51_sdpa](python/51_scaled_dot_product_attention_backward.ipynb) Shows how to run causal self attention in bprop. +* [52_sdpa](python/52_scaled_dot_product_attention_with_paged_caches.ipynb) + Shows how to run scaled dot product attention where the K and V caches are stored in non contiguous memory. + ## C++ Interface Samples Samples leveraging FE's C++ interface are located in [samples/cpp](cpp/). @@ -48,6 +51,10 @@ Users are expected to build a graph once and then execute it multiple times. Thi cudnn's sdpa operation enables various customizations on itself. These examples show how to build a graph with sdpa operation for your own custom sdpa needs. +- [Fwd SDPA with paged caches](cpp/sdpa/fp16_fwd_with_paged_caches.cpp) + +Similar to [Fwd SDPA](cpp/sdpa/fp16_fwd.cpp), but here with the ability to use non contiguous K and V caches in combination with page tables, as described in the [PagedAttention paper](https://arxiv.org/abs/2309.06180). + - [Fwd FP8 SDPA](cpp/sdpa/fp8_fwd.cpp) and [Bwd SDPA](cpp/sdpa/fp8_bwd.cpp) Extends the sdpa sample to fp8 precision. @@ -132,6 +139,10 @@ How to serialize a graph into a file and read it back on another thread/process. How to choose the best performing plan among multiple plans suggested by the heuristics. +- [Cuda Graphs](cpp/misc/cudagraphs.cpp) + +Shows how to use the native cuda graph API. The samples show how to create cudnn's cuda graph, and how to repeatedly update it with new device buffers for multiple execution. + - [SM Carveout](cpp/misc/sm_carveout.cpp) Showcases a Batch norm example, where only a partial number of SMs participate in executing the kernel. diff --git a/samples/cpp/CMakeLists.txt b/samples/cpp/CMakeLists.txt new file mode 100644 index 00000000..137ea93e --- /dev/null +++ b/samples/cpp/CMakeLists.txt @@ -0,0 +1,77 @@ +# target sources +add_executable( + samples + + sdpa/fp16_fwd.cpp + sdpa/fp16_bwd.cpp + sdpa/fp16_cached.cpp + sdpa/fp16_benchmark.cpp + sdpa/fp16_fwd_with_custom_dropout.cpp + sdpa/fp16_fwd_with_paged_caches.cpp + sdpa/fp8_fwd.cpp + sdpa/fp8_bwd.cpp + + convolution/fprop.cpp + convolution/fp8_fprop.cpp + convolution/int8_fprop.cpp + convolution/dgrads.cpp + convolution/wgrads.cpp + + matmul/matmuls.cpp + matmul/fp8_matmul.cpp + matmul/int8_matmul.cpp + matmul/mixed_matmul.cpp + + norm/batchnorm.cpp + norm/layernorm.cpp + norm/rmsnorm.cpp + + misc/serialization.cpp + misc/autotuning.cpp + misc/parallel_compilation.cpp + misc/pointwise.cpp + misc/resample.cpp + misc/slice.cpp + misc/sm_carveout.cpp + misc/cudagraphs.cpp +) + +# target flags +if(MSVC) + target_compile_options( + samples PRIVATE + /W4 /WX # warning level 3 and all warnings as errors + /wd4100 # allow unused parameters + /wd4458 # local hides class member (currently a problem for all inline setters) + /wd4505 # unreferenced function with internal linkage has been removed + /wd4101 /wd4189 # unreferenced local + /bigobj # increase number of sections in .Obj file + ) +else() + target_compile_options( + samples PRIVATE + -Wall + -Wextra + -Werror + -Wno-unused-function + ) +endif() + +# target links +target_link_libraries( + samples PRIVATE + Threads::Threads + Catch2::Catch2WithMain + cudnn_frontend + _cudnn_frontend_pch + CUDNN::cudnn + + CUDA::cudart + CUDA::cuda_driver # Needed as calls all CUDA calls will eventually move to driver +) + +# target cmake properties +set_target_properties( + samples PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/bin +) diff --git a/samples/cpp/convolution/dgrads.cpp b/samples/cpp/convolution/dgrads.cpp index 3265e007..589cb5fc 100644 --- a/samples/cpp/convolution/dgrads.cpp +++ b/samples/cpp/convolution/dgrads.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -49,7 +49,7 @@ TEST_CASE("Convolution Dgrad", "[dgrad][graph]") { DX->set_dim({4, 32, 16, 16}).set_output(true); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); @@ -105,7 +105,7 @@ TEST_CASE("Dgrad Drelu Graph", "[dgrad][graph]") { DX->set_output(true); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); @@ -209,7 +209,7 @@ TEST_CASE("Dgrad Drelu DBNweight Graph", "[dgrad][graph]") { } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); REQUIRE(graph.build_operation_graph(handle).is_good()); diff --git a/samples/cpp/convolution/fp8_fprop.cpp b/samples/cpp/convolution/fp8_fprop.cpp index 4d0b6efe..0aa2444a 100644 --- a/samples/cpp/convolution/fp8_fprop.cpp +++ b/samples/cpp/convolution/fp8_fprop.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -96,7 +96,7 @@ TEST_CASE("Convolution fp8 precision", "[conv][graph]") { REQUIRE(graph->validate().is_good()); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph->build_operation_graph(handle).is_good()); REQUIRE(graph->create_execution_plans({fe::HeurMode_t::A}).is_good()); @@ -130,5 +130,5 @@ TEST_CASE("Convolution fp8 precision", "[conv][graph]") { std::cout << graph->print() << std::endl; REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } diff --git a/samples/cpp/convolution/fprop.cpp b/samples/cpp/convolution/fprop.cpp index 73493cc0..c863f73a 100644 --- a/samples/cpp/convolution/fprop.cpp +++ b/samples/cpp/convolution/fprop.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -69,7 +69,7 @@ TEST_CASE("Convolution fprop", "[conv][graph][caching]") { cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto [graph, X, W, Y] = build_new_graph(handle); @@ -203,7 +203,7 @@ TEST_CASE("Convolution fprop dynamic shape", "[conv][graph][dynamic_shape]") { }; cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); for (int idx_shape = 0; idx_shape < conv_shapes_count; ++idx_shape) { auto [graph, X, W, Y] = build_new_graph(handle, idx_shape); @@ -291,7 +291,7 @@ TEST_CASE("CSBR Graph", "[conv][graph][caching]") { }; cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto [graph, X, W, B, S, Y] = lookup_cache_or_build_graph(handle); @@ -433,7 +433,7 @@ TEST_CASE("CSBR Graph dynamic shape", "[conv][graph][dynamic_shape]") { }; cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); for (int idx_shape = 0; idx_shape < conv_shapes_count; idx_shape++) { auto [graph, X, W, B, S, Y] = lookup_cache_or_build_graph(handle, idx_shape); @@ -525,7 +525,7 @@ TEST_CASE("SBRCS", "[conv][genstats][graph]") { SKIP("SBRCS requires Ampere or Hopper"); } - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto [graph, X, W, B, S, Y, SUM, SQ_SUM] = build_new_graph(handle); @@ -635,7 +635,7 @@ TEST_CASE("CBR Graph NCHW", "[conv][graph][caching]") { }; cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto [graph, X, W, Z, B, Y] = lookup_cache_or_build_graph(handle); @@ -719,7 +719,7 @@ TEST_CASE("Convolution fprop large", "[conv][graph][caching]") { cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto [graph, X, W, Y] = build_new_graph(handle); diff --git a/samples/cpp/convolution/int8_fprop.cpp b/samples/cpp/convolution/int8_fprop.cpp index 233e569f..3d5ac2fd 100644 --- a/samples/cpp/convolution/int8_fprop.cpp +++ b/samples/cpp/convolution/int8_fprop.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -83,7 +83,7 @@ TEST_CASE("Conv with Int8 datatypes", "[conv][graph][caching]") { SKIP("Int8 datatype convolutions require Ampere and later architectures"); } - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto [graph, X, W, Y] = build_new_graph(handle); diff --git a/samples/cpp/convolution/wgrads.cpp b/samples/cpp/convolution/wgrads.cpp index 12cb72ed..15970697 100644 --- a/samples/cpp/convolution/wgrads.cpp +++ b/samples/cpp/convolution/wgrads.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -48,7 +48,7 @@ TEST_CASE("Convolution Wgrad", "[wgrad][graph][wgrad][Conv_wgrad]") { DW->set_output(true).set_dim({64, 64, 3, 3}); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); @@ -115,7 +115,7 @@ TEST_CASE("Wgrad Graph", "[wgrad][graph][scale-bias-relu-wgrad][ConvBNwgrad]") { } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); diff --git a/samples/cpp/matmul/fp8_matmul.cpp b/samples/cpp/matmul/fp8_matmul.cpp index 62f63d79..c6470cdd 100644 --- a/samples/cpp/matmul/fp8_matmul.cpp +++ b/samples/cpp/matmul/fp8_matmul.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -24,7 +24,7 @@ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -105,7 +105,7 @@ TEST_CASE("Matmul fp8 precision", "[matmul][graph]") { REQUIRE(graph.validate().is_good()); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.build_operation_graph(handle).is_good()); REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); @@ -127,5 +127,5 @@ TEST_CASE("Matmul fp8 precision", "[matmul][graph]") { {B_descale, B_descale_gpu.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } diff --git a/samples/cpp/matmul/int8_matmul.cpp b/samples/cpp/matmul/int8_matmul.cpp index 788b49f4..cf4353a2 100644 --- a/samples/cpp/matmul/int8_matmul.cpp +++ b/samples/cpp/matmul/int8_matmul.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -24,7 +24,7 @@ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -87,7 +87,7 @@ TEST_CASE("Int8 Matmul", "[matmul][graph]") { REQUIRE(graph.validate().is_good()); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.build_operation_graph(handle).is_good()); REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); @@ -113,5 +113,5 @@ TEST_CASE("Int8 Matmul", "[matmul][graph]") { std::cout << graph.print() << std::endl; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } \ No newline at end of file diff --git a/samples/cpp/matmul/matmuls.cpp b/samples/cpp/matmul/matmuls.cpp index ef79c429..ed0f10b1 100644 --- a/samples/cpp/matmul/matmuls.cpp +++ b/samples/cpp/matmul/matmuls.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -24,7 +24,7 @@ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -152,7 +152,7 @@ matmul_dynamic_shapes(bool use_abs = false, bool use_bias = false) { // Run cudnn graph cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); for (int idx_shape = 0; idx_shape < matmul_shapes_count; idx_shape++) { auto [graph, A, B, C, Bias] = build_new_graph(handle, idx_shape); @@ -172,7 +172,7 @@ matmul_dynamic_shapes(bool use_abs = false, bool use_bias = false) { REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); } - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } TEST_CASE("Matmul dynamic shape", "[matmul][graph][dynamic_shape]") { @@ -238,7 +238,7 @@ TEST_CASE("Matmul", "[matmul][graph]") { REQUIRE(graph.validate().is_good()); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.build_operation_graph(handle).is_good()); REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); @@ -257,7 +257,7 @@ TEST_CASE("Matmul", "[matmul][graph]") { std::unordered_map, void*> variant_pack = { {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } TEST_CASE("Abs + Matmul", "[matmul][graph]") { @@ -307,7 +307,7 @@ TEST_CASE("Abs + Matmul", "[matmul][graph]") { REQUIRE(graph.validate().is_good()); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.build_operation_graph(handle).is_good()); @@ -326,7 +326,7 @@ TEST_CASE("Abs + Matmul", "[matmul][graph]") { std::unordered_map, void*> variant_pack = { {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } TEST_CASE("Bias + Matmul", "[matmul][graph]") { @@ -392,7 +392,7 @@ TEST_CASE("Bias + Matmul", "[matmul][graph]") { REQUIRE(graph.validate().is_good()); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.build_operation_graph(handle).is_good()); REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); @@ -435,7 +435,7 @@ TEST_CASE("Bias + Matmul", "[matmul][graph]") { std::mt19937{std::random_device{}()}); Surface workspace(graph.get_workspace_size_plan_at_index(random_successful.front()), false); REQUIRE(graph.execute_plan_at_index(handle, variant_pack, workspace.devPtr, random_successful.front()).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } TEST_CASE("Matmul SBR Graph", "[matmul][graph]") { @@ -528,7 +528,7 @@ TEST_CASE("Matmul SBR Graph", "[matmul][graph]") { }; cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); Surface x_tensor(4 * 16 * 64, false); Surface w_tensor(4 * 64 * 32, false); @@ -594,7 +594,7 @@ TEST_CASE("Matmul with restricted shared memory", "[matmul][graph]") { REQUIRE(graph.validate().is_good()); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.build_operation_graph(handle).is_good()); REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); @@ -613,5 +613,5 @@ TEST_CASE("Matmul with restricted shared memory", "[matmul][graph]") { std::unordered_map, void*> variant_pack = { {A, A_gpu.devPtr}, {B, B_gpu.devPtr}, {C, C_gpu.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } \ No newline at end of file diff --git a/samples/cpp/matmul/mixed_matmul.cpp b/samples/cpp/matmul/mixed_matmul.cpp index 956f88f5..ab3e195f 100644 --- a/samples/cpp/matmul/mixed_matmul.cpp +++ b/samples/cpp/matmul/mixed_matmul.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -24,7 +24,7 @@ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -80,7 +80,7 @@ TEST_CASE("Mixed Precision Matmul", "[matmul][graph]") { REQUIRE(graph.validate().is_good()); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.build_operation_graph(handle).is_good()); REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); @@ -105,5 +105,5 @@ TEST_CASE("Mixed Precision Matmul", "[matmul][graph]") { std::cout << graph.print() << std::endl; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } diff --git a/samples/cpp/misc/autotuning.cpp b/samples/cpp/misc/autotuning.cpp index bd61ac1c..646fc7a5 100644 --- a/samples/cpp/misc/autotuning.cpp +++ b/samples/cpp/misc/autotuning.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -45,7 +45,7 @@ TEST_CASE("Matmul autotuning", "[matmul][graph][autotuning]") { int64_t a_uid = 0, b_uid = 1, c_uid = 2; cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto create_graph = [&]() -> fe::graph::Graph { // Make cudnn graph @@ -158,5 +158,5 @@ TEST_CASE("Matmul autotuning", "[matmul][graph][autotuning]") { Surface workspace(graph.get_workspace_size_plan_at_index(candidate_index), false); REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } \ No newline at end of file diff --git a/samples/cpp/misc/cudagraphs.cpp b/samples/cpp/misc/cudagraphs.cpp new file mode 100644 index 00000000..2bb475e9 --- /dev/null +++ b/samples/cpp/misc/cudagraphs.cpp @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../utils/helpers.h" + +#include + +#define A_UID 0 +#define B_UID 1 +#define C_UID 2 +#define D_UID 3 + +std::shared_ptr +create_graph(int64_t b, int64_t m, int64_t n, int64_t k, float scale_value) { + //// Create the cudnn graph + auto graph = std::make_shared(); + graph->set_io_data_type(cudnn_frontend::DataType_t::HALF) + .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + + auto A = graph->tensor( + cudnn_frontend::graph::Tensor_attributes().set_dim({b, m, k}).set_stride({m * k, k, 1}).set_uid(A_UID)); + + auto scale_options = cudnn_frontend::graph::Pointwise_attributes().set_mode(cudnn_frontend::PointwiseMode_t::MUL); + auto S = graph->pointwise(A, graph->tensor(scale_value), scale_options); + S->set_data_type(cudnn_frontend::DataType_t::HALF); + + auto B = graph->tensor( + cudnn_frontend::graph::Tensor_attributes().set_dim({b, k, n}).set_stride({n * k, n, 1}).set_uid(B_UID)); + auto T = graph->matmul(S, B, cudnn_frontend::graph::Matmul_attributes()); + + auto C = graph->tensor(cudnn_frontend::graph::Tensor_attributes() + .set_dim({1, 1, 1}) + .set_stride({1, 1, 1}) + .set_is_pass_by_value(true) + .set_uid(C_UID)); + auto add_options = cudnn_frontend::graph::Pointwise_attributes().set_mode(cudnn_frontend::PointwiseMode_t::ADD); + auto D = graph->pointwise(T, C, add_options); + D->set_output(true).set_uid(D_UID); + return graph; +} + +TEST_CASE("Cuda graphs with matmul add", "[cudagraph][graph]") { + //// Main graph + // This example shows how to add a cudnn cuda graph to an already existing cuda graph. + cudaGraph_t main_cuda_graph; + cudaGraphCreate(&main_cuda_graph, 0); + + // Create any FE graph that you want to create a cuda graph for + int64_t b = 8, m = 32, n = 16, k = 8; + float scale_value = .5f; + auto graph = create_graph(b, m, n, k, scale_value); + + // Create the execution plan, as that is needed to populate cuda graph with cudnn kernels + cudnnHandle_t handle; + CUDNN_CHECK(cudnnCreate(&handle)); + + // Validare the graph and lower the FE graph to BE graph + REQUIRE(graph->validate().is_good()); + REQUIRE(graph->build_operation_graph(handle).is_good()); + REQUIRE(graph->create_execution_plans({cudnn_frontend::HeurMode_t::A}).is_good()); + + // Make sure the selected executino plan supports cuda graph + graph->select_behavior_notes({cudnn_frontend::BehaviorNote_t::SUPPORTS_CUDA_GRAPH_NATIVE_API}); + auto status = graph->check_support(handle); + if (cudnn_frontend::detail::get_backend_version() >= 90500) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.is_bad()); + SKIP("cudnn versions 9.5 and earlier don't support behavior note of SUPPORTS_CUDA_GRAPH_NATIVE_API."); + } + REQUIRE(graph->build_plans(handle).is_good()); + + //// Populate an exisiting cuda graph with cudnn's cuda graph + cudaGraph_t cudnn_cuda_graph; + + // Initialize the cudnn cuda graph. + // The responsibility to destroy is on the user. + cudaGraphCreate(&cudnn_cuda_graph, 0); // 0 is just what the API says to pass + + Surface workspace(graph->get_workspace_size(), false); + + half starter_value = __float2half(1.f); + half bias_value = __float2half(2.f); + Surface a_gpu(b * m * k, false, starter_value); + Surface b_gpu(b * k * n, false, starter_value); + Surface d_gpu(b * m * n, false); + std::unordered_map variant_pack = { + {A_UID, a_gpu.devPtr}, {B_UID, b_gpu.devPtr}, {C_UID, &bias_value}, {D_UID, d_gpu.devPtr}}; + + REQUIRE(graph->populate_cuda_graph(handle, variant_pack, workspace.devPtr, cudnn_cuda_graph).is_good()); + + // Put cudnn's cuda graph into main graph + cudaGraphNode_t cudnn_node_in_main_graph; + cudaGraphAddChildGraphNode(&cudnn_node_in_main_graph, + main_cuda_graph, + NULL, + 0, + cudnn_cuda_graph); // Note that this clones cudnn_cuda_graph. + + // It is safe to destroy cudnn_cuda_graph here. + cudaGraphDestroy(cudnn_cuda_graph); + + //// Instantiate the main graph. + cudaGraphExec_t cuda_graph_exec; + cudaGraphInstantiate(&cuda_graph_exec, main_cuda_graph, NULL, NULL, 0); + + cudaGraphLaunch(cuda_graph_exec, 0); + + //// Functional correctness + CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK( + cudaMemcpy(d_gpu.hostPtr, d_gpu.devPtr, sizeof(d_gpu.hostPtr[0]) * d_gpu.n_elems, cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaDeviceSynchronize()); + + for (int i = 0; i < d_gpu.n_elems; i++) { + REQUIRE(__half2float(d_gpu.hostPtr[i]) == + scale_value * k * __half2float(starter_value) + __half2float(bias_value)); + } + + //// Update the instantiated cuda graph with new device pointers + Surface workspace_new(graph->get_workspace_size(), false); + + half starter_value_new = __float2half(1.f); + half bias_value_new = __float2half(1.f); + Surface a_gpu_new(b * m * k, false, starter_value_new); + Surface b_gpu_new(b * k * n, false, starter_value_new); + Surface d_gpu_new(b * m * n, false); + std::unordered_map variant_pack_new = { + {A_UID, a_gpu_new.devPtr}, {B_UID, b_gpu_new.devPtr}, {C_UID, &bias_value_new}, {D_UID, d_gpu_new.devPtr}}; + + // This needs a cudnn cuda graph, which we can query from the cudnn_node in the main graph + cudaGraph_t cudnn_cuda_graph_new; + cudaGraphChildGraphNodeGetGraph(cudnn_node_in_main_graph, &cudnn_cuda_graph_new); + + REQUIRE(graph->update_cuda_graph(handle, variant_pack_new, workspace_new.devPtr, cudnn_cuda_graph_new).is_good()); + + cudaGraphExecChildGraphNodeSetParams(cuda_graph_exec, cudnn_node_in_main_graph, cudnn_cuda_graph_new); + + cudaGraphLaunch(cuda_graph_exec, 0); + + //// Functional correctness + cudaDeviceSynchronize(); + CUDA_CHECK(cudaMemcpy( + d_gpu_new.hostPtr, d_gpu_new.devPtr, sizeof(d_gpu_new.hostPtr[0]) * d_gpu_new.n_elems, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + + for (int i = 0; i < d_gpu_new.n_elems; i++) { + REQUIRE(__half2float(d_gpu_new.hostPtr[i]) == + (scale_value * k * __half2float(starter_value_new) + __half2float(bias_value_new))); + } + + //// Cleanup + cudaGraphExecDestroy(cuda_graph_exec); + cudaGraphDestroy(main_cuda_graph); + cudaGraphDestroy(cudnn_cuda_graph_new); + cudnnDestroy(handle); +} \ No newline at end of file diff --git a/samples/cpp/misc/parallel_compilation.cpp b/samples/cpp/misc/parallel_compilation.cpp index 99dc4100..61c3387f 100644 --- a/samples/cpp/misc/parallel_compilation.cpp +++ b/samples/cpp/misc/parallel_compilation.cpp @@ -25,7 +25,7 @@ #include #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -57,7 +57,7 @@ TEST_CASE("Parallel build", "[matmul][graph][parallel]") { int64_t a_uid = 0, b_uid = 1, c_uid = 2; cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto create_graph = [&]() -> fe::graph::Graph { // Make cudnn graph @@ -148,5 +148,5 @@ TEST_CASE("Parallel build", "[matmul][graph][parallel]") { }; } - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } \ No newline at end of file diff --git a/samples/cpp/misc/pointwise.cpp b/samples/cpp/misc/pointwise.cpp index e3801b18..8f8d699d 100644 --- a/samples/cpp/misc/pointwise.cpp +++ b/samples/cpp/misc/pointwise.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -44,7 +44,7 @@ TEST_CASE("Reduction", "[reduction]") { C->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({1, 1, 1, 1}); REQUIRE(graph.validate().is_good()); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.build_operation_graph(handle).is_good()); REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); @@ -56,7 +56,7 @@ TEST_CASE("Reduction", "[reduction]") { Surface workspace(workspace_size, false); REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } TEST_CASE("Fused scalar", "[scalar][graph]") { @@ -78,7 +78,7 @@ TEST_CASE("Fused scalar", "[scalar][graph]") { REQUIRE(graph.validate().is_good()); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.build_operation_graph(handle).is_good()); REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); @@ -94,7 +94,7 @@ TEST_CASE("Fused scalar", "[scalar][graph]") { REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } TEST_CASE("Fused Amax Reduction and type conversion", "[reduction]") { @@ -136,7 +136,7 @@ TEST_CASE("Fused Amax Reduction and type conversion", "[reduction]") { REQUIRE(graph.validate().is_good()); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.build_operation_graph(handle).is_good()); REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); @@ -153,5 +153,5 @@ TEST_CASE("Fused Amax Reduction and type conversion", "[reduction]") { Surface workspace(workspace_size, false); REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } \ No newline at end of file diff --git a/samples/cpp/misc/resample.cpp b/samples/cpp/misc/resample.cpp index ac3acb9b..3f782e77 100644 --- a/samples/cpp/misc/resample.cpp +++ b/samples/cpp/misc/resample.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -57,7 +57,7 @@ TEST_CASE("Resample Max Pooling NHWC Inference", "[resample][pooling][max][graph assert(Index == nullptr); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); REQUIRE(graph.build_operation_graph(handle).is_good()); @@ -75,7 +75,7 @@ TEST_CASE("Resample Max Pooling NHWC Inference", "[resample][pooling][max][graph REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } TEST_CASE("Resample Max Pooling NHWC Training", "[resample][pooling][max][graph]") { @@ -112,7 +112,7 @@ TEST_CASE("Resample Max Pooling NHWC Training", "[resample][pooling][max][graph] Index->set_output(true).set_data_type(fe::DataType_t::INT8); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); @@ -138,7 +138,7 @@ TEST_CASE("Resample Max Pooling NHWC Training", "[resample][pooling][max][graph] REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } TEST_CASE("Resample Avg Pooling", "[resample][pooling][average][graph]") { @@ -174,7 +174,7 @@ TEST_CASE("Resample Avg Pooling", "[resample][pooling][average][graph]") { assert(Index == nullptr); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); REQUIRE(graph.build_operation_graph(handle).is_good()); @@ -192,5 +192,5 @@ TEST_CASE("Resample Avg Pooling", "[resample][pooling][average][graph]") { REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } \ No newline at end of file diff --git a/samples/cpp/misc/serialization.cpp b/samples/cpp/misc/serialization.cpp index 97c885ee..a1304064 100644 --- a/samples/cpp/misc/serialization.cpp +++ b/samples/cpp/misc/serialization.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -42,7 +42,7 @@ TEST_CASE("CSBR Graph with serialization", "[conv][graph][serialization]") { cudnnHandle_t handle; // Handle to use during deserialize and execute - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto build_and_validate_graph_helper = [](int64_t n, int64_t c, int64_t h, int64_t w, int64_t k, int64_t r, int64_t s) @@ -105,7 +105,7 @@ TEST_CASE("CSBR Graph with serialization", "[conv][graph][serialization]") { int64_t n, int64_t c, int64_t h, int64_t w, int64_t k, int64_t r, int64_t s) -> bool { cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto graph = build_and_validate_graph_helper(n, c, h, w, k, r, s); @@ -129,7 +129,7 @@ TEST_CASE("CSBR Graph with serialization", "[conv][graph][serialization]") { std::vector serialized_data; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto graph = build_and_validate_graph_helper(n, c, h, w, k, r, s); @@ -316,7 +316,7 @@ TEST_CASE("SDPA Graph with serialization", "[sdpa][graph][serialization]") { float dropout_probability) -> bool { cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto graph = build_and_validate_graph_helper( b, h, s_q, s_kv, d, is_attn_scale, is_inference, use_dropout_with_rng, dropout_probability); @@ -345,7 +345,7 @@ TEST_CASE("SDPA Graph with serialization", "[sdpa][graph][serialization]") { std::vector serialized_data; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto graph = build_and_validate_graph_helper( b, h, s_q, s_kv, d, is_attn_scale, is_inference, use_dropout_with_rng, dropout_probability); @@ -384,7 +384,7 @@ TEST_CASE("SDPA Graph with serialization", "[sdpa][graph][serialization]") { serialize(b, h, s_q, s_kv, d, is_attn_scale, is_inference, use_dropout_with_rng, dropout_probability); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto graph = deserialize(handle, serialize_data); @@ -417,5 +417,5 @@ TEST_CASE("SDPA Graph with serialization", "[sdpa][graph][serialization]") { REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } \ No newline at end of file diff --git a/samples/cpp/misc/slice.cpp b/samples/cpp/misc/slice.cpp index e35ff458..087ba363 100644 --- a/samples/cpp/misc/slice.cpp +++ b/samples/cpp/misc/slice.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -68,7 +68,7 @@ TEST_CASE("Slice gemm", "[slice][gemm][graph][fusion]") { C->set_output(true).set_uid(c_uid); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.build(handle, {fe::HeurMode_t::A}).is_good()); @@ -92,5 +92,5 @@ TEST_CASE("Slice gemm", "[slice][gemm][graph][fusion]") { REQUIRE(false); } - checkCudnnErr(cudnnDestroy(handle)); + CUDNN_CHECK(cudnnDestroy(handle)); } diff --git a/samples/cpp/misc/sm_carveout.cpp b/samples/cpp/misc/sm_carveout.cpp index e6464170..d6818c0a 100644 --- a/samples/cpp/misc/sm_carveout.cpp +++ b/samples/cpp/misc/sm_carveout.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -69,21 +69,16 @@ TEST_CASE("SGBN with SM carveout", "[batchnorm][graph][sm_carveout]") { .set_stride({4 * c, 1, 4 * c, 4 * c}) .set_data_type(fe::DataType_t::FLOAT)); - auto epsilon = graph.tensor(fe::graph::Tensor_attributes() - .set_name("epsilon") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - auto momentum = graph.tensor(fe::graph::Tensor_attributes() - .set_name("momentum") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + float epsilon_cpu = 1e-05f; + float momentum_cpu = 1e-01f; + auto epsilon = graph.tensor(epsilon_cpu); + auto momentum = graph.tensor(momentum_cpu); auto batchnorm_options = fe::graph::Batchnorm_attributes() .set_epsilon(epsilon) .set_previous_running_stats(prev_running_mean, prev_running_var, momentum) .set_peer_stats({peer_stats_0, peer_stats_1}); + auto [Y, mean, inv_variance, next_running_mean, next_running_var] = graph.batchnorm(X, scale, bias, batchnorm_options); mean->set_output(true).set_data_type(fe::DataType_t::FLOAT); @@ -100,7 +95,7 @@ TEST_CASE("SGBN with SM carveout", "[batchnorm][graph][sm_carveout]") { SKIP("ConvBNFprop requires Ampere and up"); } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); @@ -121,8 +116,7 @@ TEST_CASE("SGBN with SM carveout", "[batchnorm][graph][sm_carveout]") { Surface Next_running_var_tensor(c, false); Surface Scale_tensor(c, false); Surface Bias_tensor(c, false); - float epsilon_cpu = 1e-05f; - float momentum_cpu = 1e-01f; + Surface Y_tensor(n * c * h * w, false); Surface Peer_stats_0_tensor(2 * 4 * c, false, true); Surface Peer_stats_1_tensor(2 * 4 * c, false); @@ -141,8 +135,6 @@ TEST_CASE("SGBN with SM carveout", "[batchnorm][graph][sm_carveout]") { {next_running_var, Next_running_var_tensor.devPtr}, {scale, Scale_tensor.devPtr}, {bias, Bias_tensor.devPtr}, - {epsilon, &epsilon_cpu}, - {momentum, &momentum_cpu}, {Y, Y_tensor.devPtr}, {peer_stats_0, Peer_stats_0_tensor.devPtr}, {peer_stats_1, Peer_stats_1_tensor.devPtr}}; diff --git a/samples/cpp/norm/batchnorm.cpp b/samples/cpp/norm/batchnorm.cpp index 2c4f7a4c..5949365a 100644 --- a/samples/cpp/norm/batchnorm.cpp +++ b/samples/cpp/norm/batchnorm.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -46,22 +46,14 @@ TEST_CASE("BN Finalize Graph", "[batchnorm][graph]") { fe::graph::Tensor_attributes().set_name("scale").set_dim({1, 32, 1, 1}).set_stride({32, 1, 32, 32})); auto bias = graph.tensor( fe::graph::Tensor_attributes().set_name("bias").set_dim({1, 32, 1, 1}).set_stride({32, 1, 32, 32})); - auto epsilon = graph.tensor(fe::graph::Tensor_attributes() - .set_name("epsilon") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true)); - auto momentum = graph.tensor(fe::graph::Tensor_attributes() - .set_name("momentum") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true)); - auto accum_count = graph.tensor(fe::graph::Tensor_attributes() - .set_name("accum_count") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::INT64)); + + float EPS_scalar = 0.001f; + float MOMENTUM_scalar = 0.001f; + int64_t nhw = 64; + + auto epsilon = graph.tensor(EPS_scalar); + auto momentum = graph.tensor(MOMENTUM_scalar); + auto accum_count = graph.tensor(nhw); auto bn_finalize_options = fe::graph::BN_finalize_attributes().set_previous_running_stats(prev_running_mean, prev_running_var, momentum); @@ -79,7 +71,7 @@ TEST_CASE("BN Finalize Graph", "[batchnorm][graph]") { #endif cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); @@ -103,9 +95,6 @@ TEST_CASE("BN Finalize Graph", "[batchnorm][graph]") { Surface Bias_tensor(32, false); Surface eq_scale_tensor(32, false); Surface eq_bias_tensor(32, false); - float EPS_scalar = 0.001f; - float MOMENTUM_scalar = 0.001f; - int64_t nhw = 64; int64_t workspace_size; REQUIRE(graph.get_workspace_size(workspace_size).is_good()); @@ -116,11 +105,8 @@ TEST_CASE("BN Finalize Graph", "[batchnorm][graph]") { {sq_sum, Sq_sum_tensor.devPtr}, {scale, Scale_tensor.devPtr}, {bias, Bias_tensor.devPtr}, - {epsilon, &EPS_scalar}, - {accum_count, &nhw}, {prev_running_mean, Previous_running_mean_tensor.devPtr}, {prev_running_var, Previous_running_var_tensor.devPtr}, - {momentum, &MOMENTUM_scalar}, {eq_scale, eq_scale_tensor.devPtr}, {eq_bias, eq_bias_tensor.devPtr}, {saved_mean, Mean_tensor.devPtr}, @@ -174,21 +160,11 @@ TEST_CASE("SGBN Add Relu Graph", "[batchnorm][graph]") { .set_stride({4 * 32, 1, 4 * 32, 4 * 32}) .set_data_type(fe::DataType_t::FLOAT)); - auto epsilon = graph.tensor(fe::graph::Tensor_attributes() - .set_name("epsilon") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - auto momentum = graph.tensor(fe::graph::Tensor_attributes() - .set_name("momentum") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); - - auto batchnorm_options = fe::graph::Batchnorm_attributes() - .set_epsilon(epsilon) - - .set_peer_stats({peer_stats_0, peer_stats_1}); + auto epsilon = graph.tensor(1e-05f); + auto momentum = graph.tensor(1e-01f); + + auto batchnorm_options = + fe::graph::Batchnorm_attributes().set_epsilon(epsilon).set_peer_stats({peer_stats_0, peer_stats_1}); if (has_running_stats) { batchnorm_options.set_previous_running_stats(prev_running_mean, prev_running_var, momentum); } @@ -224,7 +200,7 @@ TEST_CASE("SGBN Add Relu Graph", "[batchnorm][graph]") { SKIP("ConvBNFprop requires Ampere and up"); } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); @@ -245,8 +221,6 @@ TEST_CASE("SGBN Add Relu Graph", "[batchnorm][graph]") { Surface Next_running_var_tensor(32, false); Surface Scale_tensor(32, false); Surface Bias_tensor(32, false); - float epsilon_cpu = 1e-05f; - float momentum_cpu = 1e-01f; Surface A_tensor(4 * 32 * 16 * 16, false); Surface Y_tensor(4 * 32 * 16 * 16, false); Surface Peer_stats_0_tensor(2 * 4 * 32, false, true); @@ -262,8 +236,6 @@ TEST_CASE("SGBN Add Relu Graph", "[batchnorm][graph]") { {inv_variance, Var_tensor.devPtr}, {scale, Scale_tensor.devPtr}, {bias, Bias_tensor.devPtr}, - {epsilon, &epsilon_cpu}, - {momentum, &momentum_cpu}, {A, A_tensor.devPtr}, {Y, Y_tensor.devPtr}, {peer_stats_0, Peer_stats_0_tensor.devPtr}, @@ -274,7 +246,6 @@ TEST_CASE("SGBN Add Relu Graph", "[batchnorm][graph]") { variant_pack[prev_running_var] = Previous_running_var_tensor.devPtr; variant_pack[next_running_mean] = Next_running_mean_tensor.devPtr; variant_pack[next_running_var] = Next_running_var_tensor.devPtr; - variant_pack[momentum] = &momentum_cpu; } REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); @@ -351,7 +322,7 @@ TEST_CASE("DBN Add Relu Graph", "[BN][graph][backward]") { SKIP("BatchNorm Backward requires Ampere and up"); } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); @@ -461,7 +432,7 @@ TEST_CASE("BN_inference DRelu DBN Graph", "[Batchnorm][graph][backward]") { #endif cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); diff --git a/samples/cpp/norm/layernorm.cpp b/samples/cpp/norm/layernorm.cpp index 3446f537..ba66c269 100644 --- a/samples/cpp/norm/layernorm.cpp +++ b/samples/cpp/norm/layernorm.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -51,11 +51,8 @@ TEST_CASE("LayerNorm Training", "[layernorm][graph]") { .set_stride({hidden_size, 1, hidden_size, hidden_size}) .set_data_type(fe::DataType_t::FLOAT)); - auto epsilon = graph.tensor(fe::graph::Tensor_attributes() - .set_name("epsilon") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + float epsilon_cpu = 1e-05f; + auto epsilon = graph.tensor(epsilon_cpu); auto layernorm_options = fe::graph::Layernorm_attributes().set_forward_phase(fe::NormFwdPhase_t::TRAINING).set_epsilon(epsilon); @@ -72,7 +69,7 @@ TEST_CASE("LayerNorm Training", "[layernorm][graph]") { SKIP("ConvBNFprop requires Ampere and up"); } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); @@ -89,7 +86,6 @@ TEST_CASE("LayerNorm Training", "[layernorm][graph]") { Surface Var_tensor(batch_size * seq_length, false); Surface Scale_tensor(hidden_size, false); Surface Bias_tensor(hidden_size, false); - float epsilon_cpu = 1e-05f; Surface Y_tensor(batch_size * seq_length * hidden_size, false); int64_t workspace_size; @@ -102,7 +98,6 @@ TEST_CASE("LayerNorm Training", "[layernorm][graph]") { {inv_variance, Var_tensor.devPtr}, {scale, Scale_tensor.devPtr}, {bias, Bias_tensor.devPtr}, - {epsilon, &epsilon_cpu}, {Y, Y_tensor.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); @@ -136,11 +131,8 @@ TEST_CASE("LayerNorm Inference", "[layernorm][graph]") { .set_stride({hidden_size, 1, hidden_size, hidden_size}) .set_data_type(fe::DataType_t::FLOAT)); - auto epsilon = graph.tensor(fe::graph::Tensor_attributes() - .set_name("epsilon") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + float epsilon_cpu = 1e-05f; + auto epsilon = graph.tensor(epsilon_cpu); auto layernorm_options = fe::graph::Layernorm_attributes().set_forward_phase(fe::NormFwdPhase_t::INFERENCE).set_epsilon(epsilon); @@ -157,7 +149,7 @@ TEST_CASE("LayerNorm Inference", "[layernorm][graph]") { SKIP("ConvBNFprop requires Ampere and up"); } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); @@ -172,7 +164,6 @@ TEST_CASE("LayerNorm Inference", "[layernorm][graph]") { Surface X_tensor(batch_size * seq_length * hidden_size, false); Surface Scale_tensor(hidden_size, false); Surface Bias_tensor(hidden_size, false); - float epsilon_cpu = 1e-05f; Surface Y_tensor(batch_size * seq_length * hidden_size, false); int64_t workspace_size; @@ -180,11 +171,7 @@ TEST_CASE("LayerNorm Inference", "[layernorm][graph]") { Surface workspace(workspace_size, false); std::unordered_map, void*> variant_pack = { - {X, X_tensor.devPtr}, - {scale, Scale_tensor.devPtr}, - {bias, Bias_tensor.devPtr}, - {epsilon, &epsilon_cpu}, - {Y, Y_tensor.devPtr}}; + {X, X_tensor.devPtr}, {scale, Scale_tensor.devPtr}, {bias, Bias_tensor.devPtr}, {Y, Y_tensor.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); @@ -240,7 +227,7 @@ TEST_CASE("LayerNorm Backward", "[layernorm][graph]") { SKIP("LayerNorm Backward requires Ampere and up"); } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); diff --git a/samples/cpp/norm/rmsnorm.cpp b/samples/cpp/norm/rmsnorm.cpp index 55871ddd..878086c1 100644 --- a/samples/cpp/norm/rmsnorm.cpp +++ b/samples/cpp/norm/rmsnorm.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -45,12 +45,8 @@ TEST_CASE("RmsNorm Training", "[rmsnorm][graph]") { .set_stride({hidden_size, 1, hidden_size, hidden_size}) .set_data_type(fe::DataType_t::FLOAT)); - auto epsilon = graph.tensor(fe::graph::Tensor_attributes() - .set_name("epsilon") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); + float epsilon_cpu = 1e-05f; + auto epsilon = graph.tensor(epsilon_cpu); auto rmsnorm_options = fe::graph::Rmsnorm_attributes().set_forward_phase(fe::NormFwdPhase_t::TRAINING).set_epsilon(epsilon); @@ -65,7 +61,7 @@ TEST_CASE("RmsNorm Training", "[rmsnorm][graph]") { SKIP("RMSNorm requires Ampere and up"); } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); @@ -80,7 +76,6 @@ TEST_CASE("RmsNorm Training", "[rmsnorm][graph]") { Surface X_tensor(batch_size * seq_length * hidden_size, false); Surface Var_tensor(batch_size * seq_length, false); Surface Scale_tensor(hidden_size, false); - float epsilon_cpu = 1e-05f; Surface Y_tensor(batch_size * seq_length * hidden_size, false); int64_t workspace_size; @@ -88,11 +83,7 @@ TEST_CASE("RmsNorm Training", "[rmsnorm][graph]") { Surface workspace(workspace_size, false); std::unordered_map, void*> variant_pack = { - {X, X_tensor.devPtr}, - {inv_variance, Var_tensor.devPtr}, - {scale, Scale_tensor.devPtr}, - {epsilon, &epsilon_cpu}, - {Y, Y_tensor.devPtr}}; + {X, X_tensor.devPtr}, {inv_variance, Var_tensor.devPtr}, {scale, Scale_tensor.devPtr}, {Y, Y_tensor.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); @@ -124,12 +115,8 @@ TEST_CASE("RmsNorm Inference", "[rmsnorm][graph]") { .set_stride({hidden_size, 1, hidden_size, hidden_size}) .set_data_type(fe::DataType_t::FLOAT)); - auto epsilon = graph.tensor(fe::graph::Tensor_attributes() - .set_name("epsilon") - .set_dim({1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_is_pass_by_value(true) - .set_data_type(fe::DataType_t::FLOAT)); + float epsilon_cpu = 1e-05f; + auto epsilon = graph.tensor(epsilon_cpu); auto rmsnorm_options = fe::graph::Rmsnorm_attributes() .set_forward_phase(fe::NormFwdPhase_t::INFERENCE) @@ -146,7 +133,7 @@ TEST_CASE("RmsNorm Inference", "[rmsnorm][graph]") { SKIP("RmsNorm requires Ampere and up"); } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); @@ -161,7 +148,6 @@ TEST_CASE("RmsNorm Inference", "[rmsnorm][graph]") { Surface X_tensor(batch_size * seq_length * hidden_size, false); Surface Scale_tensor(hidden_size, false); Surface Bias_tensor(hidden_size, false); - float epsilon_cpu = 1e-05f; Surface Y_tensor(batch_size * seq_length * hidden_size, false); int64_t workspace_size; @@ -169,11 +155,7 @@ TEST_CASE("RmsNorm Inference", "[rmsnorm][graph]") { Surface workspace(workspace_size, false); std::unordered_map, void*> variant_pack = { - {X, X_tensor.devPtr}, - {scale, Scale_tensor.devPtr}, - {bias, Bias_tensor.devPtr}, - {epsilon, &epsilon_cpu}, - {Y, Y_tensor.devPtr}}; + {X, X_tensor.devPtr}, {scale, Scale_tensor.devPtr}, {bias, Bias_tensor.devPtr}, {Y, Y_tensor.devPtr}}; REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); @@ -224,7 +206,7 @@ TEST_CASE("RmsNorm Backward", "[rmsnorm][graph]") { SKIP("RmsNorm Backward requires Ampere and up"); } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); REQUIRE(graph.validate().is_good()); diff --git a/samples/cpp/sdpa/fp16_benchmark.cpp b/samples/cpp/sdpa/fp16_benchmark.cpp index 0bc51cfa..f6681f3d 100644 --- a/samples/cpp/sdpa/fp16_benchmark.cpp +++ b/samples/cpp/sdpa/fp16_benchmark.cpp @@ -22,7 +22,7 @@ #include #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -102,7 +102,7 @@ TEST_CASE("Benchmark sdpa graph API runtimes", "[graph][sdpa][flash]") { } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); BENCHMARK_ADVANCED("Create")(Catch::Benchmark::Chronometer meter) { meter.measure([&] { auto g = create_sdpa_forward_graph(b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v); }); diff --git a/samples/cpp/sdpa/fp16_bwd.cpp b/samples/cpp/sdpa/fp16_bwd.cpp index 6168cf64..749cbedb 100644 --- a/samples/cpp/sdpa/fp16_bwd.cpp +++ b/samples/cpp/sdpa/fp16_bwd.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -198,7 +198,7 @@ TEST_CASE("Toy sdpa backward", "[graph][sdpa][flash][backward]") { } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); // Create the SDPA backward graph auto graph = create_sdpa_backward_graph(b, @@ -260,15 +260,15 @@ TEST_CASE("Toy sdpa backward", "[graph][sdpa][flash][backward]") { std::vector hostActualSeqlenQ(b, 20); std::vector hostActualSeqlenKV(b, 20); - checkCudaErr(cudaMemcpy(devActualSeqlenQ.devPtr, - hostActualSeqlenQ.data(), - sizeof(hostActualSeqlenQ[0]) * b, - cudaMemcpyHostToDevice)); - checkCudaErr(cudaMemcpy(devActualSeqlenKV.devPtr, - hostActualSeqlenKV.data(), - sizeof(hostActualSeqlenKV[0]) * b, - cudaMemcpyHostToDevice)); - checkCudaErr(cudaDeviceSynchronize()); + CUDA_CHECK(cudaMemcpy(devActualSeqlenQ.devPtr, + hostActualSeqlenQ.data(), + sizeof(hostActualSeqlenQ[0]) * b, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(devActualSeqlenKV.devPtr, + hostActualSeqlenKV.data(), + sizeof(hostActualSeqlenKV[0]) * b, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); variant_pack[SEQ_LEN_Q_UID] = devActualSeqlenQ.devPtr; variant_pack[SEQ_LEN_KV_UID] = devActualSeqlenKV.devPtr; @@ -281,7 +281,7 @@ TEST_CASE("Toy sdpa backward", "[graph][sdpa][flash][backward]") { REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudaErr(cudaDeviceSynchronize()); + CUDA_CHECK(cudaDeviceSynchronize()); cudnnDestroy(handle); } diff --git a/samples/cpp/sdpa/fp16_cached.cpp b/samples/cpp/sdpa/fp16_cached.cpp index 10711180..d0462713 100644 --- a/samples/cpp/sdpa/fp16_cached.cpp +++ b/samples/cpp/sdpa/fp16_cached.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -115,7 +115,7 @@ TEST_CASE("Cached sdpa", "[graph][sdpa][flash]") { } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto fwd_graph = create_sdpa_forward_graph(b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v); auto bwd_graph = create_sdpa_backward_graph(b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v); @@ -151,7 +151,7 @@ TEST_CASE("Cached sdpa", "[graph][sdpa][flash]") { Surface fwd_workspace(workspace_size, false); REQUIRE(fwd_graph2->execute(handle, variant_pack, fwd_workspace.devPtr).is_good()); - checkCudaErr(cudaDeviceSynchronize()); + CUDA_CHECK(cudaDeviceSynchronize()); Surface dO_tensor(b * h_q * s_q * d_qk, false); Surface dQ_tensor(b * h_q * s_q * d_qk, false); @@ -175,7 +175,7 @@ TEST_CASE("Cached sdpa", "[graph][sdpa][flash]") { REQUIRE(bwd_graph2->execute(handle, variant_pack, bwd_workspace.devPtr).is_good()); - checkCudaErr(cudaDeviceSynchronize()); + CUDA_CHECK(cudaDeviceSynchronize()); cudnnDestroy(handle); } diff --git a/samples/cpp/sdpa/fp16_fwd.cpp b/samples/cpp/sdpa/fp16_fwd.cpp index 66344025..b3acf5e5 100644 --- a/samples/cpp/sdpa/fp16_fwd.cpp +++ b/samples/cpp/sdpa/fp16_fwd.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -151,7 +151,7 @@ TEST_CASE("Toy sdpa forward", "[graph][sdpa][flash][forward]") { } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto graph = create_sdpa_forward_graph(b, h_q, @@ -191,15 +191,15 @@ TEST_CASE("Toy sdpa forward", "[graph][sdpa][flash][forward]") { std::vector hostActualSeqlenQ(b, 20); std::vector hostActualSeqlenKV(b, 20); - checkCudaErr(cudaMemcpy(devActualSeqlenQ.devPtr, - hostActualSeqlenQ.data(), - sizeof(hostActualSeqlenQ[0]) * b, - cudaMemcpyHostToDevice)); - checkCudaErr(cudaMemcpy(devActualSeqlenKV.devPtr, - hostActualSeqlenKV.data(), - sizeof(hostActualSeqlenKV[0]) * b, - cudaMemcpyHostToDevice)); - checkCudaErr(cudaDeviceSynchronize()); + CUDA_CHECK(cudaMemcpy(devActualSeqlenQ.devPtr, + hostActualSeqlenQ.data(), + sizeof(hostActualSeqlenQ[0]) * b, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(devActualSeqlenKV.devPtr, + hostActualSeqlenKV.data(), + sizeof(hostActualSeqlenKV[0]) * b, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); variant_pack[SEQ_LEN_Q_UID] = devActualSeqlenQ.devPtr; variant_pack[SEQ_LEN_KV_UID] = devActualSeqlenKV.devPtr; @@ -216,7 +216,7 @@ TEST_CASE("Toy sdpa forward", "[graph][sdpa][flash][forward]") { REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudaErr(cudaDeviceSynchronize()); + CUDA_CHECK(cudaDeviceSynchronize()); cudnnDestroy(handle); } diff --git a/samples/cpp/sdpa/fp16_fwd_with_custom_dropout.cpp b/samples/cpp/sdpa/fp16_fwd_with_custom_dropout.cpp index 5c70151a..36cfba40 100644 --- a/samples/cpp/sdpa/fp16_fwd_with_custom_dropout.cpp +++ b/samples/cpp/sdpa/fp16_fwd_with_custom_dropout.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include @@ -30,7 +30,7 @@ namespace fe = cudnn_frontend; /* Run this example by using command: -bin/samples "Toy sdpa forward" +bin/samples "Toy sdpa forward with dropout" This example shows how to construct a sdpa forward graph. */ @@ -146,7 +146,7 @@ TEST_CASE("Toy sdpa forward with dropout", "[graph][sdpa][flash][forward]") { } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto graph = create_sdpa_forward_graph_with_custom_dropout( b, h_q, h_k, h_v, s_q, s_kv, d_qk, d_v, attn_scale, is_inference, causal_mask, has_attn_bias); @@ -184,7 +184,7 @@ TEST_CASE("Toy sdpa forward with dropout", "[graph][sdpa][flash][forward]") { REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudaErr(cudaDeviceSynchronize()); + CUDA_CHECK(cudaDeviceSynchronize()); cudnnDestroy(handle); } diff --git a/samples/cpp/sdpa/fp16_fwd_with_paged_caches.cpp b/samples/cpp/sdpa/fp16_fwd_with_paged_caches.cpp new file mode 100644 index 00000000..18dd9378 --- /dev/null +++ b/samples/cpp/sdpa/fp16_fwd_with_paged_caches.cpp @@ -0,0 +1,280 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../utils/helpers.h" + +#include + +#include +namespace fe = cudnn_frontend; + +#include + +/* +Run this example by using command: +bin/samples "Toy sdpa forward with paged caches" + +This example shows how to construct a sdpa forward graph with paged caches. +*/ + +// Tensors in forward pass +#define Q_UID 1 +#define K_UID 2 +#define V_UID 3 +#define O_UID 4 +#define STATS_UID 5 +#define BIAS_UID 6 +#define SEQ_LEN_Q_UID 7 +#define SEQ_LEN_KV_UID 8 +#define PAGE_TABLE_K_UID 9 +#define PAGE_TABLE_V_UID 10 + +std::shared_ptr +create_sdpa_forward_graph_with_paged_caches(int64_t const b, + int64_t const h_q, + int64_t const h_k, + int64_t const h_v, + int64_t const s_q, + int64_t const s_kv, + int64_t const d_qk, + int64_t const d_v, + int64_t const block_size, + int64_t const num_blocks_k, + int64_t const num_blocks_v, + int64_t const table_size, + float const attn_scale = 1.0f, + bool const is_inference = false, + bool const causal_mask = false, + bool const alibi_mask = false, + bool has_attn_bias = false) { + // Create a graph and set common global properties. + auto graph = std::make_shared(); + graph->set_io_data_type(fe::DataType_t::BFLOAT16) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto Q = graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_uid(Q_UID) + .set_dim({b, h_q, s_q, d_qk}) + .set_stride({h_q * s_q * d_qk, s_q * d_qk, d_qk, 1})); + + auto K = graph->tensor(fe::graph::Tensor_attributes() + .set_name("container_K") + .set_uid(K_UID) + .set_dim({num_blocks_k, h_k, block_size, d_qk}) + .set_stride({h_k * block_size * d_qk, block_size * d_qk, d_qk, 1})); + + auto V = graph->tensor(fe::graph::Tensor_attributes() + .set_name("container_V") + .set_uid(V_UID) + .set_dim({num_blocks_v, h_v, block_size, d_v}) + .set_stride({h_v * block_size * d_v, block_size * d_v, d_v, 1})); + + auto sdpa_options = fe::graph::SDPA_attributes() + .set_name("flash_attention") + .set_is_inference(is_inference) + .set_alibi_mask(alibi_mask) + .set_causal_mask(causal_mask) + .set_attn_scale(attn_scale); + + if (has_attn_bias) { + auto bias = graph->tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_uid(BIAS_UID) + .set_dim({b, 1, s_q, s_kv}) + .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); + sdpa_options.set_bias(bias); + } + + // Setup padding mask + auto seq_q = graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_uid(SEQ_LEN_Q_UID) + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + auto seq_kv = graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_uid(SEQ_LEN_KV_UID) + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_options.set_padding_mask(true).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv); + + auto page_table_k = graph->tensor(fe::graph::Tensor_attributes() + .set_name("page_table_k") + .set_uid(PAGE_TABLE_K_UID) + .set_dim({b, 1, table_size, 1}) + .set_stride({{table_size, table_size, 1, 1}}) + .set_data_type(fe::DataType_t::INT32)); + auto page_table_v = graph->tensor(fe::graph::Tensor_attributes() + .set_name("page_table_v") + .set_uid(PAGE_TABLE_V_UID) + .set_dim({b, 1, table_size, 1}) + .set_stride({{table_size, table_size, 1, 1}}) + .set_data_type(fe::DataType_t::INT32)); + + sdpa_options.set_paged_attention_k_table(page_table_k); + sdpa_options.set_paged_attention_v_table(page_table_v); + sdpa_options.set_paged_attention_max_seq_len_kv(static_cast(s_kv)); + + auto [O, Stats] = graph->sdpa(Q, K, V, sdpa_options); + + O->set_output(true).set_dim({b, h_q, s_q, d_v}).set_stride({h_q * d_v, d_v, b * h_q * d_v, 1}).set_uid(O_UID); + + if (is_inference) { + assert(Stats == nullptr); + } else { + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_uid(STATS_UID); + } + + return graph; +} + +TEST_CASE("Toy sdpa forward with paged caches", "[graph][sdpa][flash][paged][forward]") { + int64_t b = 3; // batch size + int64_t h_q = 4; // head dim + int64_t h_k = 4; // head dim + int64_t h_v = 4; // head dim + int64_t s_q = 1024; // q tensor is padded to this seq length + int64_t s_kv = 1024; // k and v tensor is padded to this seq length + int64_t d_qk = 128; // hidden dim + int64_t d_v = 128; // hidden dim + int64_t block_size = 64; // block size for paged attention + int64_t num_blocks_k = ((s_kv + block_size - 1) / block_size) * b; // Number of blocks in container_k + int64_t num_blocks_v = ((s_kv + block_size - 1) / block_size) * b; // Number of blocks in container_v + int64_t page_table_size = (s_kv + block_size - 1) / block_size; // per-batch size of the page tables + bool is_inference = false; + float attn_scale = 0.123f; + bool causal_mask = true; + bool alibi_mask = false; + bool has_attn_bias = false; + + if (cudnnGetVersion() < 90500) { + SKIP("Test requires cudnn 9.5.0 or above"); + return; + } + + cudnnHandle_t handle; + CUDNN_CHECK(cudnnCreate(&handle)); + + auto graph = create_sdpa_forward_graph_with_paged_caches(b, + h_q, + h_k, + h_v, + s_q, + s_kv, + d_qk, + d_v, + block_size, + num_blocks_k, + num_blocks_v, + page_table_size, + attn_scale, + is_inference, + causal_mask, + alibi_mask, + has_attn_bias); + + REQUIRE(graph->build(handle, {fe::HeurMode_t::A}).is_good()); + + //// Build variant pack + Surface q_tensor(b * h_q * s_q * d_qk, false); + Surface k_container_tensor(num_blocks_k * h_k * d_qk * block_size, false); + Surface v_container_tensor(num_blocks_v * h_v * d_v * block_size, false); + + Surface o_tensor(b * s_q * h_q * d_qk, false); + + Surface page_table_k_tensor(b * page_table_size, false); + Surface page_table_v_tensor(b * page_table_size, false); + + std::vector host_page_table_k(b * page_table_size); + std::vector host_page_table_v(b * page_table_size); + + // Initialize the page tables + std::mt19937 rng; + std::uniform_int_distribution distribution(0, int32_t(std::min(num_blocks_k, num_blocks_v)) - 1); + + for (auto& elem : host_page_table_k) { + elem = distribution(rng); + } + for (auto& elem : host_page_table_v) { + elem = distribution(rng); + } + + CUDA_CHECK(cudaMemcpy(page_table_k_tensor.devPtr, + host_page_table_k.data(), + sizeof(host_page_table_k[0]) * b, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(page_table_v_tensor.devPtr, + host_page_table_v.data(), + sizeof(host_page_table_v[0]) * b, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); + + std::unordered_map variant_pack = { + {Q_UID, q_tensor.devPtr}, + {K_UID, k_container_tensor.devPtr}, + {V_UID, v_container_tensor.devPtr}, + {O_UID, o_tensor.devPtr}, + {PAGE_TABLE_K_UID, page_table_k_tensor.devPtr}, + {PAGE_TABLE_V_UID, page_table_v_tensor.devPtr}}; + + Surface bias_tensor(b * 1 * s_q * s_kv, false); + if (has_attn_bias) { + variant_pack[BIAS_UID] = bias_tensor.devPtr; + } + + // Create variable sequence lengths + Surface devActualSeqlenQ(b, false); + Surface devActualSeqlenKV(b, false); + std::vector hostActualSeqlenQ(b, 20); + std::vector hostActualSeqlenKV(b, 20); + + CUDA_CHECK(cudaMemcpy( + devActualSeqlenQ.devPtr, hostActualSeqlenQ.data(), sizeof(hostActualSeqlenQ[0]) * b, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(devActualSeqlenKV.devPtr, + hostActualSeqlenKV.data(), + sizeof(hostActualSeqlenKV[0]) * b, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); + + variant_pack[SEQ_LEN_Q_UID] = devActualSeqlenQ.devPtr; + variant_pack[SEQ_LEN_KV_UID] = devActualSeqlenKV.devPtr; + + Surface statsTensor(b * h_q * s_q * 1, false); + if (is_inference == false) { + variant_pack[STATS_UID] = statsTensor.devPtr; + } + + int64_t workspace_size; + REQUIRE(graph->get_workspace_size(workspace_size).is_good()); + Surface workspace(workspace_size, false); + + REQUIRE(graph->execute(handle, variant_pack, workspace.devPtr).is_good()); + + CUDA_CHECK(cudaDeviceSynchronize()); + + cudnnDestroy(handle); +} diff --git a/samples/cpp/sdpa/fp8_bwd.cpp b/samples/cpp/sdpa/fp8_bwd.cpp index 487aed2d..82e542b6 100644 --- a/samples/cpp/sdpa/fp8_bwd.cpp +++ b/samples/cpp/sdpa/fp8_bwd.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include #include @@ -133,7 +133,7 @@ TEST_CASE("sdpa_fp8_bprop", "[graph][sdpa][fp8][backward]") { Amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_stride({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto status = mha_graph.validate(); if ((cudnnGetVersion() >= 90100) && check_device_arch_newer_than("hopper")) { @@ -220,7 +220,7 @@ TEST_CASE("sdpa_fp8_bprop", "[graph][sdpa][fp8][backward]") { REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudaErr(cudaDeviceSynchronize()); + CUDA_CHECK(cudaDeviceSynchronize()); cudnnDestroy(handle); } @@ -304,7 +304,7 @@ TEST_CASE("sdpa_fp8_gqa_bprop", "[graph][sdpa][fp8][backward]") { amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_stride({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto status = mha_graph.validate(); if ((cudnnGetVersion() >= 90100) && check_device_arch_newer_than("hopper")) { @@ -391,7 +391,7 @@ TEST_CASE("sdpa_fp8_gqa_bprop", "[graph][sdpa][fp8][backward]") { REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudaErr(cudaDeviceSynchronize()); + CUDA_CHECK(cudaDeviceSynchronize()); cudnnDestroy(handle); } \ No newline at end of file diff --git a/samples/cpp/sdpa/fp8_fwd.cpp b/samples/cpp/sdpa/fp8_fwd.cpp index 0426d0a3..6ede98d1 100644 --- a/samples/cpp/sdpa/fp8_fwd.cpp +++ b/samples/cpp/sdpa/fp8_fwd.cpp @@ -21,7 +21,7 @@ */ #include -#include "../../utils/helpers.h" +#include "../utils/helpers.h" #include #include @@ -94,7 +94,7 @@ TEST_CASE("sdpa_fp8_fprop", "[graph][sdpa][fp8][forward]") { } cudnnHandle_t handle; - checkCudnnErr(cudnnCreate(&handle)); + CUDNN_CHECK(cudnnCreate(&handle)); auto status = mha_graph.validate(); if ((cudnnGetVersion() >= 90100) && check_device_arch_newer_than("hopper")) { @@ -152,7 +152,7 @@ TEST_CASE("sdpa_fp8_fprop", "[graph][sdpa][fp8][forward]") { REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); - checkCudaErr(cudaDeviceSynchronize()); + CUDA_CHECK(cudaDeviceSynchronize()); cudnnDestroy(handle); } \ No newline at end of file diff --git a/samples/cpp/utils/helpers.h b/samples/cpp/utils/helpers.h new file mode 100644 index 00000000..76bb7fa9 --- /dev/null +++ b/samples/cpp/utils/helpers.h @@ -0,0 +1,348 @@ +#pragma once + +#include +#include + +#include +#include + +#include + +#define CUDA_CHECK(status) \ + { \ + cudaError_t err = status; \ + if (err != cudaSuccess) { \ + std::stringstream err_msg; \ + err_msg << "CUDA Error: " << cudaGetErrorString(err) << " (" << err << ") at " << __FILE__ << ":" \ + << __LINE__; \ + FAIL(err_msg.str()); \ + } \ + } + +#define CUDNN_CHECK(status) \ + { \ + cudnnStatus_t err = status; \ + if (err != CUDNN_STATUS_SUCCESS) { \ + std::stringstream err_msg; \ + err_msg << "cuDNN Error: " << cudnnGetErrorString(err) << " (" << err << ") at " << __FILE__ << ":" \ + << __LINE__; \ + FAIL(err_msg.str()); \ + } \ + } + +inline size_t +get_compute_capability() { + struct cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, 0)); + return prop.major * 10 + prop.minor; +} + +inline bool +is_ampere_arch() { + auto cc = get_compute_capability(); + return (80 <= cc) && (cc < 89); +} + +inline bool +is_ada_arch() { + auto cc = get_compute_capability(); + return (cc == 89); +} + +inline bool +is_hopper_arch() { + auto cc = get_compute_capability(); + return (90 <= cc); +} + +inline bool +is_arch_supported_by_cudnn() { + if (cudnnGetVersion() < 8600 && (is_hopper_arch() || is_ada_arch())) { + return false; + } + return true; +} + +inline bool +check_device_arch_newer_than(std::string const& arch) { + size_t arch_major = 6; + size_t arch_minor = 0; + if (arch == "hopper") { + arch_major = 9; + } + if (arch == "ampere") { + arch_major = 8; + } + if (arch == "turing") { + arch_major = 7; + arch_minor = 5; + } + if (arch == "volta") { + arch_major = 7; + } + if (arch == "pascal") { + arch_major = 6; + } + + auto queried_version = arch_major * 10 + arch_minor; + if (get_compute_capability() >= queried_version) { + return true; + } + return false; +} + +static half +cpu_float2half_rn(float f) { + void* f_ptr = &f; + unsigned x = *((int*)f_ptr); + unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; + + __half_raw hr; + + // Get rid of +NaN/-NaN case first. + if (u > 0x7f800000) { + hr.x = 0x7fffU; + // Add an indirection to get around type aliasing check + void* hr_ptr = &hr; + return *reinterpret_cast(hr_ptr); + } + + sign = ((x >> 16) & 0x8000); + + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + hr.x = static_cast(sign | 0x7c00U); + // Add an indirection to get around type aliasing check + void* hr_ptr = &hr; + return *reinterpret_cast(hr_ptr); + } + if (u < 0x33000001) { + hr.x = static_cast(sign | 0x0000U); + // Add an indirection to get around type aliasing check + void* hr_ptr = &hr; + return *reinterpret_cast(hr_ptr); + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x70) { + shift = 13; + exponent -= 0x70; + } else { + shift = 0x7e - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0x3ff)) { + ++exponent; + mantissa = 0; + } + } + + hr.x = static_cast((sign | (exponent << 10) | mantissa)); + + // Add an indirection to get around type aliasing check + void* hr_ptr = &hr; + return *reinterpret_cast(hr_ptr); +} + +static float +cpu_half2float(half h) { + // Add an indirection to get around type aliasing check + void* h_ptr = &h; + __half_raw hr = *reinterpret_cast<__half_raw*>(h_ptr); + + unsigned sign = ((hr.x >> 15) & 1); + unsigned exponent = ((hr.x >> 10) & 0x1f); + unsigned mantissa = ((hr.x & 0x3ff) << 13); + + if (exponent == 0x1f) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x71; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ + } + } else { + exponent += 0x70; + } + + int temp = ((sign << 31) | (exponent << 23) | mantissa); + + // Add an indirection to get around type aliasing check + void* temp_ptr = &temp; + float* res_ptr = reinterpret_cast(temp_ptr); + return *res_ptr; +} + +// Generate uniform numbers [0,1) +static void +initImage(float* image, int64_t imageSize) { + static unsigned seed = 123456789; + for (int64_t index = 0; index < imageSize; index++) { + seed = (1103515245 * seed + 12345) & 0xffffffff; + image[index] = float(seed) * 2.3283064e-10f; // 2^-32 + } +} + +static void +initImage(half* image, int64_t imageSize) { + static unsigned seed = 123456789; + for (int64_t index = 0; index < imageSize; index++) { + seed = (1103515245 * seed + 12345) & 0xffffffff; + image[index] = cpu_float2half_rn(float(seed) * 2.3283064e-10f); // 2^-32 + } +} + +// Currently set to generate uniform integers [-2, 2] to avoid int8 overflow +static void +initImage(int8_t* image, int64_t imageSize) { + static unsigned seed = 123456789; + for (int64_t index = 0; index < imageSize; index++) { + seed = (1103515245 * seed + 12345) & 0xffffffff; + // Takes floats from [0, 1), scales and casts to ints from [0, 4], then subtracts from 2 + image[index] = 2 - (int8_t)(5 * float(seed) * 2.3283064e-10f); // 2^-32 + } +} + +// Currently set to generate random integers [0, 50] to avoid uint8 overflow +static void +initImage(uint8_t* image, int64_t imageSize) { + static unsigned seed = 123456789; + for (int64_t index = 0; index < imageSize; index++) { + seed = (1103515245 * seed + 12345) & 0xffffffff; + // Takes floats from [0, 1), scales and casts to ints from [0, 50] + image[index] = (uint8_t)(50 * float(seed) * 2.3283064e-10f); // 2^-32 + } +} + +// Currently set to generate uniform integers [0,1] +static void +initImage(int32_t* image, int64_t imageSize) { + static unsigned seed = 123456789; + for (int64_t index = 0; index < imageSize; index++) { + seed = (1103515245 * seed + 12345) & 0xffffffff; + // Takes floats from [0, 1), scales and casts to ints from [0, 4], then divides by 4 + image[index] = ((int32_t)(5.f * float(seed) * 2.3283064e-10f)) / 4; // 2^-32 + } +} + +// Currently set to generate uniform integers [0,1] +static void +initImage(int64_t* image, int64_t imageSize) { + static unsigned seed = 123456789; + for (int64_t index = 0; index < imageSize; index++) { + seed = (1103515245 * seed + 12345) & 0xffffffff; + // Takes floats from [0, 1), scales and casts to ints from [0, 4], then divides by 4 + image[index] = ((int64_t)(5.f * float(seed) * 2.3283064e-10f)) / 4; // 2^-32 + } +} + +// Currently set to generate booleans +static void +initImage(bool* image, int64_t imageSize) { + static unsigned seed = 123456789; + for (int64_t index = 0; index < imageSize; index++) { + seed = (1103515245 * seed + 12345) & 0xffffffff; + // Takes floats from [0, 1), scales and casts to ints from [0, 4], then divides by 4 + int64_t val = ((int32_t)(5.f * float(seed) * 2.3283064e-10f)) / 4; // 2^-32 + + // val is 0 or 1 + image[index] = (val == 1); + } +} + +template +struct Surface { + T_ELEM* devPtr = NULL; + T_ELEM* hostPtr = NULL; + int64_t n_elems = 0; + + protected: + explicit Surface() {} + + public: + explicit Surface(int64_t n_elems, [[maybe_unused]] bool hasRef) : n_elems(n_elems) { + CUDA_CHECK(cudaMalloc((void**)&(devPtr), (size_t)((n_elems) * sizeof(devPtr[0])))); + hostPtr = (T_ELEM*)calloc((size_t)n_elems, sizeof(hostPtr[0])); + initImage(hostPtr, n_elems); + CUDA_CHECK(cudaMemcpy(devPtr, hostPtr, size_t(sizeof(hostPtr[0]) * n_elems), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); + } + + explicit Surface(int64_t n_elems, [[maybe_unused]] bool hasRef, bool isInterleaved) { + (void)isInterleaved; + CUDA_CHECK(cudaMalloc((void**)&(devPtr), (n_elems) * sizeof(devPtr[0]))); + hostPtr = (T_ELEM*)calloc(n_elems, sizeof(hostPtr[0])); + initImage(hostPtr, n_elems); + uint32_t* temp = (uint32_t*)hostPtr; + for (auto i = 0; i < n_elems; i = i + 2) { + temp[i + 1] = 1u; + } + + CUDA_CHECK(cudaMemcpy(devPtr, hostPtr, size_t(sizeof(hostPtr[0]) * n_elems), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); + } + + explicit Surface(int64_t size, [[maybe_unused]] bool hasRef, T_ELEM fillValue) : n_elems(size) { + CUDA_CHECK(cudaMalloc((void**)&(devPtr), (size) * sizeof(devPtr[0]))); + hostPtr = (T_ELEM*)calloc(size, sizeof(hostPtr[0])); + for (int i = 0; i < size; i++) { + hostPtr[i] = fillValue; + } + CUDA_CHECK(cudaMemcpy(devPtr, hostPtr, sizeof(hostPtr[0]) * n_elems, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); + } + + Surface(const Surface& other) : n_elems(n_elems) { + CUDA_CHECK(cudaMalloc((void**)&(devPtr), (size_t)((n_elems) * sizeof(devPtr[0])))); + hostPtr = (T_ELEM*)calloc((size_t)n_elems, sizeof(hostPtr[0])); + std::copy(other.hostPtr, other.hostPtr + n_elems, hostPtr); + CUDA_CHECK(cudaMemcpy(devPtr, hostPtr, size_t(sizeof(hostPtr[0]) * n_elems), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaDeviceSynchronize()); + } + + Surface(Surface&& other) noexcept : Surface() { swap(*this, other); } + + Surface& + operator=(Surface other) { + swap(*this, other); + return *this; + } + + friend void + swap(Surface& first, Surface& second) { + std::swap(first.n_elems, second.n_elems); + std::swap(first.hostPtr, second.hostPtr); + std::swap(first.devPtr, second.devPtr); + } + + ~Surface() { + if (devPtr) { + cudaFree(devPtr); + devPtr = nullptr; + } + if (hostPtr) { + free(hostPtr); + hostPtr = nullptr; + } + } +}; diff --git a/samples/legacy_samples/CMakeLists.txt b/samples/legacy_samples/CMakeLists.txt new file mode 100644 index 00000000..019f17c0 --- /dev/null +++ b/samples/legacy_samples/CMakeLists.txt @@ -0,0 +1,54 @@ +# target sources +add_executable( + legacy_samples + + conv_sample.cpp + test_list.cpp + fp16_emu.cpp + helpers.cpp + fusion_sample.cpp + fp8_sample.cpp + norm_samples.cpp + fused_mha_sample.cpp + f16_flash_mha_sample.cpp + fp8_flash_mha_sample.cpp +) + +# target flags +if(MSVC) + target_compile_options( + legacy_samples PRIVATE + /W4 /WX # warning level 3 and all warnings as errors + /wd4100 # allow unused parameters + /wd4458 # local hides class member (currently a problem for all inline setters) + /wd4505 # unreferenced function with internal linkage has been removed + /wd4101 /wd4189 # unreferenced local + /bigobj # increase number of sections in .Obj file + ) +else() + target_compile_options( + legacy_samples PRIVATE + -Wall + -Wextra + -Werror + -Wno-unused-function + ) +endif() + +# target links +target_link_libraries( + legacy_samples PRIVATE + Threads::Threads + Catch2::Catch2WithMain + cudnn_frontend + _cudnn_frontend_pch + CUDNN::cudnn + + CUDA::cudart +) + +# target cmake properties +set_target_properties( + legacy_samples PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/bin +) diff --git a/samples/legacy_samples/conv_sample.h b/samples/legacy_samples/conv_sample.h index 7563bca5..c51c3db1 100644 --- a/samples/legacy_samples/conv_sample.h +++ b/samples/legacy_samples/conv_sample.h @@ -34,9 +34,9 @@ #include -#include "../utils/fp16_dev.h" -#include "../utils/fp16_emu.h" -#include "../utils/helpers.h" +#include "./utils/fp16_dev.h" +#include "./utils/fp16_emu.h" +#include "./utils/helpers.h" void run_from_global_index(int64_t* dimA_padded, diff --git a/samples/legacy_samples/cpu_references.h b/samples/legacy_samples/cpu_references.h index 461187b2..2c9bafcb 100644 --- a/samples/legacy_samples/cpu_references.h +++ b/samples/legacy_samples/cpu_references.h @@ -22,7 +22,7 @@ #pragma once -#include "../utils/helpers.h" +#include "./utils/helpers.h" #include template diff --git a/samples/legacy_samples/f16_flash_mha_sample.cpp b/samples/legacy_samples/f16_flash_mha_sample.cpp index 6a1dffdb..5e2f77b0 100644 --- a/samples/legacy_samples/f16_flash_mha_sample.cpp +++ b/samples/legacy_samples/f16_flash_mha_sample.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -22,7 +22,7 @@ #include "f16_flash_mha_sample.h" #include -#include "../utils/error_util.h" +#include "./utils/error_util.h" #define Q_ID 1 #define K_ID 2 diff --git a/samples/legacy_samples/f16_flash_mha_sample.h b/samples/legacy_samples/f16_flash_mha_sample.h index 9c91ed45..f7403a54 100644 --- a/samples/legacy_samples/f16_flash_mha_sample.h +++ b/samples/legacy_samples/f16_flash_mha_sample.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -33,9 +33,9 @@ #include #include -#include "../utils/fp16_dev.h" -#include "../utils/fp16_emu.h" -#include "../utils/helpers.h" +#include "./utils/fp16_dev.h" +#include "./utils/fp16_emu.h" +#include "./utils/helpers.h" #if (CUDNN_VERSION >= 8900) diff --git a/samples/legacy_samples/fp16_dev.cu b/samples/legacy_samples/fp16_dev.cu index 3955feec..d2daa2a0 100644 --- a/samples/legacy_samples/fp16_dev.cu +++ b/samples/legacy_samples/fp16_dev.cu @@ -20,8 +20,8 @@ * DEALINGS IN THE SOFTWARE. */ -#include "utils/error_util.h" -#include "utils/fp16_dev.h" +#include "./utils/error_util.h" +#include "./utils/fp16_dev.h" #define BLOCK_SIZE 128 template diff --git a/samples/legacy_samples/fp16_emu.cpp b/samples/legacy_samples/fp16_emu.cpp index a345ad3b..195c5c99 100644 --- a/samples/legacy_samples/fp16_emu.cpp +++ b/samples/legacy_samples/fp16_emu.cpp @@ -20,7 +20,7 @@ * DEALINGS IN THE SOFTWARE. */ -#include "../utils/fp16_emu.h" +#include "./utils/fp16_emu.h" #define STATIC_ASSERT(cond) \ { static_assert(cond, "static_assert failed."); } diff --git a/samples/legacy_samples/fp8_flash_mha_sample.cpp b/samples/legacy_samples/fp8_flash_mha_sample.cpp index b84d55c9..d99bb48d 100644 --- a/samples/legacy_samples/fp8_flash_mha_sample.cpp +++ b/samples/legacy_samples/fp8_flash_mha_sample.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -22,7 +22,7 @@ #include "fp8_flash_mha_sample.h" #include -#include "../utils/error_util.h" +#include "./utils/error_util.h" #if (CUDNN_VERSION >= 8900) std::unordered_map tensor_name_to_uid = {{"Q", 1}, diff --git a/samples/legacy_samples/fp8_flash_mha_sample.h b/samples/legacy_samples/fp8_flash_mha_sample.h index 78978d8c..00a80bb0 100644 --- a/samples/legacy_samples/fp8_flash_mha_sample.h +++ b/samples/legacy_samples/fp8_flash_mha_sample.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -33,9 +33,9 @@ #include #include -#include "../utils/fp16_dev.h" -#include "../utils/fp16_emu.h" -#include "../utils/helpers.h" +#include "./utils/fp16_dev.h" +#include "./utils/fp16_emu.h" +#include "./utils/helpers.h" #if (CUDNN_VERSION >= 8900) void diff --git a/samples/legacy_samples/fp8_sample.cpp b/samples/legacy_samples/fp8_sample.cpp index 6352ed50..891185ad 100644 --- a/samples/legacy_samples/fp8_sample.cpp +++ b/samples/legacy_samples/fp8_sample.cpp @@ -1,6 +1,6 @@ #include "fp8_sample.h" #include -#include "../utils/error_util.h" +#include "./utils/error_util.h" using namespace cudnn_frontend; diff --git a/samples/legacy_samples/fp8_sample.h b/samples/legacy_samples/fp8_sample.h index ef1d162e..6c459cd4 100644 --- a/samples/legacy_samples/fp8_sample.h +++ b/samples/legacy_samples/fp8_sample.h @@ -33,9 +33,9 @@ #include #include -#include "../utils/fp16_dev.h" -#include "../utils/fp16_emu.h" -#include "../utils/helpers.h" +#include "./utils/fp16_dev.h" +#include "./utils/fp16_emu.h" +#include "./utils/helpers.h" void run_fp8_conv_scale(int64_t* x_dim, diff --git a/samples/legacy_samples/fused_mha_sample.cpp b/samples/legacy_samples/fused_mha_sample.cpp index d9ee7788..3fc6e77e 100644 --- a/samples/legacy_samples/fused_mha_sample.cpp +++ b/samples/legacy_samples/fused_mha_sample.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -22,7 +22,7 @@ #include "fused_mha_sample.h" #include -#include "../utils/error_util.h" +#include "./utils/error_util.h" #define Q_ID 1 #define K_ID 2 diff --git a/samples/legacy_samples/fused_mha_sample.h b/samples/legacy_samples/fused_mha_sample.h index 39c7fc35..e8fc5695 100644 --- a/samples/legacy_samples/fused_mha_sample.h +++ b/samples/legacy_samples/fused_mha_sample.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -33,9 +33,9 @@ #include #include -#include "../utils/fp16_dev.h" -#include "../utils/fp16_emu.h" -#include "../utils/helpers.h" +#include "./utils/fp16_dev.h" +#include "./utils/fp16_emu.h" +#include "./utils/helpers.h" #if (CUDNN_VERSION >= 8700) void diff --git a/samples/legacy_samples/fusion_sample.cpp b/samples/legacy_samples/fusion_sample.cpp index 1e5ddcbe..d285d90a 100644 --- a/samples/legacy_samples/fusion_sample.cpp +++ b/samples/legacy_samples/fusion_sample.cpp @@ -22,7 +22,7 @@ #include "fusion_sample.h" #include -#include "../utils/error_util.h" +#include "./utils/error_util.h" bool allowAll(cudnnBackendDescriptor_t engine_config) { diff --git a/samples/legacy_samples/fusion_sample.h b/samples/legacy_samples/fusion_sample.h index 42a47343..83309a9f 100644 --- a/samples/legacy_samples/fusion_sample.h +++ b/samples/legacy_samples/fusion_sample.h @@ -33,9 +33,9 @@ #include #include -#include "../utils/fp16_dev.h" -#include "../utils/fp16_emu.h" -#include "../utils/helpers.h" +#include "./utils/fp16_dev.h" +#include "./utils/fp16_emu.h" +#include "./utils/helpers.h" #include diff --git a/samples/legacy_samples/helpers.cpp b/samples/legacy_samples/helpers.cpp index ac0abc4f..3bfb9ef3 100644 --- a/samples/legacy_samples/helpers.cpp +++ b/samples/legacy_samples/helpers.cpp @@ -20,7 +20,7 @@ * DEALINGS IN THE SOFTWARE. */ -#include "../utils/helpers.h" +#include "./utils/helpers.h" size_t get_compute_capability() { diff --git a/samples/legacy_samples/norm_samples.cpp b/samples/legacy_samples/norm_samples.cpp index 5b07e89a..cbe1e783 100644 --- a/samples/legacy_samples/norm_samples.cpp +++ b/samples/legacy_samples/norm_samples.cpp @@ -23,8 +23,8 @@ #include "norm_samples.h" #include -#include "../utils/error_util.h" -#include "../utils/helpers.h" +#include "./utils/error_util.h" +#include "./utils/helpers.h" bool AllowAll(cudnnBackendDescriptor_t engine_config) { diff --git a/samples/legacy_samples/test_list.cpp b/samples/legacy_samples/test_list.cpp index 0e42f3f9..93747d0f 100644 --- a/samples/legacy_samples/test_list.cpp +++ b/samples/legacy_samples/test_list.cpp @@ -112,19 +112,12 @@ TEST_CASE("Use global(index) for execution", "[frontend][global_index][wgrad]") } cudnnConvolutionMode_t mode = CUDNN_CONVOLUTION; - - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", dimA[0], dimA[1], dimA[2], dimA[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - filterdimA[0], - filterdimA[1], - filterdimA[2], - filterdimA[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - outdimA[0], - outdimA[1], - outdimA[2], - outdimA[3]); + std::cout << "====DIMENSIONS====\n"; + std::cout << "input dims are " << dimA[0] << ", " << dimA[1] << ", " << dimA[2] << ", " << dimA[3] << "\n"; + std::cout << "filter dims are " << filterdimA[0] << ", " << filterdimA[1] << ", " << filterdimA[2] << ", " + << filterdimA[3] << "\n"; + std::cout << "output dims are " << outdimA[0] << ", " << outdimA[1] << ", " << outdimA[2] << ", " << outdimA[3] + << "\n"; int64_t Xsize = dimA[0] * dimA[1] * dimA[2] * dimA[3]; int64_t Wsize = filterdimA[0] * filterdimA[1] * filterdimA[2] * filterdimA[3]; @@ -193,18 +186,12 @@ TEST_CASE("Use heuristics for execution", "[frontend][heuristics][conv]") { cudnnConvolutionMode_t mode = CUDNN_CONVOLUTION; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", dimA[0], dimA[1], dimA[2], dimA[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - filterdimA[0], - filterdimA[1], - filterdimA[2], - filterdimA[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - outdimA[0], - outdimA[1], - outdimA[2], - outdimA[3]); + std::cout << "====DIMENSIONS====\n"; + std::cout << "input dims are " << dimA[0] << ", " << dimA[1] << ", " << dimA[2] << ", " << dimA[3] << "\n"; + std::cout << "filter dims are " << filterdimA[0] << ", " << filterdimA[1] << ", " << filterdimA[2] << ", " + << filterdimA[3] << "\n"; + std::cout << "output dims are " << outdimA[0] << ", " << outdimA[1] << ", " << outdimA[2] << ", " << outdimA[3] + << "\n"; int64_t Xsize = dimA[0] * dimA[1] * dimA[2] * dimA[3]; int64_t Wsize = filterdimA[0] * filterdimA[1] * filterdimA[2] * filterdimA[3]; @@ -275,18 +262,12 @@ TEST_CASE("Use DNN based heuristics for execution", "[frontend][dnn_heuristics][ cudnnConvolutionMode_t mode = CUDNN_CONVOLUTION; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", dimA[0], dimA[1], dimA[2], dimA[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - filterdimA[0], - filterdimA[1], - filterdimA[2], - filterdimA[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - outdimA[0], - outdimA[1], - outdimA[2], - outdimA[3]); + std::cout << "====DIMENSIONS====\n"; + std::cout << "input dims are " << dimA[0] << ", " << dimA[1] << ", " << dimA[2] << ", " << dimA[3] << "\n"; + std::cout << "filter dims are " << filterdimA[0] << ", " << filterdimA[1] << ", " << filterdimA[2] << ", " + << filterdimA[3] << "\n"; + std::cout << "output dims are " << outdimA[0] << ", " << outdimA[1] << ", " << outdimA[2] << ", " << outdimA[3] + << "\n"; int64_t Xsize = dimA[0] * dimA[1] * dimA[2] * dimA[3]; int64_t Wsize = filterdimA[0] * filterdimA[1] * filterdimA[2] * filterdimA[3]; @@ -357,18 +338,12 @@ TEST_CASE("Use fallback for execution", "[frontend][global_index][dgrad]") { cudnnConvolutionMode_t mode = CUDNN_CONVOLUTION; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", dimA[0], dimA[1], dimA[2], dimA[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - filterdimA[0], - filterdimA[1], - filterdimA[2], - filterdimA[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - outdimA[0], - outdimA[1], - outdimA[2], - outdimA[3]); + std::cout << "====DIMENSIONS====\n"; + std::cout << "input dims are " << dimA[0] << ", " << dimA[1] << ", " << dimA[2] << ", " << dimA[3] << "\n"; + std::cout << "filter dims are " << filterdimA[0] << ", " << filterdimA[1] << ", " << filterdimA[2] << ", " + << filterdimA[3] << "\n"; + std::cout << "output dims are " << outdimA[0] << ", " << outdimA[1] << ", " << outdimA[2] << ", " << outdimA[3] + << "\n"; int64_t Xsize = dimA[0] * dimA[1] * dimA[2] * dimA[3]; int64_t Wsize = filterdimA[0] * filterdimA[1] * filterdimA[2] * filterdimA[3]; @@ -435,22 +410,15 @@ TEST_CASE("ConvBiasAct sample", "[frontend][convAddBiasAct]") { getFwdConvOutputDim(xTensorDim[dim + 2], padding[dim], wTensorDim[dim + 2], convstride[dim], dilation[dim]); } - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim[0], - wTensorDim[1], - wTensorDim[2], - wTensorDim[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim[0], - yTensorDim[1], - yTensorDim[2], - yTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim[0] << ", " << wTensorDim[1] << ", " << wTensorDim[2] << ", " + << wTensorDim[3] << std::endl; + + std::cout << "output dims are " << yTensorDim[0] << ", " << yTensorDim[1] << ", " << yTensorDim[2] << ", " + << yTensorDim[3] << std::endl; int64_t Xsize = xTensorDim[0] * xTensorDim[1] * xTensorDim[2] * xTensorDim[3]; int64_t Ysize = yTensorDim[0] * yTensorDim[1] * yTensorDim[2] * yTensorDim[3]; @@ -500,18 +468,12 @@ TEST_CASE("Use cudnnFindPlan for execution", "[frontend][cudnnFindPlan][conv]") cudnnConvolutionMode_t mode = CUDNN_CONVOLUTION; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", dimA[0], dimA[1], dimA[2], dimA[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - filterdimA[0], - filterdimA[1], - filterdimA[2], - filterdimA[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - outdimA[0], - outdimA[1], - outdimA[2], - outdimA[3]); + std::cout << "====DIMENSIONS====\n"; + std::cout << "input dims are " << dimA[0] << ", " << dimA[1] << ", " << dimA[2] << ", " << dimA[3] << "\n"; + std::cout << "filter dims are " << filterdimA[0] << ", " << filterdimA[1] << ", " << filterdimA[2] << ", " + << filterdimA[3] << "\n"; + std::cout << "output dims are " << outdimA[0] << ", " << outdimA[1] << ", " << outdimA[2] << ", " << outdimA[3] + << "\n"; int64_t Xsize = dimA[0] * dimA[1] * dimA[2] * dimA[3]; int64_t Wsize = filterdimA[0] * filterdimA[1] * filterdimA[2] * filterdimA[3]; @@ -577,22 +539,15 @@ TEST_CASE("ConvBiasAct sample with cudnnFindPlan", "[frontend][cudnnFindPlan][co getFwdConvOutputDim(xTensorDim[dim + 2], padding[dim], wTensorDim[dim + 2], convstride[dim], dilation[dim]); } - printf("====PADDING DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim[0], - wTensorDim[1], - wTensorDim[2], - wTensorDim[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim[0], - yTensorDim[1], - yTensorDim[2], - yTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim[0] << ", " << wTensorDim[1] << ", " << wTensorDim[2] << ", " + << wTensorDim[3] << std::endl; + + std::cout << "output dims are " << yTensorDim[0] << ", " << yTensorDim[1] << ", " << yTensorDim[2] << ", " + << yTensorDim[3] << std::endl; int64_t Xsize = xTensorDim[0] * xTensorDim[1] * xTensorDim[2] * xTensorDim[3]; int64_t Ysize = yTensorDim[0] * yTensorDim[1] * yTensorDim[2] * yTensorDim[3]; @@ -642,18 +597,12 @@ TEST_CASE("Use cudnnGetPlan for execution", "[frontend][cudnnGetPlan][conv]") { cudnnConvolutionMode_t mode = CUDNN_CONVOLUTION; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", dimA[0], dimA[1], dimA[2], dimA[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - filterdimA[0], - filterdimA[1], - filterdimA[2], - filterdimA[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - outdimA[0], - outdimA[1], - outdimA[2], - outdimA[3]); + std::cout << "====DIMENSIONS====\n"; + std::cout << "input dims are " << dimA[0] << ", " << dimA[1] << ", " << dimA[2] << ", " << dimA[3] << "\n"; + std::cout << "filter dims are " << filterdimA[0] << ", " << filterdimA[1] << ", " << filterdimA[2] << ", " + << filterdimA[3] << "\n"; + std::cout << "output dims are " << outdimA[0] << ", " << outdimA[1] << ", " << outdimA[2] << ", " << outdimA[3] + << "\n"; int64_t Xsize = dimA[0] * dimA[1] * dimA[2] * dimA[3]; int64_t Wsize = filterdimA[0] * filterdimA[1] * filterdimA[2] * filterdimA[3]; @@ -717,22 +666,15 @@ TEST_CASE("ConvScaleBiasAddAct sample", "[frontend][fusion][ConvScaleBiasAddAct] int64_t bTensorDim[] = {1, 32, 1, 1}; // bias int64_t aTensorDim[] = {4, 32, 31, 31}; // add - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim[0], - wTensorDim[1], - wTensorDim[2], - wTensorDim[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim[0], - yTensorDim[1], - yTensorDim[2], - yTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim[0] << ", " << wTensorDim[1] << ", " << wTensorDim[2] << ", " + << wTensorDim[3] << std::endl; + + std::cout << "output dims are " << yTensorDim[0] << ", " << yTensorDim[1] << ", " << yTensorDim[2] << ", " + << yTensorDim[3] << std::endl; int64_t Ysize = yTensorDim[0] * yTensorDim[1] * yTensorDim[2] * yTensorDim[3]; @@ -784,22 +726,15 @@ TEST_CASE("ConvScaleBiasAddAct sample_float", "[frontend][fusion][ConvScaleBiasA int64_t bTensorDim[] = {1, 32, 1, 1}; // bias int64_t aTensorDim[] = {4, 32, 512, 512}; // add - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim[0], - wTensorDim[1], - wTensorDim[2], - wTensorDim[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim[0], - yTensorDim[1], - yTensorDim[2], - yTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim[0] << ", " << wTensorDim[1] << ", " << wTensorDim[2] << ", " + << wTensorDim[3] << std::endl; + + std::cout << "output dims are " << yTensorDim[0] << ", " << yTensorDim[1] << ", " << yTensorDim[2] << ", " + << yTensorDim[3] << std::endl; int64_t Ysize = yTensorDim[0] * yTensorDim[1] * yTensorDim[2] * yTensorDim[3]; @@ -849,23 +784,15 @@ TEST_CASE("ConvBiasScaleAct sample", "[frontend][fusion][ConvBiasScaleAct]") { int64_t bTensorDim[] = {1, 64, 1, 1}; // bias int64_t sTensorDim[] = {1, 64, 1, 1}; // scale + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim[0] << ", " << wTensorDim[1] << ", " << wTensorDim[2] << ", " + << wTensorDim[3] << std::endl; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim[0], - wTensorDim[1], - wTensorDim[2], - wTensorDim[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim[0], - yTensorDim[1], - yTensorDim[2], - yTensorDim[3]); + std::cout << "output dims are " << yTensorDim[0] << ", " << yTensorDim[1] << ", " << yTensorDim[2] << ", " + << yTensorDim[3] << std::endl; int64_t Ysize = yTensorDim[0] * yTensorDim[1] * yTensorDim[2] * yTensorDim[3]; @@ -912,23 +839,15 @@ TEST_CASE("ConvBiasScaleActSerialization sample", "[frontend][fusion][serializat int64_t bTensorDim[] = {1, 64, 1, 1}; // bias int64_t sTensorDim[] = {1, 64, 1, 1}; // scale + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim[0], - wTensorDim[1], - wTensorDim[2], - wTensorDim[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim[0], - yTensorDim[1], - yTensorDim[2], - yTensorDim[3]); + std::cout << "filter dims are " << wTensorDim[0] << ", " << wTensorDim[1] << ", " << wTensorDim[2] << ", " + << wTensorDim[3] << std::endl; + + std::cout << "output dims are " << yTensorDim[0] << ", " << yTensorDim[1] << ", " << yTensorDim[2] << ", " + << yTensorDim[3] << std::endl; int64_t Ysize = yTensorDim[0] * yTensorDim[1] * yTensorDim[2] * yTensorDim[3]; @@ -977,23 +896,15 @@ TEST_CASE("ConvScaleBiasActGenIndexSelection sample", "[frontend][fusion][ConvSc int64_t sTensorDim[] = {1, 64, 1, 1}; // scale int64_t thresholdTensorDim[] = {1, 1, 1, 1}; // scalar number + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim[0] << ", " << wTensorDim[1] << ", " << wTensorDim[2] << ", " + << wTensorDim[3] << std::endl; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim[0], - wTensorDim[1], - wTensorDim[2], - wTensorDim[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim[0], - yTensorDim[1], - yTensorDim[2], - yTensorDim[3]); + std::cout << "output dims are " << yTensorDim[0] << ", " << yTensorDim[1] << ", " << yTensorDim[2] << ", " + << yTensorDim[3] << std::endl; int64_t Ysize = yTensorDim[0] * yTensorDim[1] * yTensorDim[2] * yTensorDim[3]; @@ -1057,22 +968,15 @@ TEST_CASE("ConvScaleBiasAct_int8 sample", "[frontend][fusion][ConvScaleBiasAct_i int64_t bTensorDim[] = {1, 256, 1, 1}; // bias int64_t sTensorDim[] = {1, 256, 1, 1}; // scale - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim[0], - wTensorDim[1], - wTensorDim[2], - wTensorDim[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim[0], - yTensorDim[1], - yTensorDim[2], - yTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim[0] << ", " << wTensorDim[1] << ", " << wTensorDim[2] << ", " + << wTensorDim[3] << std::endl; + + std::cout << "output dims are " << yTensorDim[0] << ", " << yTensorDim[1] << ", " << yTensorDim[2] << ", " + << yTensorDim[3] << std::endl; int64_t Ysize = yTensorDim[0] * yTensorDim[1] * yTensorDim[2] * yTensorDim[3]; @@ -1141,18 +1045,12 @@ TEST_CASE("PoolScaleBiasAct_int8 sample", "[pooling][forward][avgerage_pooling]" int64_t postPaddingA[CUDNN_DIM_MAX] = {0, 0}; int64_t strideA[CUDNN_DIM_MAX] = {2, 2}; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim[0], - yTensorDim[1], - yTensorDim[2], - yTensorDim[3]); + std::cout << "output dims are " << yTensorDim[0] << ", " << yTensorDim[1] << ", " << yTensorDim[2] << ", " + << yTensorDim[3] << std::endl; int64_t Ysize = yTensorDim[0] * yTensorDim[1] * yTensorDim[2] * yTensorDim[3]; @@ -1197,11 +1095,12 @@ TEST_CASE("MatmulBiasAct sample", "[frontend][fusion][MatmulBiasAct]") { int64_t zTensorDim[] = {1, 1, 64}; // bias - printf("====DIMENSIONS====\n"); - printf("a matrix dims are %" PRId64 ", %" PRId64 ", %" PRId64 "\n", aTensorDim[0], aTensorDim[1], aTensorDim[2]); - printf("b matrix dims are %" PRId64 ", %" PRId64 ", %" PRId64 "\n", bTensorDim[0], bTensorDim[1], bTensorDim[2]); - printf("c matrix dims are %" PRId64 ", %" PRId64 ", %" PRId64 "\n", cTensorDim[0], cTensorDim[1], cTensorDim[2]); + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "a matrix dims are " << aTensorDim[0] << ", " << aTensorDim[1] << ", " << aTensorDim[2] << std::endl; + + std::cout << "b matrix dims are " << bTensorDim[0] << ", " << bTensorDim[1] << ", " << bTensorDim[2] << std::endl; + std::cout << "c matrix dims are " << cTensorDim[0] << ", " << cTensorDim[1] << ", " << cTensorDim[2] << std::endl; int64_t Csize = cTensorDim[0] * cTensorDim[1] * cTensorDim[2]; Surface A(aTensorDim[0] * aTensorDim[1] * aTensorDim[2], false); @@ -1240,11 +1139,12 @@ TEST_CASE("MatmulBiasAct sample_float", "[frontend][fusion][MatmulBiasAct]") { int64_t zTensorDim[] = {1, 1, 64}; // bias - printf("====DIMENSIONS====\n"); - printf("a matrix dims are %" PRId64 ", %" PRId64 ", %" PRId64 "\n", aTensorDim[0], aTensorDim[1], aTensorDim[2]); - printf("b matrix dims are %" PRId64 ", %" PRId64 ", %" PRId64 "\n", bTensorDim[0], bTensorDim[1], bTensorDim[2]); - printf("c matrix dims are %" PRId64 ", %" PRId64 ", %" PRId64 "\n", cTensorDim[0], cTensorDim[1], cTensorDim[2]); + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "a matrix dims are " << aTensorDim[0] << ", " << aTensorDim[1] << ", " << aTensorDim[2] << std::endl; + std::cout << "b matrix dims are " << bTensorDim[0] << ", " << bTensorDim[1] << ", " << bTensorDim[2] << std::endl; + + std::cout << "c matrix dims are " << cTensorDim[0] << ", " << cTensorDim[1] << ", " << cTensorDim[2] << std::endl; int64_t Csize = cTensorDim[0] * cTensorDim[1] * cTensorDim[2]; Surface A(aTensorDim[0] * aTensorDim[1] * aTensorDim[2], false); @@ -1283,13 +1183,14 @@ TEST_CASE("MatmulDGeluDBias sample", "[frontend][fusion][MatmulDGeluDBias]") { int64_t cTensorDim[] = {1, 2048, 4096}; // batch M N int64_t zTensorDim[] = {1, 1, 4096}; // bias + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "a matrix dims are " << aTensorDim[0] << ", " << aTensorDim[1] << ", " << aTensorDim[2] << std::endl; + + std::cout << "b matrix dims are " << bTensorDim[0] << ", " << bTensorDim[1] << ", " << bTensorDim[2] << std::endl; - printf("====DIMENSIONS====\n"); - printf("a matrix dims are %" PRId64 ", %" PRId64 ", %" PRId64 "\n", aTensorDim[0], aTensorDim[1], aTensorDim[2]); - printf("b matrix dims are %" PRId64 ", %" PRId64 ", %" PRId64 "\n", bTensorDim[0], bTensorDim[1], bTensorDim[2]); - printf("c matrix dims are %" PRId64 ", %" PRId64 ", %" PRId64 "\n", cTensorDim[0], cTensorDim[1], cTensorDim[2]); - printf("z matrix dims are %" PRId64 ", %" PRId64 ", %" PRId64 "\n", zTensorDim[0], zTensorDim[1], zTensorDim[2]); + std::cout << "c matrix dims are " << cTensorDim[0] << ", " << cTensorDim[1] << ", " << cTensorDim[2] << std::endl; + std::cout << "z matrix dims are " << zTensorDim[0] << ", " << zTensorDim[1] << ", " << zTensorDim[2] << std::endl; int64_t Csize = cTensorDim[0] * cTensorDim[1] * cTensorDim[2]; int64_t Zsize = zTensorDim[0] * zTensorDim[1] * zTensorDim[2]; @@ -1345,22 +1246,15 @@ TEST_CASE("ConvDrelu sample", "[frontend][convDrelu][drelu]") { wTensorDim_padded[i] = wTensorDim[i]; } - printf("====PADDING DIMENSIONS====\n"); - printf("padded input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim_padded[0], - xTensorDim_padded[1], - xTensorDim_padded[2], - xTensorDim_padded[3]); - printf("padded filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim_padded[0], - wTensorDim_padded[1], - wTensorDim_padded[2], - wTensorDim_padded[3]); - printf("padded output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim_padded[0], - yTensorDim_padded[1], - yTensorDim_padded[2], - yTensorDim_padded[3]); + std::cout << "====PADDING DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim_padded[0] << ", " << xTensorDim_padded[1] << ", " + << xTensorDim_padded[2] << ", " << xTensorDim_padded[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim_padded[0] << ", " << wTensorDim_padded[1] << ", " + << wTensorDim_padded[2] << ", " << wTensorDim_padded[3] << std::endl; + + std::cout << "output dims are " << yTensorDim_padded[0] << ", " << yTensorDim_padded[1] << ", " + << yTensorDim_padded[2] << ", " << yTensorDim_padded[3] << std::endl; int64_t Xsize = xTensorDim_padded[0] * xTensorDim_padded[1] * xTensorDim_padded[2] * xTensorDim_padded[3]; int64_t Ysize = yTensorDim_padded[0] * yTensorDim_padded[1] * yTensorDim_padded[2] * yTensorDim_padded[3]; @@ -1418,22 +1312,15 @@ TEST_CASE("DgradDrelu sample", "[frontend][dgradDrelu][drelu]") { wTensorDim_padded[i] = wTensorDim[i]; } - printf("====PADDING DIMENSIONS====\n"); - printf("padded input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim_padded[0], - xTensorDim_padded[1], - xTensorDim_padded[2], - xTensorDim_padded[3]); - printf("padded filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim_padded[0], - wTensorDim_padded[1], - wTensorDim_padded[2], - wTensorDim_padded[3]); - printf("padded output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim_padded[0], - yTensorDim_padded[1], - yTensorDim_padded[2], - yTensorDim_padded[3]); + std::cout << "====PADDING DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim_padded[0] << ", " << xTensorDim_padded[1] << ", " + << xTensorDim_padded[2] << ", " << xTensorDim_padded[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim_padded[0] << ", " << wTensorDim_padded[1] << ", " + << wTensorDim_padded[2] << ", " << wTensorDim_padded[3] << std::endl; + + std::cout << "output dims are " << yTensorDim_padded[0] << ", " << yTensorDim_padded[1] << ", " + << yTensorDim_padded[2] << ", " << yTensorDim_padded[3] << std::endl; int64_t Xsize = xTensorDim_padded[0] * xTensorDim_padded[1] * xTensorDim_padded[2] * xTensorDim_padded[3]; int64_t Ysize = yTensorDim_padded[0] * yTensorDim_padded[1] * yTensorDim_padded[2] * yTensorDim_padded[3]; @@ -1477,22 +1364,16 @@ TEST_CASE("ConvColReduction sample", "[frontend][fusion][ConvColReduction]") { int64_t reducedTensorDim[] = {1, 256, 1, 1}; // output is NPQ * C reduced to C column - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim[0], - wTensorDim[1], - wTensorDim[2], - wTensorDim[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - reducedTensorDim[0], - reducedTensorDim[1], - reducedTensorDim[2], - reducedTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; + + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim[0] << ", " << wTensorDim[1] << ", " << wTensorDim[2] << ", " + << wTensorDim[3] << std::endl; + + std::cout << "output dims are " << reducedTensorDim[0] << ", " << reducedTensorDim[1] << ", " << reducedTensorDim[2] + << ", " << reducedTensorDim[3] << std::endl; int64_t outputSize = reducedTensorDim[0] * reducedTensorDim[1] * reducedTensorDim[2] * reducedTensorDim[3]; @@ -1544,18 +1425,12 @@ TEST_CASE("Use errata to block global(index) for execution", "[frontend][errata] cudnnConvolutionMode_t mode = CUDNN_CONVOLUTION; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", dimA[0], dimA[1], dimA[2], dimA[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - filterdimA[0], - filterdimA[1], - filterdimA[2], - filterdimA[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - outdimA[0], - outdimA[1], - outdimA[2], - outdimA[3]); + std::cout << "====DIMENSIONS====\n"; + std::cout << "input dims are " << dimA[0] << ", " << dimA[1] << ", " << dimA[2] << ", " << dimA[3] << "\n"; + std::cout << "filter dims are " << filterdimA[0] << ", " << filterdimA[1] << ", " << filterdimA[2] << ", " + << filterdimA[3] << "\n"; + std::cout << "output dims are " << outdimA[0] << ", " << outdimA[1] << ", " << outdimA[2] << ", " << outdimA[3] + << "\n"; int64_t Xsize = dimA[0] * dimA[1] * dimA[2] * dimA[3]; int64_t Wsize = filterdimA[0] * filterdimA[1] * filterdimA[2] * filterdimA[3]; @@ -1603,18 +1478,12 @@ TEST_CASE("DP4A execution with cudnnFindPlan", "[frontend][cudnnFindPlan][conv]" cudnnConvolutionMode_t mode = CUDNN_CONVOLUTION; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", dimA[0], dimA[1], dimA[2], dimA[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - filterdimA[0], - filterdimA[1], - filterdimA[2], - filterdimA[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - outdimA[0], - outdimA[1], - outdimA[2], - outdimA[3]); + std::cout << "====DIMENSIONS====\n"; + std::cout << "input dims are " << dimA[0] << ", " << dimA[1] << ", " << dimA[2] << ", " << dimA[3] << "\n"; + std::cout << "filter dims are " << filterdimA[0] << ", " << filterdimA[1] << ", " << filterdimA[2] << ", " + << filterdimA[3] << "\n"; + std::cout << "output dims are " << outdimA[0] << ", " << outdimA[1] << ", " << outdimA[2] << ", " << outdimA[3] + << "\n"; int64_t Xsize = vectorCount * dimA[0] * dimA[1] * dimA[2] * dimA[3]; int64_t Wsize = vectorCount * filterdimA[0] * filterdimA[1] * filterdimA[2] * filterdimA[3]; @@ -1667,18 +1536,12 @@ TEST_CASE("IMMA execution with manual autotuning", "[frontend][cudnnGetPlan][con cudnnConvolutionMode_t mode = CUDNN_CONVOLUTION; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", dimA[0], dimA[1], dimA[2], dimA[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - filterdimA[0], - filterdimA[1], - filterdimA[2], - filterdimA[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - outdimA[0], - outdimA[1], - outdimA[2], - outdimA[3]); + std::cout << "====DIMENSIONS====\n"; + std::cout << "input dims are " << dimA[0] << ", " << dimA[1] << ", " << dimA[2] << ", " << dimA[3] << "\n"; + std::cout << "filter dims are " << filterdimA[0] << ", " << filterdimA[1] << ", " << filterdimA[2] << ", " + << filterdimA[3] << "\n"; + std::cout << "output dims are " << outdimA[0] << ", " << outdimA[1] << ", " << outdimA[2] << ", " << outdimA[3] + << "\n"; int64_t Xsize = vectorCount * dimA[0] * dimA[1] * dimA[2] * dimA[3]; int64_t Wsize = vectorCount * filterdimA[0] * filterdimA[1] * filterdimA[2] * filterdimA[3]; @@ -1729,18 +1592,12 @@ TEST_CASE("Use Plan cache for rerunning the same convolution", "[frontend][dnn_h cudnnConvolutionMode_t mode = CUDNN_CONVOLUTION; - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", dimA[0], dimA[1], dimA[2], dimA[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - filterdimA[0], - filterdimA[1], - filterdimA[2], - filterdimA[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - outdimA[0], - outdimA[1], - outdimA[2], - outdimA[3]); + std::cout << "====DIMENSIONS====\n"; + std::cout << "input dims are " << dimA[0] << ", " << dimA[1] << ", " << dimA[2] << ", " << dimA[3] << "\n"; + std::cout << "filter dims are " << filterdimA[0] << ", " << filterdimA[1] << ", " << filterdimA[2] << ", " + << filterdimA[3] << "\n"; + std::cout << "output dims are " << outdimA[0] << ", " << outdimA[1] << ", " << outdimA[2] << ", " << outdimA[3] + << "\n"; int64_t Xsize = dimA[0] * dimA[1] * dimA[2] * dimA[3]; int64_t Wsize = filterdimA[0] * filterdimA[1] * filterdimA[2] * filterdimA[3]; @@ -2318,18 +2175,13 @@ TEST_CASE("Max pooling idx tensor dump", "[pooling][forward][max_pooling]") { int64_t postPaddingA[] = {1, 1}; int64_t strideA[] = {2, 2}; - printf("====DIMENSIONS====\n"); - printf("x dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; + + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; - printf("y dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim[0], - yTensorDim[1], - yTensorDim[2], - yTensorDim[3]); + std::cout << "y dims are " << yTensorDim[0] << ", " << yTensorDim[1] << ", " << yTensorDim[2] << ", " + << yTensorDim[3] << std::endl; int64_t Xsize = xTensorDim[0] * xTensorDim[1] * xTensorDim[2] * xTensorDim[3]; int64_t Ysize = yTensorDim[0] * yTensorDim[1] * yTensorDim[2] * yTensorDim[3]; @@ -2403,18 +2255,13 @@ TEST_CASE("Backward pooling", "[pooling][backward][max_pooling]") { int64_t postPaddingA[] = {0, 0}; int64_t strideA[] = {2, 2}; - printf("====DIMENSIONS====\n"); - printf("dx dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - dxTensorDim[0], - dxTensorDim[1], - dxTensorDim[2], - dxTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; - printf("dy dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - dyTensorDim[0], - dyTensorDim[1], - dyTensorDim[2], - dyTensorDim[3]); + std::cout << "dx dims are " << dxTensorDim[0] << ", " << dxTensorDim[1] << ", " << dxTensorDim[2] << ", " + << dxTensorDim[3] << std::endl; + + std::cout << "dy dims are " << dyTensorDim[0] << ", " << dyTensorDim[1] << ", " << dyTensorDim[2] << ", " + << dyTensorDim[3] << std::endl; int64_t dXsize = dxTensorDim[0] * dxTensorDim[1] * dxTensorDim[2] * dxTensorDim[3]; int64_t dYsize = dyTensorDim[0] * dyTensorDim[1] * dyTensorDim[2] * dyTensorDim[3]; @@ -2587,22 +2434,16 @@ TEST_CASE("Conv Scale", "[frontend][fusion][ConvScaleReduction]") { int64_t amaxTensorDim[] = {1, 1, 1, 1}; // Output is AMAX of conv + scale - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim[0], - wTensorDim[1], - wTensorDim[2], - wTensorDim[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - amaxTensorDim[0], - amaxTensorDim[1], - amaxTensorDim[2], - amaxTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; + + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim[0] << ", " << wTensorDim[1] << ", " << wTensorDim[2] << ", " + << wTensorDim[3] << std::endl; + + std::cout << "output dims are " << amaxTensorDim[0] << ", " << amaxTensorDim[1] << ", " << amaxTensorDim[2] << ", " + << amaxTensorDim[3] << std::endl; int64_t outputSize = amaxTensorDim[0] * amaxTensorDim[1] * amaxTensorDim[2] * amaxTensorDim[3]; @@ -2650,22 +2491,15 @@ TEST_CASE("Conv Descale Descale Amax Scale sample", "[frontend][fusion][ConvScal int64_t amaxTensorDim[] = {1, 1, 1, 1}; // Output is AMAX of conv + scale - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim[0], - wTensorDim[1], - wTensorDim[2], - wTensorDim[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim[0], - yTensorDim[1], - yTensorDim[2], - yTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim[0] << ", " << wTensorDim[1] << ", " << wTensorDim[2] << ", " + << wTensorDim[3] << std::endl; + + std::cout << "output dims are " << yTensorDim[0] << ", " << yTensorDim[1] << ", " << yTensorDim[2] << ", " + << yTensorDim[3] << std::endl; int64_t inputSize = xTensorDim[0] * xTensorDim[1] * xTensorDim[2] * xTensorDim[3]; int64_t filterSize = wTensorDim[0] * wTensorDim[1] * wTensorDim[2] * wTensorDim[3]; @@ -2719,17 +2553,12 @@ TEST_CASE("Scale transpose convert amax sample", "[frontend][fusion][Transpose]" int64_t amaxTensorDim[] = {1, 1, 1, 1}; // Output is AMAX of conv + scale - printf("====DIMENSIONS====\n"); - printf("input dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - xTensorDim[0], - xTensorDim[1], - xTensorDim[2], - xTensorDim[3]); - printf("output dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - yTensorDim[0], - yTensorDim[1], - yTensorDim[2], - yTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; + std::cout << "input dims are " << xTensorDim[0] << ", " << xTensorDim[1] << ", " << xTensorDim[2] << ", " + << xTensorDim[3] << std::endl; + + std::cout << "output dims are " << yTensorDim[0] << ", " << yTensorDim[1] << ", " << yTensorDim[2] << ", " + << yTensorDim[3] << std::endl; int64_t inputSize = xTensorDim[0] * xTensorDim[1] * xTensorDim[2] * xTensorDim[3]; int64_t outputSize = yTensorDim[0] * yTensorDim[1] * yTensorDim[2] * yTensorDim[3]; @@ -2782,22 +2611,16 @@ TEST_CASE("Dgrad Descale Descale Amax Scale sample", "[frontend][fusion][ConvSca int64_t scaleDim[] = {1, 1, 1, 1}; // Scalar scale int64_t amaxTensorDim[] = {1, 1, 1, 1}; // Output is AMAX of conv + scale - printf("====DIMENSIONS====\n"); - printf("dx dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - dxTensorDim[0], - dxTensorDim[1], - dxTensorDim[2], - dxTensorDim[3]); - printf("filter dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - wTensorDim[0], - wTensorDim[1], - wTensorDim[2], - wTensorDim[3]); - printf("dy dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - dyTensorDim[0], - dyTensorDim[1], - dyTensorDim[2], - dyTensorDim[3]); + std::cout << "====DIMENSIONS====" << std::endl; + + std::cout << "dx dims are " << dxTensorDim[0] << ", " << dxTensorDim[1] << ", " << dxTensorDim[2] << ", " + << dxTensorDim[3] << std::endl; + + std::cout << "filter dims are " << wTensorDim[0] << ", " << wTensorDim[1] << ", " << wTensorDim[2] << ", " + << wTensorDim[3] << std::endl; + + std::cout << "dy dims are " << dyTensorDim[0] << ", " << dyTensorDim[1] << ", " << dyTensorDim[2] << ", " + << dyTensorDim[3] << std::endl; int64_t dxSize = dxTensorDim[0] * dxTensorDim[1] * dxTensorDim[2] * dxTensorDim[3]; int64_t filterSize = wTensorDim[0] * wTensorDim[1] * wTensorDim[2] * wTensorDim[3]; @@ -2934,33 +2757,22 @@ TEST_CASE("Back2Back Batch GEMM sample", "[frontend][fusion][back2backBatchGemm] int64_t sTensorStride[] = {4194304, 262144, 512, 1}; int64_t vTensorStride[] = {524288, 64, 1024, 1}; int64_t oTensorStride[] = {524288, 64, 1024, 1}; + std::cout << "====DIMENSIONS====" << std::endl; + + std::cout << "q dims are " << qTensorDim[0] << ", " << qTensorDim[1] << ", " << qTensorDim[2] << ", " + << qTensorDim[3] << std::endl; + + std::cout << "k dims are " << kTensorDim[0] << ", " << kTensorDim[1] << ", " << kTensorDim[2] << ", " + << kTensorDim[3] << std::endl; + + std::cout << "s dims are " << sTensorDim[0] << ", " << sTensorDim[1] << ", " << sTensorDim[2] << ", " + << sTensorDim[3] << std::endl; + + std::cout << "v dims are " << vTensorDim[0] << ", " << vTensorDim[1] << ", " << vTensorDim[2] << ", " + << vTensorDim[3] << std::endl; - printf("====DIMENSIONS====\n"); - printf("q dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - qTensorDim[0], - qTensorDim[1], - qTensorDim[2], - qTensorDim[3]); - printf("k dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - kTensorDim[0], - kTensorDim[1], - kTensorDim[2], - kTensorDim[3]); - printf("s dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - sTensorDim[0], - sTensorDim[1], - sTensorDim[2], - sTensorDim[3]); - printf("v dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - vTensorDim[0], - vTensorDim[1], - vTensorDim[2], - vTensorDim[3]); - printf("o dims are %" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "\n", - oTensorDim[0], - oTensorDim[1], - oTensorDim[2], - oTensorDim[3]); + std::cout << "o dims are " << oTensorDim[0] << ", " << oTensorDim[1] << ", " << oTensorDim[2] << ", " + << oTensorDim[3] << std::endl; int64_t qSize = qTensorDim[0] * qTensorDim[1] * qTensorDim[2] * qTensorDim[3]; int64_t kSize = kTensorDim[0] * kTensorDim[1] * kTensorDim[2] * kTensorDim[3]; @@ -3025,14 +2837,9 @@ TEST_CASE("MHA Fprop sample", "[frontend][fusion][mhaFprop]") { bool is_causal_masking = false; // specify if we need causal masking - printf("====PARAMETERS====\n"); - printf("batch is %" PRId64 ", head dim is %" PRId64 ", q sequence length is %" PRId64 - ", kv sequence length is %" PRId64 ", hidden dim is %" PRId64 "\n", - b, - h, - s_q, - s_kv, - d); + std::cout << "====PARAMETERS====" << std::endl; + std::cout << "batch is " << b << ", head dim is " << h << ", q sequence length is " << s_q + << ", kv sequence length is " << s_kv << ", hidden dim is " << d << std::endl; void* devPtrQ = nullptr; // queries void* devPtrK = nullptr; // keys @@ -3146,14 +2953,9 @@ TEST_CASE("MHA Bprop sample", "[frontend][fusion][mhaBprop]") { bool is_causal_masking = false; // specify if we need causal masking - printf("====PARAMETERS====\n"); - printf("batch is %" PRId64 ", head dim is %" PRId64 ", q sequence length is %" PRId64 - ", kv sequence length is %" PRId64 ", hidden dim is %" PRId64 "\n", - b, - h, - s_q, - s_kv, - d); + std::cout << "====PARAMETERS====" << std::endl; + std::cout << "batch is " << b << ", head dim is " << h << ", q sequence length is " << s_q + << ", kv sequence length is " << s_kv << ", hidden dim is " << d << std::endl; void* devPtrQ = nullptr; // queries void* devPtrK = nullptr; // keys @@ -3291,14 +3093,9 @@ TEST_CASE("BF16 LLM Flash MHA Fprop sample", "[frontend][fusion][BF16LLMFprop]") bool isTraining = true; // training or inference mode double dropout_probability = 0.2f; // probability of dropout. Should be 0.0 for inference mode - printf("====PARAMETERS====\n"); - printf("batch is %" PRId64 ", head dim is %" PRId64 ", q sequence length is %" PRId64 - ", kv sequence length is %" PRId64 ", hidden dim is %" PRId64 "\n", - b, - h, - s_q, - s_kv, - d); + std::cout << "====PARAMETERS====" << std::endl; + std::cout << "batch is " << b << ", head dim is " << h << ", q sequence length is " << s_q + << ", kv sequence length is " << s_kv << ", hidden dim is " << d << std::endl; void* devPtrQ = nullptr; // queries void* devPtrK = nullptr; // keys @@ -3383,14 +3180,9 @@ TEST_CASE("BF16 LLM Flash MHA Bprop sample", "[frontend][fusion][BF16LLMBprop]") int64_t seed = 123456; // seed for generating the dropout mask - printf("====PARAMETERS====\n"); - printf("batch is %" PRId64 ", head dim is %" PRId64 ", q sequence length is %" PRId64 - ", kv sequence length is %" PRId64 ", hidden dim is %" PRId64 "\n", - b, - h, - s_q, - s_kv, - d); + std::cout << "====PARAMETERS====" << std::endl; + std::cout << "batch is " << b << ", head dim is " << h << ", q sequence length is " << s_q + << ", kv sequence length is " << s_kv << ", hidden dim is " << d << std::endl; void* devPtrQ = nullptr; // queries void* devPtrKTranspose = nullptr; // keys transposed @@ -3505,14 +3297,9 @@ TEST_CASE("FP8 Flash MHA Fprop sample", "[frontend][fusion][fp8flashmhaFprop]") float dropoutProbability = 0.0f; // probability of dropout. If inference, dropout should be 0.0f int64_t seed = 123456; // seed for generating the dropout mask - printf("====PARAMETERS====\n"); - printf("batch is %" PRId64 ", head dim is %" PRId64 ", q sequence length is %" PRId64 - ", kv sequence length is %" PRId64 ", hidden dim is %" PRId64 "\n", - b, - h, - s_q, - s_kv, - d); + std::cout << "====PARAMETERS====" << std::endl; + std::cout << "batch is " << b << ", head dim is " << h << ", q sequence length is " << s_q + << ", kv sequence length is " << s_kv << ", hidden dim is " << d << std::endl; void* devPtrQKV = nullptr; // QKV interleaved tensor void* devPtrM = nullptr; // M tensor (row reduction max of QK.T) @@ -3700,14 +3487,9 @@ TEST_CASE("FP8 Flash MHA Bprop sample", "[frontend][fusion][fp8flashmhaBprop]") float dropoutProbability = 0.0f; // probability of dropout. If inference, dropout should be 0.0f int64_t seed = 123456; // seed for generating the dropout mask - printf("====PARAMETERS====\n"); - printf("batch is %" PRId64 ", head dim is %" PRId64 ", q sequence length is %" PRId64 - ", kv sequence length is %" PRId64 ", hidden dim is %" PRId64 "\n", - b, - h, - s_q, - s_kv, - d); + std::cout << "====PARAMETERS====" << std::endl; + std::cout << "batch is " << b << ", head dim is " << h << ", q sequence length is " << s_q + << ", kv sequence length is " << s_kv << ", hidden dim is " << d << std::endl; void* devPtrQKV = nullptr; // QKV interleaved tensor void* devPtrM = nullptr; // M tensor (row reduction max of QK.T) diff --git a/samples/utils/error_util.h b/samples/legacy_samples/utils/error_util.h similarity index 100% rename from samples/utils/error_util.h rename to samples/legacy_samples/utils/error_util.h diff --git a/samples/utils/fp16_dev.h b/samples/legacy_samples/utils/fp16_dev.h similarity index 100% rename from samples/utils/fp16_dev.h rename to samples/legacy_samples/utils/fp16_dev.h diff --git a/samples/utils/fp16_emu.h b/samples/legacy_samples/utils/fp16_emu.h similarity index 100% rename from samples/utils/fp16_emu.h rename to samples/legacy_samples/utils/fp16_emu.h diff --git a/samples/utils/helpers.h b/samples/legacy_samples/utils/helpers.h similarity index 100% rename from samples/utils/helpers.h rename to samples/legacy_samples/utils/helpers.h diff --git a/samples/python/50_scaled_dot_product_attention.ipynb b/samples/python/50_scaled_dot_product_attention.ipynb index 26ab1d76..f2538c2a 100644 --- a/samples/python/50_scaled_dot_product_attention.ipynb +++ b/samples/python/50_scaled_dot_product_attention.ipynb @@ -14,7 +14,7 @@ "\n", "The full documentation can be found in: [docs/operations/Attention.md#scaled-dot-product-attention](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention)\n", "\n", - "The python test code for the full set of features can be found in: [test/python_fe/test.mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py)" + "The python test code for the full set of features can be found in: [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py)" ] }, { diff --git a/samples/python/51_scaled_dot_product_attention_backward.ipynb b/samples/python/51_scaled_dot_product_attention_backward.ipynb index acacf1e5..c5e5c56f 100644 --- a/samples/python/51_scaled_dot_product_attention_backward.ipynb +++ b/samples/python/51_scaled_dot_product_attention_backward.ipynb @@ -10,7 +10,7 @@ "\n", "The full documentation can be found in: [docs/operations/Attention.md#scaled-dot-product-attention-backward](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention-backward)\n", "\n", - "The python test code for the full set of features can be found in: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py)" + "The python test code for the full set of features can be found in: [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py)" ] }, { diff --git a/samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb b/samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb new file mode 100644 index 00000000..7d54ad1d --- /dev/null +++ b/samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb @@ -0,0 +1,329 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Paged Attention in cuDNN Frontend\n", + "\n", + "This notebook illustrates how the cuDNN's frontend scaled dot product attention operator can be used to supported paged attention. For a simpler introduction to the scaled dot product attention operator, please refer to [samples/python/50_scaled_dot_product_attention.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb)\n", + "\n", + "The full documentation of cuDNN's scaled dot production attention operator can be found in: [docs/operations/Attention.md#scaled-dot-product-attention](https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention). The python test code for the full set of features can be found in: [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py)\n", + "\n", + "More details on paged attention can be found in the [PagedAttention paper](https://arxiv.org/abs/2309.06180)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Prerequisites and Setup\n", + "This notebook requires an NVIDIA GPU A100 or newer. If running on Colab, go to Runtime → Change runtime type → Hardware accelerator and select a GPU." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# get_ipython().system('nvidia-smi')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# get_ipython().system('pip install nvidia-cudnn-cu12')\n", + "# get_ipython().system('pip install nvidia-cudnn-frontend')\n", + "# get_ipython().system('pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import cudnn\n", + "import torch\n", + "import math\n", + "\n", + "torch.manual_seed(42)\n", + "handle = cudnn.create_handle()\n", + "\n", + "assert torch.cuda.is_available()\n", + "assert (\n", + " torch.cuda.get_device_capability()[0] >= 8\n", + "), \"SDPA operation is only supported on SM80 architecture (Ampere) or above\"\n", + "\n", + "assert (\n", + " cudnn.backend_version() >= 90500\n", + "), \"SDPA operation is only supported cuDNN version 9.5.0 or above\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Problem sizes and tensor setup\n", + "\n", + "For this example, we will use the same problem size as in [samples/python/50_scaled_dot_product_attention.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb).\n", + "In addition we are setting the block_size for both K and V to 64" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create the query, key, value, and output GPU tensors using PyTorch. However, the user may use any DLPack compatible tensor instead." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "b = 4 # batch size\n", + "h = 12 # query number of heads\n", + "s = 1024 # maximum sequence length\n", + "d = 64 # embedding dimension per head\n", + "\n", + "block_size_k = block_size_v = (\n", + " 64 # block size to be used by the non contiguous K/V containers\n", + ")\n", + "\n", + "attn_scale = 1.0 / math.sqrt(d)\n", + "\n", + "# The tensors will have non-interleaved\n", + "# BSHD (batch, sequence_length, num_head, dims_per_head) physical tensor layout\n", + "# BHSD (batch, num_head, sequence_length, dims_per_head) logical tensor layout\n", + "dims = (b, h, s, d)\n", + "strides = (s * h * d, d, h * d, 1)\n", + "\n", + "q_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n", + "k_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n", + "v_gpu = torch.randn(b * s * h * d).half().cuda().as_strided(dims, strides)\n", + "o_gpu = torch.empty(b * s * h * d).half().cuda().as_strided(dims, strides)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Create variable sequence length tensors. These are required when using paged K/V caches. To keep things simple, we set these to the maximum sequence length `s` in this example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set to s for all batches, just for the notebook sample\n", + "seq_len_q_gpu = torch.full((b, 1, 1, 1), s, device=\"cuda\")\n", + "seq_len_kv_gpu = torch.full((b, 1, 1, 1), s, device=\"cuda\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Generate containers and page tables for K and V\n", + "\n", + "In a real world scenario, container and page table tensors are generated by other parts of the model. For illustration purposes in this example, we provide a helper function to generate a trivial container from contiguous K and V caches. \n", + "The helper function basically takes e.g., the K-cache and splits up the sequence (`S`) dimension in different blocks of length `block_size`. The resulting page table then helps identify which block belongs to which sequence ID." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Helper function to create a non contiguous container in blocks of block_size from a contiguous tensor\n", + "def create_container_and_page_table(tensor, block_size):\n", + " B, H, S, D = tensor.shape\n", + " blocks_per_batch = math.ceil(S / block_size)\n", + "\n", + " # This assertion keeps the helper function of this example simple, but is not a requirement for paged attention.\n", + " assert (blocks_per_batch * block_size) == S\n", + "\n", + " # Create a container by splitting on the S dimension and concatenating at the block dimension\n", + " # Its dimensions are [num_blocks, H, block_size, D] with num_blocks = B * blocks_per_batch\n", + " container = torch.cat((tensor.clone()).chunk(blocks_per_batch, dim=2), dim=0)\n", + "\n", + " # Create the page table\n", + " page_table = torch.linspace(\n", + " 0,\n", + " B * blocks_per_batch - 1,\n", + " B * blocks_per_batch,\n", + " device=\"cuda\",\n", + " dtype=torch.int32,\n", + " ).reshape(blocks_per_batch, 1, B, 1)\n", + " page_table = torch.transpose(page_table, 0, 2)\n", + "\n", + " return (container, page_table)\n", + "\n", + "\n", + "# Create non contiguous containers with page tables for K and V from the contiguous k_gpu and v_gpu\n", + "container_k_gpu, page_table_k_gpu = create_container_and_page_table(k_gpu, block_size_k)\n", + "container_v_gpu, page_table_v_gpu = create_container_and_page_table(v_gpu, block_size_v)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Graph creation and execution\n", + "\n", + "Create the graph" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "graph = cudnn.pygraph(\n", + " io_data_type=cudnn.data_type.HALF,\n", + " intermediate_data_type=cudnn.data_type.FLOAT,\n", + " compute_data_type=cudnn.data_type.FLOAT,\n", + ")\n", + "\n", + "q = graph.tensor_like(q_gpu)\n", + "\n", + "container_k = graph.tensor_like(container_k_gpu)\n", + "container_v = graph.tensor_like(container_v_gpu)\n", + "page_table_k = graph.tensor_like(page_table_k_gpu)\n", + "page_table_v = graph.tensor_like(page_table_v_gpu)\n", + "\n", + "seq_len_q = graph.tensor_like(seq_len_q_gpu)\n", + "seq_len_kv = graph.tensor_like(seq_len_kv_gpu)\n", + "\n", + "o, _ = graph.sdpa(\n", + " name=\"sdpa\",\n", + " q=q,\n", + " k=container_k, # Container K: non contiguous container with K blocks\n", + " v=container_v, # Container V: non contiguous container with V blocks\n", + " is_inference=True,\n", + " attn_scale=attn_scale,\n", + " use_causal_mask=True,\n", + " use_padding_mask=True,\n", + " seq_len_q=seq_len_q,\n", + " seq_len_kv=seq_len_kv,\n", + " paged_attention_k_table=page_table_k, # Page Table K: Tensor containing offsets to the container with K blocks\n", + " paged_attention_v_table=page_table_v, # Page Table V: Tensor containing offsets to the container with V blocks\n", + " paged_attention_max_seq_len_kv=s, # The maximum sequence length for K caches (this is optional, but recommended)\n", + ")\n", + "\n", + "o.set_output(True).set_dim(dims).set_stride(strides)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Build the graph" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "graph.validate()\n", + "graph.build_operation_graph()\n", + "graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])\n", + "graph.check_support()\n", + "graph.build_plans()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Execute the graph" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "variant_pack = {\n", + " q: q_gpu,\n", + " container_k: container_k_gpu,\n", + " container_v: container_v_gpu,\n", + " page_table_k: page_table_k_gpu,\n", + " page_table_v: page_table_v_gpu,\n", + " seq_len_q: seq_len_q_gpu,\n", + " seq_len_kv: seq_len_kv_gpu,\n", + " o: o_gpu,\n", + "}\n", + "\n", + "workspace = torch.empty(graph.get_workspace_size(), device=\"cuda\", dtype=torch.uint8)\n", + "graph.execute(variant_pack, workspace)\n", + "torch.cuda.synchronize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Run the PyTorch reference and compare against cuDNN's output" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "q_ref = q_gpu.detach().float().requires_grad_()\n", + "k_ref = k_gpu.detach().float().requires_grad_()\n", + "v_ref = v_gpu.detach().float().requires_grad_()\n", + "\n", + "o_ref = torch.nn.functional.scaled_dot_product_attention(\n", + " q_ref, k_ref, v_ref, is_causal=True, scale=attn_scale\n", + ")\n", + "\n", + "torch.testing.assert_close(o_ref, o_gpu.float(), atol=5e-3, rtol=3e-3)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/setup.py b/setup.py index ac9218c2..aaa6bf15 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ def build_extension(self, ext: CMakeExtension) -> None: f"-DCUDNN_FRONTEND_BUILD_PYTHON_BINDINGS=ON", # There's no need to build cpp samples and tests with python f"-DCUDNN_FRONTEND_BUILD_SAMPLES=OFF", - f"-DCUDNN_FRONTEND_BUILD_UNIT_TESTS=OFF", + f"-DCUDNN_FRONTEND_BUILD_TESTS=OFF", # All these are handled by pip f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", f"-DCUDNN_FRONTEND_KEEP_PYBINDS_IN_BINARY_DIR=OFF", diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index e155038b..86822eb9 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,3 +1,3 @@ cmake_minimum_required(VERSION 3.18) -add_subdirectory(unit_tests) \ No newline at end of file +add_subdirectory(cpp) diff --git a/test/unit_tests/CMakeLists.txt b/test/cpp/CMakeLists.txt similarity index 89% rename from test/unit_tests/CMakeLists.txt rename to test/cpp/CMakeLists.txt index 633fb5ec..e244cd0c 100644 --- a/test/unit_tests/CMakeLists.txt +++ b/test/cpp/CMakeLists.txt @@ -18,8 +18,9 @@ endif() include(${PROJECT_SOURCE_DIR}/cmake/cuDNN.cmake) add_executable( - unit_tests + tests + pointwise_tests.cpp serialize.cpp validate.cpp version.cpp @@ -28,7 +29,7 @@ add_executable( if (MSVC) target_compile_options( - unit_tests PRIVATE + tests PRIVATE /W4 /WX # warning level 3 and all warnings as errors /wd4100 # allow unused parameters /wd4458 # local hides class member (currently a problem for all inline setters) @@ -38,7 +39,7 @@ if (MSVC) ) else() target_compile_options( - unit_tests PRIVATE + tests PRIVATE -Wall -Wextra -Werror @@ -47,19 +48,20 @@ else() endif() target_link_libraries( - unit_tests + tests cudnn_frontend _cudnn_frontend_pch Catch2::Catch2WithMain - CUDNN::cudnn_all + CUDNN::cudnn + + CUDA::cudart ) # cuDNN dlopen's its libraries # Add all libraries in link line as NEEDED set_target_properties( - unit_tests + tests PROPERTIES - LINK_WHAT_YOU_USE TRUE RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/bin ) diff --git a/test/cpp/pointwise_tests.cpp b/test/cpp/pointwise_tests.cpp new file mode 100644 index 00000000..ad3e350c --- /dev/null +++ b/test/cpp/pointwise_tests.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ +#include + +#include + +TEST_CASE("Pointwise shape deduction", "[pointwise_shape_deduction]") { + namespace fe = cudnn_frontend; + + cudnnHandle_t handle; + cudnnCreate(&handle); + + fe::graph::Graph graph; + graph.set_io_data_type(fe::DataType_t::HALF) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto in0 = graph.tensor( + fe::graph::Tensor_attributes().set_name("in0").set_dim({8, 128, 16000, 1}).set_stride({2048000, 1, 128, 128})); + + auto in1 = graph.tensor( + fe::graph::Tensor_attributes().set_name("in1").set_dim({1, 128, 1, 1}).set_stride({128, 1, 128, 128})); + + auto add_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::ADD); + + auto out_0 = graph.pointwise(in0, in1, add_options); + + out_0->set_output(true); + + REQUIRE(graph.validate().is_good()); + + REQUIRE(graph.build_operation_graph(handle).is_good()); + + REQUIRE(out_0->get_dim() == in0->get_dim()); + REQUIRE(out_0->get_stride() == in0->get_stride()); + + cudnnDestroy(handle); +} + +TEST_CASE("Pointwise Add shape deduction", "[pointwise_shape_deduction]") { + namespace fe = cudnn_frontend; + + cudnnHandle_t handle; + cudnnCreate(&handle); + + fe::graph::Graph graph; + graph.set_io_data_type(fe::DataType_t::HALF) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto in0 = graph.tensor( + fe::graph::Tensor_attributes().set_name("in0").set_dim({1, 4194304, 1}).set_stride({1, 1, 4194304})); + + auto in1 = + graph.tensor(fe::graph::Tensor_attributes().set_name("in1").set_dim({1, 4194304, 32}).set_stride({1, 32, 1})); + + auto add_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::ADD); + + auto out_0 = graph.pointwise(in0, in1, add_options); + out_0->set_output(true); + + REQUIRE(graph.validate().is_good()); + + REQUIRE(graph.build_operation_graph(handle).is_good()); + + REQUIRE(out_0->get_dim() == in1->get_dim()); + REQUIRE(out_0->get_stride() == in1->get_stride()); + + cudnnDestroy(handle); +} \ No newline at end of file diff --git a/test/unit_tests/serialize.cpp b/test/cpp/serialize.cpp similarity index 99% rename from test/unit_tests/serialize.cpp rename to test/cpp/serialize.cpp index 5b28d5e6..d138e2b2 100644 --- a/test/unit_tests/serialize.cpp +++ b/test/cpp/serialize.cpp @@ -140,6 +140,8 @@ TEST_CASE("Graph key", "[graph][key]") { REQUIRE(graph.build_plans(handle).is_good()); REQUIRE(key == graph.key()); + + cudnnDestroy(handle); } TEST_CASE("Graph key dynamic shape", "[graph][key][dynamic_shape]") { @@ -218,6 +220,8 @@ TEST_CASE("Graph key dynamic shape", "[graph][key][dynamic_shape]") { REQUIRE(graph.build_plans(handle).is_good()); REQUIRE(key == graph.key()); + + cudnnDestroy(handle); } } diff --git a/test/unit_tests/tensor.cpp b/test/cpp/tensor.cpp similarity index 100% rename from test/unit_tests/tensor.cpp rename to test/cpp/tensor.cpp diff --git a/test/unit_tests/validate.cpp b/test/cpp/validate.cpp similarity index 100% rename from test/unit_tests/validate.cpp rename to test/cpp/validate.cpp diff --git a/test/unit_tests/version.cpp b/test/cpp/version.cpp similarity index 100% rename from test/unit_tests/version.cpp rename to test/cpp/version.cpp diff --git a/test/python_fe/conftest.py b/test/python/conftest.py similarity index 90% rename from test/python_fe/conftest.py rename to test/python/conftest.py index 61919585..2ca5f132 100644 --- a/test/python_fe/conftest.py +++ b/test/python/conftest.py @@ -41,3 +41,8 @@ def pytest_addoption(parser): default=None, help="[test_mhas.py] force deterministic algorithm", ) + parser.addoption( + "--mha_block_size", + default=None, + help="[test_mhas.py] block size for paged attention", + ) diff --git a/test/python_fe/test_apply_rope.py b/test/python/test_apply_rope.py similarity index 100% rename from test/python_fe/test_apply_rope.py rename to test/python/test_apply_rope.py diff --git a/test/python_fe/test_batchnorm.py b/test/python/test_batchnorm.py similarity index 100% rename from test/python_fe/test_batchnorm.py rename to test/python/test_batchnorm.py diff --git a/test/python_fe/test_conv_bias.py b/test/python/test_conv_bias.py similarity index 100% rename from test/python_fe/test_conv_bias.py rename to test/python/test_conv_bias.py diff --git a/test/python_fe/test_conv_genstats.py b/test/python/test_conv_genstats.py similarity index 100% rename from test/python_fe/test_conv_genstats.py rename to test/python/test_conv_genstats.py diff --git a/test/python_fe/test_conv_reduction.py b/test/python/test_conv_reduction.py similarity index 100% rename from test/python_fe/test_conv_reduction.py rename to test/python/test_conv_reduction.py diff --git a/test/python_fe/test_instancenorm.py b/test/python/test_instancenorm.py similarity index 100% rename from test/python_fe/test_instancenorm.py rename to test/python/test_instancenorm.py diff --git a/test/python/test_kernel_cache.py b/test/python/test_kernel_cache.py new file mode 100644 index 00000000..08cc9ed2 --- /dev/null +++ b/test/python/test_kernel_cache.py @@ -0,0 +1,96 @@ +import cudnn +import pytest +import torch +import itertools +from looseversion import LooseVersion + +from collections import namedtuple + +problem_defintion = namedtuple("problem_defintion", ["b", "m", "n", "k"]) + +shapes = [ + problem_defintion(b=16, m=32, n=32, k=128), + problem_defintion(b=16, m=64, n=64, k=128), + problem_defintion(b=16, m=80, n=80, k=128), + problem_defintion(b=32, m=128, n=128, k=256), + problem_defintion(b=32, m=64, n=64, k=256), +] + + +def build_cudnn_graph(handle, cache, shape): + graph = cudnn.pygraph( + io_data_type=cudnn.data_type.HALF, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=handle, + kernel_cache=cache, + ) + + A = graph.tensor( + name="A", + dim=[shape.b, shape.m, shape.k], + stride=[shape.m * shape.k, shape.k, 1], + ) + B = graph.tensor( + name="B", + dim=[shape.b, shape.k, shape.n], + stride=[shape.n * shape.k, shape.n, 1], + ) + + C = graph.matmul(name="matmul", A=A, B=B) + C.set_output(True).set_uid(2) + + A.set_uid(0) + B.set_uid(1) + + graph.build([cudnn.heur_mode.A]) + + return graph + + +@pytest.mark.skipif( + LooseVersion(cudnn.backend_version_string()) < "9.5", + reason="requires cudnn 9.5 or higher", +) +def test_kernel_cache(cudnn_handle): + + cache = cudnn.create_kernel_cache() + + for shape in shapes: + graph = build_cudnn_graph(cudnn_handle, cache, shape) + + A = torch.randn( + shape.b, + shape.m, + shape.k, + requires_grad=False, + device="cuda", + dtype=torch.bfloat16, + ) + B = torch.randn( + shape.b, + shape.k, + shape.n, + requires_grad=False, + device="cuda", + dtype=torch.bfloat16, + ) + C = torch.randn( + shape.b, + shape.m, + shape.n, + requires_grad=False, + device="cuda", + dtype=torch.bfloat16, + ) + + workspace = torch.empty( + graph.get_workspace_size(), device="cuda", dtype=torch.uint8 + ) + + print("Executing", shape) + graph.execute({0: A, 1: B, 2: C}, workspace, handle=cudnn_handle) + + +if __name__ == "__main__": + test_kernel_cache(cudnn_handle) diff --git a/test/python_fe/test_layernorm.py b/test/python/test_layernorm.py similarity index 100% rename from test/python_fe/test_layernorm.py rename to test/python/test_layernorm.py diff --git a/test/python_fe/test_matmul_bias_relu.py b/test/python/test_matmul_bias_relu.py similarity index 100% rename from test/python_fe/test_matmul_bias_relu.py rename to test/python/test_matmul_bias_relu.py diff --git a/test/python_fe/test_mhas.py b/test/python/test_mhas.py similarity index 92% rename from test/python_fe/test_mhas.py rename to test/python/test_mhas.py index 9b7f0c2d..e037baff 100644 --- a/test/python_fe/test_mhas.py +++ b/test/python/test_mhas.py @@ -1,3 +1,13 @@ +""" +This test harness allows for testing the various options of the attention operator. See example usage under "main" below. + +The full documentation on the attention operator can be found in: https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md#scaled-dot-product-attention + +Notebooks that demonstrate the attention operator can be found here: +- Introductory example: https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb +- Example with paged caches: https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb +""" + import cudnn import pytest import torch @@ -21,6 +31,7 @@ dropout_options = [False, True] ragged_options = [False, True] is_infer_options = [False, True] +page_table_options = [False, True] def convert_to_cudnn_type(torch_type): @@ -424,6 +435,7 @@ def convert_ragged_to_uniform(ragged_tensor, seq_len): @pytest.mark.parametrize("is_padding", padding_mask_options, ids=lambda p: f"padding{int(p)}") @pytest.mark.parametrize("is_alibi", alibi_mask_options, ids=lambda p: f"alibi{int(p)}") @pytest.mark.parametrize("is_bias", bias_options, ids=lambda p: f"bias{int(p)}") +@pytest.mark.parametrize("is_paged_attention", page_table_options, ids=lambda p: f"paged{int(p)}") @pytest.mark.parametrize("head_group", head_group_options) @pytest.mark.parametrize("layout", layout_options) @pytest.mark.parametrize("input_type", input_type_options, ids=lambda p: str(p)) @@ -433,6 +445,7 @@ def test_sdpa( input_type, layout, head_group, + is_paged_attention, is_bias, is_alibi, is_padding, @@ -445,6 +458,8 @@ def test_sdpa( request, cudnn_handle ): + + #pytest.set_trace() cudnn_version = LooseVersion(cudnn.backend_version_string()) @@ -478,6 +493,10 @@ def test_sdpa( if is_ragged and not is_padding: pytest.skip("Ragged tensor is only tested with packed variable length tensors") + if is_paged_attention and (not is_padding or cudnn_version < "9.4" or not layout == "bshd_bshd_bshd" or is_ragged): + pytest.skip("Paged attention is only tested with packed variable length tensors, thd_thd_thd, no ragged offsets, and only on cuDNNv9.4 or greater") + + # -------------------------- default randomized parameter testing ------------------------ # batch size b = 2 @@ -511,6 +530,9 @@ def test_sdpa( else: assert False, "Head group must be either MHA, GQA, or MQA" + # block size for paged attention + block_size = random.choice([32, 64, 128]) + # -------------------------- override test parameters if args are provided ---------------- b = int(request.config.option.mha_b) if request.config.option.mha_b != None else b s_q = int(request.config.option.mha_s_q) if request.config.option.mha_s_q != None else s_q @@ -520,6 +542,7 @@ def test_sdpa( h_q = int(request.config.option.mha_h_q) if request.config.option.mha_h_q != None else h_q h_k = int(request.config.option.mha_h_k) if request.config.option.mha_h_k != None else h_k h_v = int(request.config.option.mha_h_v) if request.config.option.mha_h_v != None else h_v + block_size = int(request.config.option.mha_block_size) if request.config.option.mha_block_size != None else block_size if d_qk != d_v and cudnn_version < "8.9.6": pytest.skip("d_qk != d_v is only supported on 8.9.6 onwards.") @@ -532,7 +555,7 @@ def test_sdpa( print("\n=============== TEST CMD TO REPRODUCE ===============") print( - f"pytest {request.node.nodeid} --mha_b={b} --mha_s_q={s_q} --mha_s_kv={s_kv} --mha_d_qk={d_qk} --mha_d_v={d_v} --mha_h_q={h_q} --mha_h_k={h_k} --mha_h_v={h_v}" + f"pytest {request.node.nodeid} --mha_b={b} --mha_s_q={s_q} --mha_s_kv={s_kv} --mha_d_qk={d_qk} --mha_d_v={d_v} --mha_h_q={h_q} --mha_h_k={h_k} --mha_h_v={h_v} --mha_block_size={block_size}" ) print("=====================================================") @@ -619,6 +642,35 @@ def test_sdpa( else None ) + def create_container_and_page_table(tensor, block_size): + B, H, S, D = tensor.shape + # num_blocks = math.ceil(S/block_size) * B + blocks_per_batch = math.ceil(S/block_size) + + padding_seq = (blocks_per_batch * block_size) - S + if padding_seq > 0: + zeros = torch.zeros(B,H,padding_seq,D, device='cuda', dtype=tensor.dtype) + cat_tensor = torch.cat((tensor, zeros), axis = 2) + else: + cat_tensor = tensor + + reshaped = torch.cat((cat_tensor.clone()).chunk(blocks_per_batch, dim=2), dim=0) + + table_size = math.ceil(S/block_size) + page_table = torch.linspace(0, B*table_size-1, B*table_size, device='cuda', dtype=torch.int32).reshape(table_size,1,B,1) + page_table = torch.transpose(page_table,0,2) + + return(reshaped, page_table) + + + container_k_gpu = None + container_v_gpu = None + page_table_k_gpu = None + page_table_v_gpu = None + if is_paged_attention: + container_k_gpu, page_table_k_gpu = create_container_and_page_table(k_gpu, block_size) + container_v_gpu, page_table_v_gpu = create_container_and_page_table(v_gpu, block_size) + stream = torch.cuda.current_stream().cuda_stream cudnn.set_stream(handle=cudnn_handle, stream=stream) @@ -631,8 +683,11 @@ def test_sdpa( ) q = graph.tensor_like(q_gpu) - k = graph.tensor_like(k_gpu) - v = graph.tensor_like(v_gpu) + k = graph.tensor_like(k_gpu) if not is_paged_attention else graph.tensor_like(container_k_gpu) + v = graph.tensor_like(v_gpu) if not is_paged_attention else graph.tensor_like(container_v_gpu) + + page_table_k = graph.tensor_like(page_table_k_gpu) if is_paged_attention else None + page_table_v = graph.tensor_like(page_table_v_gpu) if is_paged_attention else None bias = graph.tensor_like(bias_gpu) if is_bias else None @@ -660,6 +715,8 @@ def test_sdpa( if is_sliding_window: sliding_window_length = s_kv // 4 + + o, stats = graph.sdpa( name="sdpa", q=q, @@ -677,6 +734,9 @@ def test_sdpa( sliding_window_length=sliding_window_length, dropout=dropout_tuple if is_dropout else None, rng_dump=rng_dump, + paged_attention_k_table=page_table_k, + paged_attention_v_table=page_table_v, + paged_attention_max_seq_len_kv=s_kv if is_paged_attention else None ) o.set_output(True).set_dim(shape_o).set_stride(stride_o) @@ -689,19 +749,21 @@ def test_sdpa( try: graph.validate() except cudnn.cudnnGraphNotSupportedError as e: + print("Graph not supported") pytest.xfail(repr(e)) except Exception as e: pytest.fail(repr(e)) graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + #graph.create_execution_plans([cudnn.heur_mode.FALLBACK]) graph.check_support() graph.build_plans() variant_pack = { q: q_gpu, - k: k_gpu, - v: v_gpu, + k: k_gpu if not is_paged_attention else container_k_gpu, + v: v_gpu if not is_paged_attention else container_v_gpu, bias: bias_gpu, seq_len_q: seq_len_q_gpu, seq_len_kv: seq_len_kv_gpu, @@ -712,6 +774,8 @@ def test_sdpa( o: o_gpu, stats: stats_gpu, rng_dump: rng_dump_gpu, + page_table_k: page_table_k_gpu, + page_table_v: page_table_v_gpu } if is_dropout: @@ -847,9 +911,6 @@ def test_sdpa_backward( if is_ragged and not (layout == "bshd_bshd_bshd" or layout == "bs3hd"): pytest.skip("Ragged tensor is only tested with thd_thd_thd and t3hd") - if is_ragged and head_group != "multi_head": - pytest.skip("Ragged offset is only supported with multi_head") - if is_ragged and layout == "bs3hd" and cudnn_version < "9.1.0": pytest.skip("t3hd is only supported on 9.1.0 onwards") @@ -1342,7 +1403,7 @@ def test_sdpa_backward( # ================== forward ================== """ pytest \ - test/python_fe/test_mhas.py::test_sdpa[torch.float16-bshd_bshd_bshd-group_query-bias0-alibi0-padding0-causal0-causal_bottom_right0-sliding_window0-dropout0-ragged0-infer0] \ + test/python/test_mhas.py::test_sdpa[torch.float16-bshd_bshd_bshd-group_query-paged0-bias0-alibi0-padding0-causal0-causal_bottom_right0-sliding_window0-dropout0-ragged0-infer0] \ -s \ --mha_b 3 \ --mha_s_q 256 \ @@ -1357,7 +1418,7 @@ def test_sdpa_backward( # ================== backward ================== """ pytest \ - test/python_fe/test_mhas.py::test_sdpa_backward[torch.float16-bshd_bshd_bshd-group_query-bias0-alibi0-padding0-causal0-causal_bottom_right0-sliding_window0-dropout0-ragged0] \ + test/python/test_mhas.py::test_sdpa_backward[torch.float16-bshd_bshd_bshd-group_query-bias0-alibi0-padding0-causal0-causal_bottom_right0-sliding_window0-dropout0-ragged0] \ -s \ --mha_b 3 \ --mha_s_q 256 \ diff --git a/test/python_fe/test_rmsnorm.py b/test/python/test_rmsnorm.py similarity index 100% rename from test/python_fe/test_rmsnorm.py rename to test/python/test_rmsnorm.py diff --git a/test/python_fe/test_silu_and_mul.py b/test/python/test_silu_and_mul.py similarity index 100% rename from test/python_fe/test_silu_and_mul.py rename to test/python/test_silu_and_mul.py diff --git a/test/python_fe/test_slice.py b/test/python/test_slice.py similarity index 100% rename from test/python_fe/test_slice.py rename to test/python/test_slice.py diff --git a/test/python_fe/test_utils.py b/test/python/test_utils.py similarity index 100% rename from test/python_fe/test_utils.py rename to test/python/test_utils.py diff --git a/test/python_fe/test_wgrads.py b/test/python/test_wgrads.py similarity index 100% rename from test/python_fe/test_wgrads.py rename to test/python/test_wgrads.py