cudnn prerelease_3:
Improvements over prerelease 2:
[Feature] Added SDPA flash attention backward node.
[Enhancement] Resolved an issue where the computed ALiBi slopes were copied to GPU memory on the default stream instead of the user-specified stream in the handle.
[Bug fix] Fixed a Windows compilation error when pedantic warnings are treated as errors.
[Bug fix] Fixed an issue in causal padding where masked values were `std::numeric_limits<float>::min()` instead of `std::numeric_limits<float>::lowest()`; see the snippet below.
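
For context on that last fix: `std::numeric_limits<float>::min()` is the smallest *positive* normalized float (~1.2e-38), so masked logits were effectively zero rather than strongly negative, while `lowest()` is the most negative finite float (~-3.4e38), which softmax maps to (near) zero probability. A minimal standalone illustration of the numeric difference (not the cuDNN code path):

```
#include <cstdio>
#include <limits>

int main() {
    // min() is tiny but positive -- a near-zero logit, not an effective mask.
    std::printf("min():    %g\n", std::numeric_limits<float>::min());
    // lowest() is the most negative finite float -- softmax sends it to ~0.
    std::printf("lowest(): %g\n", std::numeric_limits<float>::lowest());
    return 0;
}
```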

Under investigation and development:
- Additional features for SDPA backprop
- Better error messages and logging
Anerudhan committed Sep 25, 2023
1 parent 6e59c45 commit ea7f8b9
Showing 33 changed files with 2,109 additions and 778 deletions.
3 changes: 2 additions & 1 deletion README.FE.1.0.md
@@ -42,7 +42,8 @@ FE v1.0 API follows a functional style of building a graph. Operations take in i
| Generate stats of output| genstats <br>Genstats_attributes | genstats |
| BN Finalize of stats | bn_finalize <br>BN_finalize_attributes | bn_finalize |
| Dbn weight | dbn_weight <br>DBN_weight_attributes | dbn_weight |
| Scale dot product flash attention | scaled_dot_product_flash_attention<br> Scaled_dot_product_flash_attention_attributes | scaled_dot_product_flash_attention |
| Scale dot product flash attention_backward | scaled_dot_product_flash_attention_backward<br> Scaled_dot_product_flash_attention_backward_attributes | scaled_dot_product_flash_attention_backward |

### Create Graph
Instantiate an object of class `cudnn_frontend::graph::Graph` which will house tensors and operations.
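
As a quick illustration, a minimal sketch of this step (the graph-level `set_*_data_type` defaults are assumptions about the v1.0 API, not taken from this page):

```
namespace fe = cudnn_frontend;

// Create the graph that will house tensors and operations.
fe::graph::Graph graph;

// Assumed graph-wide defaults; individual tensors and ops can override these.
graph.set_io_data_type(fe::DataType_t::HALF)
     .set_intermediate_data_type(fe::DataType_t::FLOAT)
     .set_compute_data_type(fe::DataType_t::FLOAT);
```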
152 changes: 112 additions & 40 deletions docs/operations/Attention.md
@@ -1,11 +1,12 @@
## Table of Contents
1. [Scaled Dot Product Flash Attention](#scaled-dot-product-flash-attention)
2. [Scaled Dot Product Flash Attention Backward](#scaled-dot-product-flash-attention-backward)

### Scaled Dot Product Flash Attention
Computes the scaled dot product attention for the given query, key, and value tensors. Optionally, a dropout probability and a causal mask can be set. Softmax statistics can also be dumped for use in the backward-pass computation.

API:

```
std::array<std::shared_ptr<Tensor_attributes>, 2>
scaled_dot_product_flash_attention
(std::shared_ptr<Tensor_attributes> q,
 std::shared_ptr<Tensor_attributes> k,
 std::shared_ptr<Tensor_attributes> v,
 Scaled_dot_product_flash_attention_attributes options);
```

where the output array has tensors in order of: `[output, softmax_stats]` and `Scaled_dot_product_flash_attention_attributes` controls the sub-graph in the operation.

```
Scaled_dot_product_flash_attention_attributes &
set_is_inference(bool const value);
Scaled_dot_product_flash_attention_attributes &
set_causal_mask(bool const value);
Scaled_dot_product_flash_attention_attributes &
set_bias(std::shared_ptr<Tensor_attributes> value);
Scaled_dot_product_flash_attention_attributes &
set_attn_scale(std::shared_ptr<Tensor_attributes> value);
Scaled_dot_product_flash_attention_attributes &
set_dropout(float const probability,
            std::shared_ptr<Tensor_attributes> seed,
            std::shared_ptr<Tensor_attributes> offset);
Scaled_dot_product_flash_attention_attributes &
set_dropout(std::shared_ptr<Tensor_attributes> mask,
            std::shared_ptr<Tensor_attributes> scale);
Scaled_dot_product_flash_attention_attributes &
set_compute_data_type(DataType_t value);
```
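
For orientation, a hedged sketch of wiring the forward op into a graph. The `graph.tensor(...)` creation style follows the Create Graph section; the example dims/strides and exposing the op as a `Graph` method are assumptions, not verified against this release:

```
namespace fe = cudnn_frontend;

fe::graph::Graph graph;

// Assumed (b, h, s, d) layout; dims and strides here are placeholders.
auto make_tensor = [&](char const* name) {
    return graph.tensor(fe::graph::Tensor_attributes()
                            .set_name(name)
                            .set_dim({4, 8, 1024, 64})
                            .set_stride({8 * 1024 * 64, 1024 * 64, 64, 1})
                            .set_data_type(fe::DataType_t::HALF));
};
auto Q = make_tensor("Q");
auto K = make_tensor("K");
auto V = make_tensor("V");

auto options = fe::graph::Scaled_dot_product_flash_attention_attributes()
                   .set_is_inference(false)  // training: also emit softmax stats
                   .set_causal_mask(true);

// Output array order is [output, softmax_stats], per the description above.
auto [O, stats] = graph.scaled_dot_product_flash_attention(Q, K, V, options);
```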

Python API:

```
Args:
    q (cudnn_tensor): The query data.
    k (cudnn_tensor): The key data.
    v (cudnn_tensor): The value data.
    seq_len_q (Optional[cudnn_tensor]): The sequence length of the query.
    seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key.
    is_inference (bool): Whether it is an inference step or training step.
    attn_scale (Optional[cudnn_tensor]): The scale factor for attention. Default is None.
    bias (Optional[cudnn_tensor]): The bias data for attention. Default is None.
    use_padding_mask (Optional[bool]): Whether to use padding mask. Default is False.
    use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False.
    use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False.
    dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None.
    compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET.
    name (Optional[str]): The name of the operation.
Returns:
    o (cudnn_tensor): The result of scaled dot-product flash attention.
    stats (Optional[cudnn_tensor]): The softmax statistics in case the operation is in a training step.
```

### Scaled Dot Product Flash Attention Backward
Computes the query, key, and value gradient tensors for scaled dot product flash attention. Optionally, a dropout probability and a causal mask can be set.

API:
```
std::array<std::shared_ptr<Tensor_attributes>, 3>
scaled_dot_product_flash_attention_backward
(std::shared_ptr<Tensor_attributes> q,
 std::shared_ptr<Tensor_attributes> k,
 std::shared_ptr<Tensor_attributes> v,
 std::shared_ptr<Tensor_attributes> o,
 std::shared_ptr<Tensor_attributes> dO,
 std::shared_ptr<Tensor_attributes> stats,
 Scaled_dot_product_flash_attention_backward_attributes options);
```

where the output array has tensors in order of: `[dQ, dK, dV]` and `Scaled_dot_product_flash_attention_backward_attributes` controls the sub-graph in the operation.

```
Scaled_dot_product_flash_attention_backward_attributes&
set_attn_scale(std::shared_ptr<Tensor_attributes> value);
Scaled_dot_product_flash_attention_backward_attributes&
set_bias(std::shared_ptr<Tensor_attributes> value);
Scaled_dot_product_flash_attention_backward_attributes&
set_causal_mask(bool const value);
Scaled_dot_product_flash_attention_backward_attributes&
set_dropout(float const probability,
            std::shared_ptr<Tensor_attributes> seed,
            std::shared_ptr<Tensor_attributes> offset);
Scaled_dot_product_flash_attention_backward_attributes&
set_dropout(std::shared_ptr<Tensor_attributes> mask,
            std::shared_ptr<Tensor_attributes> scale,
            std::shared_ptr<Tensor_attributes> scale_inv);
Scaled_dot_product_flash_attention_backward_attributes&
set_compute_data_type(DataType_t const value);
```
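
Continuing the hedged forward sketch above (`Q`, `K`, `V`, `O`, `dO`, and `stats` are assumed to be tensors already created on the same graph):

```
auto bwd_options = fe::graph::Scaled_dot_product_flash_attention_backward_attributes()
                       .set_causal_mask(true);

// Output array order is [dQ, dK, dV], per the description above.
auto [dQ, dK, dV] =
    graph.scaled_dot_product_flash_attention_backward(Q, K, V, O, dO, stats, bwd_options);
```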

Python API:

```
Args:
    q (cudnn_tensor): The query data.
    k (cudnn_tensor): The key data.
    v (cudnn_tensor): The value data.
    o (cudnn_tensor): The output data.
    dO (cudnn_tensor): The output loss gradient.
    stats (cudnn_tensor): The softmax statistics from the forward pass.
    attn_scale (Optional[cudnn_tensor]): The scale factor for attention. Default is None.
    bias (Optional[cudnn_tensor]): The bias data for attention. Default is None.
    use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False.
    dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None.
    compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET.
    name (Optional[str]): The name of the operation.
Returns:
    dQ (cudnn_tensor): The query gradient tensor of scaled dot-product flash attention.
    dK (cudnn_tensor): The key gradient tensor of scaled dot-product flash attention.
    dV (cudnn_tensor): The value gradient tensor of scaled dot-product flash attention.
```

## Miscellaneous
- FE provides shadow enums, which help users avoid working around enum values that differ across cuDNN versions.
8 changes: 5 additions & 3 deletions include/cudnn_frontend/cudnn_frontend_cudnn_interface.h
@@ -47,22 +47,24 @@ class ICudnn {
    error_t
    create_cudnn_tensor(std::shared_ptr<graph::Tensor_attributes> const& props) {
        // Check whether tensor already created
        auto const uid = props->get_uid();
        if (tensors.find(uid) != tensors.end()) {
            getLogger() << "[cudnn_frontend] INFO: Backend tensor already created for Id: " << uid << ".\n";
            return {error_code_t::OK, ""};
        }

        // Create new backend tensor
        auto tensor = cudnn_frontend::TensorBuilder()
                          .setDim(props->get_dim().size(), props->get_dim().data())
                          .setStrides(props->get_stride().size(), props->get_stride().data())
                          .setId(uid)
                          .setAlignment(16)
                          .setDataType(props->get_data_type())
                          .setVirtual(props->get_is_virtual())
                          .setByValue(props->get_is_pass_by_value())
                          .setReorderType(props->get_reordering_type())
                          .build();
        tensors.emplace(uid, std::make_shared<Tensor>(std::move(tensor)));

        return {error_code_t::OK, ""};
    }