From ed3f12cca84149f62cfba69e8b5cabd45f9dd859 Mon Sep 17 00:00:00 2001 From: drownfish19 Date: Mon, 7 Oct 2024 08:16:43 +0000 Subject: [PATCH 1/3] add flashmask --- paddlenlp/transformers/qwen2/modeling.py | 68 +++++++++++++----------- 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index 2a9b79c6ef30..7d347e3f6d75 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -46,7 +46,7 @@ from ..model_utils import PretrainedModel, register_base_model from ..utils import caculate_llm_flops from .configuration import Qwen2Config - +from ..llama import fusion_ops try: from paddle.incubate.nn.functional import fused_rotary_position_embedding except ImportError: @@ -156,6 +156,7 @@ def scaled_dot_product_attention( value_states, attention_mask, output_attentions, + attn_mask_startend_row_indices=None, training=True, sequence_parallel=False, ): @@ -166,32 +167,16 @@ def scaled_dot_product_attention( # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] - version = paddle.version.full_version - if version != "0.0.0" and version <= "2.5.2": - attn_output, attn_weights = flash_attention( - query_states, - key_states, - value_states, - causal=True, - return_softmax=output_attentions, - ) - else: - attn_output = F.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - is_causal=attention_mask is None, - dropout_p=config.attention_dropout if training else 0.0, - training=training, - ) - attn_weights = None - - if sequence_parallel: - attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) - else: - attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) - return (attn_output, attn_weights) if output_attentions else attn_output + return fusion_ops.fusion_flash_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + sequence_parallel=sequence_parallel, + ) else: # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] query_states = paddle.transpose(query_states, [0, 2, 1, 3]) @@ -510,6 +495,7 @@ def forward( attention_mask: Optional[paddle.Tensor] = None, output_attentions: bool = False, use_cache: bool = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, **kwargs, ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -574,6 +560,7 @@ def forward( value_states, attention_mask, output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, training=self.training, sequence_parallel=self.sequence_parallel, use_reentrant=self.config.recompute_use_reentrant, @@ -586,6 +573,7 @@ def forward( value_states, attention_mask, output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, training=self.training, sequence_parallel=self.sequence_parallel, ) @@ -640,6 +628,7 @@ def forward( output_attentions: Optional[bool] = False, past_key_value: Optional[Tuple[paddle.Tensor]] = None, use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, **kwargs, ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: """ @@ -677,6 +666,7 @@ def forward( attention_mask, output_attentions, use_cache, + 
attn_mask_startend_row_indices, use_reentrant=self.config.recompute_use_reentrant, ) else: @@ -687,6 +677,7 @@ def forward( attention_mask, output_attentions, use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) if type(outputs) is tuple: @@ -992,6 +983,7 @@ def recompute_training_full( output_attentions: bool, past_key_value: Tensor, use_cache: bool, + attn_mask_startend_row_indices=None, ): def create_custom_forward(module): def custom_forward(*inputs): @@ -1007,6 +999,7 @@ def custom_forward(*inputs): output_attentions, past_key_value, use_cache, + attn_mask_startend_row_indices, use_reentrant=self.config.recompute_use_reentrant, ) @@ -1023,6 +1016,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + attn_mask_startend_row_indices=None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1062,16 +1056,17 @@ def forward( inputs_embeds = ScatterOp.apply(inputs_embeds) # embed positions - if attention_mask is None: + if attention_mask is None and attn_mask_startend_row_indices is None: # [bs, seq_len] attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) if position_ids is None: position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype - ) # [bs, 1, seq_len, seq_len] + if attn_mask_startend_row_indices is None: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] if self.config.use_flash_attention: is_casual = is_casual_mask(attention_mask) if is_casual: @@ -1103,6 +1098,7 @@ def forward( output_attentions, past_key_value, use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) else: layer_outputs = decoder_layer( @@ -1112,6 +1108,7 @@ def forward( output_attentions, past_key_value, use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) # NOTE: clear outdate cache after it has been used for memory saving @@ -1340,6 +1337,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + attn_mask_startend_row_indices=None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1373,6 +1371,13 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if attn_mask_startend_row_indices is not None and attention_mask is not None: + logger.warning( + "You have provided both attn_mask_startend_row_indices and attention_mask. " + "The attn_mask_startend_row_indices will be used." 
+ ) + attention_mask = None + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.qwen2( input_ids=input_ids, @@ -1384,6 +1389,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) hidden_states = outputs[0] From 90f1c66b4bfecabe3b8b97c57c02af3b3eb038d8 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Mon, 14 Oct 2024 13:52:32 +0800 Subject: [PATCH 2/3] update --- README.md | 26 +++++++++++----------- llm/run_finetune.py | 5 ++++- paddlenlp/transformers/qwen2/modeling.py | 28 ++++++++++++++---------- 3 files changed, 33 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 7aa0866c5eb3..a3a1a2bec419 100644 --- a/README.md +++ b/README.md @@ -115,19 +115,19 @@ Unified Checkpoint 大模型存储格式在模型参数分布上支持动态扩 * 大模型预训练、精调(包含 SFT、PEFT 技术)、对齐、量化已支持 LLaMA 系列、Baichuan 系列、Bloom 系列、ChatGLM 系列、Mistral 系列、OPT 系列和 Qwen 系列,【LLM】模型预训练、精调、对齐、量化支持列表如下: -| 模型名称/能力支持 | Pretrain | SFT | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert | -|:------------------:|:--------:|:---:|:----:|:-------------:|:---:|:----:|:------------:|:-------------:| -| Llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| Qwen | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | -| Mixtral | ✅ | ✅ | ✅ | ❌ | 🚧 | 🚧 | 🚧 | 🚧 | -| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | -| Baichuan/Baichuan2 | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ | -| ChatGLM-6B | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ❌ | -| ChatGLM2/ChatGLM3 | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ | -| Bloom | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ | -| GPT-3 | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ | -| OPT | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ | -| Yuan2 | ✅ | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ | +| 模型名称/能力支持 | Pretrain | SFT | FlashMask | LoRA | Prefix Tuning | DPO | RLHF | Quantization | Torch convert | +|:------------------:|:--------:|:---:|:---------:|:----:|:-------------:|:---:|:----:|:------------:|:-------------:| +| Llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Qwen | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | +| Mixtral | ✅ | ✅ | 🚧 | ✅ | ❌ | 🚧 | 🚧 | 🚧 | 🚧 | +| Mistral | ✅ | ✅ | 🚧 | ✅ | ✅ | ✅ | 🚧 | 🚧 | ✅ | +| Baichuan/Baichuan2 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🚧 | ✅ | ✅ | +| ChatGLM-6B | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | ✅ | ❌ | +| ChatGLM2/ChatGLM3 | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ | +| Bloom | ✅ | ✅ | 🚧 | ✅ | ✅ | 🚧 | 🚧 | ✅ | ✅ | +| GPT-3 | ✅ | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | 🚧 | ✅ | +| OPT | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ | +| Yuan2 | ✅ | ✅ | 🚧 | ✅ | 🚧 | 🚧 | 🚧 | 🚧 | ✅ | ------------------------------------------------------------------------------------------ * [大模型推理](./llm/docs/predict/inference.md)已支持 LLaMA 系列、Qwen 系列、Mistral 系列、ChatGLM 系列、Bloom 系列和 Baichuan 系列,支持 Weight Only INT8及 INT4推理,支持 WAC(权重、激活、Cache KV)进行 INT8、FP8量化的推理,【LLM】模型推理支持列表如下: diff --git a/llm/run_finetune.py b/llm/run_finetune.py index 265036207830..8e0190e2589f 100644 --- a/llm/run_finetune.py +++ b/llm/run_finetune.py @@ -52,6 +52,8 @@ LlamaForCausalLM, LlamaForCausalLMPipe, LlamaTokenizer, + Qwen2ForCausalLM, + Qwen2ForCausalLMPipe, register_sequence_parallel_allreduce_hooks, ) from paddlenlp.transformers.configuration_utils import LlmMetaConfig @@ -69,7 +71,7 @@ # Fine-tune Environment Variables to support sharding stage1 overlap optimization. 
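Note on the FlashMask representation referenced in the README table above: instead of a dense [bs, 1, seq_len, seq_len] attention mask, FlashMask passes one row index per key position marking where masking begins. Below is a minimal sketch of how such indices can be built for packed (zero-padding) sequences; the [batch, 1, seq_len] int32 layout and the helper name are assumptions based on the FlashMask design, not code from this patch, and padding handling is omitted.

import numpy as np
import paddle

def build_startend_row_indices(doc_lens, seq_len):
    # Rows at or beyond a document's end offset must not attend to its tokens,
    # so every position inside a document is assigned that document's end offset.
    indices = np.full([seq_len], seq_len, dtype="int32")
    offset = 0
    for n in doc_lens:
        indices[offset:offset + n] = offset + n
        offset += n
    return paddle.to_tensor(indices).reshape([1, 1, seq_len])

# Two packed documents of lengths 3 and 2 in a length-5 sequence:
print(build_startend_row_indices([3, 2], 5))  # values: [3, 3, 3, 5, 5]

The modeling changes in this patch only forward such a tensor (attn_mask_startend_row_indices) from Qwen2ForCausalLM down to fusion_ops.fusion_flash_attention; building it is the data collator's job.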
os.environ["USE_CASUAL_MASK"] = "False" -flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe] +flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe, Qwen2ForCausalLM, Qwen2ForCausalLMPipe] def main(): @@ -109,6 +111,7 @@ def main(): if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1: try: from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401 + LinearConfig.enable_accumulate_steps_opt() LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps) except ImportError: diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index 7d347e3f6d75..0f9fb994539c 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -37,6 +37,7 @@ from ..activations import ACT2FN from ..conversion_utils import StateDictNameMapping, init_name_mappings from ..linear_utils import Linear +from ..llama import fusion_ops from ..model_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -44,9 +45,9 @@ TokenClassifierOutput, ) from ..model_utils import PretrainedModel, register_base_model -from ..utils import caculate_llm_flops +from ..utils import caculate_llm_flops, logger from .configuration import Qwen2Config -from ..llama import fusion_ops + try: from paddle.incubate.nn.functional import fused_rotary_position_embedding except ImportError: @@ -1056,21 +1057,24 @@ def forward( inputs_embeds = ScatterOp.apply(inputs_embeds) # embed positions - if attention_mask is None and attn_mask_startend_row_indices is None: + if attn_mask_startend_row_indices is not None: + attention_mask = None + else: # [bs, seq_len] - attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + attention_mask = ( + paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + if attention_mask is None + else attention_mask + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + if self.config.use_flash_attention: + attention_mask = None if is_casual_mask(attention_mask) else attention_mask if position_ids is None: position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) - if attn_mask_startend_row_indices is None: - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype - ) # [bs, 1, seq_len, seq_len] - if self.config.use_flash_attention: - is_casual = is_casual_mask(attention_mask) - if is_casual: - attention_mask = None hidden_states = inputs_embeds # decoder layers From 0d872f14d04d6e44904129adfd098dccd9041e00 Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Mon, 14 Oct 2024 16:43:06 +0800 Subject: [PATCH 3/3] add flashmask in modeling_pp --- paddlenlp/transformers/qwen2/modeling_pp.py | 87 +++++++++++++++++---- 1 file changed, 71 insertions(+), 16 deletions(-) diff --git a/paddlenlp/transformers/qwen2/modeling_pp.py b/paddlenlp/transformers/qwen2/modeling_pp.py index 549e9e55be26..ae0396ba6312 100644 --- a/paddlenlp/transformers/qwen2/modeling_pp.py +++ b/paddlenlp/transformers/qwen2/modeling_pp.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ +from typing import OrderedDict + import paddle import paddle.distributed.fleet as fleet import paddle.nn as nn @@ -41,17 +44,17 @@ def parse_args(args): if isinstance(args, tuple): - if len(args) == 3: - hidden_states, attention_mask, position_ids = args + if len(args) == 4: + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args + elif len(args) == 3: + hidden_states, attention_mask, attn_mask_startend_row_indices = args + position_ids = None elif len(args) == 2: hidden_states, attention_mask = args - position_ids = None - elif len(args) == 1: - hidden_states = args - attention_mask, position_ids = None, None + attn_mask_startend_row_indices, position_ids = None, None else: hidden_states = args - attention_mask, position_ids = None, None + attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None if position_ids is not None: position_ids.stop_gradient = True @@ -59,14 +62,19 @@ def parse_args(args): if attention_mask is not None: attention_mask.stop_gradient = True - return hidden_states, attention_mask, position_ids + if attn_mask_startend_row_indices is not None: + attn_mask_startend_row_indices.stop_gradient = True + + return hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids -def return_args(hidden_states, attention_mask=None, position_ids=None): +def return_args(hidden_states, attention_mask=None, attn_mask_startend_row_indices=None, position_ids=None): ret = (hidden_states,) if attention_mask is not None: ret += (attention_mask.clone(),) + if attn_mask_startend_row_indices is not None: + ret += (attn_mask_startend_row_indices.clone(),) if position_ids is not None: ret += (position_ids.clone(),) if len(ret) == 1: @@ -112,7 +120,7 @@ def forward(self, args): Returns: _type_: _description_ """ - input_ids, attention_mask, position_ids = parse_args(args) + input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) input_embeds = self.embed_tokens(input_ids) if self.config.sequence_parallel: from paddlenlp.transformers import ScatterOp @@ -126,6 +134,10 @@ def forward(self, args): batch_size, seq_length = input_ids.shape if attention_mask is not None: + assert ( + attn_mask_startend_row_indices is None + ), "attention_mask and attn_mask_startend_row_indices can not be set at same time" + attention_mask = Qwen2Model._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), 0, input_embeds.dtype ) @@ -136,22 +148,34 @@ def forward(self, args): attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool")) attention_mask.stop_gradient = True - return return_args(input_embeds, attention_mask, position_ids) + return return_args(input_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) class Qwen2DecoderLayerPipe(Qwen2DecoderLayer): def forward(self, args): - hidden_states, attention_mask, position_ids = parse_args(args) + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) has_gradient = not hidden_states.stop_gradient + if attention_mask is not None and attention_mask.dtype == paddle.int32: + attention_mask, attn_mask_startend_row_indices, position_ids = ( + None, + attention_mask, + attn_mask_startend_row_indices, + ) + elif attention_mask is not None and attention_mask.dtype == paddle.int64: + attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask + elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == 
paddle.int64: + attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: - if attention_mask is not None: + if attention_mask is not None or attn_mask_startend_row_indices is not None: hidden_states = recompute( super().forward, hidden_states, position_ids=position_ids, attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, use_reentrant=False, ) else: @@ -160,12 +184,18 @@ def forward(self, args): super().forward, hidden_states, position_ids=position_ids, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, use_reentrant=self.config.recompute_use_reentrant, ) else: - hidden_states = super().forward(hidden_states, position_ids=position_ids, attention_mask=attention_mask) + hidden_states = super().forward( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + ) - return return_args(hidden_states, attention_mask, position_ids) + return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids) class Qwen2RMSNormPipe(nn.Layer): @@ -174,7 +204,7 @@ def __init__(self, config): self.norm = Qwen2RMSNorm(config) def forward(self, args): - hidden_states, attention_mask, position_ids = parse_args(args) + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) return self.norm(hidden_states) @@ -202,6 +232,31 @@ class Qwen2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): # DONOT Add base_model_prefix !!!! + @classmethod + def _prepare_pipeline_inputs_func(cls, inputs): + + first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"] + last_stage_keys = ["labels"] + + def get_expected_keys(inputs, keys): + ret = tuple([inputs.pop(k) if k in inputs else None for k in keys]) + if len(ret) == 1: + ret = ret[0] + return ret + + if type(inputs) is dict or type(inputs) is OrderedDict: + return [ + get_expected_keys(inputs, first_stage_keys), + get_expected_keys(inputs, last_stage_keys), + ] + + keys = list(inputs[0].keys()) + inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys} + return [ + get_expected_keys(inputs_batch, first_stage_keys), + get_expected_keys(inputs_batch, last_stage_keys), + ] + def __init__(self, config: Qwen2Config): self.config = config
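A short usage sketch for the new _prepare_pipeline_inputs_func. The batch contents are made up, and the example assumes Qwen2ForCausalLMPipe is importable from paddlenlp.transformers as the run_finetune.py import block suggests: given a collator dict that carries attn_mask_startend_row_indices instead of a dense attention_mask, the first stage receives the four first-stage tensors with the missing ones as None, and the last stage receives only the labels.

import paddle

from paddlenlp.transformers import Qwen2ForCausalLMPipe

batch = {
    "input_ids": paddle.randint(0, 32000, [2, 8], dtype="int64"),
    "attn_mask_startend_row_indices": paddle.full([2, 1, 8], 8, dtype="int32"),
    "labels": paddle.randint(0, 32000, [2, 8], dtype="int64"),
}

# get_expected_keys pops matched keys, so the dict is consumed in the process.
first_stage, last_stage = Qwen2ForCausalLMPipe._prepare_pipeline_inputs_func(batch)
# first_stage == (input_ids, None, attn_mask_startend_row_indices, None)
#   attention_mask and position_ids are absent from the batch, so they become None.
# last_stage  == labels  (a single expected key is unwrapped from its one-element tuple)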