
[NEW Feature] Add hook-based refined_recompute support #9396

Merged · 20 commits merged on Nov 27, 2024
Changes from 16 commits
5 changes: 5 additions & 0 deletions llm/run_finetune.py
@@ -65,6 +65,7 @@
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.transformers.refined_recompute import update_refined_recompute
from paddlenlp.trl import SFTTrainer
from paddlenlp.trl.llm_utils import (
ZeroPaddingIterDatasetCallback,
@@ -146,6 +147,10 @@ def main():
)

LlmMetaConfig.set_llm_config(model_config, training_args)
model_config.refined_recompute = update_refined_recompute(
training_args.refined_recompute,
model_args.lora,
)
model_config.use_fast_layer_norm = model_args.use_fast_layer_norm

# Config for model using dropout, such as GPT.
4 changes: 4 additions & 0 deletions llm/run_pretrain.py
@@ -44,6 +44,7 @@
register_sequence_parallel_allreduce_hooks,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig, llmmetaclass
from paddlenlp.transformers.refined_recompute import update_refined_recompute
from paddlenlp.utils.batch_sampler import DistributedBatchSampler
from paddlenlp.utils.log import logger
from paddlenlp.utils.tools import get_env_device
@@ -413,6 +414,9 @@ def main():
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
# set all llm config
LlmMetaConfig.set_llm_config(config, training_args)
config.refined_recompute = update_refined_recompute(
training_args.refined_recompute,
)
config.use_fast_layer_norm = model_args.use_fast_layer_norm

config.seq_length = data_args.max_seq_length
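Note on the wiring above: update_refined_recompute itself lives in paddlenlp/transformers/refined_recompute.py and is not part of this diff. run_finetune.py additionally passes model_args.lora, presumably so LoRA-specific restrictions are handled inside the helper; that behaviour is not modelled below. As a hedged sketch only (the string syntax is an assumption, not confirmed by this PR), the helper can be thought of as turning the refined_recompute training argument into the skip_recompute_ops mapping that the modeling code reads later on:

# Hypothetical sketch, not the implementation merged in this PR. Assumes the
# refined_recompute argument is a comma-separated list of op names such as
# "flash_attn,mlp_row_ln"; the real syntax and the lora handling may differ.
_SUPPORTED_OPS = (
    "mlp_row_ln", "mlp_column_ln", "attention_row_ln", "attention_column_ln", "flash_attn",
)

def update_refined_recompute_sketch(refined_recompute: str, lora: bool = False) -> dict:
    skip_recompute_ops = {}
    for op in (part.strip() for part in refined_recompute.split(",")):
        if not op:
            continue
        if op not in _SUPPORTED_OPS:
            raise ValueError(f"unknown refined_recompute op: {op}")
        # How lora narrows the eligible ops is not visible in this diff; the
        # sketch simply records every requested op.
        skip_recompute_ops[op] = True
    return skip_recompute_ops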
8 changes: 8 additions & 0 deletions paddlenlp/transformers/configuration_utils.py
@@ -268,6 +268,14 @@ class LlmMetaConfig:
"Recompute granularity, Choose among ['full', 'core_attn', 'full_attn']",
),
("recompute_use_reentrant", bool, False, "recompute_use_reentrant"),
Collaborator: Will this configuration be passed down to downstream tasks?

Collaborator: Do we need _set_unsavable_keys?

Member (Author): Not needed. zhonghui knows this usage well, and I checked that the implementation meets the requirement: 1) llmmetaclass is added, and 2) LlmMetaConfig.set_llm_config(model_config, training_args) is called:

@dataclass
@llmmetaclass
@add_start_docstrings(TrainingArguments.__doc__)
class TrainingArguments(TrainingArguments):

# refined_recompute attributes
(
"refined_recompute",
str,
"",
"refined_recompute, Choose from 'mlp_row_ln', 'mlp_column_ln', 'attention_row_ln', 'attention_column_ln', 'flash_attn']",
),
("skip_recompute_ops", Optional[Dict[str, int]], None, "skip_recompute_ops"),
]

@classmethod
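To illustrate the reviewer exchange above with a minimal, self-contained sketch (not the real llmmetaclass from configuration_utils.py): each (name, type, default, help) tuple registered here becomes a field on the decorated TrainingArguments class, which is why refined_recompute can be passed as a training argument and then copied onto the model config by LlmMetaConfig.set_llm_config. All names prefixed with "sketch" below are hypothetical.

# Minimal sketch of the tuple-to-dataclass-field idea; only the two attributes
# added in this diff are shown.
import dataclasses
from typing import Dict, Optional

_SKETCH_ATTRIBUTES = [
    ("refined_recompute", str, "", "ops excluded from recompute"),
    ("skip_recompute_ops", Optional[Dict[str, int]], None, "per-op skip map"),
]

def sketch_llmmetaclass(cls):
    for name, typ, default, help_text in _SKETCH_ATTRIBUTES:
        cls.__annotations__[name] = typ
        setattr(cls, name, dataclasses.field(default=default, metadata={"help": help_text}))
    return cls

@dataclasses.dataclass
@sketch_llmmetaclass
class SketchTrainingArguments:
    output_dir: str = "./checkpoints"

args = SketchTrainingArguments(refined_recompute="flash_attn,mlp_row_ln")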
14 changes: 11 additions & 3 deletions paddlenlp/transformers/llama/fusion_ops.py
@@ -51,6 +51,7 @@
except:
flash_attention = None

from paddlenlp.transformers.refined_recompute import no_recompute
Contributor: Why is it called no_recompute? The name feels a bit odd.

Member (Author): We could rename it to skip_recompute instead.

Member (Author): The point is the symmetry with the existing API: recompute(func, xxxxx) vs no_recompute(func, xxxxxx).
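For readers of this thread: as used in fusion_ops.py below, no_recompute(func, *args, enable=flag, **kwargs) behaves like a plain func(*args, **kwargs) when enable is False, and excludes func from the enclosing recompute region when enable is True, so its outputs are kept for backward instead of being recomputed. A rough, runnable sketch of that dispatch (the hook management that actually detaches the call from the recompute region is the substance of this PR and is reduced to a no-op placeholder here):

import contextlib

@contextlib.contextmanager
def _suspend_recompute_hooks():
    # Placeholder for the recompute/saved-tensor hook management this PR adds;
    # kept as a no-op so the sketch stays self-contained and runnable.
    yield

def no_recompute_sketch(func, *args, enable=True, **kwargs):
    if not enable:
        # Degenerates to a normal call; inside a recompute region this path is
        # still recomputed in backward as usual.
        return func(*args, **kwargs)
    # When enabled, run func with the surrounding recompute bookkeeping suspended
    # so its activations are stored rather than dropped and recomputed.
    with _suspend_recompute_hooks():
        return func(*args, **kwargs)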

from paddlenlp.transformers.ring_flash_attention import RingFlashAttention


@@ -174,6 +175,7 @@
sequence_parallel=False,
reshard_layer=None,
npu_is_casual=False,
skip_recompute=False,
):
bsz, q_len, num_heads, head_dim = query_states.shape
_, kv_seq_len, _, _ = value_states.shape
@@ -257,28 +259,34 @@
attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)

if hasattr(F, "flashmask_attention"):
attn_output = F.flashmask_attention(
attn_output = no_recompute(
F.flashmask_attention,
query_states,
key_states,
value_states,
startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1),
causal=True,
enable=skip_recompute,
)
else:
attn_output = F.flash_attention_with_sparse_mask(
attn_output = no_recompute(
F.flash_attention_with_sparse_mask,
query_states,
key_states,
value_states,
attn_mask_start_row_indices=attn_mask_startend_row_indices,
is_causal=True,
enable=skip_recompute,
)
else:
attn_output = F.scaled_dot_product_attention(
attn_output = no_recompute(
F.scaled_dot_product_attention,
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=query_states.shape[1] != 1,
enable=skip_recompute,
)
attn_weights = None

54 changes: 52 additions & 2 deletions paddlenlp/transformers/llama/modeling.py
@@ -28,7 +28,15 @@
from paddle.autograd import PyLayer
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.refined_recompute import (
RRColumnParallelLinear,
RRColumnSequenceParallelLinear,
RRRowParallelLinear,
RRRowSequenceParallelLinear,
create_skip_config_for_refined_recompute,
recompute,
)

try:
from paddle.incubate.nn.functional import fused_rotary_position_embedding
@@ -215,6 +223,7 @@
sequence_parallel=False,
reshard_layer=None,
npu_is_casual=False,
skip_recompute=False,
):
bsz, q_len, num_heads, head_dim = query_states.shape
_, kv_seq_len, _, _ = value_states.shape
@@ -232,6 +241,7 @@
sequence_parallel,
reshard_layer,
npu_is_casual,
skip_recompute=skip_recompute,
)

# Paddle Flash Attention input [ bz, seqlen, nhead, head_dim]
@@ -604,10 +614,24 @@
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear

else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
if config.fuse_attention_ffn:
self.gate_up_fused_proj = ColumnParallelLinear(
Expand Down Expand Up @@ -718,9 +742,22 @@
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear

else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowParallelLinear


if config.tensor_parallel_degree > 1:
if self.fuse_attention_qkv:
Expand Down Expand Up @@ -820,6 +857,14 @@

self.attn_func = scaled_dot_product_attention

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if (
config.recompute
and not config.recompute_use_reentrant
and config.skip_recompute_ops.get("flash_attn", False)
):
self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True)

def _init_rope(self):
if (
hasattr(self.config, "rope_scaling")
Expand Down Expand Up @@ -1470,7 +1515,12 @@
)

self.layers = nn.LayerList(
[LlamaDecoderLayer(config, i not in self.no_recompute_layers) for i in range(config.num_hidden_layers)]
[
LlamaDecoderLayer(
create_skip_config_for_refined_recompute(i, config), i not in self.no_recompute_layers
)
for i in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config)

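Two pieces referenced above are defined outside this diff: the RRColumn/RRRow(Sequence)ParallelLinear classes, which are assumed to be drop-in replacements whose expensive matmul is excluded from recompute, and create_skip_config_for_refined_recompute(i, config), which hands each decoder layer its own per-layer skip decision. A hedged sketch of the latter follows; the per-layer budget convention is an assumption, not taken from this PR.

import copy

# Hypothetical sketch; the real create_skip_config_for_refined_recompute in
# refined_recompute.py may use a different per-layer convention.
def create_skip_config_sketch(layer_idx, config):
    skip_config = copy.deepcopy(config)
    budgets = getattr(config, "skip_recompute_ops", None) or {}
    # Assumed convention: a negative integer budget means every layer skips
    # recompute for that op; otherwise only the first `budget` layers do.
    skip_config.skip_recompute_ops = {
        op: (budget < 0) or (layer_idx < budget) for op, budget in budgets.items()
    }
    return skip_config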
11 changes: 9 additions & 2 deletions paddlenlp/transformers/llama/modeling_pp.py
@@ -22,9 +22,12 @@
PipelineLayer,
SharedLayerDesc,
)
from paddle.distributed.fleet.utils import recompute

from paddlenlp.transformers.model_utils import PipelinePretrainedModel
from paddlenlp.transformers.refined_recompute import (
create_skip_config_for_refined_recompute,
recompute,
)
from paddlenlp.utils.tools import get_env_device

from .modeling import (
@@ -371,7 +374,11 @@ def get_hcg():

for i in range(config.num_hidden_layers):
self.add_sequential_layer(
LayerDesc(LlamaDecoderLayerPipe, config=config, layerwise_recompute=i not in self.no_recompute_layers),
LayerDesc(
LlamaDecoderLayerPipe,
config=create_skip_config_for_refined_recompute(i, config),
layerwise_recompute=i not in self.no_recompute_layers,
),
f"llama.layers.{i}",
)
self.add_sequential_layer(LayerDesc(LlamaRMSNormPipe, config=config), "llama")
48 changes: 45 additions & 3 deletions paddlenlp/transformers/qwen/modeling.py
@@ -24,9 +24,18 @@
from paddle import Tensor, nn
from paddle.distributed import fleet
from paddle.distributed.fleet.layers.mpu.random import get_rng_state_tracker
from paddle.distributed.fleet.utils import recompute
from paddle.utils import try_import

from paddlenlp.transformers.refined_recompute import (
RRColumnParallelLinear,
RRColumnSequenceParallelLinear,
RRRowParallelLinear,
RRRowSequenceParallelLinear,
create_skip_config_for_refined_recompute,
no_recompute,
recompute,
)

try:
from paddle.incubate.nn.functional import swiglu
except ImportError:
@@ -154,9 +163,22 @@
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear

else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
RowParallelLinear = RRRowParallelLinear


if config.tensor_parallel_degree > 1:
if config.num_attention_heads % config.tensor_parallel_degree != 0:
Expand Down Expand Up @@ -227,12 +249,19 @@
return_softmax=self.config.attn_dropout_prob > 0.0,
)
else:
attn_output = F.scaled_dot_product_attention(
skip_recompute = (
self.config.recompute
and not self.config.recompute_use_reentrant
and self.config.skip_recompute_ops.get("flash_attn", False)
)
attn_output = no_recompute(
F.scaled_dot_product_attention,
query,
key,
value,
attn_mask=attention_mask,
is_causal=attention_mask is None,
enable=skip_recompute,
)
attn_weights = None

@@ -388,9 +417,22 @@
if config.sequence_parallel:
ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear
RowParallelLinear = linear_utils.RowSequenceParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowSequenceParallelLinear

else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
RowParallelLinear = RRRowParallelLinear


if config.tensor_parallel_degree > 1:
if self.fuse_attention_ffn:
@@ -684,7 +726,7 @@
self.h = nn.LayerList(
[
QWenBlock(
config,
create_skip_config_for_refined_recompute(i, config),
)
for i in range(config.num_hidden_layers)
]
5 changes: 4 additions & 1 deletion paddlenlp/transformers/qwen/modeling_pp.py
@@ -18,6 +18,9 @@
from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer

from paddlenlp.transformers.model_utils import PipelinePretrainedModel
from paddlenlp.transformers.refined_recompute import (
create_skip_config_for_refined_recompute,
)

from .modeling import (
QWenBlock,
@@ -170,7 +173,7 @@ def get_hcg():
self.add_sequential_layer(LayerDesc(QWenEmbeddingPipe, config=config), "qwen")
for i in range(config.num_hidden_layers):
self.add_sequential_layer(
LayerDesc(QWenBlockPipe, config=config),
LayerDesc(QWenBlockPipe, config=create_skip_config_for_refined_recompute(i, config)),
f"qwen.h.{i}",
)
self.add_sequential_layer(LayerDesc(QWenRMSNormPipe, config=config), "qwen.ln_f")