diff --git a/CHANGELOG.md b/CHANGELOG.md
index c551226f2c..df328734bd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Four blocksparsity layouts from DeepSpeed [#320]
 - Support several initialization options [#312]
 - Conv2DFeedforward feedforward part [#321]
+- VisualAttention [#329]
 
 ## [0.0.11] - 2022-05-30
 
diff --git a/README.md b/README.md
index 096fc5d9ce..5180425d19 100644
--- a/README.md
+++ b/README.md
@@ -167,6 +167,9 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
 
 - [2D Pooling](xformers/components/attention/pooling.py)
   - *[Metaformer is actually what you need for vision, Yu et al.](https://arxiv.org/pdf/2111.11418v1.pdf)*
 
+- [Visual Attention](xformers/components/attention/visual.py)
+  - *[Visual Attention Network, Guo et al.](https://arxiv.org/pdf/2202.09741.pdf)*
+
 - ... add a new one [see Contribution.md](CONTRIBUTING.md)
 
@@ -199,7 +202,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
 Initializations
 
 This is completely optional, and will only occur when generating full models through xFormers, not when picking parts individually.
-
+
 There are basically two initialization mechanisms exposed, but the user is free to initialize weights as he/she sees fit after the fact.
 - Parts can expose a `init_weights()` method, which define sane defaults
 - xFormers supports [specific init schemes](xformers/factory/weight_init.py) which *can take precedence* over the init_weights()
diff --git a/examples/cifarMetaformer.py b/examples/cifarMetaformer.py
index cdc6e96eb1..0c52f61e3e 100644
--- a/examples/cifarMetaformer.py
+++ b/examples/cifarMetaformer.py
@@ -31,6 +31,7 @@ def __init__(
         num_classes=10,
         dim=384,
         attention="scaled_dot_product",
+        feedforward="MLP",
         layer_norm_style="pre",
         use_rotary_embeddings=True,
         linear_warmup_ratio=0.1,
@@ -45,8 +46,7 @@ def __init__(
 
         # Generate the skeleton of our hierarchical Transformer
         # This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32)
-        # Any other related config would work,
-        # and the attention mechanisms don't have to be the same across layers
+        # Any other related config would work, and the attention mechanisms don't have to be the same across layers
         base_hierarchical_configs = [
             BasicLayerConfig(
                 embedding=64,
@@ -121,8 +121,8 @@ def forward(self, x):
 
     # Adjust batch depending on the available memory on your machine.
    # You can also use reversible layers to save memory
-    REF_BATCH = 512
-    BATCH = 512  # lower if not enough GPU memory
+    REF_BATCH = 768
+    BATCH = 256  # lower if not enough GPU memory
 
     MAX_EPOCHS = 50
     NUM_WORKERS = 4
@@ -172,6 +172,7 @@ def forward(self, x):
         num_classes=num_classes,
         attention="scaled_dot_product",
         layer_norm_style="pre",
+        feedforward="MLP",
         use_rotary_embeddings=True,
     )
     trainer = pl.Trainer(
diff --git a/tests/test_attentions.py b/tests/test_attentions.py
index 3ad29da3ed..825d36e477 100644
--- a/tests/test_attentions.py
+++ b/tests/test_attentions.py
@@ -27,7 +27,7 @@
 )
 
 BATCH = 2
-SEQ = 128 if torch.cuda.is_available() else 32
+SEQ = 128 if torch.cuda.is_available() else 36
 MODEL = 128 if torch.cuda.is_available() else 16
 GLOBAL_ATTENTION_RATIO = (
     _DENSITY_THRESHOLD * 0.9
@@ -35,6 +35,8 @@
 
 assert ATTENTION_REGISTRY.keys(), "Attention layers should have been registered"
 
+_non_order_invariant_attentions = ["visual", "pooling"]
+
 
 def _get_multihead(
     attention_name,
@@ -93,7 +95,9 @@ def noop(x):
 @pytest.mark.parametrize("residual_dropout", [0.0, 0.1])
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("heads", [1, 4])
-@pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
+@pytest.mark.parametrize(
+    "attention_name", ATTENTION_REGISTRY.keys() - _non_order_invariant_attentions
+)
 @pytest.mark.parametrize("device", DEVICES)
 def test_order_invariance(
     attention_name: str,
@@ -104,9 +108,6 @@ def test_order_invariance(
     device: torch.device,
 ):
 
-    if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "pooling":
-        pytest.skip(f"{attention_name} requires squared sequence lengths")
-
     torch.manual_seed(42)
     torch.cuda.manual_seed_all(42)
 
@@ -120,6 +121,12 @@ def test_order_invariance(
         use_seperate_proj_weights=False,
     )
 
+    if (
+        int(math.sqrt(SEQ)) ** 2 != SEQ
+        and multi_head.attention.requires_squared_context
+    ):
+        pytest.skip(f"{attention_name} requires squared sequence lengths")
+
     # Check that a shuffled input produces the same results
     seqs = [SEQ, SEQ // 2] if (attention_name != "blocksparse") else [SEQ]
 
@@ -304,12 +311,15 @@ def test_broadcast_batch_dimension(
     device: torch.device,
     batch_sizes: Tuple[int, int, int],
 ):
-    if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "pooling":
-        pytest.skip(f"{attention_name} requires squared sequence lengths")
-
     Q_BATCH, K_BATCH, V_BATCH = batch_sizes
     multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)
 
+    if (
+        int(math.sqrt(SEQ)) ** 2 != SEQ
+        and multi_head.attention.requires_squared_context
+    ):
+        pytest.skip(f"{attention_name} requires squared sequence lengths")
+
     if multi_head.attention.requires_same_k_q_dimensions:
         # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
         pytest.skip(f"{attention_name} does not support different k, q dimensions yet.")
@@ -388,7 +398,7 @@ def test_torch_script_ability(
     heads: int,
     attn_dropout: float,
 ):
-    if attention_name in {"favor", "global", "local", "random", "pooling"}:
+    if attention_name in {"favor", "global", "local", "random"}:
         # pyre-fixme[29]: The library function `pytest.skip` is not supported by Pyre.
         pytest.skip(f"{attention_name} does not support scripting yet.")
 
@@ -396,6 +406,12 @@ def test_torch_script_ability(
 
     multi_head = _get_multihead(attention_name, attn_dropout, 0.0, False, heads, device)
 
+    if (
+        int(math.sqrt(SEQ)) ** 2 != SEQ
+        and multi_head.attention.requires_squared_context
+    ):
+        pytest.skip(f"{attention_name} requires squared sequence lengths")
+
     # input for tracing the function
     q = torch.rand((BATCH, SEQ, MODEL), device=device)
     k = torch.rand((BATCH, SEQ, MODEL), device=device)
diff --git a/xformers/components/attention/base.py b/xformers/components/attention/base.py
index 52d72f7e4b..376cd54bb1 100644
--- a/xformers/components/attention/base.py
+++ b/xformers/components/attention/base.py
@@ -53,6 +53,9 @@ def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
         # so that the MHA wrapper should skip it
         self.requires_skip_multi_head = False
 
+        # This attention requires a squared context length, often due to 2D pooling
+        self.requires_squared_context = False
+
         # Whether this attention mechanism supports attention masks
         self.supports_attention_mask = True
         self.supports_key_padding_mask = False
diff --git a/xformers/components/attention/pooling.py b/xformers/components/attention/pooling.py
index eef85eab3b..d323e51deb 100644
--- a/xformers/components/attention/pooling.py
+++ b/xformers/components/attention/pooling.py
@@ -62,6 +62,10 @@ def __init__(
         # This operator does not really handle q,k,v
         self.requires_same_k_q_dimensions = True
 
+        # This attention requires the 2D structure to be recovered from the context,
+        # implicitly assumed to be a squared length
+        self.requires_squared_context = True
+
     def forward(self, q: torch.Tensor, *_, **__):
         # Expose the 2D token structure
         B, HW, C = q.shape
diff --git a/xformers/components/attention/visual.py b/xformers/components/attention/visual.py
new file mode 100644
index 0000000000..6ea81f41c2
--- /dev/null
+++ b/xformers/components/attention/visual.py
@@ -0,0 +1,96 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+
+from xformers.components.attention import Attention, AttentionConfig, register_attention
+
+
+@dataclass
+class VisualAttentionConfig(AttentionConfig):
+    dim_model: int  # dimension of the input sequence
+
+
+class LKA(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+        self.conv_spatial = nn.Conv2d(
+            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3
+        )
+        self.conv1 = nn.Conv2d(dim, dim, 1)
+
+    def forward(self, x: torch.Tensor):
+        u = x.clone()
+        attn = self.conv0(x)
+        attn = self.conv_spatial(attn)
+        attn = self.conv1(attn)
+
+        return u * attn
+
+
+@register_attention("visual", VisualAttentionConfig)
+class Visual(Attention):
+    def __init__(
+        self,
+        dim_model: int,
+        *_,
+        **__,
+    ):
+        """
+        Large kernel attention mechanism, as proposed in `Visual Attention Network`_, Guo et al. (2022).
+        The original notation is tentatively kept as is. See https://github.com/Visual-Attention-Network
+        for the reference implementation.
+
+        .. Note: compared to the paper, this block contains the LKA (Large Kernel Attention)
+            and the prior and posterior transformations (Conv2d and activation)
+
+        .. _`Visual Attention Network` : https://arxiv.org/pdf/2202.09741.pdf
+        """
+        super().__init__()
+
+        self.block = nn.Sequential(
+            nn.Conv2d(dim_model, dim_model, 1),
+            nn.GELU(),
+            LKA(dim_model),
+            nn.Conv2d(dim_model, dim_model, 1),
+        )
+
+        # MHA related flags:
+        self.requires_same_k_q_dimensions = (
+            True  # This mechanism only really supports self-attention
+        )
+        self.supports_attention_mask = False
+        self.requires_skip_multi_head = (
+            True  # This mechanism skips the multihead attention altogether
+        )
+        self.requires_squared_context = (
+            True  # Recovering the 2D structure from the context assumes a squared length
+        )
+
+        self.requires_input_projection = (
+            False  # This mechanism does not require that the MHA wrapper projects the inputs
+        )
+
+    def forward(self, q: torch.Tensor, *_, **__):
+        # Expose the 2D token structure
+        B, HW, C = q.shape
+        H = int(math.sqrt(HW))
+        assert H * H == HW
+
+        x = q.transpose(-2, -1).reshape(B, C, H, H)
+
+        # Large kernel attention
+        residual = x.clone()
+        x = self.block(x)
+        x = x + residual
+
+        # Get back to B, HW, C
+        return x.flatten(2, 3).transpose(-2, -1)
diff --git a/xformers/factory/block_factory.py b/xformers/factory/block_factory.py
index 8e10401a63..99fbf5bd6f 100644
--- a/xformers/factory/block_factory.py
+++ b/xformers/factory/block_factory.py
@@ -278,7 +278,10 @@ def __init__(self, config: xFormerDecoderConfig, **kwargs):
 
         # Expose attention or feedforward specific capabilities
         self.supports_attention_mask = mha.attention.supports_attention_mask
         self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
-        self.requires_squared_context_length = feedforward.requires_squared_context
+        self.requires_squared_context_length = (
+            feedforward.requires_squared_context
+            or mha.attention.requires_squared_context
+        )
         self.causal_attention = (
             mha.attention.causal if hasattr(mha.attention, "causal") else False
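
Usage sketch (not part of the patch): a minimal way to exercise the newly registered "visual" attention through the registry, assuming `build_attention` accepts a plain config dict with the fields shown; the shapes and values below are illustrative only.

    import torch

    from xformers.components.attention import build_attention

    # Build the "visual" attention; dim_model is the channel count of the tokens
    attn = build_attention({"name": "visual", "dropout": 0.0, "dim_model": 32})

    # A flattened 8x8 "image": the context length (H*W = 64) must be a perfect square
    # so that the block can recover the 2D structure internally
    x = torch.rand(2, 64, 32)  # (batch, H*W, channels)
    y = attn(x)  # k and v are ignored, the block acts on q only
    assert y.shape == x.shape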