From 5d9bc483912262e1d0370f0ad8f76ff75fd5cb6a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 28 Sep 2023 12:55:23 +0200 Subject: [PATCH 01/23] add FA-2 support for mistral --- docs/source/en/perf_infer_gpu_one.md | 1 + .../models/mistral/modeling_mistral.py | 207 +++++++++++++++++- 2 files changed, 205 insertions(+), 3 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index f0c0bf0b107154..d24299012e9fe1 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -32,6 +32,7 @@ Make sure to follow the installation guide on the repository mentioned above to We natively support Flash Attention 2 for the following models: - Llama +- Mistral - Falcon You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.* diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 62cfca29465f3e..b35f5e57afc0d4 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -25,19 +25,38 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +import torch.nn.functional as F + from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, is_flash_attn_available from .configuration_mistral import MistralConfig +if is_flash_attn_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MistralConfig" +def _get_unpad_data(padding_mask): + seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + def _make_sliding_window_causal_mask( input_ids_shape: torch.Size, dtype: torch.dtype, @@ -226,6 +245,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -291,11 +311,182 @@ def forward( return attn_output, attn_weights, past_key_value +class MistralFlashAttention2(MistralAttention): + """ + Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # TODO: llama does not have dropout in the config?? + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to" + " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + " float16." + ) + + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward(query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + padding_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if padding_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, padding_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=True, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + padding_mask = padding_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + class MistralDecoderLayer(nn.Module): def __init__(self, config: MistralConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MistralAttention(config=config) + self.self_attn = MistralAttention(config=config) if not getattr(config, "_flash_attn_2_enabled", False) else MistralFlashAttention2(config) self.mlp = MistralMLP(config) self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -308,6 +499,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -335,6 +527,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + padding_mask=padding_mask, ) hidden_states = residual + hidden_states @@ -382,6 +575,7 @@ class MistralPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MistralDecoderLayer"] _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True def _init_weights(self, module): std = self.config.initializer_range @@ -569,11 +763,17 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + + padding_mask = None + # embed positions if attention_mask is None: attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) + elif 0 in attention_mask: + padding_mask = attention_mask + attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), @@ -607,7 +807,7 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, past_key_value, output_attentions) + return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) return custom_forward @@ -625,6 +825,7 @@ def custom_forward(*inputs): past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + padding_mask=padding_mask, ) hidden_states = layer_outputs[0] From 0983d88becf430030d5bdb34450291659ff7152d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 28 Sep 2023 11:03:34 +0000 Subject: [PATCH 02/23] fixup --- .../models/mistral/modeling_mistral.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index b35f5e57afc0d4..d5a942237712b2 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -22,24 +22,29 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -import torch.nn.functional as F - from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, is_flash_attn_available +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_available, + logging, + replace_return_docstrings, +) from .configuration_mistral import MistralConfig + if is_flash_attn_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MistralConfig" @@ -381,8 +386,10 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - - attn_output = self._flash_attention_forward(query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -486,7 +493,11 @@ class MistralDecoderLayer(nn.Module): def __init__(self, config: MistralConfig): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MistralAttention(config=config) if not getattr(config, "_flash_attn_2_enabled", False) else MistralFlashAttention2(config) + self.self_attn = ( + MistralAttention(config=config) + if not getattr(config, "_flash_attn_2_enabled", False) + else MistralFlashAttention2(config) + ) self.mlp = MistralMLP(config) self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) From b8c9198005c2c67a211d0fc947ff85507897be73 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 28 Sep 2023 18:29:52 +0200 Subject: [PATCH 03/23] add sliding windows --- .../models/mistral/modeling_mistral.py | 61 +++++++++++++------ 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index b35f5e57afc0d4..4219370148343c 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -19,6 +19,7 @@ # limitations under the License. """ PyTorch Mistral model.""" import math +import inspect from typing import List, Optional, Tuple, Union import torch @@ -38,6 +39,8 @@ from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + _is_flash_using_slicing_windows = "window_size" in list(inspect.signature(flash_attn_func).parameters) + logger = logging.get_logger(__name__) @@ -342,10 +345,12 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + use_sliding_windows = _is_flash_using_slicing_windows and kv_seq_len > self.config.sliding_window and not self.training + if past_key_value is not None: - # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) @@ -382,7 +387,7 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = self._flash_attention_forward(query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate) + attn_output = self._flash_attention_forward(query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate, use_sliding_windows=use_sliding_windows) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -394,7 +399,7 @@ def forward( # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( - self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None, use_sliding_windows=False ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -425,24 +430,44 @@ def _flash_attention_forward( cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=True, - ) + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=True, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=True, + window_size=(self.config.sliding_window, self.config.sliding_window) + ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True - ) + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + ) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True, window_size=(self.config.sliding_window // 2, self.config.sliding_window // 2) + ) return attn_output From bd58ca72dddf3b64099e44866a6901ace3399ff9 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 28 Sep 2023 19:00:10 +0200 Subject: [PATCH 04/23] fixing few nits --- src/transformers/models/mistral/modeling_mistral.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 1a45ae3db9d377..839bdd6d5732bc 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -45,7 +45,7 @@ from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - _is_flash_using_slicing_windows = "window_size" in list(inspect.signature(flash_attn_func).parameters) + _is_flash_using_sliding_windows = "window_size" in list(inspect.signature(flash_attn_func).parameters) logger = logging.get_logger(__name__) @@ -353,7 +353,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - use_sliding_windows = _is_flash_using_slicing_windows and kv_seq_len > self.config.sliding_window and not self.training + use_sliding_windows = _is_flash_using_sliding_windows and self.config.sliding_window is not None if past_key_value is not None: key_states = torch.cat([past_key_value[0], key_states], dim=2) @@ -470,9 +470,7 @@ def _flash_attention_forward( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True ) else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True, window_size=(self.config.sliding_window // 2, self.config.sliding_window // 2) - ) + attn_output = flash_attn_func(query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True, window_size=(self.config.sliding_window, self.config.sliding_window)) return attn_output From 43b02897684132d1cfc640d4f45275a6c6958b62 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 29 Sep 2023 09:16:08 +0200 Subject: [PATCH 05/23] v1 slicing cache - logits do not match --- .../models/mistral/modeling_mistral.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 839bdd6d5732bc..515528776460fc 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -175,8 +175,10 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -349,13 +351,26 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + rotary_seq_len = max(kv_seq_len, position_ids.max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - use_sliding_windows = _is_flash_using_sliding_windows and self.config.sliding_window is not None + use_sliding_windows = _is_flash_using_sliding_windows and self.config.sliding_window is not None and kv_seq_len > self.config.sliding_window if past_key_value is not None: + if use_sliding_windows and kv_seq_len > self.config.sliding_window: + slicing_tokens = kv_seq_len - self.config.sliding_window + + past_key = past_key_value[0] + past_value = past_key_value[1] + + past_key = past_key[:, :, :-slicing_tokens, :].contiguous() + past_value = past_value[:, :, :-slicing_tokens, :].contiguous() + + past_key_value = (past_key, past_value) + key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) From ed2616f2ae1a679539b95f0049fd39d403342db0 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 29 Sep 2023 11:25:03 +0200 Subject: [PATCH 06/23] add comment --- src/transformers/models/mistral/modeling_mistral.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 515528776460fc..9d86e19004ea0a 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -352,6 +352,7 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] + # Extrapolate the RoPE in case of sliding windows rotary_seq_len = max(kv_seq_len, position_ids.max().item()) + 1 cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) @@ -366,8 +367,8 @@ def forward( past_key = past_key_value[0] past_value = past_key_value[1] - past_key = past_key[:, :, :-slicing_tokens, :].contiguous() - past_value = past_value[:, :, :-slicing_tokens, :].contiguous() + past_key = past_key[:, :, kv_seq_len-slicing_tokens:, :].contiguous() + past_value = past_value[:, :, kv_seq_len-slicing_tokens:, :].contiguous() past_key_value = (past_key, past_value) From 7cafc2d7cbafd5d5cd27a3912c3280b278ed65c3 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 2 Oct 2023 12:41:58 +0200 Subject: [PATCH 07/23] fix bugs --- .../models/mistral/modeling_mistral.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 9d86e19004ea0a..9353223589a5b4 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -362,16 +362,20 @@ def forward( if past_key_value is not None: if use_sliding_windows and kv_seq_len > self.config.sliding_window: - slicing_tokens = kv_seq_len - self.config.sliding_window + slicing_tokens = (kv_seq_len - self.config.sliding_window) + 1 past_key = past_key_value[0] past_value = past_key_value[1] - past_key = past_key[:, :, kv_seq_len-slicing_tokens:, :].contiguous() - past_value = past_value[:, :, kv_seq_len-slicing_tokens:, :].contiguous() + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() past_key_value = (past_key, past_value) + if padding_mask is not None: + padding_mask = padding_mask[:, slicing_tokens:] + padding_mask = torch.cat([padding_mask, torch.ones_like(padding_mask[:, -1:])], dim=-1) + key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) @@ -492,11 +496,18 @@ def _flash_attention_forward( # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + # HACK: deep dive why we need this? + if kv_seq_len != padding_mask.shape[-1]: + padding_mask_num_tokens = padding_mask.shape[-1] + padding_mask = padding_mask[:, padding_mask_num_tokens-kv_seq_len:] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + if query_length == kv_seq_len: query_layer = index_first_axis( query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k From 2b8c7b46f0fef47d57a267fd49e862834a67b908 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 2 Oct 2023 17:11:59 +0200 Subject: [PATCH 08/23] more mem efficient --- src/transformers/models/mistral/modeling_mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 9353223589a5b4..ae5a32d7e19995 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -353,7 +353,7 @@ def forward( kv_seq_len += past_key_value[0].shape[-2] # Extrapolate the RoPE in case of sliding windows - rotary_seq_len = max(kv_seq_len, position_ids.max().item()) + 1 + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) From 4a3387df8ab06d08c891abf1a4cfcf5ba567b411 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 2 Oct 2023 17:19:28 +0200 Subject: [PATCH 09/23] add warning once --- src/transformers/models/mistral/modeling_mistral.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index ae5a32d7e19995..487a7c4a9d52f1 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -358,10 +358,17 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - use_sliding_windows = _is_flash_using_sliding_windows and self.config.sliding_window is not None and kv_seq_len > self.config.sliding_window + use_sliding_windows = _is_flash_using_sliding_windows and hasattr(self.config, "sliding_window") is not None and kv_seq_len > self.config.sliding_window + + if not _is_flash_using_sliding_windows: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." + ) if past_key_value is not None: - if use_sliding_windows and kv_seq_len > self.config.sliding_window: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window: slicing_tokens = (kv_seq_len - self.config.sliding_window) + 1 past_key = past_key_value[0] @@ -444,6 +451,8 @@ def _flash_attention_forward( Attention dropout softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate """ # Contains at least one padding token in the sequence if padding_mask is not None: From 885b601487b45d73949d996577ae8f5a66611e0c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 2 Oct 2023 17:35:21 +0200 Subject: [PATCH 10/23] add warning once --- .../models/mistral/modeling_mistral.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 487a7c4a9d52f1..b1edee8003a2bc 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -429,7 +429,7 @@ def forward( return attn_output, attn_weights, past_key_value - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None, use_sliding_windows=False ): @@ -503,11 +503,11 @@ def _flash_attention_forward( return attn_output - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape - # HACK: deep dive why we need this? + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place if kv_seq_len != padding_mask.shape[-1]: padding_mask_num_tokens = padding_mask.shape[-1] padding_mask = padding_mask[:, padding_mask_num_tokens-kv_seq_len:] @@ -842,6 +842,15 @@ def forward( elif 0 in attention_mask: padding_mask = attention_mask + if padding_mask is not None and hasattr(self.config, "_flash_attn_2_enabled", False): + is_padding_right = padding_mask[:, -1].sum().item() != batch_size + if not is_padding_right: + logger.warning_once( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), From 172d99a406e769b925640f9d121bb3de53d76fe0 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 2 Oct 2023 17:38:06 +0200 Subject: [PATCH 11/23] oops --- src/transformers/models/mistral/modeling_mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index b1edee8003a2bc..6ce19319a8a7ab 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -842,7 +842,7 @@ def forward( elif 0 in attention_mask: padding_mask = attention_mask - if padding_mask is not None and hasattr(self.config, "_flash_attn_2_enabled", False): + if padding_mask is not None and hasattr(self.config, "_flash_attn_2_enabled"): is_padding_right = padding_mask[:, -1].sum().item() != batch_size if not is_padding_right: logger.warning_once( From 253b3830e267cdce747cfd9a2741ea23032a6dfb Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 2 Oct 2023 15:38:50 +0000 Subject: [PATCH 12/23] fixup --- .../models/mistral/modeling_mistral.py | 49 ++++++++++++++----- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 6ce19319a8a7ab..b2b1760f7f5ceb 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -18,8 +18,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Mistral model.""" -import math import inspect +import math from typing import List, Optional, Tuple, Union import torch @@ -358,7 +358,11 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - use_sliding_windows = _is_flash_using_sliding_windows and hasattr(self.config, "sliding_window") is not None and kv_seq_len > self.config.sliding_window + use_sliding_windows = ( + _is_flash_using_sliding_windows + and hasattr(self.config, "sliding_window") is not None + and kv_seq_len > self.config.sliding_window + ) if not _is_flash_using_sliding_windows: logger.warning_once( @@ -374,8 +378,8 @@ def forward( past_key = past_key_value[0] past_value = past_key_value[1] - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() past_key_value = (past_key, past_value) @@ -418,8 +422,16 @@ def forward( query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - - attn_output = self._flash_attention_forward(query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate, use_sliding_windows=use_sliding_windows) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + padding_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -429,9 +441,16 @@ def forward( return attn_output, attn_weights, past_key_value - def _flash_attention_forward( - self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None, use_sliding_windows=False + self, + query_states, + key_states, + value_states, + padding_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -489,7 +508,7 @@ def _flash_attention_forward( dropout_p=dropout, softmax_scale=softmax_scale, causal=True, - window_size=(self.config.sliding_window, self.config.sliding_window) + window_size=(self.config.sliding_window, self.config.sliding_window), ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) @@ -499,7 +518,15 @@ def _flash_attention_forward( query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True ) else: - attn_output = flash_attn_func(query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True, window_size=(self.config.sliding_window, self.config.sliding_window)) + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=True, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) return attn_output @@ -510,7 +537,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l # by slicing it on the proper place if kv_seq_len != padding_mask.shape[-1]: padding_mask_num_tokens = padding_mask.shape[-1] - padding_mask = padding_mask[:, padding_mask_num_tokens-kv_seq_len:] + padding_mask = padding_mask[:, padding_mask_num_tokens - kv_seq_len :] indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) From e4d0fb7a4ed8b7fb4e8f115245686fdd24d08ed5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 2 Oct 2023 17:42:34 +0200 Subject: [PATCH 13/23] more comments --- src/transformers/models/mistral/modeling_mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index b2b1760f7f5ceb..0aa66540ea37c5 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -471,7 +471,7 @@ def _flash_attention_forward( softmax_scale (`float`, *optional*): The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) use_sliding_windows (`bool`, *optional*): - Whether to activate + Whether to activate sliding window attention. """ # Contains at least one padding token in the sequence if padding_mask is not None: From a245722d93cda7f9f45637ea2ac743ab0db4fcdd Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 2 Oct 2023 15:51:48 +0000 Subject: [PATCH 14/23] copy --- src/transformers/models/mistral/modeling_mistral.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index b2b1760f7f5ceb..85f24949685dc9 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -175,10 +175,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed From e71c50d3f9dc85955010604971c0cf9bd965d2c3 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 2 Oct 2023 19:12:16 +0200 Subject: [PATCH 15/23] add safety checker --- src/transformers/models/mistral/modeling_mistral.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 0aa66540ea37c5..eea38dc52299cb 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -373,7 +373,7 @@ def forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window: - slicing_tokens = (kv_seq_len - self.config.sliding_window) + 1 + slicing_tokens = (kv_seq_len - self.config.sliding_window) past_key = past_key_value[0] past_value = past_key_value[1] @@ -381,6 +381,12 @@ def forward( past_key = past_key[:, :, slicing_tokens:, :].contiguous() past_value = past_value[:, :, slicing_tokens:, :].contiguous() + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + past_key_value = (past_key, past_value) if padding_mask is not None: From 5d1f5890beef59bc4e6c95f58cdb7eacb9f5442d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 2 Oct 2023 18:59:50 +0000 Subject: [PATCH 16/23] fixup --- src/transformers/models/mistral/modeling_mistral.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 33d75b1f9a328e..154b50ae43d2b8 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -371,7 +371,7 @@ def forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window: - slicing_tokens = (kv_seq_len - self.config.sliding_window) + slicing_tokens = kv_seq_len - self.config.sliding_window past_key = past_key_value[0] past_value = past_key_value[1] @@ -381,8 +381,8 @@ def forward( if past_key.shape[-2] != self.config.sliding_window - 1: raise ValueError( - f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" + f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" ) past_key_value = (past_key, past_value) From b478e047c9bdff7403705c700626206f91451cdb Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 3 Oct 2023 10:51:15 +0200 Subject: [PATCH 17/23] Update src/transformers/models/mistral/modeling_mistral.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/mistral/modeling_mistral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 154b50ae43d2b8..a7f32ce28f1926 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -350,7 +350,7 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - # Extrapolate the RoPE in case of sliding windows + # Because the input can be padded, the absolute sequence length depends on the max position id. rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) From 2fe2f490f4735c94305e2eaf5ccab5989671e07c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 3 Oct 2023 08:56:00 +0000 Subject: [PATCH 18/23] copied from --- src/transformers/models/mistral/modeling_mistral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a7f32ce28f1926..adea9d24d99af1 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -53,6 +53,7 @@ _CONFIG_FOR_DOC = "MistralConfig" +# Copied from transformers.models.llama.modeling_llama._get_unpad_data def _get_unpad_data(padding_mask): seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() From 25789d1c95a055c2c92dbfbce7282c4f2284028f Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 3 Oct 2023 09:00:10 +0000 Subject: [PATCH 19/23] up --- src/transformers/models/mistral/modeling_mistral.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index adea9d24d99af1..97579b43356051 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -45,7 +45,7 @@ from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - _is_flash_using_sliding_windows = "window_size" in list(inspect.signature(flash_attn_func).parameters) + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) logger = logging.get_logger(__name__) @@ -358,12 +358,12 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) use_sliding_windows = ( - _is_flash_using_sliding_windows + _flash_supports_window_size and hasattr(self.config, "sliding_window") is not None and kv_seq_len > self.config.sliding_window ) - if not _is_flash_using_sliding_windows: + if not _flash_supports_window_size: logger.warning_once( "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" " make sure to upgrade flash-attn library." From 05ec7f47a1f143b223a978c63e2384a1bde03377 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 3 Oct 2023 11:17:27 +0200 Subject: [PATCH 20/23] raise when padding side is right --- .../models/mistral/modeling_mistral.py | 6 +- tests/models/mistral/test_modeling_mistral.py | 75 ++++++++++++++++++- tests/test_modeling_common.py | 2 +- 3 files changed, 78 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a7f32ce28f1926..5e321ab932cc4a 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -873,10 +873,10 @@ def forward( elif 0 in attention_mask: padding_mask = attention_mask - if padding_mask is not None and hasattr(self.config, "_flash_attn_2_enabled"): + if padding_mask is not None and hasattr(self.config, "_flash_attn_2_enabled") and self.config._flash_attn_2_enabled: is_padding_right = padding_mask[:, -1].sum().item() != batch_size - if not is_padding_right: - logger.warning_once( + if is_padding_right: + raise ValueError( "You are attempting to perform batched generation with padding_side='right'" " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index df3bcf9d671fe2..0737e46a9e3678 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -16,9 +16,11 @@ import unittest +import tempfile +from pytest import mark from transformers import AutoTokenizer, MistralConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import require_torch, slow, torch_device, require_flash_attn, require_torch_gpu from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -352,6 +354,77 @@ def test_past_key_values_format(self): pass + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_right(self): + import torch + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=False, low_cpu_mem_usage=True + ).to(torch_device) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True + ).to(torch_device) + + with self.assertRaises(ValueError): + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + def test_flash_attn_2_inference_padding_right(self): + import torch + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=True + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, use_flash_attention_2=False + ) + model.to(torch_device) + + dummy_input = torch.LongTensor([[1, 2, 3, 4, 5]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1, 0]]).to(torch_device) + + _ = model(dummy_input, output_hidden_states=True).hidden_states[-1] + with self.assertRaises(ValueError): + _ = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True).hidden_states[-1] + + @require_torch class MistralIntegrationTest(unittest.TestCase): @slow diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 8c2a277b4b27bd..806799c47f17b3 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2926,7 +2926,7 @@ def test_flash_attn_2_generate_use_cache(self): model.save_pretrained(tmpdirname) dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) - dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True From f9a69bcced9d40b837dacff6a169a8cdbad5e672 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 3 Oct 2023 09:18:37 +0000 Subject: [PATCH 21/23] fixup --- src/transformers/models/mistral/modeling_mistral.py | 6 +++++- tests/models/mistral/test_modeling_mistral.py | 13 +++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index f6e27529bddc3c..b85d68735190c3 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -874,7 +874,11 @@ def forward( elif 0 in attention_mask: padding_mask = attention_mask - if padding_mask is not None and hasattr(self.config, "_flash_attn_2_enabled") and self.config._flash_attn_2_enabled: + if ( + padding_mask is not None + and hasattr(self.config, "_flash_attn_2_enabled") + and self.config._flash_attn_2_enabled + ): is_padding_right = padding_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 0737e46a9e3678..18d7370372dce8 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -15,12 +15,13 @@ """ Testing suite for the PyTorch Mistral model. """ -import unittest import tempfile +import unittest + from pytest import mark from transformers import AutoTokenizer, MistralConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device, require_flash_attn, require_torch_gpu +from transformers.testing_utils import require_flash_attn, require_torch, require_torch_gpu, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -353,7 +354,6 @@ def test_save_load_fast_init_from_base(self): def test_past_key_values_format(self): pass - @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -377,7 +377,7 @@ def test_flash_attn_2_generate_padding_right(self): dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) - out = model.generate( + model.generate( dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False ) @@ -390,7 +390,6 @@ def test_flash_attn_2_generate_padding_right(self): dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False ) - @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -422,7 +421,9 @@ def test_flash_attn_2_inference_padding_right(self): _ = model(dummy_input, output_hidden_states=True).hidden_states[-1] with self.assertRaises(ValueError): - _ = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True).hidden_states[-1] + _ = model_fa( + dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True + ).hidden_states[-1] @require_torch From 6a48dd31d1258437bf3f21066b0b7ff68c24cc5d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 3 Oct 2023 11:48:51 +0200 Subject: [PATCH 22/23] add doc + few minor changes --- docs/source/en/model_doc/mistral.md | 45 +++++++++++++++++++ .../models/mistral/modeling_mistral.py | 4 +- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/docs/source/en/model_doc/mistral.md b/docs/source/en/model_doc/mistral.md index 578906abe643ee..4a8257d25d6a9e 100644 --- a/docs/source/en/model_doc/mistral.md +++ b/docs/source/en/model_doc/mistral.md @@ -82,6 +82,51 @@ tokenizer = LlamaTokenizer.from_pretrained("/output/path") model = MistralForCausalLM.from_pretrained("/output/path") ``` +## Combining Mistral and Flash Attention 2 + +First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature. + +```bash +pip install -U flash-attn --no-build-isolation +``` + +Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of [`flash-attn`](https://github.com/Dao-AILab/flash-attention) repository. Make also sure to load your model in half-precision (e.g. `torch.float16`) + +To load and run a model using Flash Attention 2, refer to the snippet below: + +```python +>>> import torch +>>> from transformers import AutoModelForCausalLM, AutoTokenizer +>>> device = "cuda" # the device to load the model onto + +>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True) +>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") + +>>> prompt = "My favourite condiment is" + +>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device) +>>> model.to(device) + +>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True) +>>> tokenizer.batch_decode(generated_ids)[0] +"The expected outupt" +``` + +### Expected speedups + +Below is a expected speedup diagram that compares pure inference time between the native implementation in transformers using `mistralai/Mistral-7B-v0.1` checkpoint and the Flash Attention 2 version of the model. + +
+ +
+ +### Sliding window Attention + +The current implementation supports the sliding window attention mechanism and memory efficient cache management. +To enable sliding window attention, just make sure to have a `flash-attn` version that is compatible with sliding window attention (`>=2.3.0`). + +The Flash Attention-2 model uses also a more memory efficient cache slicing mechanism - as recommended per the official implementation of Mistral model that use rolling cache mechanism we keep the cache size fixed (`self.config.sliding_window`), support batched generation only for `padding_side="left"` and use the absolute position of the current token to compute the positional embedding. + ## The Mistral Team Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed. diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index f6e27529bddc3c..cf6c198f6d1949 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -401,7 +401,7 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - # TODO: llama does not have dropout in the config?? + # TODO: Mistral does not have dropout in the config?? # It is recommended to use dropout with FA according to the docs # when training. dropout_rate = 0.0 # if not self.training else self.attn_dropout @@ -409,8 +409,6 @@ def forward( # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in float16 just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) input_dtype = query_states.dtype if input_dtype == torch.float32: logger.warning_once( From c2869469cb2cb15ae608cf2529bf87ba2faadfc8 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 3 Oct 2023 09:50:59 +0000 Subject: [PATCH 23/23] fixup --- tests/models/mistral/test_modeling_mistral.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 18d7370372dce8..403f2cc7347041 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -377,9 +377,7 @@ def test_flash_attn_2_generate_padding_right(self): dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) - model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False - ) + model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, use_flash_attention_2=True, low_cpu_mem_usage=True