Commit d7cb5e1

[Llama FA2] Re-add _expand_attention_mask and clean a couple things (huggingface#27074)

* clean

* clean llama

* fix more

* make style

* Apply suggestions from code review

* Apply suggestions from code review

* Update src/transformers/models/llama/modeling_llama.py

* Update src/transformers/models/llama/modeling_llama.py

* Apply suggestions from code review

* finish

* make style
patrickvonplaten authored Oct 26, 2023
1 parent 4864d08 commit d7cb5e1
Showing 3 changed files with 58 additions and 25 deletions.
17 changes: 10 additions & 7 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -148,7 +148,7 @@ def to_4d(
         self,
         attention_mask_2d: torch.Tensor,
         query_length: int,
-        key_value_length: int,
+        key_value_length: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
     ) -> torch.Tensor:
         """
@@ -157,12 +157,16 @@ def to_4d(
         causal, a causal mask will be added.
         """
         input_shape = (attention_mask_2d.shape[0], query_length)
-        past_key_values_length = key_value_length - query_length

         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         causal_4d_mask = None
         if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
+            if key_value_length is None:
+                raise ValueError(
+                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
+                )
+            past_key_values_length = key_value_length - query_length
             causal_4d_mask = self._make_causal_mask(
                 input_shape,
@@ -182,8 +186,8 @@ def to_4d(

         return expanded_4d_mask

+    @staticmethod
     def _make_causal_mask(
-        self,
         input_ids_shape: torch.Size,
         dtype: torch.dtype,
         device: torch.device,
@@ -212,7 +216,8 @@ def _make_causal_mask(

         return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

-    def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    @staticmethod
+    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
         """
         Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
         """
@@ -837,7 +842,7 @@ def _flash_attention_forward(
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
             )

         return attn_output
@@ -1154,8 +1159,6 @@ def __init__(self, config: FalconConfig):
         # Embedding + LN Embedding
         self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)

-        # create attention mask cache that trickles down to each attention layer
-        # so that the attention_mask cache can be shared among layers
         self.attn_mask_converter = AttnMaskConverter(is_causal=True)

         # Transformer blocks
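
The net effect of the Falcon changes above: `to_4d` now accepts `key_value_length=None` but requires a value whenever the converter is causal, and `_expand_mask` becomes a staticmethod. A rough sketch of the new contract (illustrative only, not part of the diff; the import path assumes this commit's layout of modeling_falcon, and the tensor sizes are made up):

import torch
from transformers.models.falcon.modeling_falcon import AttnMaskConverter

converter = AttnMaskConverter(is_causal=True)
mask_2d = torch.ones(2, 5, dtype=torch.long)  # [batch_size, seq_len], 1 = attend

# Causal converter with key_value_length given: expands to a 4D additive mask.
mask_4d = converter.to_4d(mask_2d, query_length=5, key_value_length=5, dtype=torch.float16)
print(mask_4d.shape)  # expected: torch.Size([2, 1, 5, 5])

# Causal converter without key_value_length: now raises a descriptive ValueError.
try:
    converter.to_4d(mask_2d, query_length=5)
except ValueError as err:
    print(err)

# _expand_mask no longer needs an instance.
padding_only = AttnMaskConverter._expand_mask(mask_2d, dtype=torch.float16, tgt_len=5)
print(padding_only.shape)  # expected: torch.Size([2, 1, 5, 5])
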
42 changes: 32 additions & 10 deletions src/transformers/models/llama/modeling_llama.py
@@ -64,6 +64,24 @@ def _get_unpad_data(attention_mask):
     )


+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    warnings.warn(
+        "Calling `transformers.models.llama.modeling_llama._expand_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttnMaskConverter._expand_mask"
+    )
+    return AttnMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+
+
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+    warnings.warn(
+        "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttnMaskConverter._make_causal_mask"
+    )
+    return AttnMaskConverter._make_causal_mask(
+        input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
+    )
+
+
 class AttnMaskConverter:
     """
     A utility attention mask class that allows:
@@ -122,7 +140,7 @@ def to_4d(
         self,
         attention_mask_2d: torch.Tensor,
         query_length: int,
-        key_value_length: int,
+        key_value_length: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
     ) -> torch.Tensor:
         """
@@ -131,12 +149,16 @@ def to_4d(
         causal, a causal mask will be added.
         """
         input_shape = (attention_mask_2d.shape[0], query_length)
-        past_key_values_length = key_value_length - query_length

         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         causal_4d_mask = None
         if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
+            if key_value_length is None:
+                raise ValueError(
+                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
+                )
+            past_key_values_length = key_value_length - query_length
             causal_4d_mask = self._make_causal_mask(
                 input_shape,
@@ -156,8 +178,8 @@ def to_4d(

         return expanded_4d_mask

+    @staticmethod
     def _make_causal_mask(
-        self,
         input_ids_shape: torch.Size,
         dtype: torch.dtype,
         device: torch.device,
@@ -186,7 +208,8 @@ def _make_causal_mask(

         return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

-    def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    @staticmethod
+    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
         """
         Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
         """
@@ -555,7 +578,7 @@ def forward(
         value_states = self.v_proj(hidden_states)

         # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dime x hidden_dim
+        # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -669,7 +692,7 @@ def _flash_attention_forward(
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
             )

         return attn_output
@@ -739,8 +762,9 @@ def forward(
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, sequence_length)` where padding elements are indicated by 0.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -914,8 +938,6 @@ def __init__(self, config: LlamaConfig):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        # create attention mask cache that trickles down to each attention layer
-        # so that the attention_mask cache can be shared among layers
         self.attn_mask_converter = AttnMaskConverter(is_causal=True)

         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
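
For Llama, the module-level `_expand_mask` and `_make_causal_mask` helpers are re-added as thin deprecation shims around the new staticmethods, so code importing them from modeling_llama keeps working until v4.37. A rough sketch of what a caller sees (illustrative only, not part of the diff; import paths assume this commit's layout of modeling_llama):

import warnings

import torch
from transformers.models.llama import modeling_llama

pad_mask = torch.tensor([[1, 1, 1, 0]])  # [bsz, seq_len], 0 marks padding

# The old import path still works but emits a deprecation warning ...
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = modeling_llama._expand_mask(pad_mask, dtype=torch.float32, tgt_len=4)
print([str(w.message) for w in caught])  # mentions removal in v4.37

# ... and simply forwards to the staticmethod, so both paths agree.
new = modeling_llama.AttnMaskConverter._expand_mask(pad_mask, dtype=torch.float32, tgt_len=4)
print(torch.equal(legacy, new))  # expected: True
print(new.shape)                 # expected: torch.Size([1, 1, 4, 4])
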
24 changes: 16 additions & 8 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -113,7 +113,7 @@ def to_4d(
         self,
         attention_mask_2d: torch.Tensor,
         query_length: int,
-        key_value_length: int,
+        key_value_length: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
     ) -> torch.Tensor:
         """
@@ -122,12 +122,16 @@ def to_4d(
         causal, a causal mask will be added.
         """
         input_shape = (attention_mask_2d.shape[0], query_length)
-        past_key_values_length = key_value_length - query_length

         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         causal_4d_mask = None
         if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
+            if key_value_length is None:
+                raise ValueError(
+                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
+                )
+            past_key_values_length = key_value_length - query_length
             causal_4d_mask = self._make_causal_mask(
                 input_shape,
@@ -147,8 +151,8 @@ def to_4d(

         return expanded_4d_mask

+    @staticmethod
     def _make_causal_mask(
-        self,
         input_ids_shape: torch.Size,
         dtype: torch.dtype,
         device: torch.device,
@@ -177,7 +181,8 @@ def _make_causal_mask(

         return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

-    def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    @staticmethod
+    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
         """
         Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
         """
@@ -645,7 +650,12 @@ def _flash_attention_forward(
         else:
             if not use_sliding_windows:
                 attn_output = flash_attn_func(
-                    query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
+                    query_states,
+                    key_states,
+                    value_states,
+                    dropout,
+                    softmax_scale=softmax_scale,
+                    causal=self.is_causal,
                 )
             else:
                 attn_output = flash_attn_func(
@@ -654,7 +664,7 @@ def _flash_attention_forward(
                     value_states,
                     dropout,
                     softmax_scale=softmax_scale,
-                    causal=True,
+                    causal=self.is_causal,
                     window_size=(self.config.sliding_window, self.config.sliding_window),
                 )

@@ -903,8 +913,6 @@ def __init__(self, config: MistralConfig):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        # create attention mask cache that trickles down to each attention layer
-        # so that the attention_mask cache can be shared among layers
         self.attn_mask_converter = AttnMaskConverter(is_causal=True, sliding_window=config.sliding_window)

         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
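
Mistral gets the same treatment, with its converter built around the configured sliding window; because `self.sliding_window is not None` takes the causal branch in `to_4d` even for a single query token, the `key_value_length` guard applies there as well. A rough sketch (illustrative only, not part of the diff; the import path assumes this commit's layout of modeling_mistral, and the window size is made up):

import torch
from transformers.models.mistral.modeling_mistral import AttnMaskConverter

converter = AttnMaskConverter(is_causal=True, sliding_window=3)
mask_2d = torch.ones(1, 6, dtype=torch.long)  # [batch_size, seq_len], no padding

# With a sliding window configured, the causal branch is always taken,
# so key_value_length is required here too.
mask_4d = converter.to_4d(mask_2d, query_length=6, key_value_length=6, dtype=torch.float32)
print(mask_4d.shape)  # expected: torch.Size([1, 1, 6, 6])

# Future positions and positions that fall outside the window receive a large
# negative bias; positions inside the window stay at zero (no padding here).
print(mask_4d[0, 0])
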
