[Blocksparse] bug fixing half + sequence length #25

Merged · 3 commits · Oct 22, 2021
4 changes: 3 additions & 1 deletion HOWTO.md
@@ -195,7 +195,9 @@ A simple example is that of a causal attention: just compute the lower triangula

If you already have a per-coefficient pattern in mind and this is not a perfect match with a block pattern, this is probably fine,
BlockSparse is fast enough so that dropping some of the computations after the fact with a fine-grained mask is still probably better than dense computations.
We provide a small helper (this is just maxpooling) to convert in between a per coefficient binary mask and the layout that you will need to build a block sparse attention.

We provide a small helper (this is just maxpooling) to convert in between a per coefficient binary mask and the layout that you will need to build a block sparse attention. Please note that for now _blocksparse attention requires the sequence length to be a power of two_.

Let's look at an example:

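As a rough illustration of the mask-to-layout conversion described above (the helper name and signature below are hypothetical, not necessarily the library's actual API), a per-coefficient binary mask can be max-pooled into a block layout along these lines:

```python
import torch


def mask_to_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor:
    # Hypothetical helper: keep a block if any coefficient inside it is active
    # (this is exactly a max-pool over block_size x block_size tiles)
    assert mask.shape[-1] % block_size == 0, "sequence length must be divisible by block_size"
    pooled = torch.nn.functional.max_pool2d(
        mask.float().unsqueeze(0), kernel_size=block_size, stride=block_size
    )
    return pooled.long()  # the triton blocksparse kernels expect an integer layout


# Causal pattern on a power-of-two sequence length, 32x32 blocks
seq, block = 1024, 32
causal_mask = torch.tril(torch.ones(seq, seq))
layout = mask_to_layout(causal_mask, block)  # shape [1, seq // block, seq // block]
```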
3 changes: 3 additions & 0 deletions docs/source/tutorials/blocksparse.rst
@@ -10,6 +10,9 @@ A simple example is that of a causal attention: just compute the lower triangula
If you already have a per-coefficient pattern in mind and this is not a perfect match with a block pattern, this is probably fine,
BlockSparse is fast enough so that dropping some of the computations after the fact with a fine-grained mask is still probably better than dense computations.
We provide a small helper (this is just maxpooling) to convert in between a per coefficient binary mask and the layout that you will need to build a block sparse attention.

*Please note that for now blocksparse attention requires the sequence length to be a power of two*.

Let's look at an example:

6 changes: 3 additions & 3 deletions xformers/benchmarks/benchmark_encoder.py
@@ -220,7 +220,7 @@ def instantiate_xformer(
"layout": torch.eye(
sequence_length // block_size,
sequence_length // block_size,
dtype=torch.int,
dtype=torch.long,
)
.unsqueeze(0)
.expand(heads, -1, -1),
@@ -325,12 +325,12 @@ def plot(args, results: List[Dict[str, Any]]):
"-emb", "--embedding_dim", nargs="+", default=[64, 128, 256], type=int
)
parser.add_argument(
"-sl", "--sequence_length", nargs="+", default=[512, 768, 1024], type=int
"-sl", "--sequence_length", nargs="+", default=[512, 1024], type=int
Contributor Author (review comment): would have been nice to test for longer sequences, but 2048 OOMs with the vanilla attention on a V100.

)
parser.add_argument("-bs", "--batch_size", nargs="+", default=[8, 16, 32], type=int)
parser.add_argument("-heads", "--heads", nargs="+", default=[8, 16], type=int)

parser.add_argument("-fp16", "--pytorch_amp", nargs="+", default=[False], type=bool)
parser.add_argument("-fp16", "--pytorch_amp", nargs="+", default=[True], type=bool)
parser.add_argument("-causal", "--causal", nargs="+", default=[False], type=bool)
parser.add_argument("-plot", "--plot", action="store_true", default=False)
parser.add_argument(
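The 768 entry drops out of the sequence-length defaults above, presumably because it is not a power of two and would be rejected by the new check added to blocksparse.py below. A small standalone equivalent of that check (illustrative only):

```python
def is_power_of_two(n: int) -> bool:
    # Equivalent to the math.log(n, 2).is_integer() check added in blocksparse.py,
    # but using bit arithmetic to avoid floating-point rounding
    return n > 0 and (n & (n - 1)) == 0


assert is_power_of_two(512) and is_power_of_two(1024)  # the remaining defaults
assert not is_power_of_two(768)  # dropped: not a power of two
```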
39 changes: 34 additions & 5 deletions xformers/components/attention/blocksparse.py
@@ -5,6 +5,7 @@


import logging
import math
from dataclasses import dataclass
from typing import Optional

@@ -22,6 +23,14 @@
from triton.ops.blocksparse import softmax as blocksparse_softmax

from xformers.triton.softmax import MaskType
from xformers.triton.utils import gpu_capabilities_older_than_70

# Blocksparse requires Tensor cores
if gpu_capabilities_older_than_70():
logging.warning(
"Blocksparse is not available: the current GPU does not expose Tensor cores"
)
_use_triton = False

except ImportError as e:
logging.warning(
@@ -47,6 +56,9 @@ class BlockSparseAttention(Attention):
.. warning: the layout is assumed to have the dimensions [heads, seq, seq].
If some dimensions are missing, we assume that the same layout is to be used across heads.

.. warning: for now, the sequence (context) length has to be a power of two. This constraint could
be relaxed in the future.

.. note: it is possible to pass a specific per batch mask in the forward call,
but this will not lead to any speed up.
Any constant sparsity pattern is better passed through the layout parameter.
@@ -141,12 +153,30 @@ def forward(
att_mask is None or att_mask.dim() == 2
), "The attention mask is constant across heads, expected dimensions are [seq x seq]"

# Self-attend: (B, nh, S, hs) x (B, nh, hs, S) -> (B, nh, S, S)
# When the computations are block sparse, the matrix types change along the way:
# - (sparse) attention matrix = (dense) Kt * (dense) Q
assert (
q.shape[-2] == k.shape[-2]
), "Blocksparse requires the same dimensions for K and Q for now"

assert (
q.shape[-2] == self.layout.shape[-2] * self.block_size
), "Actual sequence size and layout are inconsistent"
assert (
k.shape[-2] == self.layout.shape[-2] * self.block_size
), "Actual sequence size and layout are inconsistent"

assert math.log(
q.shape[-2], 2
).is_integer(), (
"For now blocksparse only works on power-of-two sequence lengths"
)

# Blocksparse only works on fp16
q_dtype = q.dtype
q, k, v = q.half(), k.half(), v.half()

# Self-attend: (B, nh, S, hs) x (B, nh, hs, S) -> (B, nh, S, S)
# When the computations are block sparse, the matrix types change along the way:
# - (sparse) attention matrix = (dense) Kt * (dense) Q
sparse_att_mat = self.sparse_dot_sdd(q, k)

# - softmax on the sparse attention matrix
@@ -161,5 +191,4 @@

# - then (dense) attention is (sparse) attention matrix * dense (value)
a = self.sparse_dot_dsd(sparse_att_mat, v)

return a
return a.to(q_dtype)
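Putting the pieces together, a minimal usage sketch (the import path and constructor arguments below are assumptions based on the benchmark above; check the class definition for the exact signature):

```python
import torch

from xformers.components.attention import BlockSparseAttention  # assumed import path

heads, seq, block_size, head_dim = 4, 1024, 32, 64  # seq must be a power of two

# Block-diagonal layout shared across heads, built the same way as in the benchmark
layout = (
    torch.eye(seq // block_size, seq // block_size, dtype=torch.long)
    .unsqueeze(0)
    .expand(heads, -1, -1)
)

# Assumed constructor arguments; the actual class may take additional parameters
attention = BlockSparseAttention(layout=layout, block_size=block_size).cuda()

# Inputs follow the (B, nh, S, hs) convention from the forward() comments above;
# they are cast to fp16 internally and the output is cast back to the input dtype
q = torch.randn(2, heads, seq, head_dim, device="cuda")
out = attention(q, q, q)  # same shape and dtype as q
```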