diff --git a/xformers/components/attention/blocksparse.py b/xformers/components/attention/blocksparse.py
index a9b701ce4e..d6302e8e4a 100644
--- a/xformers/components/attention/blocksparse.py
+++ b/xformers/components/attention/blocksparse.py
@@ -13,6 +13,7 @@
 from torch import nn
 
 from xformers.components.attention import Attention, AttentionConfig, register_attention
+from xformers.components.attention.utils import bool_mask_to_additive
 
 _mask_type_warning = True
 
@@ -119,10 +120,9 @@ def update_mask_type(self, mask: torch.Tensor, to_dtype: torch.dtype):
         global _mask_type_warning
         if _mask_type_warning:
             logging.warning(
-                "Mask has to be multiplicative. Fixing that but this slows things down"
+                "Mask has to be additive. Fixing that but this slows things down"
             )
-            _mask_type_warning = False  # Only warn once
-        mask = mask.to(to_dtype)
+        mask = bool_mask_to_additive(mask)
 
     def forward(
         self,
@@ -136,11 +136,10 @@ def forward(
         **kwargs,
     ) -> torch.Tensor:
         r"""
-        att_mask            A 2D attention mask. The dtype must be the same as q. Multiplicative mask where a value
-                            of 1 will keep the value, while a value of 0 will mask the value.
+        att_mask            A 2D attention mask. The dtype must be the same as q. An additive mask is expected,
+                            meaning float values using "-inf" to mask values.
         key_padding_mask    A mask with size (batch size x sequence length). The dtype must be the same as q.
-                            Multiplicative mask where a value of 1 will keep the value, while a value of 0 will
-                            mask the value.
+                            An additive mask is expected, meaning float values using "-inf" to mask values
         """
 
         # NOTE:
@@ -149,9 +148,9 @@
         # If blocks are to be constantly masked, better perf would thus be reached by signalling them out in the
         # initial attention setup
 
-        if att_mask is not None and att_mask.dtype != q.dtype:
+        if att_mask is not None and att_mask.dtype == torch.bool:
             self.update_mask_type(att_mask, q.dtype)
-        if key_padding_mask is not None and key_padding_mask.dtype != q.dtype:
+        if key_padding_mask is not None and key_padding_mask.dtype == torch.bool:
             self.update_mask_type(key_padding_mask, q.dtype)
 
         assert (
@@ -197,8 +196,8 @@ def forward(
             scale=scale,
             key_padding_mask=key_padding_mask,
             attn_mask=att_mask,
-            key_padding_mask_mode=MaskType.MUL,
-            attn_mask_mode=MaskType.MUL,
+            key_padding_mask_mode=MaskType.ADD,
+            attn_mask_mode=MaskType.ADD,
         )
 
         # - then (dense) attention is (sparse) attention matrix * dense (value)
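For reference, a minimal sketch of the additive-mask convention this patch switches to. The helper below is an assumption about what a conversion like bool_mask_to_additive produces (0.0 where the boolean mask is True, -inf where it is False), not the library's actual implementation; the name bool_mask_to_additive_sketch and the example tensors are illustrative only.

import torch

def bool_mask_to_additive_sketch(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Hypothetical stand-in for a bool -> additive mask conversion:
    # True keeps a position (adds 0 to the attention logits), False masks it
    # out with -inf, so the position gets zero weight after the softmax.
    assert mask.dtype == torch.bool
    additive = torch.zeros_like(mask, dtype=dtype)
    additive.masked_fill_(~mask, float("-inf"))
    return additive

# Example: keep the first two positions, mask the third.
bool_mask = torch.tensor([[True, True, False]])
att_mask = bool_mask_to_additive_sketch(bool_mask)  # tensor([[0., 0., -inf]])

Under this convention (MaskType.ADD), the mask is summed with the attention logits before the softmax, rather than multiplying values as in the previous 0/1 multiplicative scheme.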