From ea7f8b9820ac4a28729eddd10331423077e5ad61 Mon Sep 17 00:00:00 2001 From: Anerudhan Gopal Date: Sun, 24 Sep 2023 20:25:49 -0700 Subject: [PATCH] cudnn prerelease_3: Improvements over prerelease 2: [Feature] Added SDPA flash attention backwward node. [Enhancement] Resolved an issue where the computed Alibi slopes were copied onto GPU memory on default stream instead of user specified stream in the handle. [Bug fix] Fix windows compilation error when pedantic warnings are treated as error. [Bug fix] Fixed issue in causal padding where the masked values were `std::numeric_limits::min()` instead of `std::numeric_limits::lowest()` Under investigation and development: - We are still working on additional features for SDPA back prop. - Better error messages and logging --- README.FE.1.0.md | 3 +- docs/operations/Attention.md | 152 +++- .../cudnn_frontend_cudnn_interface.h | 8 +- .../cudnn_frontend_graph_helpers.h | 182 +++-- .../cudnn_frontend_graph_interface.h | 49 +- .../cudnn_frontend_graph_properties.h | 158 +++- .../cudnn_frontend_node_interface.h | 101 +-- include/cudnn_frontend/node/batchnorm.h | 28 +- .../cudnn_frontend/node/batchnorm_inference.h | 10 +- include/cudnn_frontend/node/bn_finalize.h | 10 +- include/cudnn_frontend/node/conv_dgrad.h | 14 +- include/cudnn_frontend/node/conv_fprop.h | 5 +- include/cudnn_frontend/node/conv_wgrad.h | 14 +- include/cudnn_frontend/node/dbn.h | 29 +- include/cudnn_frontend/node/dbn_weight.h | 10 +- include/cudnn_frontend/node/dln.h | 35 +- include/cudnn_frontend/node/genstats.h | 10 +- include/cudnn_frontend/node/layernorm.h | 38 +- include/cudnn_frontend/node/matmul.h | 23 +- include/cudnn_frontend/node/pointwise.h | 41 +- include/cudnn_frontend/node/reduction.h | 18 +- include/cudnn_frontend/node/reshape.h | 19 +- include/cudnn_frontend/node/rng.h | 11 +- .../node/scaled_dot_product_attention.h | 3 +- .../node/scaled_dot_product_flash_attention.h | 603 ++++++++++++-- include/cudnn_frontend/node/softmax.h | 8 +- python_bindings/cudnn_frontend_properties.cpp | 11 +- python_bindings/cudnn_frontend_pygraph.cpp | 126 ++- samples/cpp/matmuls.cpp | 4 +- samples/cpp/mha.cpp | 293 +++++-- samples/python/matmul_bias_relu.py | 45 -- samples/python/test_matmul_bias_relu.py | 63 ++ samples/python/test_mhas.py | 763 ++++++++++++------ 33 files changed, 2109 insertions(+), 778 deletions(-) delete mode 100644 samples/python/matmul_bias_relu.py create mode 100644 samples/python/test_matmul_bias_relu.py diff --git a/README.FE.1.0.md b/README.FE.1.0.md index 931d5049..e5655249 100644 --- a/README.FE.1.0.md +++ b/README.FE.1.0.md @@ -42,7 +42,8 @@ FE v1.0 API follows a functional style of building a graph. Operations take in i | Generate stats of output| genstats
Genstats_attributes | genstats | | BN Finalize of stats | bn_finalize
BN_finalize_attributes | bn_finalize | | Dbn weight | dbn_weight
DBN_weight_attributes | dbn_weight | -| Scale dot product flash attention | scaled_dot_product_flash_attention
Scaled_dot_product_flash_attention_attributes | scaled_dot_product_flash_attention| +| Scale dot product flash attention | scaled_dot_product_flash_attention
Scaled_dot_product_flash_attention_attributes | scaled_dot_product_flash_attention | +| Scale dot product flash attention_backward | scaled_dot_product_flash_attention_backward
Scaled_dot_product_flash_attention_backward_attributes | scaled_dot_product_flash_attention_backward | ### Create Graph Instantiate an object of class `cudnn_frontend::graph::Graph` which will house tensors and operations. diff --git a/docs/operations/Attention.md b/docs/operations/Attention.md index f6bb0670..eb9767c6 100644 --- a/docs/operations/Attention.md +++ b/docs/operations/Attention.md @@ -1,11 +1,12 @@ ## Table of Contents -1. [Scaled Dot Product Flash Attention](#Scaled Dot Product Flash Attention) - +1. [Scaled Dot Product Flash Attention](#scaled-dot-product-flash-attention) +2. [Scaled Dot Product Flash Attention Backward](#scaled-dot-product-flash-attention-backward) ### Scaled Dot Product Flash Attention Computes the scaled dot product attention for given Query, Key and Value tensors. Optionally, can set dropout probability, causal mask. Can optionally dump stats to be used for the bprop computation. API: + ``` std::array, 2> scaled_dot_product_flash_attention @@ -15,50 +16,121 @@ scaled_dot_product_flash_attention Scaled_dot_product_flash_attention_attributes options); ``` -where the output array has tensors in order of: `[output, softmax_stats]` -where, `Scaled_dot_product_flash_attention_attributes` controls the sub-graph in the operation +where the output array has tensors in order of: `[output, softmax_stats]` and `Scaled_dot_product_flash_attention_attributes` controls the sub-graph in the operation + +``` +Scaled_dot_product_flash_attention_attributes & +set_is_inference(bool const value); + +Scaled_dot_product_flash_attention_attributes & +set_causal_mask(bool const value); + +Scaled_dot_product_flash_attention_attributes & +set_bias(std::shared_ptr value); + +Scaled_dot_product_flash_attention_attributes & +set_attn_scale(std::shared_ptr value); + +Scaled_dot_product_flash_attention_attributes & +set_dropout(float const probability, + std::shared_ptr seed, + std::shared_ptr offset); + +Scaled_dot_product_flash_attention_attributes & +set_dropout(std::shared_ptr mask, std::shared_ptr scale); +Scaled_dot_product_flash_attention_attributes & +set_compute_data_type(DataType_t value) +``` + +Python API: ``` - Scaled_dot_product_flash_attention_attributes & - set_is_inference(bool const value); - - Scaled_dot_product_flash_attention_attributes & - set_causal_mask(bool const value); - - Scaled_dot_product_flash_attention_attributes & - set_bias(std::shared_ptr value); - - Scaled_dot_product_flash_attention_attributes & - set_attn_scale(std::shared_ptr value); - - Scaled_dot_product_flash_attention_attributes & - set_dropout(float const probability, - std::shared_ptr seed, - std::shared_ptr offset); - - Scaled_dot_product_flash_attention_attributes & - set_dropout(std::shared_ptr mask, std::shared_ptr scale); - - Scaled_dot_product_flash_attention_attributes & - set_compute_data_type(DataType_t value) +Args: + q (cudnn_tensor): The query data. + k (cudnn_tensor): The key data. + v (cudnn_tensor): The value data. + seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. + seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. + is_inference (bool): Whether it is an inference step or training step. + attn_scale (Optional[cudnn_tensor]): The scale factor for attention. Default is None. + bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. + use_padding_mask (Optional[bool]): Whether to use padding mask. Default is False. + use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False. + use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): The name of the operation. + +Returns: + o (cudnn_tensor): The result of scaled dot-product flash attention. + stats (Optional[cudnn_tensor]): The softmax statistics in case the operation is in a training step. +``` + +### Scaled Dot Product Flash Attention Backward +Computes the query, key and value gradient tensors for scaled dot product flash attention. Optionally, can set dropout probability, causal mask. + +API: +``` +std::array, 3> +scaled_dot_product_flash_attention_backward + (std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr o, + std::shared_ptr dO, + std::shared_ptr stats, + Scaled_dot_product_flash_attention_backward_attributes); +``` + +where the output array has tensors in order of: `[dQ, dK, dV]` +where, `Scaled_dot_product_flash_attention_backward_attributes` controls the sub-graph in the operation + + +``` +Scaled_dot_product_flash_attention_backward_attributes& +set_attn_scale(std::shared_ptr value) + +Scaled_dot_product_flash_attention_backward_attributes& +set_bias(std::shared_ptr value) + +Scaled_dot_product_flash_attention_backward_attributes& +set_causal_mask(bool const value) + +Scaled_dot_product_flash_attention_backward_attributes& +set_dropout(float const probability, + std::shared_ptr seed, + std::shared_ptr offset) + +Scaled_dot_product_flash_attention_backward_attributes& +set_dropout(std::shared_ptr mask, std::shared_ptr scale, std::shared_ptr scale_inv) + +Scaled_dot_product_flash_attention_backward_attributes& +set_compute_data_type(DataType_t const value) ``` Python API: - - q - - k - - v - - seq_q - - seq_k - - is_inference - - attn_scale - - bias - - use_padding_mask - - use_alibi_mask - - use_causal_mask - - dropout - - compute_data_type - - name + +``` +Args: + q (cudnn_tensor): The query data. + k (cudnn_tensor): The key data. + v (cudnn_tensor): The value data. + o (cudnn_tensor): The output data. + dO (cudnn_tensor): The output loss gradient. + stats (cudnn_tensor): The softmax statistics from the forward pass. + attn_scale (Optional[cudnn_tensor]): The scale factor for attention. Default is None. + bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. + use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): The name of the operation. + +Returns: + dQ (cudnn_tensor): The query gradient tensor of scaled dot-product flash attention. + dK (cudnn_tensor): The key gradient tensor of scaled dot-product flash attention. + dV (cudnn_tensor): The value gradient tensor of scaled dot-product flash attention. +``` ## Miscellaneous - FE provides shadow enums which help avoid users to workaround having different enums for different cudnn versions. diff --git a/include/cudnn_frontend/cudnn_frontend_cudnn_interface.h b/include/cudnn_frontend/cudnn_frontend_cudnn_interface.h index 5fb2cbac..c6bf056e 100644 --- a/include/cudnn_frontend/cudnn_frontend_cudnn_interface.h +++ b/include/cudnn_frontend/cudnn_frontend_cudnn_interface.h @@ -47,7 +47,9 @@ class ICudnn { error_t create_cudnn_tensor(std::shared_ptr const& props) { // Check whether tensor already created - if (tensors.find(props->get_uid()) != tensors.end()) { + auto const uid = props->get_uid(); + if (tensors.find(uid) != tensors.end()) { + getLogger() << "[cudnn_frontend] INFO: Backend tensor already created for Id: " << uid << ".\n"; return {error_code_t::OK, ""}; } @@ -55,14 +57,14 @@ class ICudnn { auto tensor = cudnn_frontend::TensorBuilder() .setDim(props->get_dim().size(), props->get_dim().data()) .setStrides(props->get_stride().size(), props->get_stride().data()) - .setId(props->get_uid()) + .setId(uid) .setAlignment(16) .setDataType(props->get_data_type()) .setVirtual(props->get_is_virtual()) .setByValue(props->get_is_pass_by_value()) .setReorderType(props->get_reordering_type()) .build(); - tensors.emplace(props->get_uid(), std::make_shared(std::move(tensor))); + tensors.emplace(uid, std::make_shared(std::move(tensor))); return {error_code_t::OK, ""}; } diff --git a/include/cudnn_frontend/cudnn_frontend_graph_helpers.h b/include/cudnn_frontend/cudnn_frontend_graph_helpers.h index bdc8ca6f..1240fe3b 100644 --- a/include/cudnn_frontend/cudnn_frontend_graph_helpers.h +++ b/include/cudnn_frontend/cudnn_frontend_graph_helpers.h @@ -20,6 +20,8 @@ enum class [[nodiscard]] error_code_t{OK, GRAPH_EXECUTION_FAILED, HEURISTIC_QUERY_FAILED, UNSUPPORTED_GRAPH_FORMAT, + CUDA_API_FAILED, + CUDNN_BACKEND_API_FAILED, INVALID_CUDA_DEVICE, HANDLE_ERROR}; @@ -61,69 +63,83 @@ typedef struct error_object { } error_t; -#define CHECK_CUDNN_FRONTEND_ERROR(x) \ +#ifdef WIN32 +#define CUDNN_FRONTEND_WHILE_FALSE \ + __pragma(warning(push)) __pragma(warning(disable : 4127)) while (0) __pragma(warning(pop)) +#else +#define CUDNN_FRONTEND_WHILE_FALSE while (0) +#endif + +#define CHECK_CUDNN_FRONTEND_ERROR(x) \ + do { \ + if (auto retval = x; retval.is_bad()) { \ + getLogger() << "[cudnn_frontend] ERROR: " << #x << " at " << __FILE__ << ":" << __LINE__ << std::endl; \ + return retval; \ + } \ + } \ + CUDNN_FRONTEND_WHILE_FALSE + +#define RETURN_CUDNN_FRONTEND_ERROR_IF(cond, retval, message) \ do { \ - if (x.is_bad()) { \ - getLogger() << "[cudnn_frontend] ERROR: " << #x << " code " << x.get_code() << " at " << __FILE__ << ":" \ - << __LINE__ << std::endl; \ - return x; \ + if (cond) { \ + if (retval == error_code_t::OK) { \ + getLogger() << "[cudnn_frontend] INFO: "; \ + } else { \ + getLogger() << "[cudnn_frontend] ERROR: "; \ + } \ + getLogger() << message << ". " << retval << " because (" << #cond ") at " << __FILE__ << ":" << __LINE__ \ + << "\n"; \ + return {retval, message}; \ } \ - } while (0) + } \ + CUDNN_FRONTEND_WHILE_FALSE + +#define CHECK_CUDNN_ERROR(x) \ + do { \ + if (auto cudnn_retval = x; cudnn_retval != CUDNN_STATUS_SUCCESS) { \ + std::stringstream error_msg; \ + error_msg << #x << " failed with " << cudnnGetErrorString(cudnn_retval); \ + getLogger() << "[cudnn_frontend] ERROR: " << error_msg.str() << " at " << __FILE__ << ":" << __LINE__ \ + << std::endl; \ + return {error_code_t::CUDNN_BACKEND_API_FAILED, error_msg.str()}; \ + } \ + } \ + CUDNN_FRONTEND_WHILE_FALSE -#define RETURN_CUDNN_FRONTEND_ERROR_IF(cond, retval) \ +#define CHECK_CUDA_ERROR(x) \ do { \ - if (cond) { \ - if (retval.get_code() == error_code_t::OK) { \ - getLogger() << "[cudnn_frontend] INFO: " << #cond << " returned " << retval.get_code() << " at " \ - << __FILE__ << ":" << __LINE__ << std::endl; \ - } else { \ - getLogger() << "[cudnn_frontend] ERROR: " << #cond << " returned " << retval.get_code() << " at " \ - << __FILE__ << ":" << __LINE__ << std::endl; \ - } \ - return {retval}; \ + if (auto cuda_retval = x; cuda_retval != cudaSuccess) { \ + std::stringstream error_msg; \ + error_msg << #x << " failed with " << cudaGetErrorString(cuda_retval); \ + getLogger() << "[cudnn_frontend] ERROR: " << error_msg.str() << " at " << __FILE__ << ":" << __LINE__ \ + << std::endl; \ + return {error_code_t::CUDA_API_FAILED, error_msg.str()}; \ } \ - } while (0) + } \ + CUDNN_FRONTEND_WHILE_FALSE + +NLOHMANN_JSON_SERIALIZE_ENUM(error_code_t, + { + {error_code_t::OK, "OK"}, + {error_code_t::ATTRIBUTE_NOT_SET, "ATTRIBUTE_NOT_SET"}, + {error_code_t::SHAPE_DEDUCTION_FAILED, "SHAPE_DEDUCTION_FAILED"}, + {error_code_t::INVALID_TENSOR_NAME, "INVALID_TENSOR_NAME"}, + {error_code_t::INVALID_VARIANT_PACK, "INVALID_VARIANT_PACK"}, + {error_code_t::GRAPH_NOT_SUPPORTED, "GRAPH_NOT_SUPPORTED"}, + {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, + "GRAPH_EXECUTION_PLAN_CREATION_FAILED"}, + {error_code_t::GRAPH_EXECUTION_FAILED, "GRAPH_EXECUTION_FAILED"}, + {error_code_t::HEURISTIC_QUERY_FAILED, "HEURISTIC_QUERY_FAILED"}, + {error_code_t::CUDNN_BACKEND_API_FAILED, "CUDNN_BACKEND_API_FAILED"}, + {error_code_t::CUDA_API_FAILED, "CUDA_API_FAILED"}, + {error_code_t::INVALID_CUDA_DEVICE, "INVALID_CUDA_DEVICE"}, + {error_code_t::UNSUPPORTED_GRAPH_FORMAT, "UNSUPPORTED_GRAPH_FORMAT"}, + {error_code_t::HANDLE_ERROR, "HANDLE_ERROR"}, + }) static inline std::ostream& operator<<(std::ostream& os, const error_code_t& mode) { - switch (mode) { - case error_code_t::OK: - os << "OK"; - break; - case error_code_t::ATTRIBUTE_NOT_SET: - os << "ATTRIBUTE_NOT_SET"; - break; - case error_code_t::SHAPE_DEDUCTION_FAILED: - os << "SHAPE_DEDUCTION_FAILED"; - break; - case error_code_t::INVALID_TENSOR_NAME: - os << "INVALID_TENSOR_NAME"; - break; - case error_code_t::INVALID_VARIANT_PACK: - os << "INVALID_VARIANT_PACK"; - break; - case error_code_t::GRAPH_NOT_SUPPORTED: - os << "GRAPH_NOT_SUPPORTED"; - break; - case error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED: - os << "GRAPH_EXECUTION_PLAN_CREATION_FAILED"; - break; - case error_code_t::GRAPH_EXECUTION_FAILED: - os << "GRAPH_EXECUTION_FAILED"; - break; - case error_code_t::HEURISTIC_QUERY_FAILED: - os << "HEURISTIC_QUERY_FAILED"; - break; - case error_code_t::INVALID_CUDA_DEVICE: - os << "INVALID_CUDA_DEVICE"; - break; - case error_code_t::UNSUPPORTED_GRAPH_FORMAT: - os << "UNSUPPORTED_GRAPH_FORMAT"; - break; - case error_code_t::HANDLE_ERROR: - os << "HANDLE_ERROR"; - break; - } + os << json{mode}; return os; } @@ -195,20 +211,62 @@ class Context { } }; -// Always generates NCHW (4d/5d tensors) or Col major (matrices) +// Creates dense, non-overlapping strides from given dim and stride_order. +// For example, if a is a 4D tensor with dimensions labeled NCHW, then strided(a, (3, 0, 2, 1)) produces +// strides where the C dimension has a corresponding stride of one. inline std::vector -generate_stride(std::vector const& dim) { - std::vector stride(dim.size(), 1); +generate_stride(std::vector const& dim, std::vector const& stride_order) { + size_t num_dims = dim.size(); + std::vector stride(num_dims); - stride[dim.size() - 1] = stride[1] * dim[1]; - for (int64_t d = dim.size() - 2; d >= 2; d--) { - stride[d] = stride[d + 1] * dim[d + 1]; + // Sort the dimensions according to strides from least to greatest. + // Example, dim = (2, 3, 4, 5) stride_order = (3, 1, 2, 0) + // sorted_stride_order = ((0, (3, 5)), (1, (1, 3)), (2, (2, 4)), (3, (0, 2))) + std::vector>> sorted_stride_order; + for (size_t i = 0; i < num_dims; ++i) { + sorted_stride_order.push_back({stride_order[i], {i, dim[i]}}); + } + std::sort(sorted_stride_order.begin(), sorted_stride_order.end()); + + // As dims have now been ordered starting from fastest changing, + // just fill in strides by iterating linearly over them. + int64_t product = 1; + for (size_t i = 0; i < num_dims; ++i) { + stride[sorted_stride_order[i].second.first] = product; + product *= sorted_stride_order[i].second.second; } - stride[0] = stride[2] * dim[2]; return stride; } +// Generate NHWC stride_order +inline std::vector +generate_NHWC_stride_order(int64_t const num_dims) { + std::vector stride_order(num_dims); + + int64_t order = 0; + stride_order[1] = order++; + for (size_t i = num_dims - 1; i > 1; --i) { + stride_order[i] = order++; + } + stride_order[0] = order; + + return stride_order; +} + +// Generate column major stride_order for matrices +// dim = (*, M, N) where * is batch dimsensions +// strides should be (..., N, 1) +inline std::vector +generate_column_major_stride_order(int64_t const num_dims) { + std::vector stride_order(num_dims); + + int64_t order = num_dims - 1; + std::generate(stride_order.begin(), stride_order.end(), [&order] { return order--; }); + + return stride_order; +} + } // namespace detail } // namespace cudnn_frontend \ No newline at end of file diff --git a/include/cudnn_frontend/cudnn_frontend_graph_interface.h b/include/cudnn_frontend/cudnn_frontend_graph_interface.h index 90226fbe..86fd0347 100644 --- a/include/cudnn_frontend/cudnn_frontend_graph_interface.h +++ b/include/cudnn_frontend/cudnn_frontend_graph_interface.h @@ -49,8 +49,8 @@ class Plans { inline error_t check_support(cudnnHandle_t h) { - auto status = list_of_engine_configs.check_support(h); - return status; + CHECK_CUDNN_FRONTEND_ERROR(list_of_engine_configs.check_support(h)); + return {error_code_t::OK, ""}; } int64_t @@ -184,8 +184,8 @@ Plans::filter_out_numeric_notes(std::vector const & inline error_t Plans::build_all_plans(cudnnHandle_t h) { - auto status = list_of_engine_configs.build_all_plans(h); - return status; + CHECK_CUDNN_FRONTEND_ERROR(list_of_engine_configs.build_all_plans(h)); + return {error_code_t::OK, ""}; } inline int64_t @@ -306,6 +306,14 @@ class Graph : public INode { std::shared_ptr, std::shared_ptr, Scaled_dot_product_flash_attention_attributes); + std::array, 3> scaled_dot_product_flash_attention_backward( + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Scaled_dot_product_flash_attention_backward_attributes); Plans get_execution_plan_list(HeurMode_t mode); @@ -342,12 +350,7 @@ class Graph : public INode { createOperationGraphs(cudnnHandle_t handle) override final { getLogger() << "Operation Graph has " << operations.size() << " operations." << std::endl; - auto status = create_cudnn_operation_graphs(handle); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: " << status.get_code() - << " Failed to create execution plans for graph partitioning in FlatNode." << std::endl; - return status; - } + CHECK_CUDNN_FRONTEND_ERROR(create_cudnn_operation_graphs(handle)); return {error_code_t::OK, ""}; } @@ -742,4 +745,30 @@ Graph::scaled_dot_product_flash_attention(std::shared_ptr q, return {O, Stats}; } +inline std::array, 3> +Graph::scaled_dot_product_flash_attention_backward(std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr o, + std::shared_ptr dO, + std::shared_ptr Stats, + Scaled_dot_product_flash_attention_backward_attributes options) { + // Set inputs + options.inputs.Q = q; + options.inputs.K = k; + options.inputs.V = v; + options.inputs.O = o; + options.inputs.dO = dO; + options.inputs.Stats = Stats; + + // Make required output tensors + auto dQ = options.outputs.dQ = output_tensor(options.get_name() + "::dQ"); + auto dK = options.outputs.dK = output_tensor(options.get_name() + "::dK"); + auto dV = options.outputs.dV = output_tensor(options.get_name() + "::dV"); + + sub_nodes.emplace_back(std::make_unique(std::move(options), context)); + + return {dQ, dK, dV}; +} + } // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/cudnn_frontend_graph_properties.h b/include/cudnn_frontend/cudnn_frontend_graph_properties.h index 4f9ab115..9982a8ce 100644 --- a/include/cudnn_frontend/cudnn_frontend_graph_properties.h +++ b/include/cudnn_frontend/cudnn_frontend_graph_properties.h @@ -165,6 +165,7 @@ class Operation { Reshape, Scaled_dot_product_attention, Scaled_dot_product_flash_attention, + Scaled_dot_product_flash_attention_backward, Softmax, }; Tag tag; @@ -187,27 +188,28 @@ class Operation { virtual ~Operation() = default; }; -NLOHMANN_JSON_SERIALIZE_ENUM(Operation::Tag, - { - {Operation::Tag::BN, "BN"}, - {Operation::Tag::BN_inference, "BN_inference"}, - {Operation::Tag::BN_finalize, "BN_finalize"}, - {Operation::Tag::Conv_fprop, "Conv_fprop"}, - {Operation::Tag::Conv_dgrad, "Conv_dgrad"}, - {Operation::Tag::Conv_wgrad, "Conv_wgrad"}, - {Operation::Tag::DBN, "DBN"}, - {Operation::Tag::DBN_weight, "DBN_weight"}, - {Operation::Tag::Genstats, "Genstats"}, - {Operation::Tag::Matmul, "Matmul"}, - {Operation::Tag::Pointwise, "Pointwise"}, - {Operation::Tag::Reduction, "Reduction"}, - {Operation::Tag::Rng, "Rng"}, - {Operation::Tag::Reshape, "Reshape"}, - {Operation::Tag::Scaled_dot_product_attention, "Scaled_dot_product_attention"}, - {Operation::Tag::Scaled_dot_product_flash_attention, - "Scaled_dot_product_flash_attention"}, - {Operation::Tag::Softmax, "Softmax"}, - }) +NLOHMANN_JSON_SERIALIZE_ENUM( + Operation::Tag, + { + {Operation::Tag::BN, "BN"}, + {Operation::Tag::BN_inference, "BN_inference"}, + {Operation::Tag::BN_finalize, "BN_finalize"}, + {Operation::Tag::Conv_fprop, "Conv_fprop"}, + {Operation::Tag::Conv_dgrad, "Conv_dgrad"}, + {Operation::Tag::Conv_wgrad, "Conv_wgrad"}, + {Operation::Tag::DBN, "DBN"}, + {Operation::Tag::DBN_weight, "DBN_weight"}, + {Operation::Tag::Genstats, "Genstats"}, + {Operation::Tag::Matmul, "Matmul"}, + {Operation::Tag::Pointwise, "Pointwise"}, + {Operation::Tag::Reduction, "Reduction"}, + {Operation::Tag::Rng, "Rng"}, + {Operation::Tag::Reshape, "Reshape"}, + {Operation::Tag::Scaled_dot_product_attention, "Scaled_dot_product_attention"}, + {Operation::Tag::Scaled_dot_product_flash_attention, "Scaled_dot_product_flash_attention"}, + {Operation::Tag::Scaled_dot_product_flash_attention_backward, "Scaled_dot_product_flash_attention_backward"}, + {Operation::Tag::Softmax, "Softmax"}, + }) class BN_finalize_attributes : public Operation { public: @@ -875,9 +877,15 @@ class Layernorm_backward_attributes : public Operation { inputs.SCALE->fill_from_context(context); inputs.DY->fill_from_context(context); - if (inputs.MEAN) { inputs.MEAN->fill_from_context(context);} - if (inputs.INV_VARIANCE) {inputs.INV_VARIANCE->fill_from_context(context);} - if (inputs.EPSILON) {inputs.EPSILON->fill_from_context(context);} + if (inputs.MEAN) { + inputs.MEAN->fill_from_context(context); + } + if (inputs.INV_VARIANCE) { + inputs.INV_VARIANCE->fill_from_context(context); + } + if (inputs.EPSILON) { + inputs.EPSILON->fill_from_context(context); + } outputs.DX->fill_from_context(context); outputs.DSCALE->fill_from_context(context); @@ -1606,6 +1614,108 @@ class Scaled_dot_product_flash_attention_attributes : public Operation { } }; +class Scaled_dot_product_flash_attention_backward_attributes : public Operation { + public: + struct Inputs { + std::shared_ptr Q; + std::shared_ptr K; + std::shared_ptr V; + std::shared_ptr O; + std::shared_ptr dO; + std::shared_ptr Stats; + std::shared_ptr Attn_scale; + std::shared_ptr Bias; + std::shared_ptr Seed; + std::shared_ptr Offset; + std::shared_ptr Dropout_mask; + std::shared_ptr Dropout_scale; + std::shared_ptr Dropout_scale_inv; + } inputs; + + struct Outputs { + std::shared_ptr dQ; + std::shared_ptr dK; + std::shared_ptr dV; + } outputs; + + bool causal_mask = false; + std::optional dropout_probability; + + public: + Scaled_dot_product_flash_attention_backward_attributes() + : Operation(Tag::Scaled_dot_product_flash_attention_backward) {} + + Scaled_dot_product_flash_attention_backward_attributes& + set_attn_scale(std::shared_ptr value) { + inputs.Attn_scale = value; + return *this; + } + + Scaled_dot_product_flash_attention_backward_attributes& + set_bias(std::shared_ptr value) { + inputs.Bias = value; + return *this; + } + + Scaled_dot_product_flash_attention_backward_attributes& + set_causal_mask(bool const value) { + causal_mask = value; + return *this; + } + + Scaled_dot_product_flash_attention_backward_attributes& + set_dropout(float const probability, + std::shared_ptr seed, + std::shared_ptr offset) { + dropout_probability = probability; + inputs.Seed = seed; + inputs.Offset = offset; + return *this; + } + + Scaled_dot_product_flash_attention_backward_attributes& + set_dropout(std::shared_ptr mask, + std::shared_ptr scale, + std::shared_ptr scale_inv) { + inputs.Dropout_mask = mask; + inputs.Dropout_scale = scale; + inputs.Dropout_scale_inv = scale_inv; + return *this; + } + + Scaled_dot_product_flash_attention_backward_attributes& + set_compute_data_type(DataType_t const value) { + compute_data_type = value; + return *this; + } + + Scaled_dot_product_flash_attention_backward_attributes& + set_name(std::string const& value) { + name = value; + return *this; + } + + Scaled_dot_product_flash_attention_backward_attributes& + fill_from_context(detail::Context const& context) { + // Fill node's tensors + inputs.Q->fill_from_context(context); + inputs.K->fill_from_context(context); + inputs.V->fill_from_context(context); + inputs.O->fill_from_context(context); + inputs.dO->fill_from_context(context); + inputs.Stats->fill_from_context(context); + outputs.dQ->fill_from_context(context); + outputs.dK->fill_from_context(context); + outputs.dV->fill_from_context(context); + + // Fill this node + if (get_compute_data_type() == DataType_t::NOT_SET) { + set_compute_data_type(context.get_compute_data_type()); + } + return *this; + } +}; + class Softmax_attributes : public Operation { public: struct Inputs { diff --git a/include/cudnn_frontend/cudnn_frontend_node_interface.h b/include/cudnn_frontend/cudnn_frontend_node_interface.h index b60c66d0..f6cc0ff4 100644 --- a/include/cudnn_frontend/cudnn_frontend_node_interface.h +++ b/include/cudnn_frontend/cudnn_frontend_node_interface.h @@ -55,10 +55,7 @@ class INode : public ICudnn { assign_uids() { CHECK_CUDNN_FRONTEND_ERROR(assign_uids_node()); for (auto const& sub_node : sub_nodes) { - auto status = sub_node->assign_uids(); - if (status.is_bad()) { - return status; - } + CHECK_CUDNN_FRONTEND_ERROR(sub_node->assign_uids()); } return {error_code_t::OK, ""}; } @@ -88,24 +85,24 @@ class INode : public ICudnn { } virtual error_t - pass_by_value_tensors_(std::unordered_map, pass_by_values_t>&, - [[maybe_unused]] void* node_workspace) { + pass_by_value_tensors_(cudnnHandle_t, + std::unordered_map, pass_by_values_t>&, + void*) { return {error_code_t::OK, ""}; } error_t gather_pass_by_value_tensors( + cudnnHandle_t const& handle, std::unordered_map, pass_by_values_t>& tensor_to_pass_by_value, void* fe_workspace) { void* node_workspace = fe_workspace; - CHECK_CUDNN_FRONTEND_ERROR(pass_by_value_tensors_(tensor_to_pass_by_value, node_workspace)); + CHECK_CUDNN_FRONTEND_ERROR(pass_by_value_tensors_(handle, tensor_to_pass_by_value, node_workspace)); node_workspace = static_cast(node_workspace) + get_fe_workspace_size_node(); for (auto const& sub_node : sub_nodes) { - auto status = sub_node->gather_pass_by_value_tensors(tensor_to_pass_by_value, node_workspace); + CHECK_CUDNN_FRONTEND_ERROR( + sub_node->gather_pass_by_value_tensors(handle, tensor_to_pass_by_value, node_workspace)); node_workspace = static_cast(node_workspace) + sub_node->get_fe_workspace_size_node(); - if (status.get_code() != error_code_t::OK) { - return status; - } } return {error_code_t::OK, ""}; } @@ -140,10 +137,7 @@ class INode : public ICudnn { virtual error_t createTensors() { for (auto const& sub_node : sub_nodes) { - auto status = sub_node->createTensors(); - if (status.get_code() != error_code_t::OK) { - return status; - } + CHECK_CUDNN_FRONTEND_ERROR(sub_node->createTensors()); } return {error_code_t::OK, ""}; } @@ -156,10 +150,7 @@ class INode : public ICudnn { virtual error_t createOperations() { for (auto const& sub_node : sub_nodes) { - auto status = sub_node->createOperations(); - if (status.is_bad()) { - return status; - } + CHECK_CUDNN_FRONTEND_ERROR(sub_node->createOperations()); // Roll up operations to parent node, so that parent can too partition operation graphs. for (auto&& operation_with_uids : sub_node->operations) { @@ -182,26 +173,14 @@ class INode : public ICudnn { } // validate self - auto status = validate_node(); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: Validation failed." << std::endl; - return status; - } + CHECK_CUDNN_FRONTEND_ERROR(validate_node()); // infer_properties self - status = infer_properties_node(); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: Infer properties failed." << std::endl; - return status; - } + CHECK_CUDNN_FRONTEND_ERROR(infer_properties_node()); // validate sub nodes for (auto const& sub_node : sub_nodes) { - status = sub_node->validate(); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: Validation failed." << std::endl; - return status; - } + CHECK_CUDNN_FRONTEND_ERROR(sub_node->validate()); } has_validation_checked = true; @@ -210,36 +189,11 @@ class INode : public ICudnn { error_t build_operation_graph(cudnnHandle_t handle) { - auto status = validate(); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: Failed to build." << std::endl; - return status; - } - - status = assign_uids(); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: Failed to build." << std::endl; - return status; - } - - status = createTensors(); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: Failed to build." << std::endl; - return status; - } - - status = createOperations(); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: Failed to build." << std::endl; - return status; - } - - status = createOperationGraphs(handle); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: Failed to build." << std::endl; - return status; - } - + CHECK_CUDNN_FRONTEND_ERROR(validate()); + CHECK_CUDNN_FRONTEND_ERROR(assign_uids()); + CHECK_CUDNN_FRONTEND_ERROR(createTensors()); + CHECK_CUDNN_FRONTEND_ERROR(createOperations()); + CHECK_CUDNN_FRONTEND_ERROR(createOperationGraphs(handle)); return {error_code_t::OK, ""}; } @@ -264,11 +218,7 @@ class INode : public ICudnn { void* fe_workspace = workspace; void* cudnn_workspace = static_cast(fe_workspace) + get_fe_workspace_size(); - auto status = gather_pass_by_value_tensors(tensor_to_pass_by_value, fe_workspace); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: Failed to gather_pass_by_value_tensors." << std::endl; - return status; - } + CHECK_CUDNN_FRONTEND_ERROR(gather_pass_by_value_tensors(handle, tensor_to_pass_by_value, fe_workspace)); // Add pass_by_value data pointers to tensor_uid_to_pointer map // object lifetime is controlled by tensor_to_pass_by_value which means the pointer should stay valid during @@ -281,19 +231,14 @@ class INode : public ICudnn { } else if (void** void_value_ptr = std::get_if(&value)) { tensor_uid_to_pointer_map.emplace(tensor->get_uid(), *void_value_ptr); } else { - status.code = error_code_t::INVALID_VARIANT_PACK; - status.err_msg = "[cudnn_frontend] ERROR: Unexpected type for pass by value tensor."; - return status; + RETURN_CUDNN_FRONTEND_ERROR_IF( + true, error_code_t::INVALID_VARIANT_PACK, "Unexpected type for pass by value tensor."); } } - status = execute_cudnn_plans(handle, tensor_uid_to_pointer_map, cudnn_workspace); - if (status.is_bad()) { - getLogger() << "[cudnn_frontend] ERROR: Execution failed." << std::endl; - return status; - } + CHECK_CUDNN_FRONTEND_ERROR(execute_cudnn_plans(handle, tensor_uid_to_pointer_map, cudnn_workspace)); - return status; + return {error_code_t::OK, ""}; } INode(detail::Context const& context) : context(context) {} diff --git a/include/cudnn_frontend/node/batchnorm.h b/include/cudnn_frontend/node/batchnorm.h index ce2df20a..ded10230 100644 --- a/include/cudnn_frontend/node/batchnorm.h +++ b/include/cudnn_frontend/node/batchnorm.h @@ -39,7 +39,10 @@ class BatchNormNode : public INode { Y->set_dim(x_tensor_dim); } if (Y->get_stride().empty()) { - Y->set_stride(detail::generate_stride(Y->get_dim())); + auto const& Y_dim = Y->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(Y_dim.size()); + Y->set_stride(detail::generate_stride(Y_dim, stride_order)); } // Set channel length tensors @@ -52,7 +55,10 @@ class BatchNormNode : public INode { T->set_dim(tensor_dim); } if (T->get_stride().empty()) { - T->set_stride(detail::generate_stride(T->get_dim())); + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); } }; infer_per_channel_tensors(options.outputs.MEAN); @@ -73,7 +79,10 @@ class BatchNormNode : public INode { T->set_dim(tensor_dim); } if (T->get_stride().empty()) { - T->set_stride(detail::generate_stride(T->get_dim())); + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); } }; infer_scalar_tensors(options.inputs.EPSILON); @@ -81,7 +90,10 @@ class BatchNormNode : public INode { for (auto const& peer_stat : options.inputs.peer_stats) { if (peer_stat->get_stride().empty()) { - peer_stat->set_stride(detail::generate_stride(peer_stat->get_dim())); + auto const& peer_stat_dim = peer_stat->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(peer_stat_dim.size()); + peer_stat->set_stride(detail::generate_stride(peer_stat_dim, stride_order)); } } @@ -94,11 +106,9 @@ class BatchNormNode : public INode { << "Validating BatchNormNode " << options.name << "..." << std::endl; // Norm forward phase should be set - if (options.forward_phase == NormFwdPhase_t::NOT_SET) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: Forward phase not set of batchnorm node."; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF(options.forward_phase == NormFwdPhase_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Forward phase not set of batchnorm node."); return {error_code_t::OK, ""}; } diff --git a/include/cudnn_frontend/node/batchnorm_inference.h b/include/cudnn_frontend/node/batchnorm_inference.h index 53818742..c67fa7e7 100644 --- a/include/cudnn_frontend/node/batchnorm_inference.h +++ b/include/cudnn_frontend/node/batchnorm_inference.h @@ -39,7 +39,10 @@ class BatchnormInferenceNode : public INode { Y->set_dim(x_tensor_dim); } if (Y->get_stride().empty()) { - Y->set_stride(detail::generate_stride(Y->get_dim())); + auto const& Y_dim = Y->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(Y_dim.size()); + Y->set_stride(detail::generate_stride(Y_dim, stride_order)); } // Set channel length tensors @@ -52,7 +55,10 @@ class BatchnormInferenceNode : public INode { T->set_dim(tensor_dim); } if (T->get_stride().empty()) { - T->set_stride(detail::generate_stride(T->get_dim())); + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); } }; infer_per_channel_tensors(attributes.inputs.MEAN); diff --git a/include/cudnn_frontend/node/bn_finalize.h b/include/cudnn_frontend/node/bn_finalize.h index c3e02ffd..af90f48e 100644 --- a/include/cudnn_frontend/node/bn_finalize.h +++ b/include/cudnn_frontend/node/bn_finalize.h @@ -41,7 +41,10 @@ class BatchNormFinalizeNode : public INode { T->set_dim(tensor_dim); } if (T->get_stride().empty()) { - T->set_stride(detail::generate_stride(T->get_dim())); + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); } }; infer_per_channel_tensors(options.inputs.SQ_SUM); @@ -65,7 +68,10 @@ class BatchNormFinalizeNode : public INode { T->set_dim(tensor_dim); } if (T->get_stride().empty()) { - T->set_stride(detail::generate_stride(T->get_dim())); + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); } }; infer_scalars(options.inputs.EPSILON); diff --git a/include/cudnn_frontend/node/conv_dgrad.h b/include/cudnn_frontend/node/conv_dgrad.h index 0f6eae8b..abb22f90 100644 --- a/include/cudnn_frontend/node/conv_dgrad.h +++ b/include/cudnn_frontend/node/conv_dgrad.h @@ -26,12 +26,9 @@ class DgradNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating DgradNode " << options.name << "..." << std::endl; - if (options.outputs.DX->get_dim().empty()) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: dgrad requires output tensor to have its dims set."; - getLogger() << message << std::endl; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF(options.outputs.DX->get_dim().empty(), + error_code_t::ATTRIBUTE_NOT_SET, + "dgrad requires output tensor to have its dims set."); return {error_code_t::OK, ""}; } @@ -55,7 +52,10 @@ class DgradNode : public INode { // No dim inferencing as inverse mapping from DY, W to DX is not unique. // Only infer strides if user did not set them if (DX->get_stride().empty()) { - DX->set_stride(detail::generate_stride(DX->get_dim())); + auto const& DX_dim = DX->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DX_dim.size()); + DX->set_stride(detail::generate_stride(DX_dim, stride_order)); } return {error_code_t::OK, ""}; diff --git a/include/cudnn_frontend/node/conv_fprop.h b/include/cudnn_frontend/node/conv_fprop.h index 3d0ababb..563edbba 100644 --- a/include/cudnn_frontend/node/conv_fprop.h +++ b/include/cudnn_frontend/node/conv_fprop.h @@ -56,7 +56,10 @@ class ConvolutionNode : public INode { Y->set_dim(y_tensor_dim); } if (Y->get_stride().empty()) { - Y->set_stride(detail::generate_stride(Y->get_dim())); + auto const& Y_dim = Y->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(Y_dim.size()); + Y->set_stride(detail::generate_stride(Y_dim, stride_order)); } return {error_code_t::OK, ""}; diff --git a/include/cudnn_frontend/node/conv_wgrad.h b/include/cudnn_frontend/node/conv_wgrad.h index a7fc41f7..b9c942a8 100644 --- a/include/cudnn_frontend/node/conv_wgrad.h +++ b/include/cudnn_frontend/node/conv_wgrad.h @@ -40,7 +40,10 @@ class WgradNode : public INode { // No dim inferencing as inverse mapping from DY, X to DX is not unique. // Only infer strides if user did not set them if (DW->get_stride().empty()) { - DW->set_stride(detail::generate_stride(DW->get_dim())); + auto const& DW_dim = DW->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DW_dim.size()); + DW->set_stride(detail::generate_stride(DW_dim, stride_order)); } return {error_code_t::OK, ""}; @@ -51,12 +54,9 @@ class WgradNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating WgradNode " << options.name << "..." << std::endl; - if (options.outputs.DW->get_dim().empty()) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: wgrad requires output tensor to have its dims set."; - getLogger() << message << std::endl; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF(options.outputs.DW->get_dim().empty(), + error_code_t::ATTRIBUTE_NOT_SET, + "wgrad requires output tensor to have its dims set."); return {error_code_t::OK, ""}; } diff --git a/include/cudnn_frontend/node/dbn.h b/include/cudnn_frontend/node/dbn.h index 393f4a5e..bc6627ad 100644 --- a/include/cudnn_frontend/node/dbn.h +++ b/include/cudnn_frontend/node/dbn.h @@ -27,11 +27,10 @@ class DBNNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating DBNNode " << options.name << "..." << std::endl; - if (!(options.inputs.MEAN) && !(options.inputs.INV_VARIANCE) && !(options.inputs.EPSILON)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: Either saved mean/inv_variance or epsilon required."; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF( + !(options.inputs.MEAN) && !(options.inputs.INV_VARIANCE) && !(options.inputs.EPSILON), + error_code_t::ATTRIBUTE_NOT_SET, + "Either saved mean/inv_variance or epsilon required."); return {error_code_t::OK, ""}; } @@ -55,7 +54,10 @@ class DBNNode : public INode { DY->set_dim(x_tensor_dim); } if (DY->get_stride().empty()) { - DY->set_stride(detail::generate_stride(DY->get_dim())); + auto const& DY_dim = DY->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DY_dim.size()); + DY->set_stride(detail::generate_stride(DY_dim, stride_order)); } auto DX = options.outputs.DX; @@ -66,7 +68,10 @@ class DBNNode : public INode { DX->set_dim(x_tensor_dim); } if (DX->get_stride().empty()) { - DX->set_stride(detail::generate_stride(DX->get_dim())); + auto const& DX_dim = DX->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DX_dim.size()); + DX->set_stride(detail::generate_stride(DX_dim, stride_order)); } // Set channel length tensors @@ -79,7 +84,10 @@ class DBNNode : public INode { T->set_dim(tensor_dim); } if (T->get_stride().empty()) { - T->set_stride(detail::generate_stride(T->get_dim())); + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); } }; infer_per_channel_tensors(options.inputs.MEAN); @@ -90,7 +98,10 @@ class DBNNode : public INode { for (auto const& peer_stat : options.inputs.peer_stats) { if (peer_stat->get_stride().empty()) { - peer_stat->set_stride(detail::generate_stride(peer_stat->get_dim())); + auto const& peer_stat_dim = peer_stat->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(peer_stat_dim.size()); + peer_stat->set_stride(detail::generate_stride(peer_stat_dim, stride_order)); } } diff --git a/include/cudnn_frontend/node/dbn_weight.h b/include/cudnn_frontend/node/dbn_weight.h index 0a09a796..d589bd55 100644 --- a/include/cudnn_frontend/node/dbn_weight.h +++ b/include/cudnn_frontend/node/dbn_weight.h @@ -41,7 +41,10 @@ class DBNWeightNode : public INode { X->set_dim(dy_tensor_dim); } if (X->get_stride().empty()) { - X->set_stride(detail::generate_stride(X->get_dim())); + auto const& X_dim = X->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(X_dim.size()); + X->set_stride(detail::generate_stride(X_dim, stride_order)); } // Set channel length tensors @@ -54,7 +57,10 @@ class DBNWeightNode : public INode { T->set_dim(tensor_dim); } if (T->get_stride().empty()) { - T->set_stride(detail::generate_stride(T->get_dim())); + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); } }; infer_per_channel_tensors(options.inputs.MEAN); diff --git a/include/cudnn_frontend/node/dln.h b/include/cudnn_frontend/node/dln.h index 3147a688..c90d4558 100644 --- a/include/cudnn_frontend/node/dln.h +++ b/include/cudnn_frontend/node/dln.h @@ -27,12 +27,10 @@ class DLNNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating DLNNode " << options.name << "..." << std::endl; - if (!(options.inputs.MEAN) && !(options.inputs.INV_VARIANCE) && !(options.inputs.EPSILON) && - !(options.inputs.SCALE)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: Either saved mean/inv_variance/scale or epsilon required."; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF(!(options.inputs.MEAN) && !(options.inputs.INV_VARIANCE) && + !(options.inputs.EPSILON) && !(options.inputs.SCALE), + error_code_t::ATTRIBUTE_NOT_SET, + "Either saved mean/inv_variance/scale or epsilon required."); return {error_code_t::OK, ""}; } @@ -57,7 +55,10 @@ class DLNNode : public INode { DY->set_dim(x_tensor_dim); } if (DY->get_stride().empty()) { - DY->set_stride(detail::generate_stride(DY->get_dim())); + auto const& DY_dim = DY->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DY_dim.size()); + DY->set_stride(detail::generate_stride(DY_dim, stride_order)); } auto DX = options.outputs.DX; @@ -68,7 +69,10 @@ class DLNNode : public INode { DX->set_dim(x_tensor_dim); } if (DX->get_stride().empty()) { - DX->set_stride(detail::generate_stride(DX->get_dim())); + auto const& DX_dim = DX->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(DX_dim.size()); + DX->set_stride(detail::generate_stride(DX_dim, stride_order)); } auto scale_bias_dim = X->get_dim(); @@ -84,7 +88,10 @@ class DLNNode : public INode { mean->set_dim(stats_dim); } if (mean->get_stride().empty()) { - mean->set_stride(detail::generate_stride(mean->get_dim())); + auto const& mean_dim = mean->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(mean_dim.size()); + mean->set_stride(detail::generate_stride(mean_dim, stride_order)); } auto inv_var = options.inputs.INV_VARIANCE; @@ -92,7 +99,10 @@ class DLNNode : public INode { inv_var->set_dim(stats_dim); } if (inv_var->get_stride().empty()) { - inv_var->set_stride(detail::generate_stride(inv_var->get_dim())); + auto const& inv_var_dim = inv_var->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(inv_var_dim.size()); + inv_var->set_stride(detail::generate_stride(inv_var_dim, stride_order)); } // Set channel length tensors @@ -103,7 +113,10 @@ class DLNNode : public INode { T->set_dim(scale_bias_dim); } if (T->get_stride().empty()) { - T->set_stride(detail::generate_stride(T->get_dim())); + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); } }; diff --git a/include/cudnn_frontend/node/genstats.h b/include/cudnn_frontend/node/genstats.h index 66e91e83..c6e5e7f6 100644 --- a/include/cudnn_frontend/node/genstats.h +++ b/include/cudnn_frontend/node/genstats.h @@ -41,7 +41,10 @@ class GenstatsNode : public INode { SUM->set_dim(sum_tensor_dim); } if (SUM->get_stride().empty()) { - SUM->set_stride(detail::generate_stride(SUM->get_dim())); + auto const& SUM_dim = SUM->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(SUM_dim.size()); + SUM->set_stride(detail::generate_stride(SUM_dim, stride_order)); } // Only infer dims and strides if user did not set them @@ -51,7 +54,10 @@ class GenstatsNode : public INode { SQ_SUM->set_dim(sq_sum_tensor_dim); } if (SQ_SUM->get_stride().empty()) { - SQ_SUM->set_stride(detail::generate_stride(SQ_SUM->get_dim())); + auto const& SQ_SUM_dim = SQ_SUM->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(SQ_SUM_dim.size()); + SQ_SUM->set_stride(detail::generate_stride(SQ_SUM_dim, stride_order)); } return {error_code_t::OK, ""}; diff --git a/include/cudnn_frontend/node/layernorm.h b/include/cudnn_frontend/node/layernorm.h index ff067eab..1f4bfe8c 100644 --- a/include/cudnn_frontend/node/layernorm.h +++ b/include/cudnn_frontend/node/layernorm.h @@ -40,7 +40,10 @@ class LayerNormNode : public INode { Y->set_dim(x_tensor_dim); } if (Y->get_stride().empty()) { - Y->set_stride(detail::generate_stride(Y->get_dim())); + auto const& Y_dim = Y->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(Y_dim.size()); + Y->set_stride(detail::generate_stride(Y_dim, stride_order)); } // scale_bias dim is 1,c,h,w @@ -58,7 +61,10 @@ class LayerNormNode : public INode { scale->set_dim(scale_bias_dim); } if (scale->get_stride().empty()) { - scale->set_stride(detail::generate_stride(scale->get_dim())); + auto const& scale_dim = scale->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(scale_dim.size()); + scale->set_stride(detail::generate_stride(scale_dim, stride_order)); } auto bias = options.inputs.BIAS; @@ -66,7 +72,10 @@ class LayerNormNode : public INode { bias->set_dim(scale_bias_dim); } if (bias->get_stride().empty()) { - bias->set_stride(detail::generate_stride(bias->get_dim())); + auto const& bias_dim = bias->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(bias_dim.size()); + bias->set_stride(detail::generate_stride(bias_dim, stride_order)); } if (options.forward_phase == NormFwdPhase_t::TRAINING) { @@ -75,7 +84,10 @@ class LayerNormNode : public INode { mean->set_dim(stats_dim); } if (mean->get_stride().empty()) { - mean->set_stride(detail::generate_stride(mean->get_dim())); + auto const& mean_dim = mean->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(mean_dim.size()); + mean->set_stride(detail::generate_stride(mean_dim, stride_order)); } auto inv_var = options.outputs.INV_VARIANCE; @@ -83,7 +95,10 @@ class LayerNormNode : public INode { inv_var->set_dim(stats_dim); } if (inv_var->get_stride().empty()) { - inv_var->set_stride(detail::generate_stride(inv_var->get_dim())); + auto const& inv_var_dim = inv_var->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(inv_var_dim.size()); + inv_var->set_stride(detail::generate_stride(inv_var_dim, stride_order)); } } @@ -96,7 +111,10 @@ class LayerNormNode : public INode { T->set_dim(tensor_dim); } if (T->get_stride().empty()) { - T->set_stride(detail::generate_stride(T->get_dim())); + auto const& T_dim = T->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(T_dim.size()); + T->set_stride(detail::generate_stride(T_dim, stride_order)); } }; infer_scalar_tensors(options.inputs.EPSILON); @@ -110,11 +128,9 @@ class LayerNormNode : public INode { << "Validating LayerNormNode " << options.name << "..." << std::endl; // Norm forward phase should be set - if (options.forward_phase == NormFwdPhase_t::NOT_SET) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: Forward phase not set of layernorm node."; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF(options.forward_phase == NormFwdPhase_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Forward phase not set of layernorm node."); return {error_code_t::OK, ""}; } diff --git a/include/cudnn_frontend/node/matmul.h b/include/cudnn_frontend/node/matmul.h index 8ddbb7e9..c8cd3ccd 100644 --- a/include/cudnn_frontend/node/matmul.h +++ b/include/cudnn_frontend/node/matmul.h @@ -26,23 +26,11 @@ class MatmulNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating matmul node " << options.name << "..." << std::endl; - if (!(options.inputs.A)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: matmul A not set."; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF(!(options.inputs.A), error_code_t::ATTRIBUTE_NOT_SET, "matmul A not set."); - if (!(options.inputs.B)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: matmul B not set."; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF(!(options.inputs.B), error_code_t::ATTRIBUTE_NOT_SET, "matmul B not set."); - if (!(options.outputs.C)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: matmul C not set."; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF(!(options.outputs.C), error_code_t::ATTRIBUTE_NOT_SET, "matmul C not set."); return {error_code_t::OK, ""}; } @@ -72,7 +60,10 @@ class MatmulNode : public INode { c_tensor->set_dim(c_tensor_dim); } if (c_tensor->get_stride().empty()) { - c_tensor->set_stride(detail::generate_stride(c_tensor->get_dim())); + auto const& c_dim = c_tensor->get_dim(); + // Default to Col major + auto const& stride_order = detail::generate_column_major_stride_order(c_dim.size()); + c_tensor->set_stride(detail::generate_stride(c_dim, stride_order)); } return {error_code_t::OK, ""}; diff --git a/include/cudnn_frontend/node/pointwise.h b/include/cudnn_frontend/node/pointwise.h index 1b3f5d9a..edbc688b 100644 --- a/include/cudnn_frontend/node/pointwise.h +++ b/include/cudnn_frontend/node/pointwise.h @@ -26,45 +26,26 @@ class PointwiseNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating pointwise node " << options.name << "..." << std::endl; - if (options.mode == PointwiseMode_t::NOT_SET) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: pointwise mode not set."; - getLogger() << message << std::endl; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF( + options.mode == PointwiseMode_t::NOT_SET, error_code_t::ATTRIBUTE_NOT_SET, "pointwise mode not set."); - if (!(options.inputs.IN_0)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: pointwise input IN_0 not set."; - getLogger() << message << std::endl; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF( + !(options.inputs.IN_0), error_code_t::ATTRIBUTE_NOT_SET, "pointwise input IN_0 not set."); auto const port_count = get_pointwise_mode_port_count(options.mode); if (port_count >= 3) { - if (!(options.inputs.IN_1)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: pointwise input IN_1 not set."; - getLogger() << message << std::endl; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF( + !(options.inputs.IN_1), error_code_t::ATTRIBUTE_NOT_SET, "pointwise input IN_1 not set."); } if (port_count >= 4) { - if (!(options.inputs.IN_2)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: pointwise input IN_2 not set."; - getLogger() << message << std::endl; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF( + !(options.inputs.IN_2), error_code_t::ATTRIBUTE_NOT_SET, "pointwise input IN_2 not set."); } - if (!(options.outputs.OUT_0)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: pointwise output OUT_0 not set in " + options.get_name(); - getLogger() << message << std::endl; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF(!(options.outputs.OUT_0), + error_code_t::ATTRIBUTE_NOT_SET, + "pointwise output OUT_0 not set in " + options.get_name()); return {error_code_t::OK, ""}; } diff --git a/include/cudnn_frontend/node/reduction.h b/include/cudnn_frontend/node/reduction.h index 8509ab54..1b963a21 100644 --- a/include/cudnn_frontend/node/reduction.h +++ b/include/cudnn_frontend/node/reduction.h @@ -25,17 +25,10 @@ class ReductionNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating reduction node " << options.name << "..." << std::endl; - if (!(options.inputs.X)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: reduction input not set."; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF( + !(options.inputs.X), error_code_t::ATTRIBUTE_NOT_SET, "reduction input not set."); - if (!(options.outputs.Y)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: reduction Y not set."; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF(!(options.outputs.Y), error_code_t::ATTRIBUTE_NOT_SET, "reduction Y not set."); return {error_code_t::OK, ""}; } @@ -58,7 +51,10 @@ class ReductionNode : public INode { y_tensor->set_dim(x_tensor_dim); } if (y_tensor->get_stride().empty()) { - y_tensor->set_stride(detail::generate_stride(y_tensor->get_dim())); + auto const& y_dim = y_tensor->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(y_dim.size()); + y_tensor->set_stride(detail::generate_stride(y_dim, stride_order)); } return {error_code_t::OK, ""}; diff --git a/include/cudnn_frontend/node/reshape.h b/include/cudnn_frontend/node/reshape.h index 8549f415..84d238ad 100644 --- a/include/cudnn_frontend/node/reshape.h +++ b/include/cudnn_frontend/node/reshape.h @@ -24,16 +24,10 @@ class ReshapeNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating ReshapeNode " << options.name << "..." << std::endl; - if (nullptr == options.inputs.X) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: reshape input not set."; - return {status, message}; - } - if (nullptr == options.outputs.Y) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: reshape output not set."; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF( + nullptr == options.inputs.X, error_code_t::ATTRIBUTE_NOT_SET, "reshape input not set."); + RETURN_CUDNN_FRONTEND_ERROR_IF( + nullptr == options.outputs.Y, error_code_t::ATTRIBUTE_NOT_SET, "reshape output not set."); return {error_code_t::OK, ""}; } @@ -65,7 +59,10 @@ class ReshapeNode : public INode { if (options.get_stride().size()) { y_tensor->set_stride(options.get_stride()); } else { - y_tensor->set_stride(detail::generate_stride(y_tensor->get_dim())); + auto const& y_dim = y_tensor->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(y_dim.size()); + y_tensor->set_stride(detail::generate_stride(y_dim, stride_order)); } } diff --git a/include/cudnn_frontend/node/rng.h b/include/cudnn_frontend/node/rng.h index ca0835f3..68ac0c4b 100644 --- a/include/cudnn_frontend/node/rng.h +++ b/include/cudnn_frontend/node/rng.h @@ -24,11 +24,7 @@ class RngNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating RngNode " << options.name << "..." << std::endl; - if (!(options.outputs.Y)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: rng output not set."; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF(!(options.outputs.Y), error_code_t::ATTRIBUTE_NOT_SET, "rng output not set."); return {error_code_t::OK, ""}; } @@ -74,7 +70,10 @@ class RngNode : public INode { if (options.get_stride().size()) { y_tensor->set_stride(options.get_stride()); } else { - y_tensor->set_stride(detail::generate_stride(y_tensor->get_dim())); + auto const& y_dim = y_tensor->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(y_dim.size()); + y_tensor->set_stride(detail::generate_stride(y_dim, stride_order)); } } diff --git a/include/cudnn_frontend/node/scaled_dot_product_attention.h b/include/cudnn_frontend/node/scaled_dot_product_attention.h index f5b08c13..6a3f87c8 100644 --- a/include/cudnn_frontend/node/scaled_dot_product_attention.h +++ b/include/cudnn_frontend/node/scaled_dot_product_attention.h @@ -329,8 +329,9 @@ class ScaledDotProductAttentionNode : public INode { virtual error_t pass_by_value_tensors_( + cudnnHandle_t, std::unordered_map, pass_by_values_t>& tensor_to_pass_by_value, - [[maybe_unused]] void* node_workspace) override { + void*) override { half dropout_scale_value = options.dropout_scale; tensor_to_pass_by_value.emplace(options.inputs.Dropout_scale, dropout_scale_value); diff --git a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h index e15c2e3e..df416c7e 100644 --- a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h +++ b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h @@ -37,52 +37,33 @@ class ScaledDotProductFlashAttentionNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating ScaledDotProductFlashAttentionNode " << options.name << "..." << std::endl; - if (options.is_inference.has_value() == false) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: is_infernece attribute not set."; - getLogger() << message << std::endl; - return {status, message}; - } - - if (options.dropout_probability.has_value() && options.inputs.Dropout_mask) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = - "[cudnn_frontend] ERROR: Using both, custom dropout mask and internal-mask generation using dropout " - "probability, is ill-formed."; - getLogger() << message << std::endl; - return {status, message}; - } - - if (options.dropout_probability.has_value() && options.dropout_probability.value() == 1.0) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = - "[cudnn_frontend] ERROR: Dropout probability cannot be 1 as corresponding scale wont be well formed."; - getLogger() << message << std::endl; - return {status, message}; - } - - if (context.get_intermediate_data_type() == DataType_t::NOT_SET) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = - "[cudnn_frontend] ERROR: Intermediate tensor data type needs to be set as internal tensors require it."; - getLogger() << message << std::endl; - return {status, message}; - } - - if (options.padding_mask && (!(options.inputs.SEQ_LEN_Q) || !(options.inputs.SEQ_LEN_KV))) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: Padding mask requires seq_len_q and seq_len_kv to be set."; - getLogger() << message << std::endl; - return {status, message}; - } - - if ((!options.padding_mask) && (options.inputs.SEQ_LEN_Q || options.inputs.SEQ_LEN_KV)) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = - "[cudnn_frontend] ERROR: seq_len_q and seq_len_kv needs to be set only if padding mask is enabled."; - getLogger() << message << std::endl; - return {status, message}; - } + RETURN_CUDNN_FRONTEND_ERROR_IF(options.is_inference.has_value() == false, + error_code_t::ATTRIBUTE_NOT_SET, + "is_infernece attribute not set"); + + RETURN_CUDNN_FRONTEND_ERROR_IF(options.dropout_probability.has_value() && options.inputs.Dropout_mask, + error_code_t::ATTRIBUTE_NOT_SET, + "Using both, custom dropout mask and internal-mask generation using dropout " + "probability, is ill-formed."); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + options.dropout_probability.has_value() && options.dropout_probability.value() == 1.0, + error_code_t::ATTRIBUTE_NOT_SET, + "Dropout probability cannot be 1 as corresponding scale wont be well formed."); + + RETURN_CUDNN_FRONTEND_ERROR_IF(context.get_intermediate_data_type() == DataType_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "Intermediate tensor data type needs to be set as internal tensors require it."); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + options.padding_mask && (!(options.inputs.SEQ_LEN_Q) || !(options.inputs.SEQ_LEN_KV)), + error_code_t::ATTRIBUTE_NOT_SET, + "Padding mask requires seq_len_q and seq_len_kv to be set."); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + (!options.padding_mask) && (options.inputs.SEQ_LEN_Q || options.inputs.SEQ_LEN_KV), + error_code_t::ATTRIBUTE_NOT_SET, + "seq_len_q and seq_len_kv needs to be set only if padding mask is enabled."); return {error_code_t::OK, ""}; } @@ -404,7 +385,7 @@ class ScaledDotProductFlashAttentionNode : public INode { // Two cases for training: dropout present or not // Special case: Skip dropout when 0.0 probability bool dropout_present = (options.dropout_probability.has_value() && options.dropout_probability.value() != 0.0); - dropout_present = dropout_present || options.inputs.Dropout_mask; + dropout_present = dropout_present || options.inputs.Dropout_mask; if (dropout_present) { // Lower options to rng options @@ -498,24 +479,25 @@ class ScaledDotProductFlashAttentionNode : public INode { virtual error_t pass_by_value_tensors_( + cudnnHandle_t handle, std::unordered_map, pass_by_values_t>& tensor_to_pass_by_value, void* node_workspace) override { if (options.dropout_probability.has_value()) { #if CUDNN_VERSION < 8903 - half dropout_scale_value = (1.f / (1.0 - options.dropout_probability.value())); + half dropout_scale_value = (1.0f / (1.0f - options.dropout_probability.value())); #else - float dropout_scale_value = (1.f / (1.0 - options.dropout_probability.value())); + float dropout_scale_value = (1.0f / (1.0f - options.dropout_probability.value())); #endif tensor_to_pass_by_value.emplace(dropout_scale, dropout_scale_value); } if (options.padding_mask) { - float negative_inf_value = std::numeric_limits::min(); + float negative_inf_value = std::numeric_limits::lowest(); tensor_to_pass_by_value.emplace(negative_inf_padding, negative_inf_value); } if (options.causal_mask) { - float negative_inf_value = std::numeric_limits::min(); + float negative_inf_value = std::numeric_limits::lowest(); tensor_to_pass_by_value.emplace(negative_inf_causal, negative_inf_value); } @@ -523,7 +505,10 @@ class ScaledDotProductFlashAttentionNode : public INode { int64_t const h = options.inputs.Q->get_dim()[1]; auto h_alibi_slopes_vector = detail::get_abili_slope(h); - cudaMemcpy(node_workspace, h_alibi_slopes_vector.data(), h * sizeof(float), cudaMemcpyHostToDevice); + cudaStream_t stream; + CHECK_CUDNN_ERROR(cudnnGetStream(handle, &stream)); + CHECK_CUDA_ERROR(cudaMemcpyAsync( + node_workspace, h_alibi_slopes_vector.data(), h * sizeof(float), cudaMemcpyHostToDevice, stream)); tensor_to_pass_by_value.emplace(alibi_slopes, node_workspace); } @@ -531,4 +516,518 @@ class ScaledDotProductFlashAttentionNode : public INode { } }; +class ScaledDotProductFlashAttentionBackwardNode : public INode { + private: + std::shared_ptr negative_inf_causal; + // one_tensor is needed for non-dropout graphs + std::shared_ptr one_tensor; + + // non-virtual node workspace tensors + std::shared_ptr dQ_accum; + int64_t dQ_accum_size = 0; + std::shared_ptr softmax_sum; + int64_t softmax_sum_size = 0; + + public: + Scaled_dot_product_flash_attention_backward_attributes options; + + ScaledDotProductFlashAttentionBackwardNode(Scaled_dot_product_flash_attention_backward_attributes&& options_, + detail::Context const& context) + : INode(context), options(std::move(options_)) {} + + Type + getType() override final { + return Type::COMPOSITE; + } + + error_t + validate_node() const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Validating ScaledDotProductFlashAttentionBackwardNode" << options.name << "..." << std::endl; + + RETURN_CUDNN_FRONTEND_ERROR_IF( + options.dropout_probability.has_value() && options.inputs.Dropout_mask, + error_code_t::ATTRIBUTE_NOT_SET, + "[cudnn_frontend] ERROR: Using both, custom dropout mask and internal-mask generation using dropout " + "probability, is ill-formed."); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + options.dropout_probability.has_value() && options.dropout_probability.value() == 1.0, + error_code_t::ATTRIBUTE_NOT_SET, + "[cudnn_frontend] ERROR: Dropout probability cannot be 1 as corresponding scale wont be well formed."); + + RETURN_CUDNN_FRONTEND_ERROR_IF( + context.get_intermediate_data_type() == DataType_t::NOT_SET, + error_code_t::ATTRIBUTE_NOT_SET, + "[cudnn_frontend] ERROR: Intermediate tensor data type needs to be set as internal tensors require it."); + + return {error_code_t::OK, ""}; + } + + error_t + infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferrencing properties for ScaledDotProductFlashAttentionBackwardNode " + << options.name << "..." << std::endl; + + options.fill_from_context(context); + + // Gather dims to fill properties of virtual tensors + auto const& q_dim = options.inputs.Q->get_dim(); + auto b = q_dim[0]; + auto h = q_dim[1]; + auto s_q = q_dim[2]; + auto d = q_dim[3]; + auto const& k_dim = options.inputs.K->get_dim(); + auto s_kv = k_dim[3]; + + std::shared_ptr last_output, exp_softmax_output, dp_scaled_output, rng_output; + + // --------------Initialize and create tensors before creating nodes-------------------- + + // one_tensor is needed for non-dropout graphs + one_tensor = std::make_shared(); + one_tensor->set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(DataType_t::FLOAT); + + // create tensors internal to the node + if (options.causal_mask) { + negative_inf_causal = std::make_shared(); + negative_inf_causal->set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(DataType_t::FLOAT); + } + + bool is_dropout_prob = (options.dropout_probability.has_value()); + bool is_dropout_mask = (options.inputs.Dropout_mask != nullptr); + + // if dropout_prob is used, then the node creates scale and scale inverse + // if dropout_mask is used, then the user creates scale and scale_inverse + if (is_dropout_prob) { + options.inputs.Dropout_scale = make_tensor_(true, {1, 1, 1, 1}); + options.inputs.Dropout_scale->set_data_type(DataType_t::FLOAT).set_is_pass_by_value(true); + options.inputs.Dropout_scale_inv = make_tensor_(true, {1, 1, 1, 1}); + options.inputs.Dropout_scale_inv->set_data_type(DataType_t::FLOAT).set_is_pass_by_value(true); + } + + // WAR non-virtual dQAccum is required if it is not + // cudnn verision >= 8.9.5 + // device version >= hopper + // sizeof(dp tensor) <= max_dp_workspace + bool war_use_non_virtual_dQAccum = true; + + if (cudnnGetVersion() >= 8905) { + struct cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + if (prop.major >= 9) { + // default upper limit for workspace 256MB + int64_t max_dp_workspace_bytes = 256 * 1024 * 1024; + + // allow setting the upper limit with envvars + char* env_dp_workspace_limit_char = std::getenv("CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"); + if (env_dp_workspace_limit_char != nullptr) { + try { + std::string env_dp_workspace_limit_str(env_dp_workspace_limit_char); + int64_t env_dp_workspace_limit = static_cast(std::stol(env_dp_workspace_limit_str)); + max_dp_workspace_bytes = std::max(max_dp_workspace_bytes, env_dp_workspace_limit); + } catch (...) { + RETURN_CUDNN_FRONTEND_ERROR_IF(true, + error_code_t::ATTRIBUTE_NOT_SET, + "Invalid argument for CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT " + "(int64_t; in bytes)"); + } + } + + int64_t workspace_s_q = ((s_q + 64 - 1) / 64) * 64; + int64_t workspace_s_kv = ((s_kv + 64 - 1) / 64) * 64; + int64_t required_dp_workspace_bytes = b * h * workspace_s_q * workspace_s_kv * 2; + required_dp_workspace_bytes = (required_dp_workspace_bytes + 1024 * 1024 - 1) / (1024 * 1024); + + if (required_dp_workspace_bytes <= max_dp_workspace_bytes) { + war_use_non_virtual_dQAccum = false; + } + } + } + + if (war_use_non_virtual_dQAccum) { + dQ_accum = make_tensor_(false, {b, h, s_q, d}); + dQ_accum->set_data_type(DataType_t::FLOAT).set_reordering_type(TensorReordering_t::F16x16); + dQ_accum_size = b * h * s_q * d * sizeof(float); + } + + // non-virtual softmax_sum is required for below cuDNN 8.9.5 + if (cudnnGetVersion() < 8905) { + softmax_sum = make_tensor_(false, {b, h, s_q, 1}); + softmax_sum->set_data_type(DataType_t::FLOAT); + softmax_sum_size = b * h * s_q * sizeof(float); + } + + // --------------RNG node-------------------- + + if (is_dropout_prob) { + Rng_attributes rng_attr; + rng_attr.set_distribution(RngDistribution_t::BERNOULLI); + rng_attr.set_bernoulli_probability(1.0f - options.dropout_probability.value()); + rng_attr.inputs.Seed = options.inputs.Seed; + rng_attr.inputs.Offset = options.inputs.Offset; + rng_attr.outputs.Y = rng_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(rng_attr), context)); + } else if (is_dropout_mask) { + rng_output = options.inputs.Dropout_mask; + } + + // --------------"dO * o => softmax_sum" chain-------------------- + + // pointwise mul: dO * O + Pointwise_attributes pw_mul_dO_O_attr; + pw_mul_dO_O_attr.set_name("pw_mul_dO_O"); + pw_mul_dO_O_attr.set_mode(PointwiseMode_t::MUL); + pw_mul_dO_O_attr.inputs.IN_0 = options.inputs.dO; + pw_mul_dO_O_attr.inputs.IN_1 = options.inputs.O; + pw_mul_dO_O_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, d}); + sub_nodes.emplace_back(std::make_unique(std::move(pw_mul_dO_O_attr), context)); + + // reduction add: dO * O + Reduction_attributes reduction_add_dO_O_attr; + reduction_add_dO_O_attr.set_name("reduction_add_dO_O"); + reduction_add_dO_O_attr.set_mode(ReductionMode_t::ADD); + reduction_add_dO_O_attr.inputs.X = last_output; + reduction_add_dO_O_attr.outputs.Y = last_output = make_tensor_(true, {b, h, s_q, 1}); + sub_nodes.emplace_back(std::make_unique(std::move(reduction_add_dO_O_attr), context)); + + // pointwise mul: dropout_scale inverse + Pointwise_attributes pw_mul_dropout_scale_inv_attr; + pw_mul_dropout_scale_inv_attr.set_name("pw_mul_dropout_scale_inv"); + pw_mul_dropout_scale_inv_attr.set_mode(PointwiseMode_t::MUL); + pw_mul_dropout_scale_inv_attr.inputs.IN_0 = last_output; + if (options.inputs.Dropout_scale_inv != nullptr) { + pw_mul_dropout_scale_inv_attr.inputs.IN_1 = options.inputs.Dropout_scale_inv; + } else { + // WAR dropout scale inverse is needed for non-dropout graphs + pw_mul_dropout_scale_inv_attr.inputs.IN_1 = one_tensor; + } + if (softmax_sum != nullptr) { + pw_mul_dropout_scale_inv_attr.outputs.OUT_0 = softmax_sum; + } else { + pw_mul_dropout_scale_inv_attr.outputs.OUT_0 = softmax_sum = make_tensor_(true, {b, h, s_q, 1}); + } + sub_nodes.emplace_back(std::make_unique(std::move(pw_mul_dropout_scale_inv_attr), context)); + + // --------------"Q @ KT => exp_softmax => dV" chain-------------------- + + // matmul: Q * K^T + Matmul_attributes matmul_Q_KT_attr; + matmul_Q_KT_attr.set_name("matmul_Q_KT"); + matmul_Q_KT_attr.inputs.A = options.inputs.Q; + matmul_Q_KT_attr.inputs.B = options.inputs.K; + matmul_Q_KT_attr.outputs.C = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(matmul_Q_KT_attr), context)); + + // pointwise mul: P bmmScale + if (options.inputs.Attn_scale != nullptr) { + Pointwise_attributes pw_mul_S_bmm_scale_attr; + pw_mul_S_bmm_scale_attr.set_name("pw_mul_S_bmm_scale"); + pw_mul_S_bmm_scale_attr.set_mode(PointwiseMode_t::MUL); + pw_mul_S_bmm_scale_attr.inputs.IN_0 = last_output; + pw_mul_S_bmm_scale_attr.inputs.IN_1 = options.inputs.Attn_scale; + pw_mul_S_bmm_scale_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(pw_mul_S_bmm_scale_attr), context)); + } + + // pointwise add: bias + if (options.inputs.Bias) { + Pointwise_attributes pw_add_bias_attr; + pw_add_bias_attr.set_name("pw_add_bias"); + pw_add_bias_attr.set_mode(PointwiseMode_t::ADD); + pw_add_bias_attr.inputs.IN_0 = last_output; + pw_add_bias_attr.inputs.IN_1 = options.inputs.Bias; + pw_add_bias_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(pw_add_bias_attr), context)); + } + + // Causal Mask DAG + if (options.causal_mask) { + std::shared_ptr row_index_output = make_tensor_(true, {b, h, s_q, s_kv}); + std::shared_ptr col_index_output = make_tensor_(true, {b, h, s_q, s_kv}); + std::shared_ptr row_gt_col_output = make_tensor_(true, {b, h, s_q, s_kv}); + row_gt_col_output->set_data_type(DataType_t::BOOLEAN); + + // Lower options to generate row index options + Pointwise_attributes row_index_attr; + row_index_attr.set_name("gen_row_index"); + row_index_attr.set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2); + row_index_attr.inputs.IN_0 = last_output; + row_index_attr.outputs.OUT_0 = row_index_output; + sub_nodes.emplace_back(std::make_unique(std::move(row_index_attr), context)); + + Pointwise_attributes col_index_attr; + col_index_attr.set_name("gen_col_index"); + col_index_attr.set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3); + col_index_attr.inputs.IN_0 = last_output; + col_index_attr.outputs.OUT_0 = col_index_output; + sub_nodes.emplace_back(std::make_unique(std::move(col_index_attr), context)); + + Pointwise_attributes greater_than_attr; + greater_than_attr.set_name("row_greater_than_col"); + greater_than_attr.set_mode(PointwiseMode_t::CMP_GE).set_compute_data_type(DataType_t::BOOLEAN); + greater_than_attr.inputs.IN_0 = row_index_output; + greater_than_attr.inputs.IN_1 = col_index_output; + greater_than_attr.outputs.OUT_0 = row_gt_col_output; + sub_nodes.emplace_back(std::make_unique(std::move(greater_than_attr), context)); + + Pointwise_attributes binary_select_attr; + binary_select_attr.set_name("binary_select"); + binary_select_attr.set_mode(PointwiseMode_t::BINARY_SELECT); + binary_select_attr.inputs.IN_0 = last_output; + binary_select_attr.inputs.IN_1 = negative_inf_causal; + binary_select_attr.inputs.IN_2 = row_gt_col_output; + binary_select_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(binary_select_attr), context)); + } + + // pointwise subtract S + Pointwise_attributes pw_subtract_s_attr; + pw_subtract_s_attr.set_name("pw_subtract_s"); + pw_subtract_s_attr.set_mode(PointwiseMode_t::SUB); + pw_subtract_s_attr.inputs.IN_0 = last_output; + pw_subtract_s_attr.inputs.IN_1 = options.inputs.Stats; + pw_subtract_s_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(pw_subtract_s_attr), context)); + + // pointwise exp softmax + Pointwise_attributes exp_attr; + exp_attr.set_name("exp_softmax"); + exp_attr.set_mode(PointwiseMode_t::EXP); + exp_attr.inputs.IN_0 = last_output; + exp_attr.outputs.OUT_0 = last_output = exp_softmax_output = make_tensor_(true, {b, h, s_q, s_kv}); + last_output->set_data_type(context.get_io_data_type()); + sub_nodes.emplace_back(std::make_unique(std::move(exp_attr), context)); + + // pointwise dropout mask mul + if (is_dropout_prob || is_dropout_mask) { + Pointwise_attributes mask_attr; + mask_attr.set_name("dropout_mask_mul"); + mask_attr.set_mode(PointwiseMode_t::MUL); + mask_attr.inputs.IN_0 = last_output; + mask_attr.inputs.IN_1 = rng_output; + mask_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(mask_attr), context)); + } + + // pointwise dropout scale + if (options.inputs.Dropout_scale != nullptr) { + Pointwise_attributes pw_mul_dropout_scale; + pw_mul_dropout_scale.set_name("pw_mul_dropout_scale"); + pw_mul_dropout_scale.set_mode(PointwiseMode_t::MUL); + pw_mul_dropout_scale.inputs.IN_0 = last_output; + pw_mul_dropout_scale.inputs.IN_1 = options.inputs.Dropout_scale; + pw_mul_dropout_scale.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(pw_mul_dropout_scale), context)); + } + + // reshape: transpose S + Reshape_attributes transpose_s_attr; + transpose_s_attr.set_name("transpose_s"); + transpose_s_attr.inputs.X = last_output; + transpose_s_attr.outputs.Y = last_output = + make_tensor_(true, {b, h, s_kv, s_q}, {h * s_q * s_kv, s_q * s_kv, 1, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(transpose_s_attr), context)); + + // matmul: S^T * dO + Matmul_attributes matmul_ST_dO_attr; + matmul_ST_dO_attr.set_name("matmul_ST_dO"); + matmul_ST_dO_attr.inputs.A = last_output; + matmul_ST_dO_attr.inputs.B = options.inputs.dO; + matmul_ST_dO_attr.outputs.C = options.outputs.dV; + sub_nodes.emplace_back(std::make_unique(std::move(matmul_ST_dO_attr), context)); + + // --------------"dO @ VT => dp_scaled_output => dK" chain-------------------- + + // matmul: dO * V^T + Matmul_attributes matmul_dO_VT_attr; + matmul_dO_VT_attr.set_name("matmul_dO_VT"); + matmul_dO_VT_attr.inputs.A = options.inputs.dO; + matmul_dO_VT_attr.inputs.B = options.inputs.V; + matmul_dO_VT_attr.outputs.C = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(matmul_dO_VT_attr), context)); + + // pointwise mul: dS * mask + Pointwise_attributes pw_mul_dS_mask_attr; + pw_mul_dS_mask_attr.set_name("pw_mul_dS_mask"); + pw_mul_dS_mask_attr.set_mode(PointwiseMode_t::MUL); + pw_mul_dS_mask_attr.inputs.IN_0 = last_output; + if (is_dropout_prob || is_dropout_mask) { + pw_mul_dS_mask_attr.inputs.IN_1 = rng_output; + } else { + pw_mul_dS_mask_attr.inputs.IN_1 = one_tensor; + } + pw_mul_dS_mask_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(pw_mul_dS_mask_attr), context)); + + // pointwise: subtract ds + Pointwise_attributes pw_subtract_ds_attr; + pw_subtract_ds_attr.set_name("pw_subtract_ds"); + pw_subtract_ds_attr.set_mode(PointwiseMode_t::SUB); + pw_subtract_ds_attr.inputs.IN_0 = last_output; + pw_subtract_ds_attr.inputs.IN_1 = softmax_sum; + pw_subtract_ds_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(pw_subtract_ds_attr), context)); + + // pointwise: mul dP + Pointwise_attributes pw_mul_dP_attr; + pw_mul_dP_attr.set_name("pw_mul_dP"); + pw_mul_dP_attr.set_mode(PointwiseMode_t::MUL); + pw_mul_dP_attr.inputs.IN_0 = last_output; + pw_mul_dP_attr.inputs.IN_1 = exp_softmax_output; + pw_mul_dP_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(pw_mul_dP_attr), context)); + + // pointwise: mul dP_dropout_scale + if (options.inputs.Dropout_scale != nullptr) { + Pointwise_attributes pw_mul_dP_dropout_scale_attr; + pw_mul_dP_dropout_scale_attr.set_name("pw_mul_dP_dropout_scale"); + pw_mul_dP_dropout_scale_attr.set_mode(PointwiseMode_t::MUL); + pw_mul_dP_dropout_scale_attr.inputs.IN_0 = last_output; + pw_mul_dP_dropout_scale_attr.inputs.IN_1 = options.inputs.Dropout_scale; + pw_mul_dP_dropout_scale_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(pw_mul_dP_dropout_scale_attr), context)); + } + + // pointwise: mul dP_bmmScale + if (options.inputs.Attn_scale != nullptr) { + Pointwise_attributes pw_mul_dP_bmm_scale_attr; + pw_mul_dP_bmm_scale_attr.set_name("pw_mul_dP_bmm_scale"); + pw_mul_dP_bmm_scale_attr.set_mode(PointwiseMode_t::MUL); + pw_mul_dP_bmm_scale_attr.inputs.IN_0 = last_output; + pw_mul_dP_bmm_scale_attr.inputs.IN_1 = options.inputs.Attn_scale; + pw_mul_dP_bmm_scale_attr.outputs.OUT_0 = last_output = make_tensor_(true, {b, h, s_q, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(pw_mul_dP_bmm_scale_attr), context)); + } + + dp_scaled_output = last_output; + + // tranpose dP + Reshape_attributes transpose_dP_attr; + transpose_dP_attr.set_name("transpose_dP"); + transpose_dP_attr.inputs.X = last_output; + transpose_dP_attr.outputs.Y = last_output = + make_tensor_(true, {b, h, s_kv, s_q}, {h * s_q * s_kv, s_q * s_kv, 1, s_kv}); + sub_nodes.emplace_back(std::make_unique(std::move(transpose_dP_attr), context)); + + // matmul: dP^T * Q + Matmul_attributes matmul_dP_Q_attr; + matmul_dP_Q_attr.set_name("matmul_dP_Q"); + matmul_dP_Q_attr.inputs.A = last_output; + matmul_dP_Q_attr.inputs.B = options.inputs.Q; + matmul_dP_Q_attr.outputs.C = options.outputs.dK; + sub_nodes.emplace_back(std::make_unique(std::move(matmul_dP_Q_attr), context)); + + // --------------"dp_scaled_output @ KT => dQ" chain-------------------- + + // transpose K + Reshape_attributes transpose_K_attr; + transpose_K_attr.set_name("transpose_K"); + transpose_K_attr.inputs.X = options.inputs.K; + transpose_K_attr.outputs.Y = last_output = make_tensor_(true, {b, h, s_kv, d}); + sub_nodes.emplace_back(std::make_unique(std::move(transpose_K_attr), context)); + + // matmul: dP * K + Matmul_attributes matmul_dP_K_attr; + matmul_dP_K_attr.set_name("matmul_dP_K"); + matmul_dP_K_attr.inputs.A = dp_scaled_output; + matmul_dP_K_attr.inputs.B = last_output; + if (dQ_accum != nullptr) { + matmul_dP_K_attr.outputs.C = dQ_accum; + } else { + matmul_dP_K_attr.outputs.C = options.outputs.dQ; + } + sub_nodes.emplace_back(std::make_unique(std::move(matmul_dP_K_attr), context)); + + if (dQ_accum != nullptr) { + Pointwise_attributes pw_identity_dQ_attr; + pw_identity_dQ_attr.set_name("pw_identity_dQ"); + pw_identity_dQ_attr.set_mode(PointwiseMode_t::IDENTITY); + pw_identity_dQ_attr.inputs.IN_0 = dQ_accum; + pw_identity_dQ_attr.outputs.OUT_0 = options.outputs.dQ; + sub_nodes.emplace_back(std::make_unique(std::move(pw_identity_dQ_attr), context)); + } + + return {error_code_t::OK, ""}; + } + + virtual int64_t + get_fe_workspace_size_node() const override final { + // set in infer_properties_node() + return dQ_accum_size + softmax_sum_size; + } + + error_t + pass_by_value_tensors_( + cudnnHandle_t handle, + std::unordered_map, pass_by_values_t>& tensor_to_pass_by_value, + void* node_workspace) override { + if (options.causal_mask) { + float negative_inf_value = std::numeric_limits::lowest(); + tensor_to_pass_by_value.emplace(negative_inf_causal, negative_inf_value); + } + + if (options.dropout_probability.has_value()) { + float dropout_scale_value = 1.0f / (1.0f - options.dropout_probability.value()); + float dropout_scale_inv_value = (1.0f - options.dropout_probability.value()); + tensor_to_pass_by_value.emplace(options.inputs.Dropout_scale, dropout_scale_value); + tensor_to_pass_by_value.emplace(options.inputs.Dropout_scale_inv, dropout_scale_inv_value); + } + + // one_tensor is needed for non-dropout graphs + if (one_tensor != nullptr) { + tensor_to_pass_by_value.emplace(one_tensor, 1.0f); + } + + if (dQ_accum != nullptr) { + cudaStream_t stream; + CHECK_CUDNN_ERROR(cudnnGetStream(handle, &stream)); + CHECK_CUDA_ERROR(cudaMemsetAsync(node_workspace, 0, dQ_accum_size, stream)); + tensor_to_pass_by_value.emplace(dQ_accum, node_workspace); + node_workspace = static_cast(node_workspace) + dQ_accum_size; + } + + if (softmax_sum != nullptr) { + // There is no requirement for softmax_sum to be memset to 0 + tensor_to_pass_by_value.emplace(softmax_sum, node_workspace); + } + + return {error_code_t::OK, ""}; + } + + private: + inline std::shared_ptr + make_tensor_(bool is_virtual) { + auto tensor = std::make_shared(); + tensor->set_is_virtual(is_virtual); + return tensor; + } + + inline std::shared_ptr + make_tensor_(bool is_virtual, std::vector const& dim) { + std::vector stride(dim.size()); + int64_t prod = 1; + for (int i = (int)dim.size() - 1; i >= 0; --i) { + stride[i] = prod; + prod *= dim[i]; + } + auto tensor = std::make_shared(); + tensor->set_is_virtual(is_virtual).set_dim(dim).set_stride(stride); + return tensor; + } + + inline std::shared_ptr + make_tensor_(bool is_virtual, std::vector const& dim, std::vector const& stride) { + auto tensor = std::make_shared(); + tensor->set_is_virtual(is_virtual).set_dim(dim).set_stride(stride); + return tensor; + } +}; + } // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/softmax.h b/include/cudnn_frontend/node/softmax.h index 4f390737..0ac6d809 100644 --- a/include/cudnn_frontend/node/softmax.h +++ b/include/cudnn_frontend/node/softmax.h @@ -28,12 +28,8 @@ class SoftmaxNode : public INode { getLogger() << "[cudnn_frontend] INFO: " << "Validating SoftmaxNode " << options.name << "..." << std::endl; - if (options.use_stats.has_value() == false) { - auto status = error_code_t::ATTRIBUTE_NOT_SET; - std::string message = "[cudnn_frontend] ERROR: use_stats attribute not set."; - return {status, message}; - } - + RETURN_CUDNN_FRONTEND_ERROR_IF( + options.use_stats.has_value() == false, error_code_t::ATTRIBUTE_NOT_SET, "use_stats attribute not set."); return {error_code_t::OK, ""}; } diff --git a/python_bindings/cudnn_frontend_properties.cpp b/python_bindings/cudnn_frontend_properties.cpp index 9f169910..011f57dc 100644 --- a/python_bindings/cudnn_frontend_properties.cpp +++ b/python_bindings/cudnn_frontend_properties.cpp @@ -16,15 +16,15 @@ namespace python_bindings { void throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::string const& error_msg); -void * +void* create_handle() { cudnnHandle_t handle; cudnnCreate(&handle); - return (void *)handle; + return (void*)handle; } void -destroy_handle(void *handle) { +destroy_handle(void* handle) { auto status = cudnnDestroy((cudnnHandle_t)handle); throw_if(status != CUDNN_STATUS_SUCCESS, cudnn_frontend::error_code_t::HANDLE_ERROR, "cudnnHandle Destroy failed"); } @@ -61,8 +61,7 @@ init_properties(py::module_& m) { out << json{props}; return out.str(); }); - } -} -} \ No newline at end of file +} // namespace python_bindings +} // namespace cudnn_frontend \ No newline at end of file diff --git a/python_bindings/cudnn_frontend_pygraph.cpp b/python_bindings/cudnn_frontend_pygraph.cpp index 4421589f..234b17ce 100644 --- a/python_bindings/cudnn_frontend_pygraph.cpp +++ b/python_bindings/cudnn_frontend_pygraph.cpp @@ -43,6 +43,10 @@ throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::st throw std::runtime_error(error_msg); case cudnn_frontend::error_code_t::HEURISTIC_QUERY_FAILED: throw std::runtime_error(error_msg); + case cudnn_frontend::error_code_t::CUDNN_BACKEND_API_FAILED: + throw std::runtime_error(error_msg); + case cudnn_frontend::error_code_t::CUDA_API_FAILED: + throw std::runtime_error(error_msg); case cudnn_frontend::error_code_t::INVALID_CUDA_DEVICE: throw std::runtime_error(error_msg); case cudnn_frontend::error_code_t::UNSUPPORTED_GRAPH_FORMAT: @@ -131,22 +135,21 @@ class PyGraph { cudnn_frontend::DataType_t io_data_type, cudnn_frontend::DataType_t intermediate_data_type, cudnn_frontend::DataType_t compute_data_type, - void * handle_ = nullptr) - : graph(), handle((cudnnHandle_t)handle_), - is_handle_owner(false), is_built(false) { + void* handle_ = nullptr) + : graph(), handle((cudnnHandle_t)handle_), is_handle_owner(false), is_built(false) { graph.set_compute_data_type(compute_data_type) .set_intermediate_data_type(intermediate_data_type) .set_io_data_type(io_data_type); - + if (handle_ == nullptr) { cudnnCreate(&handle); - is_handle_owner = true; + is_handle_owner = true; } } - ~PyGraph() { + ~PyGraph() { if (is_handle_owner) { - cudnnDestroy(handle); + cudnnDestroy(handle); } } @@ -343,7 +346,8 @@ class PyGraph { std::shared_ptr& B, cudnn_frontend::DataType_t const& compute_data_type, std::string const& name) { - auto attributes = cudnn_frontend::graph::Matmul_attributes().set_compute_data_type(compute_data_type).set_name(name); + auto attributes = + cudnn_frontend::graph::Matmul_attributes().set_compute_data_type(compute_data_type).set_name(name); auto C = graph.matmul(A, B, attributes); return C; @@ -507,6 +511,66 @@ class PyGraph { return {O, Stats}; } + std::array, 3> + scaled_dot_product_flash_attention_backward(std::shared_ptr& q, + std::shared_ptr& k, + std::shared_ptr& v, + std::shared_ptr& o, + std::shared_ptr& dO, + std::shared_ptr& stats, + std::shared_ptr& attn_scale, + std::shared_ptr& bias, + bool const use_causal_mask, + py::object const& dropout, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::Scaled_dot_product_flash_attention_backward_attributes() + .set_attn_scale(attn_scale) + .set_bias(bias) + .set_causal_mask(use_causal_mask) + .set_compute_data_type(compute_data_type) + .set_name(name); + + py::object cudnn_tensor_type = py::module_::import("cudnn").attr("tensor"); + + if (!dropout.is_none()) { + if (!py::isinstance(dropout)) { + throw std::runtime_error( + "dropout must be a tuple of (float probability, a seed tensor" + ", and an offset tensor) or (mask tensor, scale tensor)"); + } + py::tuple dropout_tuple = dropout.cast(); + if (dropout_tuple.size() != 3) { + throw std::runtime_error( + "dropout must be a tuple of (float probability, a seed tensor" + ", and an offset tensor) or (mask tensor, scale tensor)"); + } + + if (py::isinstance(dropout_tuple[0]) && py::isinstance(dropout_tuple[1], cudnn_tensor_type) && + py::isinstance(dropout_tuple[2], cudnn_tensor_type)) { + auto const probability = dropout_tuple[0].cast(); + auto const seed = dropout_tuple[1].cast>(); + auto const offset = dropout_tuple[2].cast>(); + attributes.set_dropout(probability, seed, offset); + } else if (py::isinstance(dropout_tuple[0], cudnn_tensor_type) && + py::isinstance(dropout_tuple[1], cudnn_tensor_type) && + py::isinstance(dropout_tuple[2], cudnn_tensor_type)) { + auto const mask = dropout_tuple[0].cast>(); + auto const scale = dropout_tuple[1].cast>(); + auto const scale_inv = + dropout_tuple[2].cast>(); + attributes.set_dropout(mask, scale, scale_inv); + } else { + throw std::runtime_error( + "dropout must be a tuple of (float probability, a seed tensor" + ", and an offset tensor) or (mask tensor, scale tensor)"); + } + } + + auto [dQ, dK, dV] = graph.scaled_dot_product_flash_attention_backward(q, k, v, o, dO, stats, attributes); + return {dQ, dK, dV}; + } + void check_support() { build(); @@ -576,7 +640,7 @@ init_pygraph_submodule(py::module_& m) { cudnn_frontend::DataType_t, cudnn_frontend::DataType_t, cudnn_frontend::DataType_t, - void *>(), + void*>(), py::arg_v("name", "test_graph"), py::arg_v("io_data_type", cudnn_frontend::DataType_t::NOT_SET), py::arg_v("intermediate_data_type", cudnn_frontend::DataType_t::NOT_SET), @@ -803,8 +867,44 @@ init_pygraph_submodule(py::module_& m) { name (Optional[str]): The name of the operation. Returns: - cudnn_tensor: The result of scaled dot-product flash attention. - Optional[cudnn_tensor]: The softmax statistics in case the operation is in a training step. + o (cudnn_tensor): The result of scaled dot-product flash attention. + stats (Optional[cudnn_tensor]): The softmax statistics in case the operation is in a training step. + )pbdoc") + .def("scaled_dot_product_flash_attention_backward", + &PyGraph::scaled_dot_product_flash_attention_backward, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("o"), + py::arg("dO"), + py::arg("stats"), + py::arg_v("attn_scale", nullptr), + py::arg_v("bias", nullptr), + py::arg_v("use_causal_mask", false), + py::arg_v("dropout", py::none()), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute the key, query, value gradients of scaled dot-product flash attention. + + Args: + q (cudnn_tensor): The query data. + k (cudnn_tensor): The key data. + v (cudnn_tensor): The value data. + o (cudnn_tensor): The output data. + dO (cudnn_tensor): The output loss gradient. + stats (cudnn_tensor): The softmax statistics from the forward pass. + attn_scale (Optional[cudnn_tensor]): The scale factor for attention. Default is None. + bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. + use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): The name of the operation. + + Returns: + dQ (cudnn_tensor): The query gradient tensor of scaled dot-product flash attention. + dK (cudnn_tensor): The key gradient tensor of scaled dot-product flash attention. + dV (cudnn_tensor): The value gradient tensor of scaled dot-product flash attention. )pbdoc") .def("build", &PyGraph::build) .def("check_support", &PyGraph::check_support) @@ -1765,6 +1865,6 @@ init_pygraph_submodule(py::module_& m) { )pbdoc"); } -} +} // namespace python_bindings -} \ No newline at end of file +} // namespace cudnn_frontend \ No newline at end of file diff --git a/samples/cpp/matmuls.cpp b/samples/cpp/matmuls.cpp index 9efe2fb5..ac680173 100644 --- a/samples/cpp/matmuls.cpp +++ b/samples/cpp/matmuls.cpp @@ -43,12 +43,12 @@ TEST_CASE("Matmul SBR Graph", "[matmul][graph]") { auto scale_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::MUL); auto S = graph.tensor( - fe::graph::Tensor_attributes().set_name("scale").set_dim({4, 16, 32}).set_stride({16 * 32, 1, 16})); + fe::graph::Tensor_attributes().set_name("scale").set_dim({4, 16, 32}).set_stride({16 * 32, 32, 1})); auto scale_output = graph.pointwise(Z, S, scale_options); auto bias_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::ADD); auto B = - graph.tensor(fe::graph::Tensor_attributes().set_name("bias").set_dim({4, 16, 32}).set_stride({16 * 32, 1, 16})); + graph.tensor(fe::graph::Tensor_attributes().set_name("bias").set_dim({4, 16, 32}).set_stride({16 * 32, 32, 1})); auto bias_output = graph.pointwise(scale_output, B, bias_options); auto relu_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::RELU_FWD); diff --git a/samples/cpp/mha.cpp b/samples/cpp/mha.cpp index c82d33a4..f7c42f4e 100644 --- a/samples/cpp/mha.cpp +++ b/samples/cpp/mha.cpp @@ -27,14 +27,25 @@ #include TEST_CASE("Flash with rng dropout", "[graph][mha][flash][forward]") { -#if CUDART_VERSION < 12000 - SKIP("Test requires cuda toolkit 12.0 or above"); - return; -#endif - int64_t b = 1; // batch size - int64_t h = 2; // head dim - int64_t s_q = 2048; // q tensor is padded to this seq length - int64_t s_kv = 2048; // k and v tensor is padded to this seq length + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + return; + } + + if (cudnnGetVersion() < 8901) { + SKIP("Test requires cuDNN version 8.9.1 or above"); + return; + } + + if (check_device_arch_newer_than("ampere") == false) { + SKIP("Test requires Hopper or above arch."); + return; + } + + int64_t b = 3; // batch size + int64_t h = 4; // head dim + int64_t s_q = 1024; // q tensor is padded to this seq length + int64_t s_kv = 1024; // k and v tensor is padded to this seq length int64_t d = 128; // hidden dim bool is_inference = false; float dropout_probability = 0.2f; @@ -87,29 +98,28 @@ TEST_CASE("Flash with rng dropout", "[graph][mha][flash][forward]") { .set_attn_scale(attn_scale) .set_dropout(dropout_probability, seed, offset); -// Optional bias in flash attention is only supported 8.9.3 onwards -#if (CUDNN_VERSION >= 8904) - scaled_dot_product_flash_attention_options.set_alibi_mask(true); -#endif + // Optional bias in flash attention is only supported 8.9.3 onwards + if (cudnnGetVersion() >= 8904) { + scaled_dot_product_flash_attention_options.set_alibi_mask(true); + } -#if (CUDNN_VERSION >= 8903) auto seq_q = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("seq_q") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); + .set_name("seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); auto seq_kv = mha_graph.tensor(fe::graph::Tensor_attributes() - .set_name("seq_kv") - .set_dim({b, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(fe::DataType_t::INT32)); - - scaled_dot_product_flash_attention_options.set_bias(bias) - .set_padding_mask(true) - .set_seq_len_q(seq_q) - .set_seq_len_kv(seq_kv); - scaled_dot_product_flash_attention_options.set_bias(bias); -#endif + .set_name("seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + + if (cudnnGetVersion() >= 8903) { + scaled_dot_product_flash_attention_options.set_bias(bias) + .set_padding_mask(true) + .set_seq_len_q(seq_q) + .set_seq_len_kv(seq_kv); + } auto [O, Stats] = mha_graph.scaled_dot_product_flash_attention(Q, K, V, scaled_dot_product_flash_attention_options); @@ -120,15 +130,6 @@ TEST_CASE("Flash with rng dropout", "[graph][mha][flash][forward]") { Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); } -#if (CUDNN_VERSION < 8900) - SKIP("MHA Graph requires cudnn 8.9 and up"); - return; -#endif - if (check_device_arch_newer_than("hopper") == false) { - SKIP("MHA Graph requires Hopper or above arch."); - return; - } - cudnnHandle_t handle; checkCudnnErr(cudnnCreate(&handle)); @@ -201,14 +202,25 @@ TEST_CASE("Flash with rng dropout", "[graph][mha][flash][forward]") { } TEST_CASE("Flash with no dropout", "[graph][mha][flash][forward]") { -#if CUDART_VERSION < 12000 - SKIP("Test requires cuda toolkit 12.0 or above"); - return; -#endif - int64_t b = 1; // batch size - int64_t h = 2; // head dim - int64_t s_q = 2048; // q tensor is padded to this seq length - int64_t s_kv = 2048; // k and v tensor is padded to this seq length + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + return; + } + + if (cudnnGetVersion() < 8903) { + SKIP("Test requires cuDNN version 8.9.3 or above"); + return; + } + + if (check_device_arch_newer_than("ampere") == false) { + SKIP("Test requires Hopper or above arch."); + return; + } + + int64_t b = 3; // batch size + int64_t h = 4; // head dim + int64_t s_q = 1024; // q tensor is padded to this seq length + int64_t s_kv = 1024; // k and v tensor is padded to this seq length int64_t d = 128; // hidden dim bool is_inference = false; @@ -250,10 +262,10 @@ TEST_CASE("Flash with no dropout", "[graph][mha][flash][forward]") { .set_attn_scale(attn_scale) .set_bias(bias); -// Alibi mask in flash attention is only supported 8.9.4 onwards -#if (CUDNN_VERSION >= 8904) - scaled_dot_product_flash_attention_options.set_alibi_mask(true); -#endif + // Alibi mask in flash attention is only supported 8.9.4 onwards + if (cudnnGetVersion() >= 8904) { + scaled_dot_product_flash_attention_options.set_alibi_mask(true); + } auto [O, Stats] = mha_graph.scaled_dot_product_flash_attention(Q, K, V, scaled_dot_product_flash_attention_options); O->set_output(true).set_stride({h * d, d, b * h * d, 1}); @@ -263,16 +275,6 @@ TEST_CASE("Flash with no dropout", "[graph][mha][flash][forward]") { Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); } -// No dropout in flash attention only supported 8.9.3 onwards. -#if (CUDNN_VERSION < 8903) - SKIP("MHA Graph requires cudnn 8.9 and up"); - return; -#endif - if (check_device_arch_newer_than("hopper") == false) { - SKIP("MHA Graph requires Hopper or above arch."); - return; - } - cudnnHandle_t handle; checkCudnnErr(cudnnCreate(&handle)); @@ -313,3 +315,176 @@ TEST_CASE("Flash with no dropout", "[graph][mha][flash][forward]") { cudnnDestroy(handle); } + +TEST_CASE("Flash backward", "[graph][mha][flash][backward]") { + if (cudnnGetCudartVersion() < 12000) { + SKIP("Test requires cuda toolkit 12.0 or above"); + return; + } + if (cudnnGetVersion() < 8903) { + SKIP("Test requires cuDNN version 8.9.3 or above"); + return; + } + + if (check_device_arch_newer_than("ampere") == false) { + SKIP("Test requires Hopper or above arch."); + return; + } + + int64_t b = 3; // batch size + int64_t h = 4; // head dim + int64_t s_q = 1024; // q tensor is padded to this seq length + int64_t s_kv = 1024; // k and v tensor is padded to this seq length + int64_t d = 128; // hidden dim + + bool is_bias = true; + float dropout_probability = 0.2f; + + namespace fe = cudnn_frontend; + fe::graph::Graph mha_graph; + mha_graph.set_io_data_type(fe::DataType_t::HALF) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + // used for bias, and dropout != 0.0f + std::shared_ptr bias, dropout_seed, dropout_offset; + + auto q = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d}) + .set_stride({h * s_q * d, s_q * d, d, 1})); + auto k = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, h, d, s_kv}) + .set_stride({h * s_kv * d, s_kv * d, 1, d})); + auto v = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, h, d, s_kv}) + .set_stride({h * s_kv * d, s_kv * d, 1, d})); + auto o = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim({b, h, s_q, d}) + .set_stride({h * s_q * d, s_q * d, d, 1})); + auto dO = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("dO") + .set_dim({b, h, s_q, d}) + .set_stride({h * s_q * d, s_q * d, d, 1})); + auto stats = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + + auto attn_scale = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("attn_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(fe::DataType_t::FLOAT)); + + if (is_bias) { + bias = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({b, 1, s_q, s_kv}) + .set_stride({s_q * s_kv, s_q * s_kv, s_kv, 1})); + } + + if (dropout_probability != 0.0f) { + dropout_seed = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + dropout_offset = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + } + + auto scaled_dot_product_flash_attention_backward_options = fe::graph::Scaled_dot_product_flash_attention_backward_attributes() + .set_name("flash_attention_backward") + .set_causal_mask(true) + .set_attn_scale(attn_scale); + + if (is_bias) { + scaled_dot_product_flash_attention_backward_options.set_bias(bias); + } + + if (dropout_probability != 0.0f) { + scaled_dot_product_flash_attention_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); + } + + auto [dQ, dK, dV] = mha_graph.scaled_dot_product_flash_attention_backward(q, k, v, o, dO, stats, scaled_dot_product_flash_attention_backward_options); + + dQ->set_output(true).set_dim({b, h, s_q, d}).set_stride({h * s_q * d, s_q * d, d, 1}); + dK->set_output(true).set_dim({b, h, s_kv, d}).set_stride({h * s_kv * d, s_kv * d, d, 1}); + dV->set_output(true).set_dim({b, h, s_kv, d}).set_stride({h * s_kv * d, s_kv * d, d, 1}); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + REQUIRE(mha_graph.validate().is_good()); + + REQUIRE(mha_graph.build_operation_graph(handle).is_good()); + + auto plans = mha_graph.get_execution_plan_list(fe::HeurMode_t::HEUR_MODE_A); + + REQUIRE(plans.check_support(handle).is_good()); + + REQUIRE(mha_graph.set_execution_plans(plans).is_good()); + + // build variant pack + // inputs + Surface q_tensor(b * h * s_q * d, false); + Surface k_tensor(b * h * d * s_kv, false); + Surface v_tensor(b * h * d * s_kv, false); + Surface o_tensor(b * h * s_q * d, false); + Surface dO_tensor(b * h * s_q * d, false); + Surface stats_tensor(b * h * s_q * 1, false); + // outputs + Surface dQ_tensor(b * h * s_q * d, false); + Surface dK_tensor(b * h * s_kv * d, false); + Surface dV_tensor(b * h * s_kv * d, false); + + float attn_scale_cpu = 0.5f; + + Surface bias_tensor(b * 1 * s_q * s_kv, false); + + int32_t seed_value = 123456; + int32_t offset_value = 789; + Surface dropout_seed_tensor(1, false, seed_value); + Surface dropout_offset_tensor(1, false, offset_value); + + std::unordered_map, void*> variant_pack = { + // inputs + {q, q_tensor.devPtr}, + {k, k_tensor.devPtr}, + {v, v_tensor.devPtr}, + {o, o_tensor.devPtr}, + {dO, dO_tensor.devPtr}, + {stats, stats_tensor.devPtr}, + // outputs + {dQ, dQ_tensor.devPtr}, + {dK, dK_tensor.devPtr}, + {dV, dV_tensor.devPtr}, + // pass by value + {attn_scale, &attn_scale_cpu} + }; + + if (is_bias) { + variant_pack[bias] = bias_tensor.devPtr; + } + + if (dropout_probability != 0.0f) { + variant_pack[dropout_seed] = dropout_seed_tensor.devPtr; + variant_pack[dropout_offset] = dropout_offset_tensor.devPtr; + } + + Surface workspace(mha_graph.get_workspace_size(), false); + REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + checkCudaErr(cudaDeviceSynchronize()); + + cudnnDestroy(handle); +} \ No newline at end of file diff --git a/samples/python/matmul_bias_relu.py b/samples/python/matmul_bias_relu.py deleted file mode 100644 index aea29541..00000000 --- a/samples/python/matmul_bias_relu.py +++ /dev/null @@ -1,45 +0,0 @@ -import cudnn -import numpy as np -import cupy as cp -import sys -print("Example 2. Executing the Matmul + bias + relu graph") - -if cudnn.backend_version() < 8500: - print("cudnn version does not support matmul+bias fusion for specified layout") - exit(0) - -graph = cudnn.pygraph(io_data_type = cudnn.data_type.HALF, intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT) - -image = graph.tensor(name = "image", dim = [4,16,64], stride = [1024,1,16]) -weight = graph.tensor(name = "weight", dim = [4,64,16], stride = [1024,1,64]) -bias = graph.tensor(name = "bias", dim = [4,16,16], stride = [256,1,16]) - -response = graph.matmul(name = "matmul", image = image, weight = weight) - -output = graph.bias(name = "bias", input = response, bias = bias) - -relu = graph.relu(name = "relu", input = output) -relu.set_output(True) - -graph.check_support() - -graph.build() - -X_cpu = np.full([4,16,64], 1, dtype=np.half) -W_cpu = np.full([4,64,16], 1, dtype=np.half) -B_cpu = np.full([4,16,16], 2, dtype=np.half) - -X_gpu = cp.asarray(X_cpu) -W_gpu = cp.asarray(W_cpu) -B_gpu = cp.asarray(B_cpu) -Y_gpu = cp.full([4,16,16], 0, dtype=cp.half) - -workspace = cp.empty(graph.get_workspace_size(), dtype=cp.uint8) -graph.execute({image : X_gpu, weight : W_gpu, bias : B_gpu, relu : Y_gpu}, workspace) - -Y_actual = cp.asnumpy(Y_gpu) - -Y_expected = np.matmul(X_cpu, W_cpu) + B_cpu -Y_expected[Y_expected < 0] = 0 - -np.testing.assert_allclose(Y_actual, Y_expected) diff --git a/samples/python/test_matmul_bias_relu.py b/samples/python/test_matmul_bias_relu.py new file mode 100644 index 00000000..bfddc6b0 --- /dev/null +++ b/samples/python/test_matmul_bias_relu.py @@ -0,0 +1,63 @@ +import cudnn +import itertools +import pytest +import torch + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.bool: + return cudnn.data_type.BOOLEAN + elif torch_type == torch.uint8: + return cudnn.data_type.UINT8 + else: + raise ValueError("Unsupported tensor data type.") + +problem_size_options = [(1, 128, 768) + # , (16, 512, 1600) TODO: BUG https://nvbugswb.nvidia.com/NvBugs5/SWBug.aspx?bugid=4291755&cmtNo= + , (1, 128, 1024)] +input_type_options = [torch.bfloat16, torch.float16] + +all_options = [elem for elem in itertools.product(*[problem_size_options, input_type_options])] + +@pytest.fixture(params=all_options) +def param_extract(request): + return request.param + +def test_matmul_bias_relu(param_extract): + problem_size_options, input_type = param_extract + b, s, e = problem_size_options + + X_gpu = torch.randn(b,s,e, requires_grad=False, device="cuda", dtype=input_type) + W_gpu = torch.randn(1,e,e*4, requires_grad=False, device="cuda", dtype=input_type) + B_gpu = torch.randn(1,1,e*4, requires_grad=False, device="cuda", dtype=input_type) + Y_expected = torch.nn.functional.linear(X_gpu, W_gpu.squeeze().T, bias=B_gpu.squeeze()) + + graph = cudnn.pygraph(intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT) + + X = graph.tensor(name = "X", dim = X_gpu.size(), stride = X_gpu.stride(), data_type = convert_to_cudnn_type(input_type)) + W = graph.tensor(name = "W", dim = W_gpu.size(), stride = W_gpu.stride(), data_type = convert_to_cudnn_type(input_type)) + B = graph.tensor(name = "B", dim = B_gpu.size(), stride = B_gpu.stride(), data_type = convert_to_cudnn_type(input_type)) + + response = graph.matmul(name = "matmul", A = X, B = W) + Y = graph.bias(name = "bias", input = response, bias = B) + Y.set_output(True).set_data_type(convert_to_cudnn_type(input_type)) + + graph.check_support() + graph.build() + + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + Y_actual = torch.zeros_like(Y_expected) + + graph.execute({X: X_gpu, W: W_gpu, B: B_gpu, Y: Y_actual}, workspace) + + rtol = 1e-2 if input_type == torch.bfloat16 else 1e-3 + torch.testing.assert_close(Y_expected, Y_actual, atol=1e-3, rtol=rtol) + +if __name__ == "__main__": + test_matmul_bias_relu(((1,128,1600), torch.float16)) \ No newline at end of file diff --git a/samples/python/test_mhas.py b/samples/python/test_mhas.py index e5f1f140..2a2314ed 100644 --- a/samples/python/test_mhas.py +++ b/samples/python/test_mhas.py @@ -6,6 +6,7 @@ import itertools import random + def convert_to_cudnn_type(torch_type): if torch_type == torch.float16: return cudnn.data_type.HALF @@ -20,71 +21,63 @@ def convert_to_cudnn_type(torch_type): else: raise ValueError("Unsupported tensor data type.") -def perr(a, b): - a, b = a.float(), b.float() - diff = (a-b) - return (diff.abs().sum() / a.abs().sum()).item() - -def mean_avg_error(a, b): - a, b = a.float(), b.float() - diff = (a-b) - return diff.abs().mean().item() - -def compare_tensors(a_, b_, tensor_name): - print("================================================") - print(tensor_name) - print("================================================") +def make_tensor_attr(graph, torch_tensor, name="", dim=None, stride=None, is_pass_by_value=None): + return graph.tensor( + name=name, + dim=dim if dim else torch_tensor.size(), + stride=stride if stride else torch_tensor.stride(), + data_type=convert_to_cudnn_type(torch_tensor.dtype), + is_pass_by_value=is_pass_by_value, + ) - n_elem = torch.numel(a_) - abs_err_tol = 0.1 - rel_err_tol = 0.1 +def compare_tensors(expected, actual, tensor_name, rtol=2e-2, atol=2e-2, fudge=1e-9, print_compare=False): + assert expected.shape == actual.shape - if a_.cuda: - a = a_.to(device='cpu') + expected = expected.to(dtype=torch.float64, device="cuda").flatten() + actual = actual.to(dtype=torch.float64, device="cuda").flatten() - if b_.cuda: - b = b_.to(device='cpu') + n_elem = torch.numel(expected) - mae = mean_avg_error(a, b) - some_perr = perr(a, b) + mae = (expected - actual).abs().mean().item() + perr = ((expected - actual).abs().sum() / expected.abs().sum()).item() + snr = (expected**2).mean().sqrt() / ((expected - actual) ** 2).mean().sqrt() + snr_db = (10 * torch.log10(snr)).item() - absolute_error = torch.abs(a - b) - relative_error = torch.div(absolute_error, a + 0.0000000001) - - max_abs_error = torch.max(absolute_error) - max_rel_error = torch.max(relative_error) - - abs_error_indices = absolute_error > abs_err_tol - rel_error_indices = relative_error > rel_err_tol + absolute_error = (expected - actual).abs() + relative_error = absolute_error / torch.where(expected.abs() < fudge, fudge, expected.abs()) + abs_error_indices = absolute_error > atol + rel_error_indices = relative_error > rtol n_abs_errors = torch.sum(abs_error_indices) n_rel_errors = torch.sum(rel_error_indices) - error_indices = torch.logical_and(abs_error_indices, rel_error_indices) n_errors = torch.sum(error_indices) - print("Absolute Tolerance = {}".format(abs_err_tol)) - print("Relative Tolerance = {}".format(rel_err_tol)) - print("Number of elements = {}".format(n_elem)) - - print("Number of absolute errors = {} ({:.2f}%)". format(n_abs_errors, (n_abs_errors * 100)/n_elem)) - print("Number of relative errors = {} ({:.2f}%)". format(n_rel_errors, (n_rel_errors * 100)/n_elem)) - - print("Number of errors (absolute and relative) = {} ({:.2f}%)". format(n_errors, (n_errors * 100)/n_elem)) - - print("Maximum absolute error = {:.4f}".format(max_abs_error)) - print("Maximum relative error = {:.4f}".format(max_rel_error)) - print("Mean average error = {:.4f}".format(mae)) - print("Perr error = {:.4f}".format(some_perr)) - print("Number of Nans = {} ({:.2f}%)".format(torch.sum(torch.isnan(b)), torch.sum(torch.isnan(b) * 100/n_elem))) - print("Number of Zeros = {} ({:.2f}%)".format(n_elem - torch.count_nonzero(b), (n_elem - torch.count_nonzero(b)) * 100/n_elem)) - - print("================================================\n") + n_nans = torch.isnan(actual).sum() + n_zeros = n_elem - torch.count_nonzero(actual) + + if print_compare or n_errors != 0: + print(f"========== {tensor_name} ==========") + print(f"Absolute Tolerance = {atol}") + print(f"Relative Tolerance = {rtol}") + print(f"Number of elements = {n_elem}") + print(f"Number of absolute errors = {n_abs_errors} ({n_abs_errors * 100 / n_elem:.2f}%)") + print(f"Number of relative errors = {n_rel_errors} ({n_rel_errors * 100 / n_elem:.2f}%)") + print(f"Number of errors (absolute and relative) = {n_errors} ({(n_errors * 100)/n_elem:.2f}%)") + print(f"Maximum absolute error = {absolute_error.max():.4f}") + print(f"Maximum relative error = {relative_error.max():.4f}") + print(f"Mean average error = {mae:.4f}") + print(f"Perr error = {perr:.4f} = 1/{1/perr:.2f}") + print(f"Signal to noise ratio = {snr.item():.2f} = {snr_db:.2f}dB") + print(f"Number of Nans = {n_nans} ({n_nans * 100 / n_elem:.2f}%)") + print(f"Number of Zeros = {n_zeros} ({n_zeros * 100 / n_elem:.2f}%)") + print("===================================\n") return n_errors + def get_slopes(n_heads: int): """ ## Get head-specific slope $m$ for each head @@ -123,230 +116,522 @@ def get_slopes(n_heads: int): m = torch.cat([m, m_hat]) # Reshape the tensor to [1, num_heads, 1, 1] - m = m.view(1, -1, 1, 1).to(device='cuda') + m = m.view(1, -1, 1, 1).to(device="cuda") return m -class scaled_dot_product_attention(torch.nn.Module): - def forward(self, query, key, value, is_causal, is_infer, bias, is_alibi, attn_scale): - _, h, s_q, d = query.shape - _, _, s_kv, _ = key.shape - - S = query @ key.transpose(-2, -1) * attn_scale - S = S.to(dtype=torch.float32) - if bias is not None: - S.add_(bias) - if is_alibi: - S.add_(((torch.arange(s_kv, dtype=torch.float32, device = 'cuda')) - torch.arange(s_q, dtype=torch.float32, device = 'cuda').view(-1, 1)) * get_slopes(h)) - if is_causal: - causal_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device = 'cuda').triu_(diagonal=1) - S.masked_fill_(causal_mask, float('-inf')) - - Stats = None - if not is_infer: - row_max, _ = torch.max(S, -1, True) - row_exp = torch.exp(S - row_max) - row_sum = torch.sum(row_exp, -1, True) - Stats = row_max + torch.log(row_sum) - - return torch.softmax(S, dim=-1).to(dtype=value.dtype) @ value, Stats - -alibi_mask_options = [True, False] -padding_mask_options = [True, False] -causal_mask_options = [True, False] -layout_options = ["non_interleaved", "bs3hd", "sbh3d"] -dropout = [False] -is_infer_options = [True, False] -bias = [True, False] -input_type_options = [torch.float16, torch.bfloat16] - -all_options = [elem for elem in itertools.product(*[alibi_mask_options, padding_mask_options, causal_mask_options, layout_options, dropout, is_infer_options, bias, input_type_options])] - -@pytest.fixture(params=all_options) -def param_extract(request): - return request.param - -@pytest.mark.skipif(cudnn.backend_version() < 8903, reason="requires cudnn 8.9 or higher") -def test_scale_dot_product_flash_attention(param_extract): - - alibi_mask, padding_mask, causal_mask, layout, dropout_enable, is_infer, bias_enable, input_type = param_extract - - if alibi_mask and cudnn.backend_version() < 8904: + +def compute_o_stats(q, k, v, attn_scale=1.0, bias=None, is_alibi=False, padding=None, is_causal=False, device="cuda"): + b, h, s_q, d = q.shape + _, _, s_kv, _ = k.shape + + assert k.shape == (b, h, s_kv, d) + assert v.shape == (b, h, s_kv, d) + + if padding is not None: + seq_len_q, seq_len_kv = padding + q_mask = torch.zeros(b, 1, s_q, 1, dtype=torch.bool, device=device) + k_mask = torch.zeros(b, 1, s_kv, 1, dtype=torch.bool, device=device) + v_mask = torch.zeros(b, 1, s_kv, 1, dtype=torch.bool, device=device) + s_mask = torch.zeros(b, 1, s_q, s_kv, dtype=torch.bool, device=device) + for i, (m, n) in enumerate(zip(seq_len_q, seq_len_kv)): + q_mask[i, :, m:, :] = True + k_mask[i, :, n:, :] = True + v_mask[i, :, n:, :] = True + s_mask[i, :, m:, :] = True + s_mask[i, :, :, n:] = True + + q = q.to(dtype=torch.float32, device=device) + k = k.to(dtype=torch.float32, device=device) + v = v.to(dtype=torch.float32, device=device) + if padding is not None: + q.masked_fill_(q_mask, 0) + k.masked_fill_(k_mask, 0) + v.masked_fill_(v_mask, 0) + s = torch.einsum("bhqd,bhkd->bhqk", q, k) * attn_scale + if bias is not None: + s.add_(bias) + if is_alibi: + lin_bias = ((torch.arange(s_kv, dtype=q.dtype)) - torch.arange(s_q, dtype=q.dtype).view(-1, 1)) + s.add_(lin_bias.to(device=device) * get_slopes(h)) + if padding is not None: + s.masked_fill_(s_mask, float("-inf")) + if is_causal: + causal_mask = torch.ones(s_q, s_kv, dtype=torch.bool, device=device).triu_(diagonal=1) + s.masked_fill_(causal_mask, float("-inf")) + p = torch.softmax(s, dim=-1) + if padding is not None: + p.masked_fill_(s_mask, 0) + o = torch.einsum("bhqk,bhkd->bhqd", p, v) + # amax (NOT absolute max) is used here to evenly distribute gradient + row_max = torch.amax(s, -1, True) + row_exp = torch.exp(s - row_max) + row_sum = torch.sum(row_exp, -1, True) + stats = row_max + torch.log(row_sum) + + return o, stats + + +class ScaledDotProductAttentionPyT(torch.nn.Module): + def __init__(self, is_causal=False, is_bias=False, is_alibi=False, attn_scale=1.0): + super(ScaledDotProductAttentionPyT, self).__init__() + self.is_bias = is_bias + self.is_causal = is_causal + self.is_alibi = is_alibi + self.attn_scale = attn_scale + + def forward(self, q, k, v, bias=None): + b, h, s_q, d = q.shape + _, _, s_kv, _ = k.shape + + assert k.shape == (b, h, s_kv, d) + assert v.shape == (b, h, s_kv, d) + + assert self.is_bias == (bias != None) + + s = torch.einsum("bhqd,bhkd->bhqk", q, k) * self.attn_scale + if self.is_bias: + s.add_(bias) + if self.is_alibi: + s.add_(((torch.arange(s_kv, dtype=q.dtype)) - torch.arange(s_q, dtype=q.dtype).view(-1, 1)) * get_slopes(h)) + if self.is_causal: + causal_mask = torch.ones(s_q, s_kv, dtype=torch.bool).triu_(diagonal=1).cuda() + s.masked_fill_(causal_mask, float("-inf")) + p = torch.softmax(s, dim=-1) + o = torch.einsum("bhqk,bhkd->bhqd", p, v) + return o + +alibi_mask_options = [False, True] +padding_mask_options = [False, True] +causal_mask_options = [False, True] +layout_options = ["non_interleaved", "bs3hd", "sbh3d"] +dropout_options = [False] +is_infer_options = [False, True] +bias_options = [False, True] +input_type_options = [torch.float16, torch.bfloat16] + +all_options_forward = [ + elem + for elem in itertools.product( + *[ + alibi_mask_options, + padding_mask_options, + causal_mask_options, + layout_options, + dropout_options, + is_infer_options, + bias_options, + input_type_options + ] + ) +] + +all_options_backward = [ + elem + for elem in itertools.product( + *[ + causal_mask_options, + dropout_options, + bias_options, + input_type_options + ] + ) +] + +@pytest.fixture(params=all_options_forward) +def param_extract_forward(request): + return request.param + + +@pytest.mark.skipif(cudnn.backend_version() < 8903, reason="requires cudnn 8.9.3 or higher") +def test_scale_dot_product_flash_attention(param_extract_forward, print_compare=False): + ( + is_alibi, + is_padding, + is_causal, + layout, + is_dropout, + is_infer, + is_bias, + input_type + ) = param_extract_forward + + if is_alibi and cudnn.backend_version() < 8904: pytest.skip("ALiBi mask is only supported 8.9.4 onwards.") - - if padding_mask and cudnn.backend_version() < 8903: + + if is_padding and cudnn.backend_version() < 8903: pytest.skip("Padding mask is only supported 8.9.3 onwards.") - s_q_choices = [256, 512, 1024, 2048] - d_choices = [64,128] - - b = 32 - h = 12 - s_q = random.choice(s_q_choices) - s_kv = s_q + s_q_choices = [256, 512, 1024, 2048] + d_choices = [64, 128] + + b = 3 + h = 4 + s_q = random.choice(s_q_choices) + s_kv = s_q d = random.choice(d_choices) - - print(param_extract) - print ("d = {} s_kv = {} s_q = {} ".format(d, s_kv, s_q)) - - attn_scale = 0.125 - - if dropout_enable == False: - dropout_prob = 1.0 + + print(f"{str(param_extract_forward)} s={s_q} d={d}") + + attn_scale_val = 0.125 + dropout_prob = 0.1 if is_dropout else 0.0 + + shape_q = (b, h, s_q, d) + shape_k = (b, h, d, s_kv) + shape_v = (b, h, s_kv, d) + shape_o = (b, h, s_q, d) + + if layout == "sbh3d": + stride_q = (3 * h * d, 3 * d, b * 3 * h * d, 1) + stride_k = (3 * h * d, 3 * d, 1, b * 3 * h * d) + stride_v = (3 * h * d, 3 * d, b * 3 * h * d, 1) + stride_o = (h * d, d, b * h * d, 1) + stride_order_o = (2, 1, 3, 0) + + offset_q = d * 0 + offset_k = d * 1 + offset_v = d * 2 + elif layout == "bs3hd": + stride_q = (s_q * 3 * h * d, d, 3 * h * d, 1) + stride_k = (s_q * 3 * h * d, d, 1, 3 * h * d) + stride_v = (s_q * 3 * h * d, d, 3 * h * d, 1) + stride_o = (s_q * h * d, d, h * d, 1) + stride_order_o = (3, 1, 2, 0) + + offset_q = h * d * 0 + offset_k = h * d * 1 + offset_v = h * d * 2 + elif layout == "non_interleaved": + stride_q = (d * s_q * h, d * s_q, d, 1) + stride_k = (d * s_kv * h, d * s_kv, 1, d) + stride_v = (d * s_kv * h, d * s_kv, d, 1) + stride_o = (d * s_q * h, d * s_q, d, 1) + stride_order_o = (3, 2, 1, 0) + + offset_q = 0 + offset_k = offset_q + b * d * s_q * h + offset_v = offset_k + b * d * s_kv * h else: - dropout_prob = 0.1 + assert False, "Layout should be either sbh3d or bs3hd or non_interleaved" - shape_Q = (b, h, s_q, d) + qkv_gpu = 1 * (torch.randn(b * s_q * 3 * h * d, dtype=input_type, device="cuda") - 0.5) + q_gpu = torch.as_strided(qkv_gpu, shape_q, stride_q, storage_offset=offset_q) + k_gpu = torch.as_strided(qkv_gpu, shape_k, stride_k, storage_offset=offset_k) + v_gpu = torch.as_strided(qkv_gpu, shape_v, stride_v, storage_offset=offset_v) - shape_K = (b, h, d, s_kv) + if attn_scale_val != 1.0: + attn_scale_cpu = torch.full((1, 1, 1, 1), attn_scale_val, dtype=torch.float32, device="cpu") - shape_V = (b, h, s_kv, d) + if is_bias: + bias_gpu = torch.randn(b, 1, s_q, s_kv, requires_grad=False, device="cuda", dtype=input_type) - stride_sbh3d = (3 * h * d, 3 * d, b * 3 * h * d, 1) - stride_sbh3d_t = (3 * h * d, 3 * d, 1, b * 3 * h * d) - stride_sbhd = (h * d, d, b * h * d, 1) + if is_padding: + seq_len_q_gpu = torch.randint(0, s_q + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") + seq_len_kv_gpu = torch.randint(0, s_kv + 1, (b, 1, 1, 1), dtype=torch.int32, device="cuda") - stride_bs3hd = (s_q * 3 * h * d, d, 3 * h * d, 1) - stride_bs3hd_t = (s_q * 3 * h * d, d, 1, 3 * h * d) - stride_bshd = (s_q * h * d, d, h * d, 1) + if is_dropout: + seed_gpu = torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda") + offset_gpu = torch.full((1, 1, 1, 1), 789, dtype=torch.int64, device="cuda") - offset_multiple_sbh3d = d - offset_multiple_bs3hd = h * d - - bias_gpu = torch.randn(b, 1, s_q, s_kv, requires_grad=False, device="cuda", dtype = input_type) if bias_enable else None + o_gpu = torch.empty(*shape_o, dtype=input_type, device="cuda").as_strided(shape_o, stride_o) + if is_infer == False: + stats_gpu = torch.empty(b, h, s_q, 1, dtype=torch.float32, device="cuda") + + # cuDNN graph + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(input_type), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q = make_tensor_attr(graph, q_gpu, "q") + k = make_tensor_attr(graph, k_gpu, "k") + v = make_tensor_attr(graph, v_gpu, "v") + + if attn_scale_val != 1.0: + attn_scale = make_tensor_attr(graph, attn_scale_cpu, "attn_scale", is_pass_by_value=True) + + if is_bias: + bias = make_tensor_attr(graph, bias_gpu, "bias") + + if is_padding: + seq_len_q = make_tensor_attr(graph, seq_len_q_gpu, "seq_len_q") + seq_len_kv = make_tensor_attr(graph, seq_len_kv_gpu, "seq_len_kv") + + if is_dropout: + seed = make_tensor_attr(graph, seed_gpu, "seed") + offset = make_tensor_attr(graph, offset_gpu, "attn_scale") + dropout_tuple = (dropout_prob, seed, offset) + + o, stats = graph.scaled_dot_product_flash_attention( + name="scaled_dot_product_flash_attention", + q=q, + k=k, + v=v, + is_inference=is_infer, + attn_scale=attn_scale if attn_scale_val != 1.0 else None, + bias=bias if is_bias else None, + use_alibi_mask=is_alibi, + use_padding_mask=is_padding, + seq_len_q=seq_len_q if is_padding else None, + seq_len_kv=seq_len_kv if is_padding else None, + use_causal_mask=is_causal, + dropout=dropout_tuple if is_dropout else None, + ) + + o.set_output(True).set_dim(shape_o).set_stride(stride_o) + if is_infer == False: + stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) - if layout == 'sbh3d': - stride_Q = stride_sbh3d - stride_K = stride_sbh3d_t - stride_V = stride_sbh3d + graph.check_support() + graph.build() - stride_O = stride_sbhd + variant_pack = { + q: q_gpu, + k: k_gpu, + v: v_gpu, + o: o_gpu + } - offset_Q = offset_multiple_sbh3d * 0 - offset_K = offset_multiple_sbh3d * 1 - offset_V = offset_multiple_sbh3d * 2 - elif layout == 'bs3hd': - stride_Q = stride_bs3hd - stride_K = stride_bs3hd_t - stride_V = stride_bs3hd + if attn_scale_val != 1.0: + variant_pack[attn_scale] = attn_scale_cpu - stride_O = stride_bshd + if is_bias: + variant_pack[bias] = bias_gpu - offset_Q = offset_multiple_bs3hd * 0 - offset_K = offset_multiple_bs3hd * 1 - offset_V = offset_multiple_bs3hd * 2 - elif layout == 'non_interleaved': - stride_Q = (1 * d * s_q * h, 1 * d * s_q, 1 * d, 1) - stride_K = (1 * d * s_kv * h, 1 * d * s_kv, 1, 1 * d) - stride_V = (1 * d * s_kv * h, 1 * d * s_kv, 1 * d, 1) - - stride_O = (d * s_q * h, d * s_q, d, 1) + if is_padding: + variant_pack[seq_len_q] = seq_len_q_gpu + variant_pack[seq_len_kv] = seq_len_kv_gpu - offset_Q = 0 - offset_K = offset_Q + b * d * s_q * h - offset_V = offset_K + b * d * s_kv * h + if is_dropout: + variant_pack[seed] = seed_gpu + variant_pack[offset] = offset_gpu - else: - assert False, "Layout should be either sbh3d or bs3hd or non_interleaved" + if is_infer == False: + variant_pack[stats] = stats_gpu - qkv_gpu = 1 * (torch.randn(b * s_q * 3 * h * d, dtype=input_type, device="cuda") - 0.5) - - Q_gpu = torch.as_strided(qkv_gpu, shape_Q, stride_Q, storage_offset=offset_Q) - K_gpu = torch.as_strided(qkv_gpu, shape_K, stride_K, storage_offset=offset_K) - V_gpu = torch.as_strided(qkv_gpu, shape_V, stride_V, storage_offset=offset_V) - - if padding_mask: - seq_len_Q_gpu = torch.full((b,1,1,1), s_q, dtype=torch.int32, device="cuda") - seq_len_KV_gpu = torch.full((b,1,1,1), s_kv, dtype=torch.int32, device="cuda") - - Attn_scale_cpu = torch.full((1,1,1,1), attn_scale, dtype=torch.float32, device="cpu") - - Seed_gpu = torch.full((1,1,1,1), 123456, dtype=torch.int64, device="cuda") - Offset_gpu = torch.full((1,1,1,1), 1, dtype=torch.int64, device="cuda") - - # Cudnn graph - graph = cudnn.pygraph(io_data_type = convert_to_cudnn_type(input_type), intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT) - Q = graph.tensor(name = "Q", dim = Q_gpu.size(), stride = Q_gpu.stride(), data_type = convert_to_cudnn_type(Q_gpu.dtype)) - K = graph.tensor(name = "K", dim = K_gpu.size(), stride = K_gpu.stride(), data_type = convert_to_cudnn_type(K_gpu.dtype)) - V = graph.tensor(name = "V", dim = V_gpu.size(), stride = V_gpu.stride(), data_type = convert_to_cudnn_type(V_gpu.dtype)) - Attn_scale = graph.tensor(name = "Attn_scale", dim = Attn_scale_cpu.size(), stride = Attn_scale_cpu.stride(), data_type = convert_to_cudnn_type(Attn_scale_cpu.dtype), is_pass_by_value = True) - Seed = graph.tensor(name = "Seed", dim = Seed_gpu.size(), stride = Seed_gpu.stride(), data_type = convert_to_cudnn_type(Seed_gpu.dtype)) - Offset = graph.tensor(name = "Offset", dim = Offset_gpu.size(), stride = Offset_gpu.stride(), data_type = convert_to_cudnn_type(Offset_gpu.dtype)) - - Bias = graph.tensor(name = "bias", dim = bias_gpu.size(), stride = bias_gpu.stride(),data_type = convert_to_cudnn_type(Q_gpu.dtype)) if bias_enable else None - - dropout_tuple = None - if dropout_enable == True: - dropout_tuple = (dropout_prob, Seed, Offset) - - seq_len_Q = None - seq_len_KV = None - if padding_mask: - seq_len_Q = graph.tensor(name = "seq_len_Q", dim = seq_len_Q_gpu.size(), stride = seq_len_Q_gpu.stride(), data_type = convert_to_cudnn_type(seq_len_Q_gpu.dtype)) - seq_len_KV = graph.tensor(name = "seq_len_KV", dim = seq_len_KV_gpu.size(), stride = seq_len_KV_gpu.stride(), data_type = convert_to_cudnn_type(seq_len_KV_gpu.dtype)) - - - O, Stats = graph.scaled_dot_product_flash_attention(name = "scaled_dot_product_flash_attention" - , q = Q, k = K, v = V - , seq_len_q = seq_len_Q, seq_len_kv = seq_len_KV - , is_inference = is_infer - , bias = Bias - , dropout = dropout_tuple - , attn_scale = Attn_scale - , use_alibi_mask = alibi_mask - , use_padding_mask = padding_mask - , use_causal_mask = causal_mask - ) - - O.set_output(True).set_stride(stride_O) - + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + graph.execute(variant_pack, workspace) + + # compare with torch reference + q_ref = q_gpu.detach().float() + k_ref = k_gpu.permute(0, 1, 3, 2).detach().float() + v_ref = v_gpu.detach().float() + + if is_bias: + bias_ref = bias_gpu.detach().float() + + if is_padding: + seq_len_q_ref = seq_len_q_gpu.detach().flatten() + seq_len_kv_ref = seq_len_kv_gpu.detach().flatten() + + o_ref, stats_ref = compute_o_stats( + q_ref, + k_ref, + v_ref, + attn_scale=attn_scale_val, + bias=bias_ref if is_bias else None, + is_alibi=is_alibi, + is_causal=is_causal, + padding=(seq_len_q_ref, seq_len_kv_ref) if is_padding else None + ) + + if is_padding: + # zero out padded region of the output for comparison + for i, (m, n) in enumerate(zip(seq_len_q_ref, seq_len_kv_ref)): + o_ref[i, :, m:, :] = 0 + o_gpu[i, :, m:, :] = 0 + if is_infer == False: + stats_ref[i, :, m:, :] = 0 + stats_gpu[i, :, m:, :] = 0 + + assert compare_tensors(o_ref, o_gpu, "O", print_compare=print_compare) == 0 if is_infer == False: - Stats.set_output(True).set_data_type(cudnn.data_type.FLOAT) - + assert compare_tensors(stats_ref, stats_gpu, "stats", print_compare=print_compare) == 0 + + +@pytest.fixture(params=all_options_backward) +def param_extract_backward(request): + return request.param + + +@pytest.mark.skipif(cudnn.backend_version() < 8903, reason="requires cudnn 8.9.3 or higher") +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires ampere or higher") +def test_scale_dot_product_flash_attention_backward(param_extract_backward, print_compare=False): + ( + is_causal, + is_dropout, + is_bias, + input_type + ) = param_extract_backward + + layout = "naive" + + s_q_choices = [256, 512, 1024] + d_choices = [64, 128] + + b = 3 + h = 4 + s_q = random.choice(s_q_choices) + s_kv = s_q + d = random.choice(d_choices) + + print(f"{str(param_extract_backward)} s={s_q} d={d}") + + attn_scale_val = 0.125 + dropout_prob = 0.1 if is_dropout else 0.0 + + q_gpu = 1 * (torch.randn((b, h, s_q, d), dtype=input_type, device="cuda") - 0.5) + k_gpu = 1 * (torch.randn((b, h, s_kv, d), dtype=input_type, device="cuda") - 0.5) + v_gpu = 1 * (torch.randn((b, h, s_kv, d), dtype=input_type, device="cuda") - 0.5) + dO_gpu = 0.1 * torch.randn((b, h, s_q, d), dtype=input_type, device="cuda") + + if attn_scale_val != 1.0: + attn_scale_cpu = torch.full((1, 1, 1, 1), attn_scale_val, dtype=torch.float32, device="cpu") + + if is_bias: + bias_gpu = torch.randn(b, 1, s_q, s_kv, device="cuda", dtype=input_type) + + if is_dropout: + seed_gpu = torch.full((1, 1, 1, 1), 123456, dtype=torch.int64, device="cuda") + offset_gpu = torch.full((1, 1, 1, 1), 789, dtype=torch.int64, device="cuda") + + o_gpu, stats_gpu = compute_o_stats( + q_gpu, + k_gpu, + v_gpu, + is_causal=is_causal, + bias=bias_gpu if is_bias else None, + attn_scale=attn_scale_val + ) + o_gpu = o_gpu.to(dtype=input_type).detach().clone() + stats_gpu = stats_gpu.to(dtype=torch.float32).detach().clone() + + dQ_gpu = torch.empty((b, h, s_q, d), dtype=input_type, device="cuda") + dK_gpu = torch.empty((b, h, s_kv, d), dtype=input_type, device="cuda") + dV_gpu = torch.empty((b, h, s_kv, d), dtype=input_type, device="cuda") + + # cuDNN graph + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(input_type), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q = make_tensor_attr(graph, q_gpu, name="q") + k = make_tensor_attr(graph, k_gpu, dim=(b, h, d, s_kv), stride=(h * s_kv * d, s_kv * d, 1, d), name="k") + v = make_tensor_attr(graph, v_gpu, dim=(b, h, d, s_kv), stride=(h * s_kv * d, s_kv * d, 1, d), name="v") + o = make_tensor_attr(graph, o_gpu, name="o") + dO = make_tensor_attr(graph, dO_gpu, name="dO") + stats = make_tensor_attr(graph, stats_gpu, name="stats") + + if attn_scale_val != 1.0: + attn_scale = make_tensor_attr(graph, attn_scale_cpu, is_pass_by_value=True, name="attn_scale") + + if is_bias: + bias = make_tensor_attr(graph, bias_gpu, "bias") + + if is_dropout: + seed = make_tensor_attr(graph, seed_gpu, "seed") + offset = make_tensor_attr(graph, offset_gpu, "attn_scale") + dropout_tuple = (dropout_prob, seed, offset) + + dQ, dK, dV = graph.scaled_dot_product_flash_attention_backward( + name="scaled_dot_product_flash_attention", + q=q, + k=k, + v=v, + o=o, + dO=dO, + stats=stats, + attn_scale=attn_scale if attn_scale_val != 1.0 else None, + bias=bias if is_bias else None, + use_causal_mask=is_causal, + dropout=dropout_tuple if is_dropout else None, + ) + + dQ.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride()) + dK.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride()) + dV.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride()) + graph.check_support() - graph.build() - O_actual = torch.zeros(b * s_q * h * d, dtype=input_type, device="cuda") - Stats_actual = torch.zeros(b * h * s_q * 1, dtype=torch.float32, device="cuda") + variant_pack = { + q: q_gpu, + k: k_gpu, + v: v_gpu, + o: o_gpu, + dO: dO_gpu, + stats: stats_gpu, + dQ: dQ_gpu, + dK: dK_gpu, + dV: dV_gpu + } + + if attn_scale_val != 1.0: + variant_pack[attn_scale] = attn_scale_cpu + + if is_bias: + variant_pack[bias] = bias_gpu + + if is_dropout: + variant_pack[seed] = seed_gpu + variant_pack[offset] = offset_gpu workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + graph.execute(variant_pack, workspace) - cudnn_to_torch_tensor = {Q: Q_gpu, K: K_gpu, V: V_gpu - , Seed: Seed_gpu, Offset: Offset_gpu - , Attn_scale: Attn_scale_cpu - , O: O_actual, Stats: Stats_actual} - - if bias_enable: - cudnn_to_torch_tensor[Bias] = bias_gpu + # compare with torch autograd reference + nn_ref = ScaledDotProductAttentionPyT( + is_causal=is_causal, + is_bias=is_bias, + attn_scale=attn_scale_val + ).cuda().float() - if padding_mask: - cudnn_to_torch_tensor[seq_len_Q] = seq_len_Q_gpu - cudnn_to_torch_tensor[seq_len_KV] = seq_len_KV_gpu + q_ref = q_gpu.detach().float() + q_ref.requires_grad = True + k_ref = k_gpu.detach().float() + k_ref.requires_grad = True + v_ref = v_gpu.detach().float() + v_ref.requires_grad = True + dO_ref = dO_gpu.detach().float() - graph.execute(cudnn_to_torch_tensor, workspace) + if is_bias: + bias_ref = bias_gpu.detach().float() + bias_ref.requires_grad = True - torch.set_printoptions(precision = 2, linewidth = 2560, threshold = 1000000, sci_mode = False) + o_ref = nn_ref(q_ref, k_ref, v_ref, bias=bias_ref if is_bias else None) - Stats_reorg = Stats_actual.view(b, h, s_q, 1) + outputs_ref = [o_ref] + inputs_ref = [q_ref, k_ref, v_ref] - if layout == 'sbh3d': - O_reorg = O_actual.view([s_q, b, h, d]).permute(1, 2, 0, 3) - elif layout == 'bs3hd': - O_reorg = O_actual.view([b, s_q, h, d]).permute(0, 2, 1, 3) - elif layout == 'non_interleaved': - O_reorg = O_actual.view([b, h, s_q, d]) - else: - assert False, "Layout should be either sbh3d or bs3hd or non_interleaved" + if is_bias: + inputs_ref.append(bias_ref) - # Cpu reference - sdpa = scaled_dot_product_attention() - O_expected, Stats_expected = sdpa(Q_gpu, K_gpu.permute(0, 1, 3, 2), V_gpu, is_causal = causal_mask, is_infer = is_infer, bias = bias_gpu, is_alibi = alibi_mask, attn_scale = attn_scale) + [dq_ref, dk_ref, dv_ref, *opt_refs] = list(torch.autograd.grad( + outputs=outputs_ref, + inputs=inputs_ref, + grad_outputs=dO_ref + )) + + assert compare_tensors(dq_ref, dQ_gpu, "dQ", print_compare=print_compare) == 0 + assert compare_tensors(dk_ref, dK_gpu, "dK", print_compare=print_compare) == 0 + assert compare_tensors(dv_ref, dV_gpu, "dV", print_compare=print_compare) == 0 + + if is_bias: + db_ref = opt_refs.pop(0) - if is_infer == False: - assert compare_tensors(Stats_expected, Stats_reorg, "Stats") == 0 - - assert compare_tensors(O_expected, O_reorg, "O") == 0 - if __name__ == "__main__": - test_scale_dot_product_flash_attention((True, "bs3hd", False, False, True)) \ No newline at end of file + """ + option_forward = (alibi_mask, padding_mask, causal_mask, layout, dropout_enable, is_infer, bias_enable, input_type) + option_backward = (is_causal, is_dropout, is_bias, input_type) + test_scale_dot_product_flash_attention((False, False, False, "bs3hd", False, False, False, torch.float16), print_compare=True) + test_scale_dot_product_flash_attention_backward((False, False, False, torch.float16), print_compare=True) + """ + + print("==========running forward tests==========") + for option in all_options_forward: + test_scale_dot_product_flash_attention(option) + + print("==========running backward tests==========") + for option in all_options_backward: + test_scale_dot_product_flash_attention_backward(option)