From 744e575f6227ce74b44e08a571dda1a39c1aef16 Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Wed, 31 Jul 2024 23:09:46 +0100 Subject: [PATCH 1/9] Addition of norm_eps parameter to spade_autoencoderkl.py in response to issue https://github.com/Project-MONAI/MONAI/issues/7946. Signed-off-by: Virginia Fernandez --- monai/networks/blocks/selfattention.py | 46 +++++++++++++++------ monai/networks/blocks/spatialattention.py | 8 +++- monai/networks/blocks/transformerblock.py | 3 ++ monai/networks/nets/diffusion_model_unet.py | 3 ++ tests/test_crossattention.py | 2 +- tests/test_selfattention.py | 12 ++++++ 6 files changed, 59 insertions(+), 15 deletions(-) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 3ab1e1fd10..d5b9b44a64 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -15,9 +15,10 @@ import torch import torch.nn as nn +import torch.nn.functional as F from monai.networks.layers.utils import get_rel_pos_embedding_layer -from monai.utils import optional_import +from monai.utils import optional_import, pytorch_after Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -42,6 +43,7 @@ def __init__( rel_pos_embedding: Optional[str] = None, input_size: Optional[Tuple] = None, attention_dtype: Optional[torch.dtype] = None, + use_flash_attention: bool = False, ) -> None: """ Args: @@ -59,6 +61,7 @@ def __init__( input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. attention_dtype: cast attention operations to this dtype. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ @@ -82,6 +85,17 @@ def __init__( if causal and sequence_length is None: raise ValueError("sequence_length is necessary for causal attention.") + if use_flash_attention and not pytorch_after(minor=0, major=2, patch=0): + raise ValueError( + "use_flash_attention is only supported for PyTorch versions > 2.0." + "Upgrade your PyTorch or set the flag to False." + ) + if use_flash_attention and save_attn: + raise ValueError( + "save_attn has been set to True, but use_flash_attention is also set" + "to True. 
save_attn can only be used if use_flash_attention is False" + ) + self.num_heads = num_heads self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) @@ -97,6 +111,7 @@ def __init__( self.attention_dtype = attention_dtype self.causal = causal self.sequence_length = sequence_length + self.use_flash_attention = use_flash_attention if causal and sequence_length is not None: # causal mask to ensure that attention is only applied to the left in the input sequence @@ -130,23 +145,28 @@ def forward(self, x): q = q.to(self.attention_dtype) k = k.to(self.attention_dtype) - att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + if self.use_flash_attention: + x = F.scaled_dot_product_attention(q, k, v) + else: + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale - # apply relative positional embedding if defined - att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + # apply relative positional embedding if defined + att_mat = ( + self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + ) - if self.causal: - att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf")) + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf")) - att_mat = att_mat.softmax(dim=-1) + att_mat = att_mat.softmax(dim=-1) - if self.save_attn: - # no gradients and new tensor; - # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html - self.att_mat = att_mat.detach() + if self.save_attn: + # no gradients and new tensor; + # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html + self.att_mat = att_mat.detach() - att_mat = self.drop_weights(att_mat) - x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 75319853d9..1cfafb1585 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -33,6 +33,7 @@ class SpatialAttentionBlock(nn.Module): num_channels: number of input channels. Must be divisible by num_head_channels. num_head_channels: number of channels per head. attention_dtype: cast attention operations to this dtype. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
""" @@ -44,6 +45,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, attention_dtype: Optional[torch.dtype] = None, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -54,7 +56,11 @@ def __init__( raise ValueError("num_channels must be divisible by num_head_channels") num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 self.attn = SABlock( - hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype + hidden_size=num_channels, + num_heads=num_heads, + qkv_bias=True, + attention_dtype=attention_dtype, + use_flash_attention=use_flash_attention, ) def forward(self, x: torch.Tensor): diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 0aa1697479..dbba930d66 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -36,6 +36,7 @@ def __init__( causal: bool = False, sequence_length: int | None = None, with_cross_attention: bool = False, + use_flash_attention: bool = False, ) -> None: """ Args: @@ -45,6 +46,7 @@ def __init__( dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ @@ -66,6 +68,7 @@ def __init__( save_attn=save_attn, causal=causal, sequence_length=sequence_length, + use_flash_attention=use_flash_attention, ) self.norm2 = nn.LayerNorm(hidden_size) self.with_cross_attention = with_cross_attention diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 8a9ac859a3..38235a113f 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -66,6 +66,7 @@ class DiffusionUNetTransformerBlock(nn.Module): dropout: dropout probability to use. cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. """ @@ -77,6 +78,7 @@ def __init__( dropout: float = 0.0, cross_attention_dim: int | None = None, upcast_attention: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.attn1 = SABlock( @@ -86,6 +88,7 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + use_flash_attention=use_flash_attention, ) self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) self.attn2 = CrossAttentionBlock( diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 4ab0ab1823..ef29c1990f 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -119,7 +119,7 @@ def test_access_attn_matrix(self): # no of elements is zero assert no_matrix_acess_blk.att_mat.nelement() == 0 - # be able to acess the attention matrix + # be able to acess the attention matrix. 
matrix_acess_blk = CrossAttentionBlock( hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True ) diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index d069d6aa30..504f4c4f4e 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -22,6 +22,7 @@ from monai.networks.blocks.selfattention import SABlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion einops, has_einops = optional_import("einops") @@ -49,11 +50,17 @@ class TestResBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_SABLOCK) @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((0, 2)) def test_shape(self, input_param, input_shape, expected_shape): net = SABlock(**input_param) with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + # With flash attention + net_fa = SABlock(**input_param, use_flash_attention=True) + with eval_mode(net): + result_fa = net_fa(torch.randn(input_shape)) + self.assertEqual(result_fa.shape, expected_shape) def test_ill_arg(self): with self.assertRaises(ValueError): @@ -62,6 +69,11 @@ def test_ill_arg(self): with self.assertRaises(ValueError): SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + @SkipIfBeforePyTorchVersion((0, 2)) + def test_save_attn_with_flash_attention(self): + with self.assertRaises(ValueError): + SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True) + def test_attention_dim_not_multiple_of_heads(self): with self.assertRaises(ValueError): SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1) From bcb5907ccf07c2026c415ada544e234d254f622d Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Thu, 1 Aug 2024 13:22:13 +0100 Subject: [PATCH 2/9] Addition of flash_attention (using new PyTorch functionality for torch >= 2.0) to cross attention and self attention blocks, and addition of parameters to diffusion model unet and to transformer block. Modification of tests to check this functionality. Signed-off-by: Virginia Fernandez --- monai/networks/blocks/crossattention.py | 46 +++++++++++++++------ monai/networks/blocks/selfattention.py | 4 +- monai/networks/blocks/transformerblock.py | 7 +++- monai/networks/nets/diffusion_model_unet.py | 1 + tests/test_crossattention.py | 14 +++++++ tests/test_selfattention.py | 2 +- 6 files changed, 57 insertions(+), 17 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index b888ea3942..79a5959025 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -17,7 +17,7 @@ import torch.nn as nn from monai.networks.layers.utils import get_rel_pos_embedding_layer -from monai.utils import optional_import +from monai.utils import optional_import, pytorch_after Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") @@ -44,6 +44,7 @@ def __init__( rel_pos_embedding: Optional[str] = None, input_size: Optional[Tuple] = None, attention_dtype: Optional[torch.dtype] = None, + use_flash_attention: bool = False, ) -> None: """ Args: @@ -62,6 +63,7 @@ def __init__( input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. attention_dtype: cast attention operations to this dtype. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
""" super().__init__() @@ -81,6 +83,17 @@ def __init__( if causal and sequence_length is None: raise ValueError("sequence_length is necessary for causal attention.") + if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0): + raise ValueError( + "use_flash_attention is only supported for PyTorch versions >= 2.0." + "Upgrade your PyTorch or set the flag to False." + ) + if use_flash_attention and save_attn: + raise ValueError( + "save_attn has been set to True, but use_flash_attention is also set" + "to True. save_attn can only be used if use_flash_attention is False" + ) + self.num_heads = num_heads self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.context_input_size = context_input_size if context_input_size else hidden_size @@ -101,6 +114,7 @@ def __init__( self.causal = causal self.sequence_length = sequence_length + self.use_flash_attention = use_flash_attention if causal and sequence_length is not None: # causal mask to ensure that attention is only applied to the left in the input sequence @@ -145,23 +159,29 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale - # apply relative positional embedding if defined - att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + if self.use_flash_attention: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v).contiguous() + else: + + att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale + # apply relative positional embedding if defined + att_mat = ( + self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat + ) - if self.causal: - att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) + if self.causal: + att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) - att_mat = att_mat.softmax(dim=-1) + att_mat = att_mat.softmax(dim=-1) - if self.save_attn: - # no gradients and new tensor; - # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html - self.att_mat = att_mat.detach() + if self.save_attn: + # no gradients and new tensor; + # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html + self.att_mat = att_mat.detach() - att_mat = self.drop_weights(att_mat) - x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + att_mat = self.drop_weights(att_mat) + x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index d5b9b44a64..48073aba9b 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -85,9 +85,9 @@ def __init__( if causal and sequence_length is None: raise ValueError("sequence_length is necessary for causal attention.") - if use_flash_attention and not pytorch_after(minor=0, major=2, patch=0): + if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0): raise ValueError( - "use_flash_attention is only supported for PyTorch versions > 2.0." + "use_flash_attention is only supported for PyTorch versions >= 2.0." 
"Upgrade your PyTorch or set the flag to False." ) if use_flash_attention and save_attn: diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index dbba930d66..cb30d14480 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -75,7 +75,12 @@ def __init__( self.norm_cross_attn = nn.LayerNorm(hidden_size) self.cross_attn = CrossAttentionBlock( - hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False + hidden_size=hidden_size, + num_heads=num_heads, + dropout_rate=dropout_rate, + qkv_bias=qkv_bias, + causal=False, + use_flash_attention=use_flash_attention, ) def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 38235a113f..69503e3c0a 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -99,6 +99,7 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + use_flash_attention=use_flash_attention, ) self.norm1 = nn.LayerNorm(num_channels) self.norm2 = nn.LayerNorm(num_channels) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index ef29c1990f..6773ab8d28 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -22,6 +22,7 @@ from monai.networks.blocks.crossattention import CrossAttentionBlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion einops, has_einops = optional_import("einops") @@ -50,10 +51,16 @@ class TestResBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_CABLOCK) @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): + # Without flash attention net = CrossAttentionBlock(**input_param) with eval_mode(net): result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"])) self.assertEqual(result.shape, expected_shape) + # With flash attention + net = CrossAttentionBlock(**input_param, use_flash_attention=True) + with eval_mode(net): + result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"])) + self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): with self.assertRaises(ValueError): @@ -62,6 +69,13 @@ def test_ill_arg(self): with self.assertRaises(ValueError): CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + @SkipIfBeforePyTorchVersion((1, 13)) + def test_save_attn_with_flash_attention(self): + with self.assertRaises(ValueError): + CrossAttentionBlock( + hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True + ) + @skipUnless(has_einops, "Requires einops") def test_attention_dim_not_multiple_of_heads(self): with self.assertRaises(ValueError): diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 504f4c4f4e..c06fba0cc5 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -69,7 +69,7 @@ def test_ill_arg(self): with self.assertRaises(ValueError): SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) - @SkipIfBeforePyTorchVersion((0, 2)) + @SkipIfBeforePyTorchVersion((1, 13)) def test_save_attn_with_flash_attention(self): with self.assertRaises(ValueError): SABlock(hidden_size=128, 
num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True) From 17c914ac7d3f9b4215712b126c9b0da2988dc68d Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Thu, 1 Aug 2024 15:15:23 +0100 Subject: [PATCH 3/9] Addition of flash_attention (using new PyTorch functionality for torch >= 2.0) to cross attention and self attention blocks, and addition of parameters to diffusion model unet and to transformer block. Modification of tests to check this functionality. >>> Implementation of proposed corrections Signed-off-by: Virginia Fernandez --- monai/networks/blocks/crossattention.py | 36 +++++++++------------ monai/networks/blocks/selfattention.py | 9 +++--- monai/networks/blocks/transformerblock.py | 14 ++++---- monai/networks/nets/diffusion_model_unet.py | 3 +- tests/test_crossattention.py | 33 +++++++++---------- tests/test_selfattention.py | 33 +++++++++---------- 6 files changed, 59 insertions(+), 69 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 79a5959025..93ac361ec1 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -47,23 +47,20 @@ def __init__( use_flash_attention: bool = False, ) -> None: """ - Args: - hidden_size (int): dimension of hidden layer. - num_heads (int): number of attention heads. - dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. - hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size. - context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size. - dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. - qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. - save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. - causal: whether to use causal attention. - sequence_length: if causal is True, it is necessary to specify the sequence length. - rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. - For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. - input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative - positional parameter size. - attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + Args: hidden_size (int): dimension of hidden layer. num_heads (int): number of attention heads. dropout_rate + (float, optional): fraction of the input units to drop. Defaults to 0.0. hidden_input_size (int, optional): + dimension of the input tensor. Defaults to hidden_size. context_input_size (int, optional): dimension of the + context tensor. Defaults to hidden_size. dim_head (int, optional): dimension of each head. Defaults to + hidden_size // num_heads. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. + save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. causal: whether to + use causal attention. sequence_length: if causal is True, it is necessary to specify the sequence length. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only + "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. 
input_size (tuple( + spatial_dim), optional): Input resolution for calculating the relative positional parameter size. + attention_dtype: cast attention operations to this dtype. + use_flash_attention: if True, use Pytorch's inbuilt + flash attention for a memory efficient attention mechanism (see + https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ super().__init__() @@ -166,9 +163,8 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined - att_mat = ( - self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat - ) + if self.rel_positional_embedding is not None: + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.causal: att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf")) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 48073aba9b..487e07b5d1 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -61,7 +61,9 @@ def __init__( input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_flash_attention: if True, use Pytorch's inbuilt + flash attention for a memory efficient attention mechanism (see + https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ @@ -151,9 +153,8 @@ def forward(self, x): att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined - att_mat = ( - self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat - ) + if self.rel_positional_embedding is not None: + att_mat = self.rel_positional_embedding(x, att_mat, q) if self.causal: att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf")) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index cb30d14480..d05755e5e7 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -39,14 +39,12 @@ def __init__( use_flash_attention: bool = False, ) -> None: """ - Args: - hidden_size (int): dimension of hidden layer. - mlp_dim (int): dimension of feedforward layer. - num_heads (int): number of attention heads. - dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. - qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False. - save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + Args: hidden_size (int): dimension of hidden layer. mlp_dim (int): dimension of feedforward layer. num_heads + (int): number of attention heads. dropout_rate (float, optional): fraction of the input units to drop. + Defaults to 0.0. qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False. + save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. 
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 69503e3c0a..a885339d0d 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -66,7 +66,8 @@ class DiffusionUNetTransformerBlock(nn.Module): dropout: dropout probability to use. cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 6773ab8d28..10955eb924 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -32,18 +32,20 @@ for num_heads in [4, 6, 8, 12]: for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: for input_size in [(16, 32), (8, 8, 8)]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding, - "input_size": input_size, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_CABLOCK.append(test_case) + for flash_attn in [True, False]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + "use_flash_attention": flash_attn, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_CABLOCK.append(test_case) class TestResBlock(unittest.TestCase): @@ -56,11 +58,6 @@ def test_shape(self, input_param, input_shape, expected_shape): with eval_mode(net): result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"])) self.assertEqual(result.shape, expected_shape) - # With flash attention - net = CrossAttentionBlock(**input_param, use_flash_attention=True) - with eval_mode(net): - result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"])) - self.assertEqual(result.shape, expected_shape) def test_ill_arg(self): with self.assertRaises(ValueError): @@ -69,7 +66,7 @@ def test_ill_arg(self): with self.assertRaises(ValueError): CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4) - @SkipIfBeforePyTorchVersion((1, 13)) + @SkipIfBeforePyTorchVersion((2, 0)) def test_save_attn_with_flash_attention(self): with self.assertRaises(ValueError): CrossAttentionBlock( diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index c06fba0cc5..9094b4fb1d 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -32,35 +32,32 @@ for num_heads in [4, 6, 8, 12]: for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: for input_size in [(16, 32), (8, 8, 8)]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding, - "input_size": input_size, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for flash_attn in [True, False]: + 
test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + "use_flash_attention": flash_attn, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_SABLOCK) @skipUnless(has_einops, "Requires einops") - @SkipIfBeforePyTorchVersion((0, 2)) + @SkipIfBeforePyTorchVersion((2, 0)) def test_shape(self, input_param, input_shape, expected_shape): net = SABlock(**input_param) with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - # With flash attention - net_fa = SABlock(**input_param, use_flash_attention=True) - with eval_mode(net): - result_fa = net_fa(torch.randn(input_shape)) - self.assertEqual(result_fa.shape, expected_shape) def test_ill_arg(self): with self.assertRaises(ValueError): From f304b18035134c593fdc5d04b8276fda8944b9e2 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Fri, 2 Aug 2024 10:02:03 +0100 Subject: [PATCH 4/9] Addition of flash_attention (using new PyTorch functionality for torch >= 2.0) to cross attention and self attention blocks, and addition of parameters to diffusion model unet and to transformer block. Modification of tests to check this functionality. >>> Implementation of proposed corrections: - Addition of causal, dropout and scale to the call to scaled_dot_product_attention - For this, addition of self.dropout_rate as an attribute - Raising error when rel_pos_embedding is not None and use_flash_attention is True - Fix of docstrings that had gone wrong (in cross and self attention and transformer block) - Addition of two tests to self and cross attention blocks tests to account for the rel_pos_embedding error and to make sure that the causal = True call works. Signed-off-by: Virginia Fernandez --- monai/networks/blocks/crossattention.py | 42 ++++++++++++++--------- monai/networks/blocks/selfattention.py | 10 ++++-- monai/networks/blocks/transformerblock.py | 14 ++++---- tests/test_crossattention.py | 30 +++++++++++++++- tests/test_selfattention.py | 30 +++++++++++++++- 5 files changed, 100 insertions(+), 26 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 93ac361ec1..6293fab8c7 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -47,20 +47,25 @@ def __init__( use_flash_attention: bool = False, ) -> None: """ - Args: hidden_size (int): dimension of hidden layer. num_heads (int): number of attention heads. dropout_rate - (float, optional): fraction of the input units to drop. Defaults to 0.0. hidden_input_size (int, optional): - dimension of the input tensor. Defaults to hidden_size. context_input_size (int, optional): dimension of the - context tensor. Defaults to hidden_size. dim_head (int, optional): dimension of each head. Defaults to - hidden_size // num_heads. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. - save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. causal: whether to - use causal attention. sequence_length: if causal is True, it is necessary to specify the sequence length. - rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only - "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. 
input_size (tuple( - spatial_dim), optional): Input resolution for calculating the relative positional parameter size. - attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use Pytorch's inbuilt - flash attention for a memory efficient attention mechanism (see - https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + Args: + hidden_size (int): dimension of hidden layer. + num_heads (int): number of attention heads. + dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. + hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size. + context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size. + dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. + qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. + save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + causal (bool, optional): whether to use causal attention. + sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length. + rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only + "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional + parameter size. + attention_dtype: cast attention operations to this dtype. + use_flash_attention: if True, use Pytorch's inbuilt + flash attention for a memory efficient attention mechanism (see + https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ super().__init__() @@ -91,6 +96,9 @@ def __init__( "to True. save_attn can only be used if use_flash_attention is False" ) + if use_flash_attention and rel_pos_embedding is not None: + raise ValueError("rel_pos_embedding must be None if you are using flash_attention.") + self.num_heads = num_heads self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.context_input_size = context_input_size if context_input_size else hidden_size @@ -104,6 +112,7 @@ def __init__( self.out_rearrange = Rearrange("b h l d -> b l (h d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) + self.dropout_rate = dropout_rate self.scale = self.head_dim**-0.5 self.save_attn = save_attn @@ -158,9 +167,10 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) if self.use_flash_attention: - x = torch.nn.functional.scaled_dot_product_attention(q, k, v).contiguous() + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + ).contiguous() else: - att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined if self.rel_positional_embedding is not None: diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 487e07b5d1..17e8a05452 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -95,9 +95,12 @@ def __init__( if use_flash_attention and save_attn: raise ValueError( "save_attn has been set to True, but use_flash_attention is also set" - "to True. 
save_attn can only be used if use_flash_attention is False" + "to True. save_attn can only be used if use_flash_attention is False." ) + if use_flash_attention and rel_pos_embedding is not None: + raise ValueError("rel_pos_embedding must be None if you are using flash_attention.") + self.num_heads = num_heads self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) @@ -107,6 +110,7 @@ def __init__( self.out_rearrange = Rearrange("b h l d -> b l (h d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) + self.dropout_rate = dropout_rate self.scale = self.dim_head**-0.5 self.save_attn = save_attn self.att_mat = torch.Tensor() @@ -148,7 +152,9 @@ def forward(self, x): k = k.to(self.attention_dtype) if self.use_flash_attention: - x = F.scaled_dot_product_attention(q, k, v) + x = F.scaled_dot_product_attention( + q, k, v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index d05755e5e7..aaeba94e23 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -39,12 +39,14 @@ def __init__( use_flash_attention: bool = False, ) -> None: """ - Args: hidden_size (int): dimension of hidden layer. mlp_dim (int): dimension of feedforward layer. num_heads - (int): number of attention heads. dropout_rate (float, optional): fraction of the input units to drop. - Defaults to 0.0. qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False. - save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. - use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism - (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + Args: + hidden_size (int): dimension of hidden layer. mlp_dim (int): dimension of feedforward layer. + num_heads (int): number of attention heads. + dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias + (bool, optional): apply bias term for the qkv linear layer. Defaults to False. + save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 10955eb924..864116f1d2 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -38,7 +38,7 @@ "hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding, + "rel_pos_embedding": rel_pos_embedding if not flash_attn else None, "input_size": input_size, "use_flash_attention": flash_attn, }, @@ -73,6 +73,18 @@ def test_save_attn_with_flash_attention(self): hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True ) + @SkipIfBeforePyTorchVersion((2, 0)) + def test_rel_pos_embedding_with_flash_attention(self): + with self.assertRaises(ValueError): + CrossAttentionBlock( + hidden_size=128, + num_heads=3, + dropout_rate=0.1, + use_flash_attention=True, + save_attn=False, + rel_pos_embedding=RelPosEmbedding.DECOMPOSED, + ) + @skipUnless(has_einops, "Requires einops") def test_attention_dim_not_multiple_of_heads(self): with self.assertRaises(ValueError): @@ -86,6 +98,22 @@ def test_causal_no_sequence_length(self): with self.assertRaises(ValueError): CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_causal_flash_attention(self): + block = CrossAttentionBlock( + hidden_size=128, + num_heads=1, + dropout_rate=0.1, + causal=True, + sequence_length=16, + save_attn=False, + use_flash_attention=True, + ) + input_shape = (1, 16, 128) + # Check it runs correctly + block(torch.randn(input_shape)) + @skipUnless(has_einops, "Requires einops") def test_causal(self): block = CrossAttentionBlock( diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 9094b4fb1d..3e98f4c5c4 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -38,7 +38,7 @@ "hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding, + "rel_pos_embedding": rel_pos_embedding if not flash_attn else None, "input_size": input_size, "use_flash_attention": flash_attn, }, @@ -66,6 +66,18 @@ def test_ill_arg(self): with self.assertRaises(ValueError): SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4) + @SkipIfBeforePyTorchVersion((2, 0)) + def test_rel_pos_embedding_with_flash_attention(self): + with self.assertRaises(ValueError): + SABlock( + hidden_size=128, + num_heads=3, + dropout_rate=0.1, + use_flash_attention=True, + save_attn=False, + rel_pos_embedding=RelPosEmbedding.DECOMPOSED, + ) + @SkipIfBeforePyTorchVersion((1, 13)) def test_save_attn_with_flash_attention(self): with self.assertRaises(ValueError): @@ -83,6 +95,22 @@ def test_causal_no_sequence_length(self): with self.assertRaises(ValueError): SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True) + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_causal_flash_attention(self): + block = SABlock( + hidden_size=128, + num_heads=1, + dropout_rate=0.1, + causal=True, + sequence_length=16, + save_attn=False, + use_flash_attention=True, + ) + input_shape = (1, 16, 128) + # Check it runs correctly + block(torch.randn(input_shape)) + @skipUnless(has_einops, "Requires einops") def test_causal(self): block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True) From d53b22bd01c1036f8232b14754a67412880c9999 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Mon, 5 
Aug 2024 14:57:17 +0100 Subject: [PATCH 5/9] Addition of flash_attention (using new PyTorch functionality for torch >= 2.0) to cross attention and self attention blocks, and addition of parameters to diffusion model unet and to transformer block. Modification of tests to check this functionality. >>>> It was necessary to transpose query, value and key passed to the PyTorch flash attention module to get a behavior that is consistent with the xformers and no flash one, and then to transpose back the result. Behavior this way is consistent with xformers. Signed-off-by: Virginia Fernandez --- monai/networks/blocks/crossattention.py | 9 +++++++-- monai/networks/blocks/selfattention.py | 9 +++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 6293fab8c7..47940dbef0 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -168,8 +168,13 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): if self.use_flash_attention: x = torch.nn.functional.scaled_dot_product_attention( - q, k, v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal - ).contiguous() + query=q.transpose(1, 2), + key=k.transpose(1, 2), + value=v.transpose(1, 2), + scale=self.scale, + dropout_p=self.dropout_rate, + is_causal=self.causal, + ).transpose(1, 2) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 17e8a05452..4f6ccc11e1 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -153,8 +153,13 @@ def forward(self, x): if self.use_flash_attention: x = F.scaled_dot_product_attention( - q, k, v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal - ) + query=q.transpose(1, 2), + key=k.transpose(1, 2), + value=v.transpose(1, 2), + scale=self.scale, + dropout_p=self.dropout_rate, + is_causal=self.causal, + ).transpose(1, 2) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale From c41c2d2db1c810a41097a403ed3b0b76d576800e Mon Sep 17 00:00:00 2001 From: Virginia Fernandez Date: Tue, 6 Aug 2024 13:33:05 +0100 Subject: [PATCH 6/9] Addressing suggested changes from PR: https://github.com/Project-MONAI/MONAI/pull/7977 In particular: - modified this line when use_causal is True in self_attention to - added pertinent transpose calls to cross attention to ensure that the behaviour matches that of xops and that the code works, as well, for flash_attention=False. - added SkipIfPytorch[...] clause before the test_shape in test_cross_attention to make sure it does not error out for cases in the case block that use flash_attention = True - fix one rogue space on docstrings that had been added I ran autofix and mypy. cross_attention was reformatted. mypy did not suggest changes. Signed-off-by: Virginia Fernandez --- monai/networks/blocks/crossattention.py | 8 +++++--- monai/networks/blocks/selfattention.py | 2 +- tests/test_crossattention.py | 1 + 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 47940dbef0..daa5abdd56 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -53,7 +53,7 @@ def __init__( dropout_rate (float, optional): fraction of the input units to drop. 
Defaults to 0.0. hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size. context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size. - dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. + dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads. qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. causal (bool, optional): whether to use causal attention. @@ -162,7 +162,7 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): q = q.to(self.attention_dtype) k = k.to(self.attention_dtype) - q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) + q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) # k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) @@ -174,7 +174,9 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal, - ).transpose(1, 2) + ).transpose( + 1, 2 + ) # Back to (b, nh, t, hs) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 4f6ccc11e1..124c00acc6 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -168,7 +168,7 @@ def forward(self, x): att_mat = self.rel_positional_embedding(x, att_mat, q) if self.causal: - att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf")) + att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf")) att_mat = att_mat.softmax(dim=-1) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 864116f1d2..44458147d6 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -52,6 +52,7 @@ class TestResBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_CABLOCK) @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) def test_shape(self, input_param, input_shape, expected_shape): # Without flash attention net = CrossAttentionBlock(**input_param) From eda38ed3ca6a99582e8f34f535303613fbe8c564 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Tue, 6 Aug 2024 13:33:37 +0100 Subject: [PATCH 7/9] Addressing suggested changes from PR: https://github.com/Project-MONAI/MONAI/pull/7977 In particular: - modified this line when use_causal is True in self_attention to - added pertinent transpose calls to cross attention to ensure that the behaviour matches that of xops and that the code works, as well, for flash_attention=False. - added SkipIfPytorch[...] clause before the test_shape in test_cross_attention to make sure it does not error out for cases in the case block that use flash_attention = True - fix one rogue space on docstrings that had been added I ran autofix and mypy. cross_attention was reformatted. mypy did not suggest changes. >>>> FIX: I forgot to sign! 
Signed-off-by: Virginia Fernandez --- tests/test_crossattention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 44458147d6..62d1e68a09 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -51,7 +51,7 @@ class TestResBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_CABLOCK) - @skipUnless(has_einops, "Requires einops") + @skipUnless(has_einops, "Requires einops_") @SkipIfBeforePyTorchVersion((2, 0)) def test_shape(self, input_param, input_shape, expected_shape): # Without flash attention From b09f8278abb8e28dcf70f4d4a0e2786e6ab5e710 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Tue, 6 Aug 2024 13:33:46 +0100 Subject: [PATCH 8/9] Addressing suggested changes from PR: https://github.com/Project-MONAI/MONAI/pull/7977 In particular: - modified this line when use_causal is True in self_attention to - added pertinent transpose calls to cross attention to ensure that the behaviour matches that of xops and that the code works, as well, for flash_attention=False. - added SkipIfPytorch[...] clause before the test_shape in test_cross_attention to make sure it does not error out for cases in the case block that use flash_attention = True - fix one rogue space on docstrings that had been added I ran autofix and mypy. cross_attention was reformatted. mypy did not suggest changes. >>>> FIX: I forgot to sign! Signed-off-by: Virginia Fernandez --- tests/test_crossattention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 62d1e68a09..44458147d6 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -51,7 +51,7 @@ class TestResBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_CABLOCK) - @skipUnless(has_einops, "Requires einops_") + @skipUnless(has_einops, "Requires einops") @SkipIfBeforePyTorchVersion((2, 0)) def test_shape(self, input_param, input_shape, expected_shape): # Without flash attention From 2ea7d7988178b1132e145d75c107ecd9fdb65753 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Tue, 6 Aug 2024 15:11:43 +0100 Subject: [PATCH 9/9] Docstrings change in transformerblock.py Signed-off-by: Virginia Fernandez --- monai/networks/blocks/transformerblock.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index aaeba94e23..28d9c563ac 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -40,10 +40,11 @@ def __init__( ) -> None: """ Args: - hidden_size (int): dimension of hidden layer. mlp_dim (int): dimension of feedforward layer. + hidden_size (int): dimension of hidden layer. + mlp_dim (int): dimension of feedforward layer. num_heads (int): number of attention heads. - dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. qkv_bias - (bool, optional): apply bias term for the qkv linear layer. Defaults to False. + dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0. + qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. 
use_flash_attention: if True, use PyTorch's inbuilt flash attention for a memory-efficient attention mechanism (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
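
A minimal usage sketch of the new use_flash_attention flag introduced by this series (not part of the patches themselves). It assumes the patched SABlock and CrossAttentionBlock are importable from monai.networks.blocks, plus PyTorch >= 2.0 and einops, and it reuses the tensor shapes from the unit tests above.

# Hedged sketch: exercises only the public API added by these patches; names and
# shapes follow tests/test_selfattention.py and tests/test_crossattention.py.
import torch

from monai.networks.blocks.crossattention import CrossAttentionBlock
from monai.networks.blocks.selfattention import SABlock

# Self-attention with flash attention enabled. save_attn and rel_pos_embedding must
# stay at their defaults (False / None): the patches raise ValueError when either is
# combined with use_flash_attention.
sa_block = SABlock(hidden_size=128, num_heads=4, dropout_rate=0.0, use_flash_attention=True)
x = torch.randn(2, 512, 128)  # (batch, sequence length, hidden_size)
print(sa_block(x).shape)      # torch.Size([2, 512, 128])

# Cross-attention against a separate context tensor; the flag is passed the same way.
ca_block = CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.0, use_flash_attention=True)
context = torch.randn(2, 77, 128)  # context length may differ from the query length
print(ca_block(x, context=context).shape)  # torch.Size([2, 512, 128])

TransformerBlock and the diffusion model UNet transformer block do not call scaled_dot_product_attention themselves; they simply forward use_flash_attention to these attention blocks, so the same constraints (PyTorch version, no save_attn, no rel_pos_embedding) apply there as well.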