-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FlashMask] Add FlashMask for Qwen2 #9264
[FlashMask] Add FlashMask for Qwen2 #9264
Conversation
Thanks for your contribution! |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #9264 +/- ##
===========================================
- Coverage 52.84% 52.62% -0.22%
===========================================
Files 661 661
Lines 107783 107365 -418
===========================================
- Hits 56955 56501 -454
- Misses 50828 50864 +36 ☔ View full report in Codecov by Sentry. |
4c38433
to
0d872f1
Compare
后续验证flashmask和fa2,可以直接用加上,能够所有loss逐位对齐即可 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
336c241
to
0d872f1
Compare
d34af01
to
0d872f1
Compare
@classmethod | ||
def _prepare_pipeline_inputs_func(cls, inputs): | ||
|
||
first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
todo:attn_mask_startend_row_indices
这个参数可以加内置函数里面吧
last_stage_keys = ["labels"] | ||
|
||
def get_expected_keys(inputs, keys): | ||
ret = tuple([inputs.pop(k) if k in inputs else None for k in keys]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PaddleNLP/paddlenlp/trainer/trainer.py
Line 1978 in 2b975b1
ret = tuple([inputs.pop(k) for k in keys if k in inputs]) |
会有None是不是?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
Others
Description
Add FlashMask for Qwen2.
对齐验证步骤:
PaddleNLP/paddlenlp/trainer/trainer.py
Lines 1326 to 1345 in 0e96b0f
2.1 验证单卡与流水线并行的一致性
对齐验证结果:
modeling + flashmask
vs.modeling_pp + flashmask
前4步训练loss无diff,后续逐渐出现随机误差,误差范围在1e-3,符合bf16精度。2.2 验证flashmask开关前后正确性
验证指令
数据类型为bf16时,训练loss误差范围为1e-3
数据类型为fp16时,训练loss误差范围为1e-4