[fix] blocksparse sanity checks #207

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
- bugfix Favor, single feature map [#183]
- sanity check blocksparse settings [#207]
- fixed some picklability [#204]

### Added
- Mixture of Experts [#181]
77 changes: 41 additions & 36 deletions tests/test_triton_blocksparse.py
@@ -139,7 +139,7 @@ def test_softmax(BLOCK, WIDTH, DTYPE):


@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
-@pytest.mark.parametrize("block", [32])  # 16, 32,
+@pytest.mark.parametrize("block", [32, 43])  # 16, 32,
def test_attention_fwd_bwd(
    block,
    input_scale=1.0,

@@ -177,45 +177,50 @@ def loss_fn(x):
    query.retain_grad()
    key.retain_grad()
    value.retain_grad()
-    block_sparse_attention = BlockSparseAttention(layout, block)
-    attn_out = block_sparse_attention(
-        att_mask=attn_mask, q=query, k=key, v=value, scale=scale
-    )
-
-    # ad hoc loss
-    loss = loss_fn(attn_out)
-    loss.backward()
-    grads = [query.grad, key.grad, value.grad]
-
-    # Torch version:
-    torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
-    torch_q = torch_q / math.sqrt(head_dim)
-    attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
-    torch_q.retain_grad()
-    torch_k.retain_grad()
-    torch_v.retain_grad()
-    scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
-    scores = scores + attn_mask
-    probs = torch.softmax(scores, dim=-1)
-    torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
-
-    # ad hoc loss
-    torch_loss = loss_fn(torch_attn_out)
-    torch_loss.backward()
-    torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
-
-    # comparison
-    assert_almost_equal(
-        loss, torch_loss, err_msg=f"Triton loss {loss} and torch loss {torch_loss}"
-    )
-
-    for g1, g2 in zip(grads, torch_grads):
-        assert_almost_equal(
-            torch.norm(g1),
-            torch.norm(g2),
-            err_msg=f"Triton grad {torch.norm(g1).item()} and torch grad {torch.norm(g2).item()}",
-        )
+    if block not in [16, 32, 64]:
+        # Check that unsupported dimensions are caught
+        with pytest.raises(AssertionError):
+            _ = BlockSparseAttention(layout, block)
+    else:
+        block_sparse_attention = BlockSparseAttention(layout, block)
+        attn_out = block_sparse_attention(
+            att_mask=attn_mask, q=query, k=key, v=value, scale=scale
+        )
+
+        # ad hoc loss
+        loss = loss_fn(attn_out)
+        loss.backward()
+        grads = [query.grad, key.grad, value.grad]
+
+        # Torch version:
+        torch_q, torch_k, torch_v = [x.clone() for x in qkvs]
+        torch_q = torch_q / math.sqrt(head_dim)
+        attn_mask = 1e6 * (-1 + (attn_mask.reshape((1, 1, n_ctx, n_ctx)).cuda()))
+        torch_q.retain_grad()
+        torch_k.retain_grad()
+        torch_v.retain_grad()
+        scores = scale * torch.einsum("bhsd,bhtd->bhst", torch_q, torch_k)
+        scores = scores + attn_mask
+        probs = torch.softmax(scores, dim=-1)
+        torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v)
+
+        # ad hoc loss
+        torch_loss = loss_fn(torch_attn_out)
+        torch_loss.backward()
+        torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad]
+
+        # comparison
+        assert_almost_equal(
+            loss, torch_loss, err_msg=f"Triton loss {loss} and torch loss {torch_loss}"
+        )
+
+        for g1, g2 in zip(grads, torch_grads):
+            assert_almost_equal(
+                torch.norm(g1),
+                torch.norm(g2),
+                err_msg=f"Triton grad {torch.norm(g1).item()} and torch grad {torch.norm(g2).item()}",
+            )


@pytest.mark.skipif(not _triton_available, reason="Triton requires a recent CUDA gpu")
def test_blocksparse_attention_parity():
9 changes: 8 additions & 1 deletion xformers/components/attention/blocksparse.py
@@ -52,6 +52,9 @@ class BlockSparseAttention(Attention):
    .. warning: for now, the sequence (context) length has to be a power of two. This constraint could
        be relaxed in the future.

+    .. warning: the block size has to be picked from [16, 32, 64]. Some speed is gained from bigger blocks.
+        It is of course possible to reproduce coarser patterns given these primitives, as the user sees fit.

    .. note: it is possible to pass a specific per batch mask in the forward call,
        but this will not lead to any speed up.
        Any constant sparsity pattern is better passed through the layout parameter.
@@ -76,7 +79,11 @@ def __init__(
            layout = layout.unsqueeze(0).expand(num_heads, -1, -1)
            logging.warning(f"New layout dimensions: {layout.shape}")

-        assert block_size >= 16, "Minimum block size is 16, for now at least"
+        assert block_size in (
+            16,
+            32,
+            64,
+        ), "Only block sizes in [16, 32, 64] are supported"

        super().__init__()
        self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
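Taken together, the new docstring warning and the tightened assertion mean that an unsupported block size now fails fast at construction time instead of deep inside the Triton kernels. Below is a minimal sketch of how that surfaces to a user, mirroring the `BlockSparseAttention(layout, block)` call exercised in the test above; the import path and the layout construction are assumptions for illustration, and actually instantiating the module requires Triton (plus a recent CUDA GPU for the forward pass).

```python
import torch

# Assumed import path; the diff only shows the module file
# xformers/components/attention/blocksparse.py.
from xformers.components.attention import BlockSparseAttention

num_heads, seq_len, block = 4, 256, 32  # block must be one of 16, 32, 64
blocks_per_dim = seq_len // block

# Block-level sparsity pattern, one (blocks x blocks) mask per head.
# Coarser patterns (e.g. 64x64 tiles) can be emulated by switching on
# groups of neighbouring 32x32 blocks, as the docstring warning notes.
layout = torch.tril(
    torch.ones(num_heads, blocks_per_dim, blocks_per_dim, dtype=torch.long)
)

# Supported block size: construction goes through as before.
attention = BlockSparseAttention(layout, block)

# Unsupported block size: the new sanity check rejects it up front.
try:
    BlockSparseAttention(layout, 43)
except AssertionError as err:
    print(err)  # "Only block sizes in [16, 32, 64] are supported"
```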