[AutoParallel] unify llama model #8127

Merged (2 commits) on Mar 15, 2024

118 changes: 70 additions & 48 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -82,13 +82,26 @@
]


def is_pp_enable():
mesh = fleet.auto.get_mesh()
return "pp" in mesh.dim_names



def get_mesh(pp_idx=0):
mesh = fleet.auto.get_mesh()
if "pp" in mesh.dim_names:
mesh = mesh.get_mesh_with_dim("pp", pp_idx)
return mesh


def global_mesh_starts_with_pp():
mesh = fleet.auto.get_mesh()
if is_pp_enable():
return mesh.get_mesh_with_dim("pp")

else:
return mesh
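
These helpers drive the unified model: is_pp_enable() reports whether the global mesh has a "pp" (pipeline) dimension, get_mesh(pp_idx) returns the sub-mesh of a single pipeline stage, and global_mesh_starts_with_pp() returns the full mesh with the "pp" dimension leading. The standalone sketch below is not part of the PR; it exercises the same mesh queries on a hypothetical 2x2 mesh built directly with paddle.distributed.ProcessMesh, standing in for the mesh that fleet.auto.get_mesh() would return (the 2x2 layout and dim names are assumptions).

# Illustrative sketch only, not part of the diff above.
import paddle.distributed as dist

# hypothetical global mesh: 2 pipeline stages x 2 data-parallel ranks
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["pp", "dp"])

print("pp" in mesh.dim_names)           # True, so is_pp_enable() would return True
print(mesh.get_dim_size("pp"))          # 2 pipeline stages
print(mesh.get_mesh_with_dim("pp", 0))  # sub-mesh of stage 0: ranks [0, 1]
print(mesh.get_mesh_with_dim("pp", 1))  # sub-mesh of stage 1: ranks [2, 3]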



def scaled_dot_product_attention(
query_states,
config,
@@ -800,21 +813,25 @@
[dist.Replicate(), dist.Shard(1)],
)

def get_layer_ipp(layer_index):
def get_layer_pp_info(layer_index):

mesh = fleet.auto.get_mesh()
if "pp" not in mesh.dim_names:
return None
if is_pp_enable() is False:
return None, False

else:
pp_degree = mesh.get_dim_size("pp")
layer_per_stage = math.ceil(config.num_hidden_layers / pp_degree)
return layer_index // layer_per_stage

self.layers = nn.LayerList(
[
LlamaDecoderLayerAuto(config, i not in self.no_recompute_layers, get_layer_ipp(i))
for i in range(config.num_hidden_layers)
]
)
input_need_reshard = layer_index % layer_per_stage == 0
return layer_index // layer_per_stage, input_need_reshard


decoder_layers = []
self.next_pp_stage_indexes = []
for i in range(config.num_hidden_layers):
pp_stage_id, input_need_reshard = get_layer_pp_info(i)
decoder_layers.append(LlamaDecoderLayerAuto(config, False, pp_stage_id))
if input_need_reshard:
self.next_pp_stage_indexes.append(i)


self.layers = nn.LayerList(decoder_layers)

self.norm = LlamaRMSNormAuto(config)

self.gradient_checkpointing = False
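
Compared with the old get_layer_ipp, the new get_layer_pp_info also reports whether a layer index is the first layer of its pipeline stage; those indexes are collected in self.next_pp_stage_indexes and later decide where hidden_states gets resharded. A minimal plain-Python sketch of that mapping (layer_pp_info is a hypothetical name, not part of the PR):

# Sketch only: reproduces the stage-assignment logic for illustration.
import math

def layer_pp_info(num_hidden_layers, pp_degree, layer_index):
    layer_per_stage = math.ceil(num_hidden_layers / pp_degree)
    stage_id = layer_index // layer_per_stage
    input_need_reshard = layer_index % layer_per_stage == 0
    return stage_id, input_need_reshard

# 8 decoder layers on 4 pipeline stages:
# stage ids are [0, 0, 1, 1, 2, 2, 3, 3], and layers 0, 2, 4, 6
# (the first layer of each stage) are flagged for resharding.
print([layer_pp_info(8, 4, i) for i in range(8)])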
@@ -840,13 +857,6 @@
combined_attention_mask = _make_causal_mask(
input_shape, past_key_values_length=past_key_values_length
)
# NOTE(zhaoyingli): infer spmd does not support [seq_len, seq_len] --> [batch, 1, seq_len, seq_len] in data_parallel
combined_attention_mask = dist.shard_tensor(
combined_attention_mask,
get_mesh(),
[dist.Replicate(), dist.Replicate()],
)

expanded_attn_mask = expanded_attn_mask & combined_attention_mask
# [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
elif len(attention_mask.shape) == 3:
@@ -903,6 +913,20 @@
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if self.config.sequence_parallel:

# [B, S, H] -> [S, B, H]
inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])


global_mesh = global_mesh_starts_with_pp()
if position_ids is None:
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))


position_ids = dist.shard_tensor(

position_ids,
global_mesh,
[dist.Replicate() for _ in range(len(global_mesh._shape))],
)

# embed positions
if attention_mask is None:
# [bs, seq_len]
@@ -914,22 +938,18 @@
else:
alibi = None

if position_ids is None:
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
# NOTE(zhaoyingli): infer spmd does not support [seq_len] --> [batch, seq_len] in data_parallel
position_ids = dist.shard_tensor(position_ids, get_mesh(), [dist.Replicate(), dist.Replicate()])

if self.config.sequence_parallel:
# [B, S, H] -> [S, B, H]
inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])

if self.config.use_flash_attention:
# attention_mask in flash_attn is always None for pretrain
attention_mask = None
else:
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
attention_mask = dist.shard_tensor(

attention_mask,
global_mesh,
[dist.Replicate() for _ in range(len(global_mesh._shape))],
)

hidden_states = inputs_embeds
hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements)
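
Because the embedding and the decoder layers no longer share one mesh once pipeline parallelism is enabled, position_ids and the dense attention mask are now marked on the global mesh returned by global_mesh_starts_with_pp(), with one dist.Replicate() placement per mesh dimension, instead of on get_mesh() with a fixed two-placement list. A hedged sketch of that placement pattern follows; the 2x2 mesh layout is an assumption and the snippet needs a matching multi-card launch (for example, python -m paddle.distributed.launch --devices 0,1,2,3 script.py):

# Sketch only, not part of the PR; run under paddle.distributed.launch.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["pp", "dp"])  # assumed layout

batch_size, seq_length = 2, 16
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

# one Replicate() placement per mesh axis, mirroring the modeling code
placements = [dist.Replicate() for _ in range(len(mesh.dim_names))]
position_ids = dist.shard_tensor(position_ids, mesh, placements)
# every rank now holds the full [batch_size, seq_length] tensor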
@@ -939,33 +959,37 @@
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None

pre_ipp = None
for idx, (decoder_layer) in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None

has_gradient = not hidden_states.stop_gradient

if decoder_layer.ipp is not None and pre_ipp != decoder_layer.ipp:
hidden_states = dist.reshard(
hidden_states,
get_mesh(decoder_layer.ipp),
self.placements,
)
position_ids = dist.reshard(
ipp = decoder_layer.ipp
if not is_pp_enable():
position_ids_input = position_ids
attention_mask_input = attention_mask

else:
position_ids_input = dist.reshard(

position_ids,
get_mesh(decoder_layer.ipp),
[dist.Shard(0), dist.Replicate()],
get_mesh(ipp),
[dist.Replicate(), dist.Replicate()],
)
attention_mask = (
attention_mask_input = (

dist.reshard(
attention_mask,
get_mesh(decoder_layer.ipp),
[dist.Shard(0), dist.Replicate()],
get_mesh(ipp),
[dist.Replicate(), dist.Replicate()],
)
if attention_mask is not None
else attention_mask
else None
)

if idx in self.next_pp_stage_indexes:
hidden_states = dist.reshard(

hidden_states,
get_mesh(ipp),
self.placements,
)

if (
@@ -977,8 +1001,8 @@
layer_outputs = recompute(
decoder_layer,
hidden_states,
position_ids,
attention_mask,
position_ids_input,
attention_mask_input,
output_attentions,
past_key_value,
use_cache,
@@ -987,16 +1011,14 @@
else:
layer_outputs = decoder_layer(
hidden_states,
position_ids,
attention_mask,
position_ids_input,
attention_mask_input,
output_attentions,
past_key_value,
use_cache,
alibi=alibi,
)

pre_ipp = decoder_layer.ipp

if type(layer_outputs) is tuple:
hidden_states = layer_outputs[0]
else:
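
Taken together, the decoder loop no longer tracks pre_ipp and no longer mutates position_ids and attention_mask in place: each iteration derives position_ids_input / attention_mask_input for the layer's stage mesh (or reuses the originals when pipeline parallelism is off), and hidden_states is resharded only at the layer indexes stored in next_pp_stage_indexes. The condensed sketch below is a hypothetical restatement of that control flow, not the PR's code; it assumes pipeline parallelism is enabled, a two-dimensional stage mesh, and decoder layers that return the hidden state as the first element of a tuple.

# Hypothetical sketch of the new per-layer flow, for illustration only.
import paddle.distributed as dist

def run_decoder(layers, hidden_states, position_ids, attention_mask,
                next_pp_stage_indexes, get_mesh, placements):
    for idx, layer in enumerate(layers):
        ipp = layer.ipp  # pipeline stage index assigned in __init__
        # small per-layer inputs are replicated onto this layer's stage sub-mesh
        position_ids_input = dist.reshard(
            position_ids, get_mesh(ipp), [dist.Replicate(), dist.Replicate()])
        attention_mask_input = (
            dist.reshard(attention_mask, get_mesh(ipp),
                         [dist.Replicate(), dist.Replicate()])
            if attention_mask is not None else None)
        # activations move to the next stage's mesh only at stage boundaries
        if idx in next_pp_stage_indexes:
            hidden_states = dist.reshard(hidden_states, get_mesh(ipp), placements)
        hidden_states = layer(hidden_states, position_ids_input,
                              attention_mask_input)[0]
    return hidden_states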