[FlashMask] Add FlashMask for Qwen2 #9264

Merged
26 changes: 13 additions & 13 deletions README.md
@@ -115,19 +115,19 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩

* Large-model pretraining, fine-tuning (including SFT and PEFT techniques), alignment, and quantization now support the LLaMA, Baichuan, Bloom, ChatGLM, Mistral, OPT, and Qwen series. The LLM pretraining / fine-tuning / alignment / quantization support matrix is as follows:

| Model / Capability | Pretrain | SFT | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert |
|:------------------:|:--------:|:---:|:----:|:-------------:|:---:|:----:|:------------:|:-------------:|
| Llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| Mixtral | ✅ | ✅ | ✅ | ❌ | 🚧 | 🚧 | 🚧 | 🚧 |
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| Baichuan/Baichuan2 | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
| ChatGLM-6B | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ❌ |
| ChatGLM2/ChatGLM3 | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| Bloom | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| GPT-3 | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
| OPT | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
| Yuan2 | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
| Model / Capability | Pretrain | SFT | FlashMask | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert |
|:------------------:|:--------:|:---:|:---------:|:----:|:-------------:|:---:|:----:|:------------:|:-------------:|
| Llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Qwen | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| Mixtral | ✅ | ✅ | 🚧 | ✅ | ❌ | 🚧 | 🚧 | 🚧 | 🚧 |
| Mistral | ✅ | ✅ | 🚧 | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ |
| Baichuan/Baichuan2 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ |
| ChatGLM-6B | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | ✅ | ❌ |
| ChatGLM2/ChatGLM3 | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| Bloom | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ |
| GPT-3 | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
| OPT | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
| Yuan2 | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ |
------------------------------------------------------------------------------------------

* [Large-model inference](./llm/docs/predict/inference.md) supports the LLaMA, Qwen, Mistral, ChatGLM, Bloom, and Baichuan series, including Weight-Only INT8 and INT4 inference as well as WAC (weight, activation, Cache KV) INT8/FP8 quantized inference. The LLM inference support matrix is as follows:
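For context, the new FlashMask column above marks models whose fine-tuning path can consume FlashMask's column-wise sparse attention masks (per-column start/end row indices) instead of a dense `[batch, 1, seq_len, seq_len]` mask, so mask memory grows linearly rather than quadratically with sequence length. A rough, purely illustrative comparison, assuming a bool dense mask and int32 indices:

```python
# Back-of-the-envelope memory comparison between a dense attention mask and
# FlashMask-style per-column row indices. Illustrative arithmetic only.
batch, seq_len = 1, 32768

dense_mask_bytes = batch * 1 * seq_len * seq_len * 1  # bool mask [bs, 1, s, s]
flashmask_bytes = batch * 1 * seq_len * 4             # int32 indices [bs, 1, s]

print(f"dense mask: {dense_mask_bytes / 2**20:.0f} MiB")  # 1024 MiB
print(f"flashmask : {flashmask_bytes / 2**10:.0f} KiB")   # 128 KiB
```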
5 changes: 4 additions & 1 deletion llm/run_finetune.py
@@ -52,6 +52,8 @@
LlamaForCausalLM,
LlamaForCausalLMPipe,
LlamaTokenizer,
Qwen2ForCausalLM,
Qwen2ForCausalLMPipe,
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
@@ -69,7 +71,7 @@
# Fine-tune Environment Variables to support sharding stage1 overlap optimization.
os.environ["USE_CASUAL_MASK"] = "False"

flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe]
flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe, Qwen2ForCausalLM, Qwen2ForCausalLMPipe]


def main():
@@ -109,6 +111,7 @@ def main():
if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
try:
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401

LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
except ImportError:
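For reference, here is a minimal sketch of how the extended `flash_mask_support_list` above might be used to gate the feature before training starts. The helper name, the `flash_mask` flag, and the error wording are illustrative assumptions, not the repository's actual check:

```python
# Hypothetical guard built on flash_mask_support_list (defined above in
# llm/run_finetune.py). Names other than the list itself are assumptions.
def ensure_flash_mask_supported(model, flash_mask: bool) -> None:
    if flash_mask and not isinstance(model, tuple(flash_mask_support_list)):
        supported = ", ".join(cls.__name__ for cls in flash_mask_support_list)
        raise NotImplementedError(
            f"{type(model).__name__} does not support FlashMask yet. "
            f"Supported model classes: {supported}."
        )
```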
82 changes: 46 additions & 36 deletions paddlenlp/transformers/qwen2/modeling.py
@@ -37,14 +37,15 @@
from ..activations import ACT2FN
from ..conversion_utils import StateDictNameMapping, init_name_mappings
from ..linear_utils import Linear
from ..llama import fusion_ops
from ..model_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
from ..model_utils import PretrainedModel, register_base_model
from ..utils import caculate_llm_flops
from ..utils import caculate_llm_flops, logger
from .configuration import Qwen2Config

try:
@@ -156,6 +157,7 @@
value_states,
attention_mask,
output_attentions,
attn_mask_startend_row_indices=None,
training=True,
sequence_parallel=False,
):
@@ -166,32 +168,16 @@
# Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
# Torch Flash Attention input [ bz, nhead, seqlen, head_dim]

version = paddle.version.full_version
if version != "0.0.0" and version <= "2.5.2":
attn_output, attn_weights = flash_attention(
query_states,
key_states,
value_states,
causal=True,
return_softmax=output_attentions,
)
else:
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=attention_mask is None,
dropout_p=config.attention_dropout if training else 0.0,
training=training,
)
attn_weights = None

if sequence_parallel:
attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
else:
attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
return (attn_output, attn_weights) if output_attentions else attn_output
return fusion_ops.fusion_flash_attention(
query_states,
config,
key_states,
value_states,
attention_mask,
output_attentions,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
sequence_parallel=sequence_parallel,
)
else:
# [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
query_states = paddle.transpose(query_states, [0, 2, 1, 3])
@@ -510,6 +496,7 @@
attention_mask: Optional[paddle.Tensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
**kwargs,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -574,6 +561,7 @@
value_states,
attention_mask,
output_attentions,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
training=self.training,
sequence_parallel=self.sequence_parallel,
use_reentrant=self.config.recompute_use_reentrant,
@@ -586,6 +574,7 @@
value_states,
attention_mask,
output_attentions,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
training=self.training,
sequence_parallel=self.sequence_parallel,
)
@@ -640,6 +629,7 @@
output_attentions: Optional[bool] = False,
past_key_value: Optional[Tuple[paddle.Tensor]] = None,
use_cache: Optional[bool] = False,
attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
**kwargs,
) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
"""
@@ -677,6 +667,7 @@
attention_mask,
output_attentions,
use_cache,
attn_mask_startend_row_indices,
use_reentrant=self.config.recompute_use_reentrant,
)
else:
@@ -687,6 +678,7 @@
attention_mask,
output_attentions,
use_cache,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
)

if type(outputs) is tuple:
@@ -992,6 +984,7 @@
output_attentions: bool,
past_key_value: Tensor,
use_cache: bool,
attn_mask_startend_row_indices=None,
):
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -1007,6 +1000,7 @@
output_attentions,
past_key_value,
use_cache,
attn_mask_startend_row_indices,
use_reentrant=self.config.recompute_use_reentrant,
)

@@ -1023,6 +1017,7 @@
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
attn_mask_startend_row_indices=None,
) -> Union[Tuple, BaseModelOutputWithPast]:

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1062,20 +1057,24 @@
inputs_embeds = ScatterOp.apply(inputs_embeds)

# embed positions
if attention_mask is None:
if attn_mask_startend_row_indices is not None:
attention_mask = None
else:
# [bs, seq_len]
attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
attention_mask = (
paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
if attention_mask is None
else attention_mask
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
if self.config.use_flash_attention:
attention_mask = None if is_casual_mask(attention_mask) else attention_mask

if position_ids is None:
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
if self.config.use_flash_attention:
is_casual = is_casual_mask(attention_mask)
if is_casual:
attention_mask = None
hidden_states = inputs_embeds

# decoder layers
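In the hunk above, supplying `attn_mask_startend_row_indices` makes the model skip building a dense decoder mask entirely. For packed (Zero Padding) SFT data, the indices typically record, for each column, the row at which attention must stop, so every packed document stays causal and isolated from its neighbours. A small construction sketch; the `[batch, 1, seq_len]` layout and int32 dtype are assumptions about the convention the fused kernel expects:

```python
import paddle

def build_startend_row_indices(doc_lens, dtype="int32"):
    """FlashMask-style end-row indices for documents packed into one sequence.

    A token in a document that ends at row `e` may only be attended to by rows
    below `e`, which keeps each document causal and independent.
    """
    indices, offset = [], 0
    for length in doc_lens:
        offset += length
        indices.extend([offset] * length)
    # Assumed layout: [batch, num_heads(=1), seq_len].
    return paddle.to_tensor(indices, dtype=dtype).reshape([1, 1, -1])

# Two packed documents of lengths 3 and 2 -> [[[3, 3, 3, 5, 5]]]
print(build_startend_row_indices([3, 2]))
```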
@@ -1103,6 +1102,7 @@
output_attentions,
past_key_value,
use_cache,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
)
else:
layer_outputs = decoder_layer(
@@ -1112,6 +1112,7 @@
output_attentions,
past_key_value,
use_cache,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
)

# NOTE: clear outdate cache after it has been used for memory saving
@@ -1340,6 +1341,7 @@
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
attn_mask_startend_row_indices=None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1373,6 +1375,13 @@
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if attn_mask_startend_row_indices is not None and attention_mask is not None:
logger.warning(
"You have provided both attn_mask_startend_row_indices and attention_mask. "
"The attn_mask_startend_row_indices will be used."
)
attention_mask = None

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.qwen2(
input_ids=input_ids,
@@ -1384,6 +1393,7 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
)

hidden_states = outputs[0]
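Once the changes above land, `attn_mask_startend_row_indices` is part of `Qwen2ForCausalLM`'s public `forward` signature. Below is a hedged end-to-end sketch of passing it directly; the checkpoint name, the `use_flash_attention` keyword, and the index layout are illustrative assumptions, and the FlashMask path additionally needs a Paddle build that ships the fused kernels:

```python
import paddle
from paddlenlp.transformers import Qwen2ForCausalLM, Qwen2Tokenizer

# Illustrative checkpoint; any Qwen2 causal-LM checkpoint should work the same way.
tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2-0.5B")
model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", use_flash_attention=True)

inputs = tokenizer("FlashMask keeps the attention mask sparse.", return_tensors="pd")
seq_len = inputs["input_ids"].shape[-1]

# Plain causal attention expressed as FlashMask indices: no column is cut off
# early, so every column's stop row is simply seq_len.
attn_mask_startend_row_indices = paddle.full([1, 1, seq_len], seq_len, dtype="int32")

outputs = model(
    input_ids=inputs["input_ids"],
    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
)
print(outputs[0].shape)  # logits: [batch, seq_len, vocab_size]
```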