From 33e3889f0f29781385295fe6fea21485badbc510 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Fri, 22 Oct 2021 10:09:18 -0700
Subject: [PATCH 1/3] bug fixing for now, would be nice to PR Triton and
 support non-power-of-two sequence lengths

---
 xformers/benchmarks/benchmark_encoder.py     |  6 +++---
 xformers/components/attention/blocksparse.py | 28 +++++++++++++++++++++-----
 2 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/xformers/benchmarks/benchmark_encoder.py b/xformers/benchmarks/benchmark_encoder.py
index bd1849ab4c..c6f1c7ca82 100644
--- a/xformers/benchmarks/benchmark_encoder.py
+++ b/xformers/benchmarks/benchmark_encoder.py
@@ -220,7 +220,7 @@ def instantiate_xformer(
                 "layout": torch.eye(
                     sequence_length // block_size,
                     sequence_length // block_size,
-                    dtype=torch.int,
+                    dtype=torch.long,
                 )
                 .unsqueeze(0)
                 .expand(heads, -1, -1),
@@ -325,12 +325,12 @@ def plot(args, results: List[Dict[str, Any]]):
         "-emb", "--embedding_dim", nargs="+", default=[64, 128, 256], type=int
     )
     parser.add_argument(
-        "-sl", "--sequence_length", nargs="+", default=[512, 768, 1024], type=int
+        "-sl", "--sequence_length", nargs="+", default=[512, 1024], type=int
     )
     parser.add_argument("-bs", "--batch_size", nargs="+", default=[8, 16, 32], type=int)
     parser.add_argument("-heads", "--heads", nargs="+", default=[8, 16], type=int)
 
-    parser.add_argument("-fp16", "--pytorch_amp", nargs="+", default=[False], type=bool)
+    parser.add_argument("-fp16", "--pytorch_amp", nargs="+", default=[True], type=bool)
     parser.add_argument("-causal", "--causal", nargs="+", default=[False], type=bool)
     parser.add_argument("-plot", "--plot", action="store_true", default=False)
     parser.add_argument(
diff --git a/xformers/components/attention/blocksparse.py b/xformers/components/attention/blocksparse.py
index ccd2856c5c..14d5e65624 100644
--- a/xformers/components/attention/blocksparse.py
+++ b/xformers/components/attention/blocksparse.py
@@ -5,6 +5,7 @@
 
 
 import logging
+import math
 from dataclasses import dataclass
 from typing import Optional
 
@@ -141,12 +142,30 @@ def forward(
             att_mask is None or att_mask.dim() == 2
         ), "The attention mask is constant across heads, expected dimensions are [seq x seq]"
 
-        # Self-attend: (B, nh, S, hs) x (B, nh, hs, S) -> (B, nh, S, S)
-        # When the computations are block sparse, the matrix types change along the way:
-        # - (sparse) attention matrix = (dense) Kt * (dense) Q
         assert (
             q.shape[-2] == k.shape[-2]
         ), "Blocksparse requires the same dimensions for K and Q for now"
+
+        assert (
+            q.shape[-2] == self.layout.shape[-2] * self.block_size
+        ), "Actual sequence size and layout are inconsistent"
+        assert (
+            k.shape[-2] == self.layout.shape[-2] * self.block_size
+        ), "Actual sequence size and layout are inconsistent"
+
+        assert math.log(
+            q.shape[-2], 2
+        ).is_integer(), (
+            "For now blocksparse only works on power-of-two sequence lengths"
+        )
+
+        # Blocksparse only works on fp16
+        q_dtype = q.dtype
+        q, k, v = q.half(), k.half(), v.half()
+
+        # Self-attend: (B, nh, S, hs) x (B, nh, hs, S) -> (B, nh, S, S)
+        # When the computations are block sparse, the matrix types change along the way:
+        # - (sparse) attention matrix = (dense) Kt * (dense) Q
         sparse_att_mat = self.sparse_dot_sdd(q, k)
 
         # - softmax on the sparse attention matrix
@@ -161,5 +180,4 @@ def forward(
 
         # - then (dense) attention is (sparse) attention matrix * dense (value)
         a = self.sparse_dot_dsd(sparse_att_mat, v)
-
-        return a
+        return a.to(q_dtype)
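Note: the runtime guards added above boil down to a few shape, size, and dtype rules. The sketch below restates them outside the class for reference. `check_blocksparse_inputs` and `fp16_roundtrip_example` are hypothetical helpers written purely for illustration, not xformers APIs, and the layout is assumed to have shape [heads, seq // block_size, seq // block_size] as in the class docstring.

```python
import math

import torch


def check_blocksparse_inputs(
    q: torch.Tensor, k: torch.Tensor, layout: torch.Tensor, block_size: int
) -> None:
    """Mirror the assertions introduced by this patch (illustrative only)."""
    seq = q.shape[-2]

    # K and Q must share the same sequence length for now
    assert seq == k.shape[-2], "Blocksparse requires the same dimensions for K and Q for now"

    # The layout must tile the actual sequence exactly: seq == layout_blocks * block_size
    assert seq == layout.shape[-2] * block_size, "Actual sequence size and layout are inconsistent"

    # Only power-of-two sequence lengths are supported for now
    assert math.log2(seq).is_integer(), "For now blocksparse only works on power-of-two sequence lengths"


def fp16_roundtrip_example(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """The dtype round-trip used in forward(): compute in fp16, return in the caller's dtype."""
    q_dtype = q.dtype
    q, k, v = q.half(), k.half(), v.half()
    # ... the blocksparse attention kernels would run here on the fp16 tensors ...
    out = v  # placeholder so the sketch stays runnable
    return out.to(q_dtype)
```

With this round-trip in place, a caller can keep fp32 activations; only the sparse kernels themselves see fp16 inputs.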
From 870afc9fc2ef8c9948fef5a822c8df3085798c22 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Fri, 22 Oct 2021 10:19:27 -0700
Subject: [PATCH 2/3] Add some warnings in the doc about the sequence length

---
 HOWTO.md                                     | 4 +++-
 docs/source/tutorials/blocksparse.rst        | 3 +++
 xformers/components/attention/blocksparse.py | 3 +++
 3 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/HOWTO.md b/HOWTO.md
index 52466736ea..473d7dc40e 100644
--- a/HOWTO.md
+++ b/HOWTO.md
@@ -195,7 +195,9 @@ A simple example is that of a causal attention: just compute the lower triangula
 
 If you already have a per-coefficient pattern in mind and this is not a perfect match with a block pattern, this is probably fine, BlockSparse is fast enough so that dropping some of the computations after the fact with a fine-grained mask is still probably better than dense computations.
 
-We provide a small helper (this is just maxpooling) to convert in between a per coefficient binary mask and the layout that you will need to build a block sparse attention.
+
+We provide a small helper (this is just maxpooling) to convert in between a per coefficient binary mask and the layout that you will need to build a block sparse attention. Please note that for now _blocksparse attention requires the sequence length to be a power of two_.
+
 Let's look at an example:
 
 ```python
diff --git a/docs/source/tutorials/blocksparse.rst b/docs/source/tutorials/blocksparse.rst
index 65eece79be..bf2945dc65 100644
--- a/docs/source/tutorials/blocksparse.rst
+++ b/docs/source/tutorials/blocksparse.rst
@@ -10,6 +10,9 @@ A simple example is that of a causal attention: just compute the lower triangula
 If you already have a per-coefficient pattern in mind and this is not a perfect match with a block pattern, this is probably fine, BlockSparse is fast enough so that dropping some of the computations after the fact with a fine-grained mask is still probably better than dense computations.
 
 We provide a small helper (this is just maxpooling) to convert in between a per coefficient binary mask and the layout that you will need to build a block sparse attention.
+
+*Please note that for now blocksparse attention requires the sequence length to be a power of two*.
+
 Let's look at an example:
 
 .. code-block:: python
diff --git a/xformers/components/attention/blocksparse.py b/xformers/components/attention/blocksparse.py
index 14d5e65624..7905b3240c 100644
--- a/xformers/components/attention/blocksparse.py
+++ b/xformers/components/attention/blocksparse.py
@@ -48,6 +48,9 @@ class BlockSparseAttention(Attention):
     .. warning: the layout is assumed to have the dimensions [heads, seq, seq].
         If some dimensions are missing, we assume that the same layout is to be used across heads.
 
+    .. warning: for now, the sequence (context) length has to be a power of two. This constraint could
+        be relaxed in the future.
+
     .. note: it is possible to pass a specific per batch mask in the forward call,
         but this will not lead to any speed up.
        Any constant sparsity pattern is better passed through the layout parameter.
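Note: the HOWTO change above refers to a small helper that is "just maxpooling" a per-coefficient binary mask into a block layout. Below is a minimal sketch of that idea; the function name `mask_to_layout` and the exact shapes are assumptions for illustration, not the xformers helper itself.

```python
import torch
import torch.nn.functional as F


def mask_to_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor:
    """Turn a per-coefficient binary mask [seq, seq] into a block layout
    [seq // block_size, seq // block_size]: a block is kept if any of its
    coefficients is kept, i.e. max-pooling over block_size x block_size tiles."""
    assert mask.shape[-1] % block_size == 0, "sequence length must be a multiple of block_size"
    pooled = F.max_pool2d(
        mask.float()[None, None, :, :], kernel_size=block_size, stride=block_size
    )
    return pooled[0, 0].long()


if __name__ == "__main__":
    seq, block_size = 1024, 32  # power-of-two sequence length, as required above
    causal_mask = torch.tril(torch.ones(seq, seq))  # per-coefficient lower-triangular mask
    layout = mask_to_layout(causal_mask, block_size)  # [32, 32] lower-triangular block layout
    print(layout.shape, int(layout.sum()))
```

A coarse layout like this keeps whole blocks, so a fine-grained mask can still be applied on top of it afterwards if the pattern is not block-aligned.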
From 5a9edb6687a40c5ecd3ae5dc55b3cdbdcf8bf35f Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Fri, 22 Oct 2021 10:40:29 -0700
Subject: [PATCH 3/3] do not expose blocksparse if not tensor core enabled GPU

---
 xformers/components/attention/blocksparse.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/xformers/components/attention/blocksparse.py b/xformers/components/attention/blocksparse.py
index 7905b3240c..94fa6c9dfc 100644
--- a/xformers/components/attention/blocksparse.py
+++ b/xformers/components/attention/blocksparse.py
@@ -23,6 +23,14 @@
     from triton.ops.blocksparse import softmax as blocksparse_softmax
 
     from xformers.triton.softmax import MaskType
+    from xformers.triton.utils import gpu_capabilities_older_than_70
+
+    # Blocksparse requires Tensor cores
+    if gpu_capabilities_older_than_70():
+        logging.warning(
+            "Blocksparse is not available: the current GPU does not expose Tensor cores"
+        )
+        _use_triton = False
 
 except ImportError as e:
     logging.warning(
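Note: the actual `gpu_capabilities_older_than_70` helper lives in `xformers.triton.utils`; the sketch below only illustrates what such a check can look like with plain PyTorch and is an assumption, not the library's implementation.

```python
import torch


def _older_than_sm70() -> bool:
    # Hypothetical stand-in for xformers.triton.utils.gpu_capabilities_older_than_70.
    # Tensor cores arrived with compute capability 7.0 (Volta), so flag any visible
    # GPU below that, or the absence of a GPU altogether.
    if not torch.cuda.is_available():
        return True
    for device_id in range(torch.cuda.device_count()):
        major, _minor = torch.cuda.get_device_capability(device_id)
        if major < 7:
            return True
    return False
```

Gating at import time, as the patch does by clearing `_use_triton`, means blocksparse attention is simply not exposed on pre-Volta GPUs instead of failing later inside the Triton kernels.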