From d7cb5e138ec1ccc848a554574b1a89f0dfaf0e90 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 26 Oct 2023 13:06:21 +0200 Subject: [PATCH] [Llama FA2] Re-add _expand_attention_mask and clean a couple things (#27074) * clean * clean llama * fix more * make style * Apply suggestions from code review * Apply suggestions from code review * Update src/transformers/models/llama/modeling_llama.py * Update src/transformers/models/llama/modeling_llama.py * Apply suggestions from code review * finish * make style --- .../models/falcon/modeling_falcon.py | 17 ++++---- .../models/llama/modeling_llama.py | 42 ++++++++++++++----- .../models/mistral/modeling_mistral.py | 24 +++++++---- 3 files changed, 58 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 642e60a72f91df..49307cf52ecd75 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -148,7 +148,7 @@ def to_4d( self, attention_mask_2d: torch.Tensor, query_length: int, - key_value_length: int, + key_value_length: Optional[int] = None, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ @@ -157,12 +157,16 @@ def to_4d( causal, a causal mask will be added. """ input_shape = (attention_mask_2d.shape[0], query_length) - past_key_values_length = key_value_length - query_length # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] causal_4d_mask = None if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." + ) + past_key_values_length = key_value_length - query_length causal_4d_mask = self._make_causal_mask( input_shape, @@ -182,8 +186,8 @@ def to_4d( return expanded_4d_mask + @staticmethod def _make_causal_mask( - self, input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, @@ -212,7 +216,8 @@ def _make_causal_mask( return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
""" @@ -837,7 +842,7 @@ def _flash_attention_forward( attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal ) return attn_output @@ -1154,8 +1159,6 @@ def __init__(self, config: FalconConfig): # Embedding + LN Embedding self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) - # create attention mask cache that trickles down to each attention layer - # so that the attention_mask cache can be shared among layers self.attn_mask_converter = AttnMaskConverter(is_causal=True) # Transformer blocks diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 207302a10e1683..4c31729337ddd1 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -64,6 +64,24 @@ def _get_unpad_data(attention_mask): ) +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + warnings.warn( + "Calling `transformers.models.llama.modeling_llama._expand_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttnMaskConverter._expand_mask" + ) + return AttnMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + warnings.warn( + "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttnMaskConverter._make_causal_mask" + ) + return AttnMaskConverter._make_causal_mask( + input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length + ) + + class AttnMaskConverter: """ A utility attention mask class that allows: @@ -122,7 +140,7 @@ def to_4d( self, attention_mask_2d: torch.Tensor, query_length: int, - key_value_length: int, + key_value_length: Optional[int] = None, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ @@ -131,12 +149,16 @@ def to_4d( causal, a causal mask will be added. """ input_shape = (attention_mask_2d.shape[0], query_length) - past_key_values_length = key_value_length - query_length # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] causal_4d_mask = None if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." + ) + past_key_values_length = key_value_length - query_length causal_4d_mask = self._make_causal_mask( input_shape, @@ -156,8 +178,8 @@ def to_4d( return expanded_4d_mask + @staticmethod def _make_causal_mask( - self, input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, @@ -186,7 +208,8 @@ def _make_causal_mask( return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
""" @@ -555,7 +578,7 @@ def forward( value_states = self.v_proj(hidden_states) # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dime x hidden_dim + # batch_size x seq_length x head_dim x hidden_dim # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -669,7 +692,7 @@ def _flash_attention_forward( attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal ) return attn_output @@ -739,8 +762,9 @@ def forward( """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -914,8 +938,6 @@ def __init__(self, config: LlamaConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - # create attention mask cache that trickles down to each attention layer - # so that the attention_mask cache can be shared among layers self.attn_mask_converter = AttnMaskConverter(is_causal=True) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index fbedb20dbc1539..f98eb4de88410c 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -113,7 +113,7 @@ def to_4d( self, attention_mask_2d: torch.Tensor, query_length: int, - key_value_length: int, + key_value_length: Optional[int] = None, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ @@ -122,12 +122,16 @@ def to_4d( causal, a causal mask will be added. """ input_shape = (attention_mask_2d.shape[0], query_length) - past_key_values_length = key_value_length - query_length # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] causal_4d_mask = None if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." 
+ ) + past_key_values_length = key_value_length - query_length causal_4d_mask = self._make_causal_mask( input_shape, @@ -147,8 +151,8 @@ def to_4d( return expanded_4d_mask + @staticmethod def _make_causal_mask( - self, input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, @@ -177,7 +181,8 @@ def _make_causal_mask( return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ @@ -645,7 +650,12 @@ def _flash_attention_forward( else: if not use_sliding_windows: attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=self.is_causal, ) else: attn_output = flash_attn_func( @@ -654,7 +664,7 @@ def _flash_attention_forward( value_states, dropout, softmax_scale=softmax_scale, - causal=True, + causal=self.is_causal, window_size=(self.config.sliding_window, self.config.sliding_window), ) @@ -903,8 +913,6 @@ def __init__(self, config: MistralConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - # create attention mask cache that trickles down to each attention layer - # so that the attention_mask cache can be shared among layers self.attn_mask_converter = AttnMaskConverter(is_causal=True, sliding_window=config.sliding_window) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
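
Note (illustrative, not part of the patch): below is a minimal sketch of how the refactored AttnMaskConverter in this patch is intended to be used, assuming the patch is applied so the class is importable from transformers.models.llama.modeling_llama with the signatures shown in the hunks above. The tensor values and shapes are made up for illustration.

    # Sketch only, assuming the patched AttnMaskConverter is available.
    import torch
    from transformers.models.llama.modeling_llama import AttnMaskConverter

    converter = AttnMaskConverter(is_causal=True)

    # 2D padding mask: batch of 2, sequence length 5; the last two tokens of
    # the first sample are padding (0 = masked out).
    attention_mask_2d = torch.tensor([[1, 1, 1, 0, 0],
                                      [1, 1, 1, 1, 1]])

    # With a causal converter, `key_value_length` must now be passed explicitly;
    # omitting it raises the ValueError added in this patch.
    mask_4d = converter.to_4d(
        attention_mask_2d,
        query_length=5,
        key_value_length=5,   # query length + past key/values length (0 here)
        dtype=torch.float32,
    )
    print(mask_4d.shape)  # torch.Size([2, 1, 5, 5])

    # Because _make_causal_mask and _expand_mask are now @staticmethod,
    # they can also be called without instantiating the converter:
    causal_4d = AttnMaskConverter._make_causal_mask(
        input_ids_shape=torch.Size([2, 5]),
        dtype=torch.float32,
        device=torch.device("cpu"),
    )
    expanded_4d = AttnMaskConverter._expand_mask(attention_mask_2d, dtype=torch.float32)

The same to_4d / staticmethod pattern applies to the Falcon and Mistral converters in this patch; the Mistral instance additionally takes sliding_window=config.sliding_window, and all three flash-attention paths now forward causal=self.is_causal instead of hard-coding causal=True.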