Update on "SwiGLU optimized fw/bw"
**NOTE**
We can improve this a bit more once NVIDIA/cutlass#674 is fixed.

**USAGE**

```python
import xformers.ops as xops

# NOTE: Important to use `unbind` from xformers for the bw pass!
w1, w2 = xops.unbind(
    w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]),
    dim=0,
)
b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0)
y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3)
```
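For context, here is a minimal sketch of how the packed parameters used above might be allocated. The shapes and the `nn.Linear`-style weight layout are assumptions for illustration, not something this commit prescribes:

```python
import torch

# Assumed shapes: B tokens, model dim H, intermediate dim I
# (matching the benchmark configurations below).
B, H, I = 4728, 1024, 1536
dtype, device = torch.float16, "cuda"

x = torch.randn(B, H, dtype=dtype, device=device)

# w1/w2 (and b1/b2) live in a single packed parameter so the fused kernel can
# treat them as one operand; they are split back out with `xops.unbind` above.
w1w2 = torch.nn.Parameter(torch.randn(2 * I, H, dtype=dtype, device=device))
b1b2 = torch.nn.Parameter(torch.randn(2 * I, dtype=dtype, device=device))
w3 = torch.nn.Parameter(torch.randn(H, I, dtype=dtype, device=device))
b3 = torch.nn.Parameter(torch.randn(H, dtype=dtype, device=device))
```

Using `xops.unbind` rather than `torch.unbind` presumably lets autograd accumulate the gradients of the two views directly into one contiguous gradient buffer for `w1w2`, which is likely why the NOTE above insists on it for the backward pass.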

**PERFORMANCE (A100 only)**

*FW*
```
[-------------------------------------------------------- swiglu_fw ---------------------------------------------------------]
                                     |  SwiGLUPackedFusedOp[fused.p.cpp]  |  eager   |  SwiGLUFusedOp[fused]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16    B=9456, I=1536, H=4096  |               1377.7               |  1581.4  |         1339.1
      f16.ac B=9456, I=1536, H=4096  |               1449.3               |  1735.3  |         1462.9
      f16    B=4440, I=1536, H=4096  |                600.4               |   735.6  |          593.9
      f16.ac B=4440, I=1536, H=4096  |                709.0               |   843.7  |          717.6
      f16    B=4728, I=1536, H=4096  |                638.9               |   776.2  |          635.3
      f16.ac B=4728, I=1536, H=4096  |                748.9               |   892.2  |          756.7
      f16    B=4728, I=1536, H=1024  |                162.3               |   201.5  |          163.1
      f16.ac B=4728, I=1536, H=1024  |                235.2               |   277.4  |          245.5

Times are in microseconds (us).
```

*BW*
```
[-------------------------------------------------------- swiglu_bw ---------------------------------------------------------]
                                     |  SwiGLUPackedFusedOp[fused.p.cpp]  |  eager   |  SwiGLUFusedOp[fused]
1 threads: -------------------------------------------------------------------------------------------------------------------
      f16    B=9456, I=1536, H=4096  |               2333.1               |  2696.7  |         2336.1
      f16.ac B=9456, I=1536, H=4096  |               2620.8               |  2990.9  |         2840.0
      f16    B=4440, I=1536, H=4096  |               1243.2               |  1413.8  |         1240.3
      f16.ac B=4440, I=1536, H=4096  |               1448.6               |  1629.0  |         1637.3
      f16    B=4728, I=1536, H=4096  |               1298.4               |  1481.5  |         1301.1
      f16.ac B=4728, I=1536, H=4096  |               1511.8               |  1705.3  |         1705.4
      f16    B=4728, I=1536, H=1024  |                463.3               |   493.9  |          463.0
      f16.ac B=4728, I=1536, H=1024  |                582.4               |   614.9  |          672.7

Times are in microseconds (us).
```
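The benchmark harness behind these tables is not part of this commit. As a rough sketch, a single forward-pass entry could be timed along these lines (shapes from the last row above; everything else is an assumption):

```python
import torch
import torch.utils.benchmark as benchmark
import xformers.ops as xops

B, I, H = 4728, 1536, 1024
dtype, device = torch.float16, "cuda"

x = torch.randn(B, H, dtype=dtype, device=device)
w1, w2 = (torch.randn(I, H, dtype=dtype, device=device) for _ in range(2))
b1, b2 = (torch.randn(I, dtype=dtype, device=device) for _ in range(2))
w3 = torch.randn(H, I, dtype=dtype, device=device)
b3 = torch.randn(H, dtype=dtype, device=device)

timer = benchmark.Timer(
    stmt="xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3)",
    globals=dict(xops=xops, x=x, w1=w1, b1=b1, w2=w2, b2=b2, w3=w3, b3=b3),
    num_threads=1,
)
# Reports per-call times comparable to the microsecond figures above, for
# whichever implementation functional_swiglu dispatches to.
print(timer.blocked_autorange())
```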

[ghstack-poisoned]
danthe3rd committed Nov 10, 2022
2 parents 3490242 + 12733f0 commit a90fe49
Showing 9 changed files with 61 additions and 20 deletions.
12 changes: 9 additions & 3 deletions tests/test_block_factory.py
@@ -8,7 +8,7 @@
 
 # Automatically fetch all registered attentions and Feedforwards
 from xformers.components import Activation
-from xformers.components.attention import ATTENTION_REGISTRY
+from xformers.components.attention import ATTENTION_REGISTRY, AttentionMask
 from xformers.components.feedforward import FEEDFORWARD_REGISTRY
 from xformers.factory import (
     xFormerDecoderBlock,
@@ -112,10 +112,12 @@ def test_xformer_encoder_block(
     _ = block(inputs)
 
     # Check that we support attention masking, at least interface wise (do not check correctness yet)
-    att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    att_mask_tensor = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    att_mask = AttentionMask.from_bool(att_mask_tensor)
 
     if block.supports_attention_mask:
         _ = block(inputs, att_mask=att_mask)
+        _ = block(inputs, att_mask=att_mask_tensor)
     else:
         with pytest.raises(AssertionError):
             # Check that passing an attention mask to a mechanism which does not support it raises
@@ -226,7 +228,8 @@ def test_xformer_decoder_block(
     )  # NOTE: does not make a lot of sense, just checking dimensions
 
     # Check that we support masking, at least interface wise (do not check correctness yet)
-    att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    att_mask_tensor = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    att_mask = AttentionMask.from_bool(att_mask_tensor)
     input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
     input_mask[input_mask < 0.0] = -float("inf")
 
@@ -235,6 +238,9 @@
     _ = decoder_block(
         inputs, encoded, encoder_att_mask=att_mask, input_mask=input_mask
     )
+    _ = decoder_block(
+        inputs, encoded, encoder_att_mask=att_mask_tensor, input_mask=input_mask
+    )
 
     # Test different sequence lengths when encoding and decoding
     if (
9 changes: 9 additions & 0 deletions tests/test_core_attention.py
@@ -59,6 +59,15 @@ def test_core_attention_mask_types():
     # Now properly handled
     assert torch.allclose(r_dense_add, r_sparse_add)
 
+    # Test additive mask with mismatched batch dim
+    d = b // 2
+    mask = torch.rand(d, s, s) > prob
+    float_mask_add = torch.zeros_like(mask, dtype=torch.float)
+    float_mask_add = float_mask_add.masked_fill(mask, float("-inf"))
+
+    # Make sure masking doesn't return errors
+    r_dense_add = scaled_dot_product_attention(a, a, a, float_mask_add)
+
 
 @pytest.mark.parametrize("device", _devices)
 def test_amp_attention_dense_no_mask(device):
10 changes: 10 additions & 0 deletions xformers/components/attention/core.py
@@ -106,6 +106,16 @@ def _matmul_with_mask(
         att[~mask] = float("-inf")
     else:
         # mask is presumed additive
+        # repeat if batch sizes don't match
+        if (
+            not isinstance(mask, SparseCS)
+            and mask.ndim == 3
+            and mask.shape[0] != att.shape[0]
+            and (att.shape[0] % mask.shape[0]) == 0
+        ):
+            repeat_factor = att.shape[0] // mask.shape[0]
+            mask = mask.repeat([repeat_factor, 1, 1])
+            logger.info("Mismatched batch dimensions for mask, repeating mask.")
         att += mask
     return att
 
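A quick illustration of what the new branch enables, from the caller's side: an additive mask whose batch dimension divides the attention batch is now repeated to match instead of erroring out (toy shapes, for illustration only):

```python
import torch
from xformers.components.attention.core import scaled_dot_product_attention

B, S, D = 4, 8, 16                     # toy batch, sequence length, head dim
q = k = v = torch.randn(B, S, D)

# Additive float mask with a smaller batch dimension; since B % 2 == 0,
# _matmul_with_mask repeats it up to B internally.
small_mask = torch.zeros(B // 2, S, S)
small_mask[:, :, -1] = float("-inf")   # block attending to the last position

out = scaled_dot_product_attention(q, k, v, small_mask)
print(out.shape)  # torch.Size([4, 8, 16])
```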
8 changes: 5 additions & 3 deletions xformers/components/attention/global_tokens.py
@@ -5,7 +5,7 @@
 
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -88,7 +88,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         *_,
         **__,
     ):
@@ -101,7 +101,9 @@
             if att_mask.dtype == torch.bool and isinstance(
                 self.attention_mask, AttentionMask
             ):
-                mask = self.attention_mask + AttentionMask.from_bool(att_mask)
+                if not isinstance(att_mask, AttentionMask):
+                    att_mask = AttentionMask.from_bool(att_mask)
+                mask = self.attention_mask + att_mask
             else:
                 mask = self.attention_mask & att_mask
         else:
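For reference, a small sketch of the mask types being combined here: `AttentionMask.from_bool` wraps a boolean keep-mask as an additive float mask (0 where allowed, -inf where masked), so two `AttentionMask` instances compose by simple addition, which is what the new branch relies on (toy shapes, illustration only):

```python
import torch
from xformers.components.attention import AttentionMask

# Two boolean keep-masks over a toy 4x4 attention pattern (True = attend)
keep_a = torch.tensor([[True, True, False, False]] * 4)
keep_b = torch.tensor([[True, False, True, False]] * 4)

mask_a = AttentionMask.from_bool(keep_a)
mask_b = AttentionMask.from_bool(keep_b)

# Additive composition: the result is -inf wherever either mask blocks a slot.
combined = mask_a + mask_b
```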
15 changes: 10 additions & 5 deletions xformers/components/attention/local.py
@@ -5,14 +5,15 @@
 
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
 
 from xformers.components.attention import (
     Attention,
     AttentionConfig,
+    AttentionMask,
     maybe_sparsify,
     register_attention,
     sparsify,
@@ -97,7 +98,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         *args,
         **kwargs,
     ):
@@ -106,9 +107,13 @@
             self.attention_mask = self._get_local_mask(q.shape).to(q.device)
 
         # Take into account the optional user mask
-        mask = (
-            self.attention_mask if att_mask is None else self.attention_mask & att_mask
-        )
+        if att_mask is None:
+            mask = self.attention_mask
+        else:
+            if isinstance(att_mask, AttentionMask):
+                # Needed because & op not defined for SparseCS with AttentionMask
+                att_mask = att_mask.to_bool()
+            mask = self.attention_mask & att_mask
 
         return scaled_dot_product_attention(
             q=q, k=k, v=v, att_mask=mask, dropout=self.attn_drop
11 changes: 8 additions & 3 deletions xformers/components/attention/ortho.py
@@ -7,14 +7,19 @@
 import logging
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.autograd.profiler as profiler
 import torch.nn as nn
 import torch.nn.functional as Fn
 
-from xformers.components.attention import Attention, AttentionConfig, register_attention
+from xformers.components.attention import (
+    Attention,
+    AttentionConfig,
+    AttentionMask,
+    register_attention,
+)
 from xformers.components.attention.core import (
     scaled_dot_product_attention,
     scaled_query_key_softmax,
@@ -83,7 +88,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
         *args,
         **kwargs,
     ):
7 changes: 5 additions & 2 deletions xformers/components/attention/random.py
@@ -5,7 +5,7 @@
 
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -91,7 +91,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         *args,
         **kwargs,
     ):
@@ -106,6 +106,9 @@
             ):
                 mask = self.rand_attention_mask + AttentionMask.from_bool(att_mask)
             else:
+                if isinstance(att_mask, AttentionMask):
+                    # Needed because & op not defined for SparseCS with AttentionMask
+                    att_mask = att_mask.to_bool()
                 mask = self.rand_attention_mask & att_mask
         else:
             mask = self.rand_attention_mask
@@ -84,7 +84,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
     EpilogueOutputOp01,
     EpilogueOutputOp01,
     EpilogueOutputOp2,
-    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>,
     kStages,
     kStoreD0,
     kStoreD1,
7 changes: 4 additions & 3 deletions xformers/factory/block_factory.py
@@ -20,6 +20,7 @@
     build_multi_head_attention,
     build_patch_embedding,
 )
+from xformers.components.attention import AttentionMask
 from xformers.components.feedforward import build_feedforward
 from xformers.components.positional_embedding import build_positional_embedding
 from xformers.components.residual import get_deepnorm_coefficients
@@ -206,7 +207,7 @@ def get_reversible_layer(config) -> Tuple[nn.Module, nn.Module]:
     def forward(
         self,
         x: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         input_mask: Optional[torch.Tensor] = None,
     ):
         if self.patch_emb is not None:
@@ -327,8 +328,8 @@ def forward(
         self,
         target: torch.Tensor,
         memory: torch.Tensor,
-        encoder_att_mask: Optional[torch.Tensor] = None,
-        decoder_att_mask: Optional[torch.Tensor] = None,
+        encoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
+        decoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         input_mask: Optional[torch.Tensor] = None,
     ):
         if self.pose_encoding is not None:
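With the widened annotations, a built encoder block accepts either mask flavour at call time. A hedged sketch (the helper below is hypothetical; block construction itself is unchanged by this commit):

```python
import torch
from xformers.components.attention import AttentionMask
from xformers.factory import xFormerEncoderBlock


def run_block_with_both_mask_types(block: xFormerEncoderBlock, seq: int, dim: int) -> None:
    # Call an already-built encoder block with both mask flavours now reflected
    # in its forward() signature.
    inputs = torch.randn(1, seq, dim)
    bool_mask = torch.ones(seq, seq, dtype=torch.bool)

    if block.supports_attention_mask:
        block(inputs, att_mask=bool_mask)                           # plain bool tensor
        block(inputs, att_mask=AttentionMask.from_bool(bool_mask))  # AttentionMask wrapper
```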
