diff --git a/CHANGELOG.md b/CHANGELOG.md
index b6002c2607..679ab09de0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 - bugfix Favor, single feature map [#183]
+- sanity check blocksparse settings [#207]
+- fixed some picklability [#204]
 
 ### Added
 - Mixture of Experts [#181]
diff --git a/tests/test_triton_blocksparse.py b/tests/test_triton_blocksparse.py
index 459a2e7ceb..773909f962 100644
--- a/tests/test_triton_blocksparse.py
+++ b/tests/test_triton_blocksparse.py
@@ -139,7 +139,7 @@ def test_softmax(BLOCK, WIDTH, DTYPE):
 
 
 @pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
-@pytest.mark.parametrize("block", [32])  # 16, 32,
+@pytest.mark.parametrize("block", [32, 43])  # 16, 32,
 def test_attention_fwd_bwd(
     block,
     input_scale=1.0,
@@ -177,45 +177,50 @@ def loss_fn(x):
     query.retain_grad()
     key.retain_grad()
     value.retain_grad()
-    block_sparse_attention = BlockSparseAttention(layout, block)
-    attn_out = block_sparse_attention(
-        att_mask=attn_mask, q=query, k=key, v=value, scale=scale
-    )
-
-    # ad hoc loss
-    loss = loss_fn(attn_out)
-    loss.backward()
-    grads = [query.grad, key.grad, value.grad]
-
-    # Torch version:
-    torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
-    torch_q = torch_q / math.sqrt(head_dim)
-    attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
-    torch_q.retain_grad()
-    torch_k.retain_grad()
-    torch_v.retain_grad()
-    scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
-    scores = scores + attn_mask
-    probs = torch.softmax(scores, dim=-1)
-    torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
-
-    # ad hoc loss
-    torch_loss = loss_fn(torch_attn_out)
-    torch_loss.backward()
-    torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
-
-    # comparison
-    assert_almost_equal(
-        loss, torch_loss, err_msg=f"Triton loss {loss} and torch loss {torch_loss}"
-    )
+    if block not in [16, 32, 64]:
+        # Check that unsupported dimensions are caught
+        with pytest.raises(AssertionError):
+            _ = BlockSparseAttention(layout, block)
+    else:
+        block_sparse_attention = BlockSparseAttention(layout, block)
+        attn_out = block_sparse_attention(
+            att_mask=attn_mask, q=query, k=key, v=value, scale=scale
+        )
 
-    for g1, g2 in zip(grads, torch_grads):
+        # ad hoc loss
+        loss = loss_fn(attn_out)
+        loss.backward()
+        grads = [query.grad, key.grad, value.grad]
+
+        # Torch version:
+        torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
+        torch_q = torch_q / math.sqrt(head_dim)
+        attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
+        torch_q.retain_grad()
+        torch_k.retain_grad()
+        torch_v.retain_grad()
+        scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
+        scores = scores + attn_mask
+        probs = torch.softmax(scores, dim=-1)
+        torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
+
+        # ad hoc loss
+        torch_loss = loss_fn(torch_attn_out)
+        torch_loss.backward()
+        torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
+
+        # comparison
         assert_almost_equal(
-            torch.norm(g1),
-            torch.norm(g2),
-            err_msg=f"Triton grad {torch.norm(g1).item()} and torch grad {torch.norm(g2).item()}",
+            loss, torch_loss, err_msg=f"Triton loss {loss} and torch loss {torch_loss}"
         )
 
+        for g1, g2 in zip(grads, torch_grads):
+            assert_almost_equal(
+                torch.norm(g1),
+                torch.norm(g2),
+                err_msg=f"Triton grad {torch.norm(g1).item()} and torch grad {torch.norm(g2).item()}",
+            )
+
 
 @pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
 def test_blocksparse_attention_parity():
diff --git a/xformers/components/attention/blocksparse.py b/xformers/components/attention/blocksparse.py
index afb5da7c15..e3b691784c 100644
--- a/xformers/components/attention/blocksparse.py
+++ b/xformers/components/attention/blocksparse.py
@@ -52,6 +52,9 @@ class BlockSparseAttention(Attention):
     .. warning: for now, the sequence (context) length has to be a power of two.
         This constraint could be relaxed in the future.
 
+    .. warning: the block size has to be picked from [16, 32, 64]. Some speed is gained from bigger blocks.
+        It is of course possible to reproduce coarser patterns given these primitives, as the user sees fit.
+
     .. note: it is possible to pass a specific per batch mask in the forward call,
         but this will not lead to any speed up. Any constant sparsity pattern is better
         passed through the layout parameter.
@@ -76,7 +79,11 @@ def __init__(
             layout = layout.unsqueeze(0).expand(num_heads, -1, -1)
             logging.warning(f"New layout dimensions: {layout.shape}")
 
-        assert block_size >= 16, "Minimum block size is 16, for now at least"
+        assert block_size in (
+            16,
+            32,
+            64,
+        ), "Only block sizes in [16, 32, 64] are supported"
 
         super().__init__()
         self.attn_drop = torch.nn.Dropout(dropout, inplace=False)