From 42e5c271a504db246c0ac4646d15fc5a3b5173aa Mon Sep 17 00:00:00 2001
From: dianaml0 <82468439+dianaml0@users.noreply.github.com>
Date: Mon, 7 Nov 2022 09:38:35 -0500
Subject: [PATCH 1/2] Repeat mask when batch dims don't match (#506)

* repeat mask when batch dims don't match

* do not do for sparseCS
---
 tests/test_core_attention.py          |  9 +++++++++
 xformers/components/attention/core.py | 10 ++++++++++
 2 files changed, 19 insertions(+)

diff --git a/tests/test_core_attention.py b/tests/test_core_attention.py
index 3a9b01fba5..5e022b96cd 100644
--- a/tests/test_core_attention.py
+++ b/tests/test_core_attention.py
@@ -59,6 +59,15 @@ def test_core_attention_mask_types():
     # Now properly handled
     assert torch.allclose(r_dense_add, r_sparse_add)
 
+    # Test additive mask with mismatched batch dim
+    d = b // 2
+    mask = torch.rand(d, s, s) > prob
+    float_mask_add = torch.zeros_like(mask, dtype=torch.float)
+    float_mask_add = float_mask_add.masked_fill(mask, float("-inf"))
+
+    # Make sure masking doesn't return errors
+    r_dense_add = scaled_dot_product_attention(a, a, a, float_mask_add)
+
 
 @pytest.mark.parametrize("device", _devices)
 def test_amp_attention_dense_no_mask(device):
diff --git a/xformers/components/attention/core.py b/xformers/components/attention/core.py
index 443bbdacb0..6709eaa84a 100644
--- a/xformers/components/attention/core.py
+++ b/xformers/components/attention/core.py
@@ -106,6 +106,16 @@ def _matmul_with_mask(
         att[~mask] = float("-inf")
     else:
         # mask is presumed additive
+        # repeat if batch sizes don't match
+        if (
+            not isinstance(mask, SparseCS)
+            and mask.ndim == 3
+            and mask.shape[0] != att.shape[0]
+            and (att.shape[0] % mask.shape[0]) == 0
+        ):
+            repeat_factor = att.shape[0] // mask.shape[0]
+            mask = mask.repeat([repeat_factor, 1, 1])
+            logger.info("Mismatched batch dimensions for mask, repeating mask.")
         att += mask
     return att
 

From 44c560d374d4a2dfc4fc5e75a9b859d440262a1a Mon Sep 17 00:00:00 2001
From: dianaml0 <82468439+dianaml0@users.noreply.github.com>
Date: Mon, 7 Nov 2022 09:38:51 -0500
Subject: [PATCH 2/2] enable AttentionMask in Block Factory (#503)

---
 tests/test_block_factory.py                    | 12 +++++++++---
 xformers/components/attention/global_tokens.py |  8 +++++---
 xformers/components/attention/local.py         | 15 ++++++++++-----
 xformers/components/attention/ortho.py         | 11 ++++++++---
 xformers/components/attention/random.py        |  7 +++++--
 xformers/factory/block_factory.py              |  7 ++++---
 6 files changed, 41 insertions(+), 19 deletions(-)

diff --git a/tests/test_block_factory.py b/tests/test_block_factory.py
index 275fcd5f56..79e85f7a1d 100644
--- a/tests/test_block_factory.py
+++ b/tests/test_block_factory.py
@@ -8,7 +8,7 @@
 
 # Automatically fetch all registered attentions and Feedforwards
 from xformers.components import Activation
-from xformers.components.attention import ATTENTION_REGISTRY
+from xformers.components.attention import ATTENTION_REGISTRY, AttentionMask
 from xformers.components.feedforward import FEEDFORWARD_REGISTRY
 from xformers.factory import (
     xFormerDecoderBlock,
@@ -112,10 +112,12 @@ def test_xformer_encoder_block(
     _ = block(inputs)
 
     # Check that we support attention masking, at least interface wise (do not check correctness yet)
-    att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    att_mask_tensor = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    att_mask = AttentionMask.from_bool(att_mask_tensor)
     if block.supports_attention_mask:
         _ = block(inputs, att_mask=att_mask)
+        _ = block(inputs, att_mask=att_mask_tensor)
     else:
         with pytest.raises(AssertionError):
             # Check that passing an attention mask to a mechanism which does not support it raises
@@ -226,7 +228,8 @@ def test_xformer_decoder_block(
     )  # NOTE: does not make a lot of sense, just checking dimensions
 
     # Check that we support masking, at least interface wise (do not check correctness yet)
-    att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    att_mask_tensor = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    att_mask = AttentionMask.from_bool(att_mask_tensor)
     input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
     input_mask[input_mask < 0.0] = -float("inf")
 
@@ -235,6 +238,9 @@ def test_xformer_decoder_block(
     _ = decoder_block(
         inputs, encoded, encoder_att_mask=att_mask, input_mask=input_mask
     )
+    _ = decoder_block(
+        inputs, encoded, encoder_att_mask=att_mask_tensor, input_mask=input_mask
+    )
 
     # Test different sequence lengths when encoding and decoding
     if (
diff --git a/xformers/components/attention/global_tokens.py b/xformers/components/attention/global_tokens.py
index d0a6f0166e..653ed619c8 100644
--- a/xformers/components/attention/global_tokens.py
+++ b/xformers/components/attention/global_tokens.py
@@ -5,7 +5,7 @@
 
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -88,7 +88,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         *_,
         **__,
     ):
@@ -101,7 +101,9 @@ def forward(
             if att_mask.dtype == torch.bool and isinstance(
                 self.attention_mask, AttentionMask
             ):
-                mask = self.attention_mask + AttentionMask.from_bool(att_mask)
+                if not isinstance(att_mask, AttentionMask):
+                    att_mask = AttentionMask.from_bool(att_mask)
+                mask = self.attention_mask + att_mask
             else:
                 mask = self.attention_mask & att_mask
         else:
diff --git a/xformers/components/attention/local.py b/xformers/components/attention/local.py
index 68df4bca3d..3220a8d401 100644
--- a/xformers/components/attention/local.py
+++ b/xformers/components/attention/local.py
@@ -5,7 +5,7 @@
 
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -13,6 +13,7 @@
 from xformers.components.attention import (
     Attention,
     AttentionConfig,
+    AttentionMask,
     maybe_sparsify,
     register_attention,
     sparsify,
@@ -97,7 +98,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         *args,
         **kwargs,
     ):
@@ -106,9 +107,13 @@ def forward(
             self.attention_mask = self._get_local_mask(q.shape).to(q.device)
 
         # Take into account the optional user mask
-        mask = (
-            self.attention_mask if att_mask is None else self.attention_mask & att_mask
-        )
+        if att_mask is None:
+            mask = self.attention_mask
+        else:
+            if isinstance(att_mask, AttentionMask):
+                # Needed because & op not defined for SparseCS with AttentionMask
+                att_mask = att_mask.to_bool()
+            mask = self.attention_mask & att_mask
 
         return scaled_dot_product_attention(
             q=q, k=k, v=v, att_mask=mask, dropout=self.attn_drop
diff --git a/xformers/components/attention/ortho.py b/xformers/components/attention/ortho.py
index 392ed96f36..3737f6cdd0 100644
--- a/xformers/components/attention/ortho.py
+++ b/xformers/components/attention/ortho.py
@@ -7,14 +7,19 @@
 import logging
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.autograd.profiler as profiler
 import torch.nn as nn
 import torch.nn.functional as Fn
 
-from xformers.components.attention import Attention, AttentionConfig, register_attention
+from xformers.components.attention import (
+    Attention,
+    AttentionConfig,
+    AttentionMask,
+    register_attention,
+)
 from xformers.components.attention.core import (
     scaled_dot_product_attention,
     scaled_query_key_softmax,
@@ -83,7 +88,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
         *args,
         **kwargs,
     ):
diff --git a/xformers/components/attention/random.py b/xformers/components/attention/random.py
index 5e3ee08e69..e07e6c8679 100644
--- a/xformers/components/attention/random.py
+++ b/xformers/components/attention/random.py
@@ -5,7 +5,7 @@
 
 
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -91,7 +91,7 @@ def forward(
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         *args,
         **kwargs,
     ):
@@ -106,6 +106,9 @@ def forward(
             ):
                 mask = self.rand_attention_mask + AttentionMask.from_bool(att_mask)
             else:
+                if isinstance(att_mask, AttentionMask):
+                    # Needed because & op not defined for SparseCS with AttentionMask
+                    att_mask = att_mask.to_bool()
                 mask = self.rand_attention_mask & att_mask
         else:
             mask = self.rand_attention_mask
diff --git a/xformers/factory/block_factory.py b/xformers/factory/block_factory.py
index e139a2a6fc..113f440fe0 100644
--- a/xformers/factory/block_factory.py
+++ b/xformers/factory/block_factory.py
@@ -20,6 +20,7 @@
     build_multi_head_attention,
     build_patch_embedding,
 )
+from xformers.components.attention import AttentionMask
 from xformers.components.feedforward import build_feedforward
 from xformers.components.positional_embedding import build_positional_embedding
 from xformers.components.residual import get_deepnorm_coefficients
@@ -206,7 +207,7 @@ def get_reversible_layer(config) -> Tuple[nn.Module, nn.Module]:
     def forward(
         self,
         x: torch.Tensor,
-        att_mask: Optional[torch.Tensor] = None,
+        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         input_mask: Optional[torch.Tensor] = None,
     ):
         if self.patch_emb is not None:
@@ -327,8 +328,8 @@ def forward(
         self,
         target: torch.Tensor,
         memory: torch.Tensor,
-        encoder_att_mask: Optional[torch.Tensor] = None,
-        decoder_att_mask: Optional[torch.Tensor] = None,
+        encoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
+        decoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
         input_mask: Optional[torch.Tensor] = None,
     ):
         if self.pose_encoding is not None:
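
Illustration, not part of either patch: a minimal standalone sketch of the mask-repeat behaviour that PATCH 1/2 adds to _matmul_with_mask, written against plain torch tensors only. The helper name repeat_additive_mask and the shapes are invented for this sketch; the real code additionally skips SparseCS masks and logs when it repeats.

import torch


def repeat_additive_mask(att: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Same condition as the patched _matmul_with_mask: only a dense, 3D additive
    # mask whose batch dim evenly divides the attention batch dim gets tiled.
    if (
        mask.ndim == 3
        and mask.shape[0] != att.shape[0]
        and (att.shape[0] % mask.shape[0]) == 0
    ):
        repeat_factor = att.shape[0] // mask.shape[0]
        mask = mask.repeat([repeat_factor, 1, 1])  # tile along the batch dim
    return mask


# Attention scores batched as (batch * heads, seq, seq), mask built per batch only.
att = torch.zeros(4, 8, 8)
additive_mask = torch.full((2, 8, 8), float("-inf")).triu(1)  # 0 = keep, -inf = drop
tiled = repeat_additive_mask(att, additive_mask)
assert tiled.shape == (4, 8, 8)
att = att + tiled  # the addition in _matmul_with_mask now succeeds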
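
And for PATCH 2/2, a short usage sketch of what the updated tests exercise: factory-built blocks now accept the attention mask either as a raw boolean tensor or as an AttentionMask object. Block construction is omitted here; encoder_block stands for any xFormerEncoderBlock built from a config, and SEQ is an arbitrary sequence length.

import torch

from xformers.components.attention import AttentionMask

SEQ = 8
bool_mask = torch.ones(SEQ, SEQ, dtype=torch.bool)  # True = attend
additive_mask = AttentionMask.from_bool(bool_mask)  # wraps the 0 / -inf form

# Both spellings are accepted when block.supports_attention_mask is True:
#   _ = encoder_block(inputs, att_mask=additive_mask)
#   _ = encoder_block(inputs, att_mask=bool_mask)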