diff --git a/paddlenlp/transformers/llama/modeling_auto.py b/paddlenlp/transformers/llama/modeling_auto.py
index 61246d3b5b22..52c2c9b0a109 100644
--- a/paddlenlp/transformers/llama/modeling_auto.py
+++ b/paddlenlp/transformers/llama/modeling_auto.py
@@ -82,6 +82,11 @@ def swiglu(x, y=None):
 ]
 
 
+def is_pp_enable():
+    mesh = fleet.auto.get_mesh()
+    return "pp" in mesh.dim_names
+
+
 def get_mesh(pp_idx=0):
     mesh = fleet.auto.get_mesh()
     if "pp" in mesh.dim_names:
@@ -89,6 +94,14 @@ def get_mesh(pp_idx=0):
     return mesh
 
 
+def global_mesh_starts_with_pp():
+    mesh = fleet.auto.get_mesh()
+    if is_pp_enable():
+        return mesh.get_mesh_with_dim("pp")
+    else:
+        return mesh
+
+
 def scaled_dot_product_attention(
     query_states,
     config,
@@ -800,21 +813,25 @@ def __init__(self, config: LlamaConfig):
             [dist.Replicate(), dist.Shard(1)],
         )
 
-        def get_layer_ipp(layer_index):
+        def get_layer_pp_info(layer_index):
             mesh = fleet.auto.get_mesh()
-            if "pp" not in mesh.dim_names:
-                return None
+            if is_pp_enable() is False:
+                return None, False
             else:
                 pp_degree = mesh.get_dim_size("pp")
                 layer_per_stage = math.ceil(config.num_hidden_layers / pp_degree)
-                return layer_index // layer_per_stage
-
-        self.layers = nn.LayerList(
-            [
-                LlamaDecoderLayerAuto(config, i not in self.no_recompute_layers, get_layer_ipp(i))
-                for i in range(config.num_hidden_layers)
-            ]
-        )
+                input_need_reshard = layer_index % layer_per_stage == 0
+                return layer_index // layer_per_stage, input_need_reshard
+
+        decoder_layers = []
+        self.next_pp_stage_indexes = []
+        for i in range(config.num_hidden_layers):
+            pp_stage_id, input_need_reshard = get_layer_pp_info(i)
+            decoder_layers.append(LlamaDecoderLayerAuto(config, False, pp_stage_id))
+            if input_need_reshard:
+                self.next_pp_stage_indexes.append(i)
+
+        self.layers = nn.LayerList(decoder_layers)
         self.norm = LlamaRMSNormAuto(config)
         self.gradient_checkpointing = False
 
@@ -840,13 +857,6 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
                     combined_attention_mask = _make_causal_mask(
                         input_shape, past_key_values_length=past_key_values_length
                     )
-                    # NOTE(zhaoyingli): infer spmd does not support [seq_len, seq_len] --> [batch, 1, seq_len, seq_len] in data_parallel
-                    combined_attention_mask = dist.shard_tensor(
-                        combined_attention_mask,
-                        get_mesh(),
-                        [dist.Replicate(), dist.Replicate()],
-                    )
-
                     expanded_attn_mask = expanded_attn_mask & combined_attention_mask
             # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
             elif len(attention_mask.shape) == 3:
@@ -903,6 +913,20 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        if self.config.sequence_parallel:
+            # [B, S, H] -> [S, B, H]
+            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])
+
+        global_mesh = global_mesh_starts_with_pp()
+        if position_ids is None:
+            position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
+
+        position_ids = dist.shard_tensor(
+            position_ids,
+            global_mesh,
+            [dist.Replicate() for _ in range(len(global_mesh._shape))],
+        )
+
         # embed positions
         if attention_mask is None:
             # [bs, seq_len]
@@ -914,15 +938,6 @@ def forward(
         else:
             alibi = None
 
-        if position_ids is None:
-            position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
-        # NOTE(zhaoyingli): infer spmd does not support [seq_len] --> [batch, seq_len] in data_parallel
-        position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Replicate(), dist.Replicate()])
-
-        if self.config.sequence_parallel:
-            # [B, S, H] -> [S, B, H]
-            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])
-
         if self.config.use_flash_attention:
             # attention_mask in flash_attn is always None for pretrain
             attention_mask = None
@@ -930,6 +945,11 @@ def forward(
             attention_mask = self._prepare_decoder_attention_mask(
                 attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
            )  # [bs, 1, seq_len, seq_len]
+            attention_mask = dist.shard_tensor(
+                attention_mask,
+                global_mesh,
+                [dist.Replicate() for _ in range(len(global_mesh._shape))],
+            )
 
         hidden_states = inputs_embeds
         hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements)
@@ -939,33 +959,37 @@ def forward(
         all_self_attns = () if output_attentions else None
         next_decoder_cache = () if use_cache else None
 
-        pre_ipp = None
         for idx, (decoder_layer) in enumerate(self.layers):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             has_gradient = not hidden_states.stop_gradient
-
-            if decoder_layer.ipp is not None and pre_ipp != decoder_layer.ipp:
-                hidden_states = dist.reshard(
-                    hidden_states,
-                    get_mesh(decoder_layer.ipp),
-                    self.placements,
-                )
-                position_ids = dist.reshard(
+            ipp = decoder_layer.ipp
+            if not is_pp_enable():
+                position_ids_input = position_ids
+                attention_mask_input = attention_mask
+            else:
+                position_ids_input = dist.reshard(
                     position_ids,
-                    get_mesh(decoder_layer.ipp),
-                    [dist.Shard(0), dist.Replicate()],
+                    get_mesh(ipp),
+                    [dist.Replicate(), dist.Replicate()],
                 )
-                attention_mask = (
+                attention_mask_input = (
                     dist.reshard(
                         attention_mask,
-                        get_mesh(decoder_layer.ipp),
-                        [dist.Shard(0), dist.Replicate()],
+                        get_mesh(ipp),
+                        [dist.Replicate(), dist.Replicate()],
                     )
                     if attention_mask is not None
-                    else attention_mask
+                    else None
+                )
+
+            if idx in self.next_pp_stage_indexes:
+                hidden_states = dist.reshard(
+                    hidden_states,
+                    get_mesh(ipp),
+                    self.placements,
                 )
 
             if (
@@ -977,8 +1001,8 @@ def forward(
                 layer_outputs = recompute(
                     decoder_layer,
                     hidden_states,
-                    position_ids,
-                    attention_mask,
+                    position_ids_input,
+                    attention_mask_input,
                     output_attentions,
                     past_key_value,
                     use_cache,
@@ -987,16 +1011,14 @@ def forward(
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    position_ids,
-                    attention_mask,
+                    position_ids_input,
+                    attention_mask_input,
                     output_attentions,
                     past_key_value,
                     use_cache,
                     alibi=alibi,
                 )
 
-            pre_ipp = decoder_layer.ipp
-
             if type(layer_outputs) is tuple:
                 hidden_states = layer_outputs[0]
             else:
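
For reviewers, a minimal framework-free sketch of the stage-assignment rule introduced by `get_layer_pp_info`: decoder layers are grouped into contiguous blocks of `ceil(num_hidden_layers / pp_degree)`, and the first layer of each block (recorded in `next_pp_stage_indexes`) is where `hidden_states` is resharded onto that stage's mesh. The helper name `layer_pp_info` below is illustrative only and not part of this patch.

```python
import math


def layer_pp_info(num_hidden_layers: int, pp_degree: int):
    """Mirror of the mapping in get_layer_pp_info, without Paddle/mesh objects."""
    layer_per_stage = math.ceil(num_hidden_layers / pp_degree)
    stage_ids = []        # pp stage owning each decoder layer
    reshard_indexes = []  # first layer of each stage (input_need_reshard is True)
    for i in range(num_hidden_layers):
        stage_ids.append(i // layer_per_stage)
        if i % layer_per_stage == 0:
            reshard_indexes.append(i)
    return stage_ids, reshard_indexes


if __name__ == "__main__":
    # e.g. 8 decoder layers over pp_degree=4 -> 2 layers per stage,
    # reshard hidden_states before layers 0, 2, 4, 6
    stages, reshard_at = layer_pp_info(8, 4)
    print(stages)      # [0, 0, 1, 1, 2, 2, 3, 3]
    print(reshard_at)  # [0, 2, 4, 6]
```

Because `0 % layer_per_stage == 0`, layer 0 is always included in `next_pp_stage_indexes`, so the loop performs a reshard at the very first iteration just as the old `pre_ipp != decoder_layer.ipp` check did.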