# cudnn frontend v1.8 release notes (#118)
## New API

### Paged Attention API
The SDPA forward operation now supports paged attention on cudnn 9.5.0 and later when the appropriate page-table descriptors are set. `SDPA_attributes` now accepts `set_paged_attention_k_table` and `set_paged_attention_v_table` to pass these descriptors in. Please refer to the samples for usage: [cpp sample](samples/cpp/sdpa/fp16_fwd_with_paged_caches.cpp), [python sample](samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb). See [docs](docs/operations/Attention.md) for more API details.
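Below is a minimal C++ sketch of what graph construction with paged caches can look like. All dimensions, strides, block sizes, and the pairing with a padding mask are illustrative assumptions; the linked samples remain the authoritative reference.

```cpp
// Hedged sketch: SDPA forward with paged K/V caches (cudnn 9.5.0+).
// Sizes and names are illustrative; see samples/cpp/sdpa/fp16_fwd_with_paged_caches.cpp
// for the real usage.
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

int main() {
    int64_t const b = 4, h = 8, d = 64, s_q = 128, s_kv = 512;
    int64_t const block_size = 32;                       // tokens per K/V block (assumed)
    int64_t const num_blocks = b * (s_kv / block_size);  // total blocks in each container

    fe::graph::Graph graph;
    graph.set_io_data_type(fe::DataType_t::HALF)
        .set_intermediate_data_type(fe::DataType_t::FLOAT)
        .set_compute_data_type(fe::DataType_t::FLOAT);

    auto Q = graph.tensor(fe::graph::Tensor_attributes()
                              .set_name("Q")
                              .set_dim({b, h, s_q, d})
                              .set_stride({h * s_q * d, s_q * d, d, 1}));

    // With paged caches, K and V are containers of non-contiguous blocks:
    // (num_blocks, h, block_size, d) instead of (b, h, s_kv, d).
    auto K = graph.tensor(fe::graph::Tensor_attributes()
                              .set_name("K_container")
                              .set_dim({num_blocks, h, block_size, d})
                              .set_stride({h * block_size * d, block_size * d, d, 1}));
    auto V = graph.tensor(fe::graph::Tensor_attributes()
                              .set_name("V_container")
                              .set_dim({num_blocks, h, block_size, d})
                              .set_stride({h * block_size * d, block_size * d, d, 1}));

    // Page tables map (batch, logical block index) -> block offset in the container.
    auto make_page_table = [&](char const* name) {
        int64_t const blocks_per_seq = s_kv / block_size;
        return graph.tensor(fe::graph::Tensor_attributes()
                                .set_name(name)
                                .set_dim({b, 1, blocks_per_seq, 1})
                                .set_stride({blocks_per_seq, blocks_per_seq, 1, 1})
                                .set_data_type(fe::DataType_t::INT32));
    };
    auto page_table_k = make_page_table("page_table_k");
    auto page_table_v = make_page_table("page_table_v");

    // Paged caches also need the per-batch sequence lengths used by the padding mask.
    auto make_seq_len = [&](char const* name) {
        return graph.tensor(fe::graph::Tensor_attributes()
                                .set_name(name)
                                .set_dim({b, 1, 1, 1})
                                .set_stride({1, 1, 1, 1})
                                .set_data_type(fe::DataType_t::INT32));
    };
    auto seq_len_q  = make_seq_len("seq_len_q");
    auto seq_len_kv = make_seq_len("seq_len_kv");

    auto sdpa_options = fe::graph::SDPA_attributes()
                            .set_name("sdpa_paged")
                            .set_is_inference(true)
                            .set_padding_mask(true)
                            .set_seq_len_q(seq_len_q)
                            .set_seq_len_kv(seq_len_kv)
                            .set_paged_attention_k_table(page_table_k)
                            .set_paged_attention_v_table(page_table_v)
                            .set_paged_attention_max_seq_len_kv(static_cast<int>(s_kv));

    auto [O, stats] = graph.sdpa(Q, K, V, sdpa_options);
    (void)stats;  // not produced in inference mode
    O->set_output(true);

    // Validate, build, and execute as with any other SDPA graph (omitted here).
    return 0;
}
```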

### CUDA Graph API
The cudnn graph can now directly build a native CUDA graph for a given sub-graph (requires cudnn 9.5.0). There are two APIs:
- `populate_cuda_graph`: adds the cudnn nodes to the empty CUDA graph provided as input.
- `update_cuda_graph`: updates the populated CUDA graph with the necessary data pointers.

See [docs](docs/cuda_graphs.md) and the [backend documentation](https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-graph-library.html#cudnnbackendpopulatecudagraph) for more details.

### Enhancements

- The kernel cache for dynamic shapes is now supported in python. Added a
[sample](test/python/test_kernel_cache.py) to showcase its usage.

- `graph.deselect_engines(str: )` now has a python equivalent through
pybind11.

- `graph.tensor(...)` can now accept `int64_t` scalars directly
(previously limited to `int32_t`, `float`, and `fp16` data types); see the short sketch after this list.

- fp8 SDPA attention now supports dropout and padding masks. Requires cudnn
9.5.0 and above.

- Further enhancements to pointwise output stride inference for the
broadcast operation: for non-unary pointwise operations, the broadcast tensor can
now be either IN_0 or IN_1.

- The SDPA backward operation now allows head dimension d up to 256 on Hopper. Requires
cudnn 9.5.0 and above.
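
As a small illustration of the `int64_t` scalar support mentioned in the list above (the variable name and value here are made up; only the accepted scalar type comes from the release note):

```cpp
// Hedged sketch: pass an int64_t scalar straight to graph.tensor(...).
// Previously only int32_t, float, and fp16 scalars were accepted this way.
#include <cudnn_frontend.h>

#include <cstdint>

namespace fe = cudnn_frontend;

int main() {
    fe::graph::Graph graph;
    // The scalar is captured directly while building the graph.
    auto step_count = graph.tensor(static_cast<int64_t>(1000000000));
    (void)step_count;
    return 0;
}
```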

### Bug fixes

- Fixed an issue when querying `cudnnGetLastErrorString()` from the
backend. The `error_t` object now carries a more meaningful message.

- Fixed build issues seen with clang-19 compiler.

- Fixed an issue where a graph with a bias in sdpa_bprop was assumed to
always have a dbias.
Anerudhan authored Oct 23, 2024
1 parent de355c7 commit 936021b
Showing 125 changed files with 3,655 additions and 1,184 deletions.
14 changes: 6 additions & 8 deletions CMakeLists.txt
```diff
@@ -1,10 +1,10 @@
 cmake_minimum_required(VERSION 3.17)
 
-project(cudnn_frontend VERSION 1.7.0)
+project(cudnn_frontend VERSION 1.8.0)
 
 option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF)
 option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON)
-option(CUDNN_FRONTEND_BUILD_UNIT_TESTS "Defines if unittests are built or not." ON)
+option(CUDNN_FRONTEND_BUILD_TESTS "Defines if unittests are built or not." ON)
 
 if(MSVC OR MSYS OR MINGW)
 option(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS "Defines if python bindings are built or not." OFF)
@@ -28,13 +28,11 @@ target_include_directories(
 )
 
 # Find the cuda compiler
-find_package(CUDAToolkit)
+find_package(CUDAToolkit REQUIRED)
 
-target_link_libraries(
+target_include_directories(
     cudnn_frontend INTERFACE
-
-    CUDA::cudart
-    CUDA::nvrtc
+    ${CUDAToolkit_INCLUDE_DIRS}
 )
 
 target_compile_features(cudnn_frontend INTERFACE cxx_std_17)
@@ -47,7 +45,7 @@ if (CUDNN_FRONTEND_BUILD_SAMPLES)
 add_subdirectory(samples)
 endif()
 
-if (CUDNN_FRONTEND_BUILD_UNIT_TESTS)
+if (CUDNN_FRONTEND_BUILD_TESTS)
 add_subdirectory(test)
 endif()
```
4 changes: 3 additions & 1 deletion README.md
````diff
@@ -60,7 +60,7 @@ To provide a custom CUDNN installation path, use environment variable: `CUDNN_PA
 #### Checking the installation
 To test whether installation is successful, run:
 ```
-pytest test/python_fe
+pytest test/python
 ```
 
 NOTE: Only v1.0 API is exposed via python bindings.
@@ -95,6 +95,8 @@ To skip building samples, use `-DCUDNN_FRONTEND_BUILD_SAMPLES=OFF`.
 
 To skip building python bindings, use `-DCUDNN_FRONTEND_BUILD_PYTHON_BINDINGS=OFF`.
 
+To add debug symbols, use `-DCMAKE_BUILD_TYPE=Debug`.
+
 In case, you have a stale cmake cache and want to update the cudnn/cuda paths, please delete the cmake cache (or build directory and redo the above steps).
 
 ## Debugging
````
4 changes: 2 additions & 2 deletions cmake/cuDNN.cmake
```diff
@@ -2,7 +2,7 @@ add_library(CUDNN::cudnn_all INTERFACE IMPORTED)
 
 find_path(
     CUDNN_INCLUDE_DIR cudnn.h
-    HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_INCLUDE_DIRS}
+    HINTS $ENV{CUDNN_INCLUDE_PATH} ${CUDNN_INCLUDE_PATH} $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_INCLUDE_DIRS}
     PATH_SUFFIXES include
     REQUIRED
 )
@@ -14,7 +14,7 @@ string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}")
 function(find_cudnn_library NAME)
     find_library(
         ${NAME}_LIBRARY ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}"
-        HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_LIBRARY_DIR}
+        HINTS $ENV{CUDNN_LIBRARY_PATH} ${CUDNN_LIBRARY_PATH} $ENV{CUDNN_PATH} ${CUDNN_PATH} ${Python_SITEARCH}/nvidia/cudnn ${CUDAToolkit_LIBRARY_DIR}
         PATH_SUFFIXES lib64 lib/x64 lib
         REQUIRED
     )
```
31 changes: 31 additions & 0 deletions docs/cuda_graphs.md


### `populate_cuda_graph`

The `populate_cuda_graph` function is a member function of the `Graph` class. It is used to populate an empty CUDA graph with the cudnn nodes of the frontend graph.

#### Parameters

- `handle`: A cuDNN handle.
- `uid_to_device_ptrs`: A map of tensor UIDs to device pointers.
- `workspace`: A pointer to the workspace memory.
- `cudnn_cuda_graph`: A pointer to the CUDA graph.

#### Return Value

- An `error_t` object indicating the success or failure of the function.

### `update_cuda_graph`

The `update_cuda_graph` function is a member function of the `Graph` class. It is used to update an already-populated CUDA graph with new data pointers.

#### Parameters

- `handle`: A cuDNN handle.
- `uid_to_device_ptrs`: A map of tensor UIDs to device pointers.
- `workspace`: A pointer to the workspace memory.
- `cudnn_cuda_graph`: A pointer to the CUDA graph.

#### Return Value

- An `error_t` object indicating the success or failure of the function.
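
Below is a hedged C++ sketch of how these two calls can sit next to the CUDA graph runtime API. The surrounding setup (a validated and built `Graph`, a UID-keyed pointer map, an allocated workspace) and the re-instantiate-after-update flow are assumptions made for illustration, not prescribed usage.

```cpp
// Hedged sketch: drive a pre-built cudnn_frontend::graph::Graph through the
// CUDA-graph path. Assumes `graph` has already been validated and its plans
// built against `handle`.
#include <cudnn_frontend.h>
#include <cuda_runtime.h>

#include <unordered_map>

namespace fe = cudnn_frontend;

fe::error_t
launch_with_cuda_graph(fe::graph::Graph& graph,
                       cudnnHandle_t handle,
                       std::unordered_map<int64_t, void*>& uid_to_device_ptrs,  // tensor UID -> device pointer
                       void* workspace,
                       cudaStream_t stream) {
    // 1. Create an empty CUDA graph and let cuDNN populate it with its nodes.
    cudaGraph_t cudnn_cuda_graph = nullptr;
    cudaGraphCreate(&cudnn_cuda_graph, 0);

    auto status = graph.populate_cuda_graph(handle, uid_to_device_ptrs, workspace, cudnn_cuda_graph);
    if (!status.is_good()) {
        return status;
    }

    // 2. Instantiate and launch like any other CUDA graph (CUDA 12 signature).
    cudaGraphExec_t graph_exec = nullptr;
    cudaGraphInstantiate(&graph_exec, cudnn_cuda_graph, 0);
    cudaGraphLaunch(graph_exec, stream);

    // 3. On later iterations only the data pointers change: refresh the already
    //    populated graph in place instead of rebuilding it from scratch.
    //    (In practice the map would now hold the next iteration's pointers.)
    status = graph.update_cuda_graph(handle, uid_to_device_ptrs, workspace, cudnn_cuda_graph);
    if (!status.is_good()) {
        return status;
    }
    cudaGraphExecDestroy(graph_exec);
    cudaGraphInstantiate(&graph_exec, cudnn_cuda_graph, 0);  // or cudaGraphExecUpdate
    cudaGraphLaunch(graph_exec, stream);

    cudaGraphExecDestroy(graph_exec);
    cudaGraphDestroy(cudnn_cuda_graph);
    return status;
}
```

The intended benefit is that per-iteration launches avoid rebuilding the cudnn graph; only the device pointers are refreshed via `update_cuda_graph`.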
66 changes: 49 additions & 17 deletions docs/operations/Attention.md
````diff
@@ -15,9 +15,11 @@ using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2
 
 - Python sample: [samples/python/50_scaled_dot_product_attention.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb)
 
+- Python sample with paged caches: [samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb)
+
 - C++ sample: [samples/cpp/sdpa](https://github.com/NVIDIA/cudnn-frontend/tree/main/samples/cpp/sdpa)
 
-- Python tests: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py)
+- Python tests: [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py)
 
 #### Configurable Options:
 
@@ -38,22 +40,35 @@ using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2
 - `dropout mask` that matches the attention weights' dimensions, indicating which weights to drop. The dimensions that are passed as 1 will apply a broadcasted dropout mask.
 - `dropout scale` used to adjust the scale of the remaining weights accordingly, such as $1 / (1 - \text{dropout probability})$.
 - Packed layout: With packed layout, the query, key, value, and output tensor should be [ragged tensors](https://www.tensorflow.org/guide/ragged_tensor), which are tensors with nested variable length lists as inner dimensions. Users must pass another tensor called ragged offset tensor using the `Tensor_attributes.set_ragged_offset()` method. the ragged offset tensor must be a tensor of size $(B + 1, 1, 1, 1)$ that contains the nested tensor's offset in terms of number of elements (not bytes). The last value of the offset tensor specifies the offset of the past-the-end element of the ragged tensor. See Appendix A for more information on the supported layouts.
+- Paged attention: with paged K and/or V caches, the K/V blocks no longer need to be contiguous, allowing users to better utilize memory by avoiding fragmentation.
+  - Users must therefore:
+    - Pass a `page table k` tensor containing offsets to the container with K blocks. This is optional, and only needed if the K cache is paged.
+    - Pass a `page table v` tensor containing offsets to the container with V blocks. This is optional, and only needed if the V cache is paged.
+    - Pass anything required for `Padding mask` above (i.e., per-batch sequence lengths for both K and V caches). This is needed if at least one of the K/V caches are paged.
+    - Optionally, but recommended, pass the maximum sequence length for the K/V caches. When omitted, it will be (over)estimated, which could result in a corrupted graph in some corner cases.
+  - Offsets to the K/V containers will be calculated as
+    - $Kcache[b,h,s,d] = K[page\ table\ k[b,1,s / bs_k, 1],h,s\ mod\ bs_{k},d]$
+    - $Vcache[b,h,s,d] = V[page\ table\ v[b,1,s / bs_v, 1],h,s\ mod\ bs_{v},d]$
+  - See also the [PagedAttention paper](https://arxiv.org/abs/2309.06180).
 
 ##### Input Tensors:
 
-| Tensor Name                         | Device     | Data Type      | Dimensions |
-|-------------------------------------|------------|----------------|------------|
-| Q                                   | GPU        | FP16 or BF16   | $(B, H_{q}, S_{q}, D_{qk})$ |
-| K                                   | GPU        | FP16 or BF16   | $(B, H_{k}, S_{kv}, D_{qk})$ |
-| V                                   | GPU        | FP16 or BF16   | $(B, H_{v}, S_{kv}, D_{v})$ |
-| (Bias mask) Bias Mask               | GPU        | FP16 or BF16   | $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ |
-| (Padding mask) Sequence Length Q    | GPU        | INT32          | $(B, 1, 1, 1)$ |
-| (Padding mask) Sequence Length KV   | GPU        | INT32          | $(B, 1, 1, 1)$ |
-| (Philoc RNG Dropout) Seed           | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ |
-| (Philoc RNG Dropout) Offset         | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ |
-| (Custom Dropout Mask) Mask          | GPU        | FP16 or BF16   | $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ |
-| (Custom Dropout Mask) Scale         | GPU        | FP32           | $(1, 1, 1, 1)$ |
-| (Packed Layout) Ragged Offset       | GPU        | INT32          | $(B + 1, 1, 1, 1)$ |
+| Tensor Name                                     | Device     | Data Type      | Dimensions |
+|-------------------------------------------------|------------|----------------|------------|
+| Q                                               | GPU        | FP16 or BF16   | $(B, H_{q}, S_{q}, D_{qk})$ |
+| K                                               | GPU        | FP16 or BF16   | $(B, H_{k}, S_{kv}, D_{qk})$, or $(num\_blocks_{k}, H_{k}, bs_{k}, D_{qk})$ in case of paged K cache |
+| V                                               | GPU        | FP16 or BF16   | $(B, H_{v}, S_{kv}, D_{v})$, or $(num\_blocks_{v}, H_{v}, bs_{v}, D_{v})$ in case of paged V cache |
+| (Bias mask) Bias Mask                           | GPU        | FP16 or BF16   | $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ |
+| (Padding mask/Paged Caches) Sequence Length Q   | GPU        | INT32          | $(B, 1, 1, 1)$ |
+| (Padding mask/Paged Caches) Sequence Length KV  | GPU        | INT32          | $(B, 1, 1, 1)$ |
+| (Philoc RNG Dropout) Seed                       | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ |
+| (Philoc RNG Dropout) Offset                     | CPU or GPU | INT32 or INT64 | $(1, 1, 1, 1)$ |
+| (Custom Dropout Mask) Mask                      | GPU        | FP16 or BF16   | $(1, 1, S_{q}, S_{kv})$, $(1, H_{q}, S_{q}, S_{kv})$, $(B, 1, S_{q}, S_{kv})$, or $(B, H_{q}, S_{q}, S_{kv})$ |
+| (Custom Dropout Mask) Scale                     | GPU        | FP32           | $(1, 1, 1, 1)$ |
+| (Packed Layout) Ragged Offset                   | GPU        | INT32          | $(B + 1, 1, 1, 1)$ |
+| (Paged Attention) Page Table K                  | GPU        | INT32          | $(B, 1, ceil(S_{kv}/bs_{k}), 1)$ |
+| (Paged Attention) Page Table V                  | GPU        | INT32          | $(B, 1, ceil(S_{kv}/bs_{v}), 1)$ |
+| (Paged Attention) Max Sequence Length KV        | CPU        | INT32 or INT64 | $(1, 1, 1, 1)$ |
 
 ##### Output Tensors
 
@@ -73,6 +88,10 @@ Where,
 - $S_{kv}$ is the sequence length of the key and value
 - $D_{qk}$ is the embedding dimension per head of query and key
 - $D_{v}$ is the embedding dimension per head of value
+- $bs_{k}$ is the (power of 2) block size of the K container
+- $bs_{v}$ is the (power of 2) block size of the V container
+- $num\_blocks_{k}$ is the number of blocks in the K container
+- $num\_blocks_{v}$ is the number of blocks in the V container
 
 #### Group-query attention (GQA) and Multi-query attention (MQA)
 
@@ -146,15 +165,25 @@ set_dropout(std::shared_ptr<Tensor_attributes> mask,
 SDPA_attributes&
 set_compute_data_type(DataType_t value);
+SDPA_attributes&
+set_paged_attention_k_table(std::shared_ptr<Tensor_attributes> value);
+SDPA_attributes&
+set_paged_attention_v_table(std::shared_ptr<Tensor_attributes> value);
+SDPA_attributes&
+set_paged_attention_max_seq_len_kv(int const value);
 ```
 
 #### Python API:
 
 ```
 Args:
     q (cudnn_tensor): The query data.
-    k (cudnn_tensor): The key data.
-    v (cudnn_tensor): The value data.
+    k (cudnn_tensor): The key data. When page_table_k is provided, 'k' is a container of non-contiguous key data.
+    v (cudnn_tensor): The value data. When page_table_v is provided, 'v' is a container of non-contiguous value data.
     is_inference (bool): Whether it is an inference step or training step.
     attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None.
     bias (Optional[cudnn_tensor]): The bias data for attention. Default is None.
@@ -166,6 +195,9 @@ Args:
     use_causal_mask_bottom_right (Optional[bool]): Whether to use bottom right aligned causal mask. Default is False.
     dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None.
     rng_dump (Optional[cudnn_tensor]): Debug tensor used to output the Philox RNG dropout mask
+    paged_attention_k_table (Optional[cudnn_tensor]): The page table to look up offsets into 'k'
+    paged_attention_v_table (Optional[cudnn_tensor]): The page table to look up offsets into 'v'
+    paged_attention_max_seq_len_kv (Optional[integer]): The maximum sequence length for k/v caches when paged attention is active.
     compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET.
     name (Optional[str]): The name of the operation.
@@ -182,7 +214,7 @@ This operation computes gradient tensors for scaled dot product attention (SDPA)
 
 - C++ sample: [samples/cpp/sdpa](https://github.com/NVIDIA/cudnn-frontend/tree/main/samples/cpp/sdpa)
 
-- Python tests: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py)
+- Python tests: [test/python/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python/test_mhas.py)
 
 #### Configurable Options:
 
````