[NEW Feature] Add hook-based refined_recompute support #9396
Conversation
Thanks for your contribution!
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files
@@ Coverage Diff @@
##            develop    #9396      +/-   ##
===========================================
+ Coverage     52.84%   52.93%   +0.09%
===========================================
  Files           688      689       +1
  Lines        109378   109796     +418
===========================================
+ Hits          57801    58121     +320
- Misses        51577    51675      +98

☔ View full report in Codecov by Sentry.
pylayer_matmul = PyLayerMatmul.apply


class BertConfig:
Why not just do this directly in bert?
Mainly for testing; bert does not have flash attn.
@@ -51,6 +51,7 @@ def swiglu(x, y=None):
except:
    flash_attention = None

from paddlenlp.transformers.refined_recompute import no_recompute
Why is it called no_recompute? It feels a bit odd.
Changing it to skip_recompute would also work.
recompute(func, xxxxx) vs no_recompute(func, xxxxxx)
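A minimal sketch of the contrast between the two calls, assuming no_recompute follows the same (func, *args) convention shown above; heavy_block, the tensor shapes, and the exact skip semantics are illustrative assumptions based on this PR's description, not the definitive implementation.

```python
import paddle
from paddle.distributed.fleet.utils import recompute
from paddlenlp.transformers.refined_recompute import no_recompute  # added by this PR

def heavy_block(x, w):
    # Stand-in for an expensive sub-layer such as flash attention.
    return paddle.nn.functional.silu(paddle.matmul(x, w))

x = paddle.randn([4, 64])
w = paddle.randn([64, 64])
x.stop_gradient = False
w.stop_gradient = False

# Standard recompute: activations inside heavy_block are dropped in forward
# and re-computed during the backward pass.
out_a = recompute(heavy_block, x, w, use_reentrant=False)

# no_recompute: when used inside a recomputed layer, the wrapped call keeps
# its activations and is not replayed on the recompute pass (assumed
# semantics based on this PR's description).
out_b = no_recompute(heavy_block, x, w)

(out_a.sum() + out_b.sum()).backward()
```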
Please also adapt the qwen models.
@ZHUI qwen and qwen2 are already supported.
@@ -268,6 +268,14 @@ class LlmMetaConfig:
        "Recompute granularity, Choose among ['full', 'core_attn', 'full_attn']",
    ),
    ("recompute_use_reentrant", bool, False, "recompute_use_reentrant"),
Will this configuration be passed down to downstream tasks?
Do we need _set_unsavable_keys?
Not needed. zhonghui knows this usage well; I looked at the implementation and it meets the requirement: (1) llmmetaclass is added, and (2) LlmMetaConfig.set_llm_config(model_config, training_args) is called.

@dataclass
@llmmetaclass
@add_start_docstrings(TrainingArguments.__doc__)
class TrainingArguments(TrainingArguments):
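A self-contained toy sketch of the mechanism described in this reply (not PaddleNLP's actual implementation; the field list, defaults, and class bodies are invented for illustration): the llmmetaclass decorator injects the shared LLM fields into the training-arguments dataclass, and LlmMetaConfig.set_llm_config later copies them onto the model config before the model is built.

```python
from dataclasses import dataclass

# Invented subset of the meta fields; the real list lives in LlmMetaConfig.
_LLM_META_FIELDS = [
    ("recompute_use_reentrant", bool, False),
    ("refined_recompute", str, ""),
]

def llmmetaclass(cls):
    # Inject each meta field (with its default) into the decorated class so
    # that @dataclass turns it into a regular training argument.
    for name, typ, default in _LLM_META_FIELDS:
        cls.__annotations__[name] = typ
        setattr(cls, name, default)
    return cls

class LlmMetaConfig:
    @staticmethod
    def set_llm_config(model_config, training_args):
        # Copy the shared fields from the training arguments onto the model
        # config, so downstream model code can read them.
        for name, _typ, default in _LLM_META_FIELDS:
            setattr(model_config, name, getattr(training_args, name, default))

@dataclass
@llmmetaclass
class TrainingArguments:
    output_dir: str = "./checkpoints"

class ModelConfig:
    pass

args = TrainingArguments(refined_recompute="flash_attn:-1")
config = ModelConfig()
LlmMetaConfig.set_llm_config(config, args)
print(config.refined_recompute)  # "flash_attn:-1"
```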
    return output


class RRRowSequenceParallelLinear(RowSequenceParallelLinear):
For RowParallelLinear no code rewrite is needed, but RRRowSequenceParallelLinear requires rewriting the code?
Non-SequenceParallel parallelism is not supported at the moment; of course, we could look into supporting it as well.
DONE, it is now supported.
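A toy sketch of the wrapper pattern discussed in this thread (the class name RRLinear and the use of plain nn.Linear are hypothetical; the PR's real RRRowSequenceParallelLinear is more involved): a thin subclass only overrides forward so the projection is routed through no_recompute and its activations are kept rather than replayed when the surrounding layer is recomputed.

```python
import paddle
from paddle import nn
from paddlenlp.transformers.refined_recompute import no_recompute  # from this PR

class RRLinear(nn.Linear):
    # Hypothetical stand-in: identical to its parent except that the forward
    # pass is excluded from the enclosing layer's recompute.
    def forward(self, x):
        return no_recompute(super().forward, x)

layer = RRLinear(64, 64)
out = layer(paddle.randn([2, 8, 64]))
```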
LGTM
LGTM
PR types
New features
PR changes
APIs
Description
Added refined_recompute support for the llama, qwen, and qwen2 models, covering the mlp_row_ln, attention_row_ln, attention_column_ln, mlp_column_ln, and flash_attn operators. During LoRA training the *_ln operators are not supported; only flash_attn is.
llama model used for testing: meta-llama/Meta-Llama-3-8B. refined_recompute only takes effect when recompute_use_reentrant=False; in any other case it has no effect.
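A hedged sketch of a fine-tuning configuration fragment that enables the feature; the "flash_attn:-1" value is quoted from the tests below, and interpreting -1 as "all layers" is an assumption, as is the exact set of surrounding fields.

```python
# Assumed finetune config fragment; the field names are those introduced in this PR.
config = {
    "model_name_or_path": "meta-llama/Meta-Llama-3-8B",
    "recompute": True,
    "recompute_use_reentrant": False,    # refined_recompute only works with False
    "refined_recompute": "flash_attn:-1",  # skip recompute of flash_attn (-1 assumed = all layers)
}
```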
1. Simple test code for refined_recompute
2. Accuracy comparison for llama SFT at 8k sequence length (rr on vs. rr off)
Conclusion: the loss is exactly identical over 10 steps, so accuracy matches, as expected.
2.1 [Accuracy] refined_recompute off
2.2 [Accuracy] refined_recompute on, "flash_attn:-1"
3. Performance comparison for llama SFT at 8k sequence length (rr on vs. rr off)
Conclusion: ips at step 2 is 1.1894 / 1.1636 = 1.022, roughly a 2.2% speedup.
3.1 [Performance] refined_recompute off
3.2 [Performance] refined_recompute on, "flash_attn:-1"
4. PP accuracy test, comparing no recompute, standard recompute, and RR recompute
5. Performance comparison code for llama 16k PostPretrain
The speedup is about 6–7%. Since fused head loss has not been added yet, the 32k and 64k configurations cannot be trained; in theory the speedup could be even larger (over 10%).
6. Added a tp+sp vs. tp comparison