cudnn pre-release-4
[API change] `Scaled_dot_product_flash_attention_attributes` and
`Scaled_dot_product_flash_attention_backward_attributes` now accept K and
V tensors instead of K-transpose and V-transpose. This is a deviation
from the backend API, made in response to feedback from multiple customers.

[New API] Add `tensor_like` python API which accepts a DLPack-compatible
tensor. This simplifies cudnn tensor creation.

[New Feature] Setting the `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT`
environment variable allows choosing between different optimized cudnn
backend kernels. See docs/operations/mha for more details.
[New Feature] Add RMSNorm and InstanceNorm forward and backward
implementations.
[New Feature] Add alibi, padding, layout support for attention bprop
node.
[New Feature] Introduce python bindings for plans. This allows validating
the graph and filtering plans.

[Bug Fix] Fix relative includes of filenames in cudnn_frontend headers.
This resolves compilation issues in certain toolchains.
[Bug Fix] Fix a segfault when dropout was set for some scaled dot product
flash attention nodes.

[New samples] Add new samples for `apply_rope`, `layernorm forward and
backward`, and `rmsnorm forward and backward`.
Anerudhan committed Oct 19, 2023
1 parent ea7f8b9 commit d337a3c
Showing 66 changed files with 6,641 additions and 3,372 deletions.
90 changes: 75 additions & 15 deletions docs/operations/Attention.md
### Scaled Dot Product Flash Attention
Computes the scaled dot product attention for the given Query, Key and Value tensors. Optionally, a dropout probability and a causal mask can be set, and stats can be dumped for use in the bprop computation.

The dimensions for

- Query tensor should be $(B, H, S_{q}, D)$
- Key tensor should be $(B, H, S_{kv}, D)$
- Value tensor should be $(B, H, S_{kv}, D)$
- Output tensor should be $(B, H, S_{q}, D)$
- Stats tensor should be $(B, H, S_{q}, 1)$

Where $B$ is the batch size, $H$ is the number of heads, $S_{q}$ is the sequence length of the query, $S_{kv}$ is the sequence length
of the key and value, and $D$ is the embedding dimension per head.

Additionally, the stride for the last dimension corresponding to the embedding dim per head for each of these tensors
must be 1.
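
For example, a query tensor laid out as $(B, H, S_{q}, D)$ with a contiguous embedding dimension can be declared as follows (a minimal sketch against the C++ graph API; the sizes, the helper name, and the choice of half precision are illustrative assumptions):

```cpp
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

// Sketch: declare a (B, H, S_q, D) query tensor whose innermost stride is 1,
// satisfying the layout requirement described above. Sizes are hypothetical.
auto make_query_tensor(fe::graph::Graph &graph) {
    int64_t const b = 4, h = 12, s_q = 1024, d = 64;
    return graph.tensor(fe::graph::Tensor_attributes()
                            .set_name("Q")
                            .set_dim({b, h, s_q, d})
                            .set_stride({h * s_q * d, s_q * d, d, 1})  // last stride must be 1
                            .set_data_type(fe::DataType_t::HALF));
}
```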

**API:**

```cpp
std::array<std::shared_ptr<Tensor_attributes>, 2>
scaled_dot_product_flash_attention
(std::shared_ptr<Tensor_attributes> q,
     ...);
```

where the output array has tensors in order of: `[output, softmax_stats]` and `Scaled_dot_product_flash_attention_attributes` controls the sub-graph in the operation.

```cpp
Scaled_dot_product_flash_attention_attributes &
set_is_inference(bool const value);

Scaled_dot_product_flash_attention_attributes &
set_attn_scale(std::shared_ptr<Tensor_attributes> value);

Scaled_dot_product_flash_attention_attributes &
set_bias(std::shared_ptr<Tensor_attributes> value);

Scaled_dot_product_flash_attention_attributes&
set_alibi_mask(bool const value)

Scaled_dot_product_flash_attention_attributes&
set_padding_mask(bool const value);

Scaled_dot_product_flash_attention_attributes&
set_seq_len_q(std::shared_ptr<Tensor_attributes> value);

Scaled_dot_product_flash_attention_attributes&
set_seq_len_kv(std::shared_ptr<Tensor_attributes> value);

Scaled_dot_product_flash_attention_attributes &
set_causal_mask(bool const value);

Scaled_dot_product_flash_attention_attributes &
set_dropout(float const probability,
std::shared_ptr<Tensor_attributes> seed,
std::shared_ptr<Tensor_attributes> offset);

Scaled_dot_product_flash_attention_attributes &
set_dropout(std::shared_ptr<Tensor_attributes> mask,
std::shared_ptr<Tensor_attributes> scale);

Scaled_dot_product_flash_attention_attributes &
set_compute_data_type(DataType_t value)
```
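
Putting the setters together, a training-time node can be added to a graph roughly as follows (a hedged sketch: it assumes `q`, `k`, `v`, `attn_scale`, `seed`, and `offset` tensors were already created on the graph, e.g. as shown above, and that the operation takes the attributes object as its last argument; the helper name and dropout probability are illustrative):

```cpp
#include <cudnn_frontend.h>
#include <memory>

namespace fe = cudnn_frontend;

// Sketch: a causal, scaled flash attention node with dropout for training.
// Only the setters documented above are used.
auto add_flash_attention(fe::graph::Graph &graph,
                         std::shared_ptr<fe::graph::Tensor_attributes> q,
                         std::shared_ptr<fe::graph::Tensor_attributes> k,
                         std::shared_ptr<fe::graph::Tensor_attributes> v,
                         std::shared_ptr<fe::graph::Tensor_attributes> attn_scale,
                         std::shared_ptr<fe::graph::Tensor_attributes> seed,
                         std::shared_ptr<fe::graph::Tensor_attributes> offset) {
    auto attributes = fe::graph::Scaled_dot_product_flash_attention_attributes()
                          .set_is_inference(false)   // training: softmax stats are emitted
                          .set_attn_scale(attn_scale)
                          .set_causal_mask(true)
                          .set_dropout(0.1f, seed, offset);

    // Returned in order [output, softmax_stats]; stats feed the backward node.
    auto outputs = graph.scaled_dot_product_flash_attention(q, k, v, attributes);
    outputs[0]->set_output(true);  // keep the attention output as a real (non-virtual) tensor
    return outputs;
}
```
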
**Python API:**
```
Args:
q (cudnn_tensor): The query data.
k (cudnn_tensor): The key data.
v (cudnn_tensor): The value data.
is_inference (bool): Whether it is an inference step or training step.
attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None.
bias (Optional[cudnn_tensor]): The bias data for attention. Default is None.
use_padding_mask (Optional[bool]): Whether to use padding mask. Default is False.
seq_len_q (Optional[cudnn_tensor]): The sequence length of the query.
seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key.
use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False.
use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False.
dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None.
Returns:
...
```

### Scaled Dot Product Flash Attention Backward
Computes the query, key and value gradient tensors for scaled dot product flash attention. Optionally, a dropout probability and a causal mask can be set.

The dimensions for

- Query tensor should be $(B, H, S_{q}, D)$
- Key tensor should be $(B, H, S_{kv}, D)$
- Value tensor should be $(B, H, S_{kv}, D)$
- Output tensor should be $(B, H, S_{q}, D)$
- Stats tensor should be $(B, H, S_{q}, 1)$
- Gradient tensors for query, key, value, and output should follow the same convention

Where $B$ is the batch size, $H$ is the number of heads, $S_{q}$ is the sequence length of the query, $S_{kv}$ is the sequence length
of the key and value, and $D$ is the embedding dimension per head.

Additionally, the stride for the last dimension corresponding to the embedding dimension per head for each of these tensors
must be 1.

API:

```cpp
std::array<std::shared_ptr<Tensor_attributes>, 3>
scaled_dot_product_flash_attention_backward
(std::shared_ptr<Tensor_attributes> q,
     ...);
```

where the output array has tensors in order of: `[dQ, dK, dV]` and `Scaled_dot_product_flash_attention_backward_attributes` controls the sub-graph in the operation.


```cpp
Scaled_dot_product_flash_attention_backward_attributes&
set_attn_scale(std::shared_ptr<Tensor_attributes> value)

Scaled_dot_product_flash_attention_backward_attributes&
set_bias(std::shared_ptr<Tensor_attributes> value)

Scaled_dot_product_flash_attention_backward_attributes&
set_alibi_mask(bool const value)

Scaled_dot_product_flash_attention_backward_attributes&
set_padding_mask(bool const value);

Scaled_dot_product_flash_attention_backward_attributes&
set_seq_len_q(std::shared_ptr<Tensor_attributes> value);

Scaled_dot_product_flash_attention_backward_attributes&
set_seq_len_kv(std::shared_ptr<Tensor_attributes> value);

Scaled_dot_product_flash_attention_backward_attributes&
set_causal_mask(bool const value)

Scaled_dot_product_flash_attention_backward_attributes&
set_dropout(float const probability,
std::shared_ptr<Tensor_attributes> seed,
std::shared_ptr<Tensor_attributes> offset)

Scaled_dot_product_flash_attention_backward_attributes&
set_dropout(std::shared_ptr<Tensor_attributes> mask,
std::shared_ptr<Tensor_attributes> scale,
std::shared_ptr<Tensor_attributes> scale_inv)

Scaled_dot_product_flash_attention_backward_attributes&
set_compute_data_type(DataType_t const value)
```

Python API:
```
Args:
q (cudnn_tensor): The query data.
k (cudnn_tensor): The key data.
v (cudnn_tensor): The value data.
o (cudnn_tensor): The output data.
dO (cudnn_tensor): The output loss gradient.
stats (cudnn_tensor): The softmax statistics from the forward pass.
attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None.
bias (Optional[cudnn_tensor]): The bias data for attention. Default is None.
use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False.
use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False.
dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)],
Tuple[mask: cudnn_tensor, scale: cudnn_tensor, scale_inv: cudnn_tensor]]]):
Whether to do dropout. Default is None.
compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET.
name (Optional[str]): The name of the operation.

Returns:
...
```
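
A corresponding backward node can be wired as follows (a hedged sketch: the full C++ argument list is abbreviated above, so it is assumed here to be q, k, v, o, dO and stats followed by the attributes, mirroring the Python arguments; the helper name is illustrative and the masking/scale settings must match the forward node):

```cpp
#include <cudnn_frontend.h>
#include <memory>

namespace fe = cudnn_frontend;

// Sketch: gradients come back in order [dQ, dK, dV].
auto add_flash_attention_backward(fe::graph::Graph &graph,
                                  std::shared_ptr<fe::graph::Tensor_attributes> q,
                                  std::shared_ptr<fe::graph::Tensor_attributes> k,
                                  std::shared_ptr<fe::graph::Tensor_attributes> v,
                                  std::shared_ptr<fe::graph::Tensor_attributes> o,
                                  std::shared_ptr<fe::graph::Tensor_attributes> dO,
                                  std::shared_ptr<fe::graph::Tensor_attributes> stats,
                                  std::shared_ptr<fe::graph::Tensor_attributes> attn_scale) {
    auto attributes = fe::graph::Scaled_dot_product_flash_attention_backward_attributes()
                          .set_attn_scale(attn_scale)
                          .set_causal_mask(true);  // must match the forward configuration

    auto grads = graph.scaled_dot_product_flash_attention_backward(q, k, v, o, dO, stats, attributes);
    for (auto &grad : grads) {
        grad->set_output(true);  // dQ, dK, dV are real (non-virtual) graph outputs
    }
    return grads;
}
```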
- The cudnn backend enums are changed as follows:
    - `cudnnBackend<enum_name>` -> `cudnn_frontend::<enum_name>`
    - `cudnn<enum_name>` -> `cudnn_frontend::<enum_name>`
- Scaled Dot Product Flash Attention Backward improves computation speed by employing an optional workspace tensor whose memory usage grows quadratically with sequence length. The default GPU memory limit for this workspace tensor is 256 MB, but users with enough available GPU memory budget can raise it by setting the `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT` environment variable to the desired limit in bytes.
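
  For example, a process that wants to allow up to 1 GiB of workspace could set the variable before building its plans (a sketch; exporting the variable from the shell has the same effect, and the value shown is only an illustration):

  ```cpp
  #include <cstdlib>

  int main() {
      // Value is in bytes: 1 GiB here. Must be set before attention backward
      // plans are built so the larger workspace can be considered.
      setenv("CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT", "1073741824", /*overwrite=*/1);

      // ... build and execute the attention backward graph as usual ...
      return 0;
  }
  ```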
71 changes: 70 additions & 1 deletion docs/operations/Normalizations.md

### Layernorm Backward

DLN operation computes the data gradient, scale gradient, and bias gradient during backpropagation of the layernorm forward operation.

The API to achieve the above is:
```
...

Layernorm_attributes&
set_compute_data_type(DataType_t value)
```

Python API:
- layernorm
- input
- scale
- loss
- compute_data_type
- name


### Instancenorm Forward

Instance norm computes

$$ output = scale*{input - mean \over \sqrt{variance + epsilon}} + bias $$

where normalization happens across each sample.

The API to achieve the above equation is:
```
std::array<std::shared_ptr<Tensor_attributes>, 3> instancenorm(std::shared_ptr<Tensor_attributes>& input,
std::shared_ptr<Tensor_attributes>& scale,
std::shared_ptr<Tensor_attributes>& bias,
Instancenorm_attributes attributes);
```
where the output array has tensors in order of: `[output, mean, variance]`

Instancenorm_attributes is a lightweight structure with setters for providing optional input tensors and other operation attributes:
```
Instancenorm_attributes&
set_name(std::string const&)
Instancenorm_attributes&
set_compute_data_type(DataType_t value)
```
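
A minimal sketch of adding the node to an existing graph, using only the signature and setters shown above (tensor creation, epsilon and forward-phase configuration are omitted, and the helper name is illustrative):

```cpp
#include <cudnn_frontend.h>
#include <memory>

namespace fe = cudnn_frontend;

// Sketch: outputs are returned in order [output, mean, variance].
auto add_instancenorm(fe::graph::Graph &graph,
                      std::shared_ptr<fe::graph::Tensor_attributes> input,
                      std::shared_ptr<fe::graph::Tensor_attributes> scale,
                      std::shared_ptr<fe::graph::Tensor_attributes> bias) {
    auto attributes = fe::graph::Instancenorm_attributes().set_name("instancenorm_fwd");
    auto outputs    = graph.instancenorm(input, scale, bias, attributes);
    outputs[0]->set_output(true);  // keep the normalized output as a real tensor
    return outputs;
}
```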

Python API:
- instancenorm
- norm_forward_phase
- input
- scale
- bias
- epsilon
- compute_data_type
- name


### Instancenorm Backward

DIN operation computes the data gradient, scale gradient, and bias gradient during backpropagation of the instancenorm forward operation.

The API to achieve the above is:
```
std::array<std::shared_ptr<Tensor_attributes>, 3>
instancenorm_backward(std::shared_ptr<Tensor_attributes> dy,
std::shared_ptr<Tensor_attributes> x,
std::shared_ptr<Tensor_attributes> scale,
Instancenorm_backward_attributes options);
```
where the output array has tensors in order of: `[input gradient, scale gradient, bias gradient]`.

Instancenorm_backward_attributes is a lightweight structure with setters for providing optional input tensors and other operation attributes:
```
Instancenorm_backward_attributes&
set_name(std::string const&)
Instancenorm_backward_attributes&
set_compute_data_type(DataType_t value)
```
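
And a matching sketch for the backward node, based on the signature above (the helper name is illustrative, and wiring of any saved statistics from the forward pass is omitted):

```cpp
#include <cudnn_frontend.h>
#include <memory>

namespace fe = cudnn_frontend;

// Sketch: gradients are returned in order [input gradient, scale gradient, bias gradient].
auto add_instancenorm_backward(fe::graph::Graph &graph,
                               std::shared_ptr<fe::graph::Tensor_attributes> dy,
                               std::shared_ptr<fe::graph::Tensor_attributes> x,
                               std::shared_ptr<fe::graph::Tensor_attributes> scale) {
    auto options = fe::graph::Instancenorm_backward_attributes().set_name("instancenorm_bwd");
    auto grads   = graph.instancenorm_backward(dy, x, scale, options);
    for (auto &grad : grads) {
        grad->set_output(true);  // mark dX, dScale, dBias as real graph outputs
    }
    return grads;
}
```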

Python API:
- layernorm
- input
2 changes: 1 addition & 1 deletion include/cudnn_frontend.h

#include "cudnn_frontend_Resample.h"

#include "cudnn_frontend/cudnn_frontend_graph_interface.h"
#include "cudnn_frontend/graph_interface.h"

#define CUDNN_FRONTEND_MAJOR_VERSION 1
#define CUDNN_FRONTEND_MINOR_VERSION 0
