Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FlashMask] Add FlashMask for Qwen2 #9264

Merged

Conversation

DrownFish19
Copy link
Collaborator

@DrownFish19 DrownFish19 commented Oct 14, 2024

PR types

New features

PR changes

Others

Description

Add FlashMask for Qwen2.

对齐验证步骤:

  1. 修改shuffle为False
    def _get_train_sampler(self) -> Optional[paddle.io.Sampler]:
    if self.train_dataset is None or not has_length(self.train_dataset):
    return None
    if self.args.world_size <= 1:
    return paddle.io.BatchSampler(
    dataset=self.train_dataset,
    shuffle=True,
    batch_size=self.args.per_device_train_batch_size,
    drop_last=self.args.dataloader_drop_last,
    )
    return DistributedBatchSampler(
    self.train_dataset,
    batch_size=self.args.per_device_train_batch_size,
    shuffle=True,
    num_replicas=self.args.dataset_world_size,
    rank=self.args.dataset_rank,
    drop_last=self.args.dataloader_drop_last,
    )
  2. 验证指令

2.1 验证单卡与流水线并行的一致性

# pipeline parallel
python -m paddle.distributed.launch  \
--devices 6,7  \
run_finetune.py  \
config/qwen/sft_argument.json \
--model_name_or_path Qwen/Qwen2-0.5B \
--gradient_accumulation_steps 1 \
--zero_padding true \
--flash_mask true \
--pipeline_parallel_degree 2
# data parallel
python -m paddle.distributed.launch  \
--devices 7  \
run_finetune.py  \
config/qwen/sft_argument.json \
--model_name_or_path Qwen/Qwen2-0.5B \
--gradient_accumulation_steps 1 \
--zero_padding true \
--flash_mask true \
--pipeline_parallel_degree 1

对齐验证结果:

  • modeling + flashmask vs. modeling_pp + flashmask 前4步训练loss无diff,后续逐渐出现随机误差,误差范围在1e-3,符合bf16精度。

2.2 验证flashmask开关前后正确性
验证指令

# pipeline parallel
python -m paddle.distributed.launch  \
--devices 6,7  \
run_finetune.py  \
config/qwen/sft_argument.json \
--model_name_or_path Qwen/Qwen2-0.5B \
--gradient_accumulation_steps 1 \
--zero_padding true \
--flash_mask true \
--pad_to_max_length true
--pipeline_parallel_degree 2

数据类型为bf16时,训练loss误差范围为1e-3
数据类型为fp16时,训练loss误差范围为1e-4

Copy link

paddle-bot bot commented Oct 14, 2024

Thanks for your contribution!

Copy link

codecov bot commented Oct 14, 2024

Codecov Report

Attention: Patch coverage is 20.37037% with 43 lines in your changes missing coverage. Please review.

Project coverage is 52.62%. Comparing base (220cc95) to head (b4a0ba5).
Report is 1 commits behind head on develop.

Files with missing lines Patch % Lines
paddlenlp/transformers/qwen2/modeling_pp.py 9.52% 38 Missing ⚠️
paddlenlp/transformers/qwen2/modeling.py 58.33% 5 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9264      +/-   ##
===========================================
- Coverage    52.84%   52.62%   -0.22%     
===========================================
  Files          661      661              
  Lines       107783   107365     -418     
===========================================
- Hits         56955    56501     -454     
- Misses       50828    50864      +36     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@DrownFish19 DrownFish19 requested review from ZHUI and lugimzzz October 14, 2024 07:24
@DrownFish19 DrownFish19 force-pushed the dev_20241007_qwen2_add_flashmask branch from 4c38433 to 0d872f1 Compare October 14, 2024 08:44
@lugimzzz
Copy link
Contributor

lugimzzz commented Oct 16, 2024

后续验证flashmask和fa2,可以直接用加上,能够所有loss逐位对齐即可
export FLAGS_cudnn_deterministic=1
export FLAGS_embedding_deterministic=1

lugimzzz
lugimzzz previously approved these changes Oct 16, 2024
Copy link
Contributor

@lugimzzz lugimzzz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@DrownFish19 DrownFish19 force-pushed the dev_20241007_qwen2_add_flashmask branch 2 times, most recently from 336c241 to 0d872f1 Compare October 16, 2024 09:07
@DrownFish19 DrownFish19 force-pushed the dev_20241007_qwen2_add_flashmask branch from d34af01 to 0d872f1 Compare October 18, 2024 02:19
@classmethod
def _prepare_pipeline_inputs_func(cls, inputs):

first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo:attn_mask_startend_row_indices 这个参数可以加内置函数里面吧

last_stage_keys = ["labels"]

def get_expected_keys(inputs, keys):
ret = tuple([inputs.pop(k) if k in inputs else None for k in keys])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ret = tuple([inputs.pop(k) for k in keys if k in inputs])

会有None是不是?

Copy link
Contributor

@lugimzzz lugimzzz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ZHUI ZHUI merged commit 76a118b into PaddlePaddle:develop Oct 21, 2024
9 of 12 checks passed
@DrownFish19 DrownFish19 deleted the dev_20241007_qwen2_add_flashmask branch October 21, 2024 07:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants