diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py
index b888ea3942..daa5abdd56 100644
--- a/monai/networks/blocks/crossattention.py
+++ b/monai/networks/blocks/crossattention.py
@@ -17,7 +17,7 @@
 import torch.nn as nn

 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import
+from monai.utils import optional_import, pytorch_after

 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -44,6 +44,7 @@ def __init__(
         rel_pos_embedding: Optional[str] = None,
         input_size: Optional[Tuple] = None,
         attention_dtype: Optional[torch.dtype] = None,
+        use_flash_attention: bool = False,
     ) -> None:
         """
         Args:
@@ -55,13 +56,16 @@
             dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
-            causal: whether to use causal attention.
-            sequence_length: if causal is True, it is necessary to specify the sequence length.
-            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
-                For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
-            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
-                positional parameter size.
+            causal (bool, optional): whether to use causal attention.
+            sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
+            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only
+                "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
+                parameter size.
             attention_dtype: cast attention operations to this dtype.
+            use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient
+                attention mechanism (see
+                https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
         """

         super().__init__()
@@ -81,6 +85,20 @@
         if causal and sequence_length is None:
             raise ValueError("sequence_length is necessary for causal attention.")

+        if use_flash_attention and not pytorch_after(major=2, minor=0):
+            raise ValueError(
+                "use_flash_attention is only supported for PyTorch versions >= 2.0. "
+                "Upgrade your PyTorch or set the flag to False."
+            )
+        if use_flash_attention and save_attn:
+            raise ValueError(
+                "save_attn has been set to True, but use_flash_attention is also set "
+                "to True. save_attn can only be used if use_flash_attention is False."
+            )
+
+        if use_flash_attention and rel_pos_embedding is not None:
+            raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
+
         self.num_heads = num_heads
         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
         self.context_input_size = context_input_size if context_input_size else hidden_size
@@ -94,6 +112,7 @@
         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
         self.drop_output = nn.Dropout(dropout_rate)
         self.drop_weights = nn.Dropout(dropout_rate)
+        self.dropout_rate = dropout_rate
         self.scale = self.head_dim**-0.5
         self.save_attn = save_attn

@@ -101,6 +120,7 @@

         self.causal = causal
         self.sequence_length = sequence_length
+        self.use_flash_attention = use_flash_attention

         if causal and sequence_length is not None:
             # causal mask to ensure that attention is only applied to the left in the input sequence
@@ -142,26 +162,38 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
             q = q.to(self.attention_dtype)
             k = k.to(self.attention_dtype)

         q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, t, hs)
         k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)
         v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2)  # (b, nh, kv_t, hs)

-        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
-        # apply relative positional embedding if defined
-        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+        if self.use_flash_attention:
+            # q, k and v are already (b, nh, tokens, hs), the layout scaled_dot_product_attention expects
+            x = torch.nn.functional.scaled_dot_product_attention(
+                query=q,
+                key=k,
+                value=v,
+                scale=self.scale,
+                dropout_p=self.dropout_rate if self.training else 0.0,
+                is_causal=self.causal,
+            )  # (b, nh, t, hs)
+        else:
+            att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+            # apply relative positional embedding if defined
+            if self.rel_positional_embedding is not None:
+                att_mat = self.rel_positional_embedding(x, att_mat, q)

-        if self.causal:
-            att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
+            if self.causal:
+                att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

-        att_mat = att_mat.softmax(dim=-1)
+            att_mat = att_mat.softmax(dim=-1)

-        if self.save_attn:
-            # no gradients and new tensor;
-            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-            self.att_mat = att_mat.detach()
+            if self.save_attn:
+                # no gradients and new tensor;
+                # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+                self.att_mat = att_mat.detach()

-        att_mat = self.drop_weights(att_mat)
-        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+            att_mat = self.drop_weights(att_mat)
+            x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
         x = self.out_rearrange(x)
         x = self.out_proj(x)
         x = self.drop_output(x)
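
For reference, the flash path introduced above is expected to reproduce the einsum/softmax computation kept in the else branch whenever no relative positional embedding, causal masking or dropout is involved. A minimal, self-contained sanity check of that equivalence (illustrative only, not part of the patch; it assumes a PyTorch version that accepts the explicit scale argument of scaled_dot_product_attention, i.e. 2.1 or later):

    import torch
    import torch.nn.functional as F

    b, nh, t, hs = 2, 4, 16, 32  # batch, heads, tokens, head size
    q, k, v = (torch.randn(b, nh, t, hs) for _ in range(3))
    scale = hs**-0.5

    # manual path, as in the else branch above
    att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * scale).softmax(dim=-1)
    manual = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)

    # flash / memory-efficient path
    flash = F.scaled_dot_product_attention(query=q, key=k, value=v, scale=scale, dropout_p=0.0)

    assert torch.allclose(manual, flash, atol=1e-5)
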
diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 3ab1e1fd10..124c00acc6 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -15,9 +15,10 @@

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from monai.networks.layers.utils import get_rel_pos_embedding_layer
-from monai.utils import optional_import
+from monai.utils import optional_import, pytorch_after

 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

@@ -42,6 +43,7 @@ def __init__(
         rel_pos_embedding: Optional[str] = None,
         input_size: Optional[Tuple] = None,
         attention_dtype: Optional[torch.dtype] = None,
+        use_flash_attention: bool = False,
     ) -> None:
         """
         Args:
@@ -59,6 +61,9 @@
             input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
                 parameter size.
             attention_dtype: cast attention operations to this dtype.
+            use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient
+                attention mechanism (see
+                https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

         """

@@ -82,6 +87,20 @@
         if causal and sequence_length is None:
             raise ValueError("sequence_length is necessary for causal attention.")

+        if use_flash_attention and not pytorch_after(major=2, minor=0):
+            raise ValueError(
+                "use_flash_attention is only supported for PyTorch versions >= 2.0. "
+                "Upgrade your PyTorch or set the flag to False."
+            )
+        if use_flash_attention and save_attn:
+            raise ValueError(
+                "save_attn has been set to True, but use_flash_attention is also set "
+                "to True. save_attn can only be used if use_flash_attention is False."
+            )
+
+        if use_flash_attention and rel_pos_embedding is not None:
+            raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
+
         self.num_heads = num_heads
         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
         self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
@@ -91,12 +110,14 @@
         self.out_rearrange = Rearrange("b h l d -> b l (h d)")
         self.drop_output = nn.Dropout(dropout_rate)
         self.drop_weights = nn.Dropout(dropout_rate)
+        self.dropout_rate = dropout_rate
         self.scale = self.dim_head**-0.5
         self.save_attn = save_attn
         self.att_mat = torch.Tensor()
         self.attention_dtype = attention_dtype
         self.causal = causal
         self.sequence_length = sequence_length
+        self.use_flash_attention = use_flash_attention

         if causal and sequence_length is not None:
             # causal mask to ensure that attention is only applied to the left in the input sequence
@@ -130,23 +151,34 @@ def forward(self, x):
             q = q.to(self.attention_dtype)
             k = k.to(self.attention_dtype)

-        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+        if self.use_flash_attention:
+            x = F.scaled_dot_product_attention(
+                query=q,
+                key=k,
+                value=v,
+                scale=self.scale,
+                dropout_p=self.dropout_rate if self.training else 0.0,
+                is_causal=self.causal,
+            )
+        else:
+            att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

-        # apply relative positional embedding if defined
-        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+            # apply relative positional embedding if defined
+            if self.rel_positional_embedding is not None:
+                att_mat = self.rel_positional_embedding(x, att_mat, q)

-        if self.causal:
-            att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))
+            if self.causal:
+                att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))

-        att_mat = att_mat.softmax(dim=-1)
+            att_mat = att_mat.softmax(dim=-1)

-        if self.save_attn:
-            # no gradients and new tensor;
-            # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-            self.att_mat = att_mat.detach()
+            if self.save_attn:
+                # no gradients and new tensor;
+                # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+                self.att_mat = att_mat.detach()

-        att_mat = self.drop_weights(att_mat)
-        x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+            att_mat = self.drop_weights(att_mat)
+            x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)

         x = self.out_rearrange(x)
         x = self.out_proj(x)
         x = self.drop_output(x)
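
Illustrative usage of the new SABlock flag (not part of the patch; assumes einops is installed and a PyTorch version that passes the version gate above). With shared weights, no relative positional embedding and dropout inactive in eval mode, enabling flash attention should not change the output:

    import torch
    from monai.networks.blocks.selfattention import SABlock

    block = SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1)
    flash_block = SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, use_flash_attention=True)
    flash_block.load_state_dict(block.state_dict())  # share weights between the two blocks
    block.eval()
    flash_block.eval()

    x = torch.randn(2, 64, 128)  # (batch, tokens, hidden_size)
    with torch.no_grad():
        assert torch.allclose(block(x), flash_block(x), atol=1e-5)
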
diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py
index 75319853d9..1cfafb1585 100644
--- a/monai/networks/blocks/spatialattention.py
+++ b/monai/networks/blocks/spatialattention.py
@@ -33,6 +33,7 @@ class SpatialAttentionBlock(nn.Module):
         num_channels: number of input channels. Must be divisible by num_head_channels.
         num_head_channels: number of channels per head.
         attention_dtype: cast attention operations to this dtype.
+        use_flash_attention: if True, use flash attention for a memory-efficient attention mechanism.

     """

@@ -44,6 +45,7 @@ def __init__(
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
         attention_dtype: Optional[torch.dtype] = None,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()

@@ -54,7 +56,11 @@
             raise ValueError("num_channels must be divisible by num_head_channels")
         num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
         self.attn = SABlock(
-            hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype
+            hidden_size=num_channels,
+            num_heads=num_heads,
+            qkv_bias=True,
+            attention_dtype=attention_dtype,
+            use_flash_attention=use_flash_attention,
         )

     def forward(self, x: torch.Tensor):
diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 0aa1697479..28d9c563ac 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -36,6 +36,7 @@ def __init__(
         causal: bool = False,
         sequence_length: int | None = None,
         with_cross_attention: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         """
         Args:
@@ -43,8 +44,10 @@
             mlp_dim (int): dimension of feedforward layer.
             num_heads (int): number of attention heads.
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient attention
+                mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

         """
@@ -66,13 +69,19 @@
             save_attn=save_attn,
             causal=causal,
             sequence_length=sequence_length,
+            use_flash_attention=use_flash_attention,
         )

         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention
         self.norm_cross_attn = nn.LayerNorm(hidden_size)
         self.cross_attn = CrossAttentionBlock(
-            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            dropout_rate=dropout_rate,
+            qkv_bias=qkv_bias,
+            causal=False,
+            use_flash_attention=use_flash_attention,
         )

     def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
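
Since TransformerBlock simply forwards the flag to its SABlock and CrossAttentionBlock members, enabling flash attention for a whole transformer layer is a one-argument change (illustrative sketch, not part of the patch; the sizes are arbitrary example values):

    import torch
    from monai.networks.blocks.transformerblock import TransformerBlock

    block = TransformerBlock(
        hidden_size=256,
        mlp_dim=1024,
        num_heads=8,
        dropout_rate=0.0,
        with_cross_attention=True,
        use_flash_attention=True,
    )
    x = torch.randn(2, 64, 256)        # (batch, tokens, hidden_size)
    context = torch.randn(2, 77, 256)  # cross-attention context
    out = block(x, context=context)    # same shape as x
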
diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py
index 8a9ac859a3..a885339d0d 100644
--- a/monai/networks/nets/diffusion_model_unet.py
+++ b/monai/networks/nets/diffusion_model_unet.py
@@ -66,6 +66,8 @@ class DiffusionUNetTransformerBlock(nn.Module):
         dropout: dropout probability to use.
         cross_attention_dim: size of the context vector for cross attention.
         upcast_attention: if True, upcast attention operations to full precision.
+        use_flash_attention: if True, use PyTorch's built-in flash attention for a memory-efficient attention
+            mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

     """

@@ -77,6 +79,7 @@ def __init__(
         dropout: float = 0.0,
         cross_attention_dim: int | None = None,
         upcast_attention: bool = False,
+        use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
         self.attn1 = SABlock(
@@ -86,6 +89,7 @@
             dim_head=num_head_channels,
             dropout_rate=dropout,
             attention_dtype=torch.float if upcast_attention else None,
+            use_flash_attention=use_flash_attention,
         )
         self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
         self.attn2 = CrossAttentionBlock(
@@ -96,6 +100,7 @@
             dim_head=num_head_channels,
             dropout_rate=dropout,
             attention_dtype=torch.float if upcast_attention else None,
+            use_flash_attention=use_flash_attention,
         )
         self.norm1 = nn.LayerNorm(num_channels)
         self.norm2 = nn.LayerNorm(num_channels)
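
The same flag threads into the diffusion-model transformer block (illustrative sketch, not part of the patch; apart from the arguments visible in the hunks above, the num_channels / num_attention_heads / num_head_channels names and the (batch, tokens, channels) calling convention are assumptions based on the block's usual signature):

    import torch
    from monai.networks.nets.diffusion_model_unet import DiffusionUNetTransformerBlock

    block = DiffusionUNetTransformerBlock(
        num_channels=64,
        num_attention_heads=8,
        num_head_channels=8,
        dropout=0.0,
        upcast_attention=True,     # attention operations run in float32
        use_flash_attention=True,  # and go through scaled_dot_product_attention
    )
    tokens = torch.randn(1, 16 * 16, 64)  # e.g. a flattened 16x16 feature map
    out = block(tokens)
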
""" @@ -66,13 +69,19 @@ def __init__( save_attn=save_attn, causal=causal, sequence_length=sequence_length, + use_flash_attention=use_flash_attention, ) self.norm2 = nn.LayerNorm(hidden_size) self.with_cross_attention = with_cross_attention self.norm_cross_attn = nn.LayerNorm(hidden_size) self.cross_attn = CrossAttentionBlock( - hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + causal=False, + use_flash_attention=use_flash_attention, ) def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 8a9ac859a3..a885339d0d 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -66,6 +66,8 @@ class DiffusionUNetTransformerBlock(nn.Module): dropout: dropout probability to use. cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ @@ -77,6 +79,7 @@ def __init__( dropout: float = 0.0, cross_attention_dim: int | None = None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.attn1 = SABlock( @@ -86,6 +89,7 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + use_flash_attention=use_flash_attention, ) self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) self.attn2 = CrossAttentionBlock( @@ -96,6 +100,7 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + use_flash_attention=use_flash_attention, ) self.norm1 = nn.LayerNorm(num_channels) self.norm2 = nn.LayerNorm(num_channels) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 4ab0ab1823..44458147d6 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -22,6 +22,7 @@ from monai.networks.blocks.crossattention import CrossAttentionBlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion einops, has_einops = optional_import("einops") @@ -31,25 +32,29 @@ for num_heads in [4, 6, 8, 12]: for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: for input_size in [(16, 32), (8, 8, 8)]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding, - "input_size": input_size, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_CABLOCK.append(test_case) + for flash_attn in [True, False]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding if not flash_attn else None, + "input_size": input_size, + "use_flash_attention": flash_attn, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_CABLOCK.append(test_case) class TestResBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_CABLOCK) 
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index d069d6aa30..3e98f4c5c4 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -22,6 +22,7 @@
 from monai.networks.blocks.selfattention import SABlock
 from monai.networks.layers.factories import RelPosEmbedding
 from monai.utils import optional_import
+from tests.utils import SkipIfBeforePyTorchVersion

 einops, has_einops = optional_import("einops")

@@ -31,24 +32,27 @@
         for num_heads in [4, 6, 8, 12]:
             for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
                 for input_size in [(16, 32), (8, 8, 8)]:
-                    test_case = [
-                        {
-                            "hidden_size": hidden_size,
-                            "num_heads": num_heads,
-                            "dropout_rate": dropout_rate,
-                            "rel_pos_embedding": rel_pos_embedding,
-                            "input_size": input_size,
-                        },
-                        (2, 512, hidden_size),
-                        (2, 512, hidden_size),
-                    ]
-                    TEST_CASE_SABLOCK.append(test_case)
+                    for flash_attn in [True, False]:
+                        test_case = [
+                            {
+                                "hidden_size": hidden_size,
+                                "num_heads": num_heads,
+                                "dropout_rate": dropout_rate,
+                                "rel_pos_embedding": rel_pos_embedding if not flash_attn else None,
+                                "input_size": input_size,
+                                "use_flash_attention": flash_attn,
+                            },
+                            (2, 512, hidden_size),
+                            (2, 512, hidden_size),
+                        ]
+                        TEST_CASE_SABLOCK.append(test_case)


 class TestResBlock(unittest.TestCase):
     @parameterized.expand(TEST_CASE_SABLOCK)
     @skipUnless(has_einops, "Requires einops")
+    @SkipIfBeforePyTorchVersion((2, 0))
     def test_shape(self, input_param, input_shape, expected_shape):
         net = SABlock(**input_param)
         with eval_mode(net):
@@ -62,6 +66,23 @@ def test_ill_arg(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

+    @SkipIfBeforePyTorchVersion((2, 0))
+    def test_rel_pos_embedding_with_flash_attention(self):
+        with self.assertRaises(ValueError):
+            SABlock(
+                hidden_size=128,
+                num_heads=3,
+                dropout_rate=0.1,
+                use_flash_attention=True,
+                save_attn=False,
+                rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
+            )
+
+    @SkipIfBeforePyTorchVersion((2, 0))
+    def test_save_attn_with_flash_attention(self):
+        with self.assertRaises(ValueError):
+            SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True)
+
     def test_attention_dim_not_multiple_of_heads(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1)
@@ -74,6 +95,22 @@ def test_causal_no_sequence_length(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True)

+    @skipUnless(has_einops, "Requires einops")
+    @SkipIfBeforePyTorchVersion((2, 0))
+    def test_causal_flash_attention(self):
+        block = SABlock(
+            hidden_size=128,
+            num_heads=1,
+            dropout_rate=0.1,
+            causal=True,
+            sequence_length=16,
+            save_attn=False,
+            use_flash_attention=True,
+        )
+        input_shape = (1, 16, 128)
+        # Check it runs correctly
+        block(torch.randn(input_shape))
+
     @skipUnless(has_einops, "Requires einops")
     def test_causal(self):
         block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True)
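
Finally, a quick way to eyeball the memory benefit that motivates the flag (illustrative only, not part of the patch; it needs a CUDA device, and the numbers will vary with the SDPA backend PyTorch selects):

    import torch
    from monai.networks.blocks.selfattention import SABlock

    if torch.cuda.is_available():
        for flag in (False, True):
            block = SABlock(hidden_size=256, num_heads=8, dropout_rate=0.0, use_flash_attention=flag).cuda().half()
            x = torch.randn(2, 2048, 256, device="cuda", dtype=torch.half)
            torch.cuda.reset_peak_memory_stats()
            with torch.no_grad():
                block(x)
            peak_mib = torch.cuda.max_memory_allocated() / 2**20
            print(f"use_flash_attention={flag}: peak memory {peak_mib:.0f} MiB")
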