
Add fused linear for the LLAMA MLP block and multi-head attention block #6425

Merged: 4 commits into PaddlePaddle:develop on Jul 26, 2023

Conversation

@littsk (Contributor) commented Jul 18, 2023

PR types

New features

PR changes

Models

Description

Add fused linear operation for MLP and multi-head attention in LLAMA.
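For context on what the fusion means: instead of launching separate GEMMs for the q/k/v projections, the fused variant runs one wider GEMM and splits the result. Below is a minimal sketch of the QKV case using plain paddle.nn.Linear (the PR itself uses the tensor-parallel mpu.ColumnParallelLinear; the sizes and names here are illustrative assumptions, not the PR's code):

```python
import paddle
import paddle.nn as nn

hidden_size = 4096  # illustrative LLaMA-7B-like width

# Unfused: three separate projections, three GEMM launches per layer.
q_proj = nn.Linear(hidden_size, hidden_size, bias_attr=False)
k_proj = nn.Linear(hidden_size, hidden_size, bias_attr=False)
v_proj = nn.Linear(hidden_size, hidden_size, bias_attr=False)

# Fused: one projection producing q, k and v in a single GEMM.
qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias_attr=False)

x = paddle.randn([2, 16, hidden_size])           # [batch, seq_len, hidden]
q, k, v = paddle.split(qkv_proj(x), 3, axis=-1)  # each [2, 16, hidden]
```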

paddle-bot (bot) commented Jul 18, 2023

Thanks for your contribution!

Review thread on examples/language_model/llama/run_pretrain.py (outdated, resolved)
    gather_output=False,
)
if self.fuse_attn_qkv:
    self.qkv_proj = mpu.ColumnParallelLinear(
Collaborator: For loading and converting pretrained parameters, could we make the conversion automatic based on the fuse setting?

littsk (Contributor, Author): Could you explain what exactly you mean?
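The thread above appears to concern checkpoint compatibility: a checkpoint saved with separate q/k/v weights would need to be concatenated into a single qkv weight when the fuse flag is on, and split back on save. A rough NumPy sketch of that conversion, purely illustrative (the key names and weight layout are assumptions, not PaddleNLP's actual state-dict format):

```python
import numpy as np

def fuse_qkv_state_dict(state_dict, prefix="llama.layers.0.self_attn."):
    """Concatenate separate q/k/v projection weights into one fused qkv weight.

    Assumes Linear weights stored as [in_features, out_features], so the
    concatenation happens along the output (last) axis.
    """
    q = state_dict.pop(prefix + "q_proj.weight")
    k = state_dict.pop(prefix + "k_proj.weight")
    v = state_dict.pop(prefix + "v_proj.weight")
    state_dict[prefix + "qkv_proj.weight"] = np.concatenate([q, k, v], axis=-1)
    return state_dict
```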

Review thread on paddlenlp/transformers/llama/modeling.py (outdated, resolved)
@littsk littsk force-pushed the fused_linear_feature branch 3 times, most recently from a460685 to 3672383 Compare July 19, 2023 12:12
@littsk littsk force-pushed the fused_linear_feature branch from 3672383 to 178e63b Compare July 19, 2023 12:14
codecov bot commented Jul 19, 2023

Codecov Report

Merging #6425 (394608f) into develop (8e27802) will decrease coverage by 0.05%.
Report is 37 commits behind head on develop.
The diff coverage is 65.33%.

@@             Coverage Diff             @@
##           develop    #6425      +/-   ##
===========================================
- Coverage    63.18%   63.14%   -0.05%     
===========================================
  Files          529      529              
  Lines        77214    77315     +101     
===========================================
+ Hits         48789    48817      +28     
- Misses       28425    28498      +73     
Files Changed | Coverage | Δ
--- | --- | ---
paddlenlp/taskflow/text_feature_extraction.py | 46.00% <50.00%> | +0.62% ⬆️
paddlenlp/transformers/llama/modeling.py | 70.86% <65.38%> | -5.30% ⬇️
paddlenlp/taskflow/task.py | 64.21% <100.00%> | ø
paddlenlp/taskflow/taskflow.py | 84.88% <100.00%> | ø
paddlenlp/transformers/llama/configuration.py | 100.00% <100.00%> | ø

... and 24 files with indirect coverage changes

Review thread on paddlenlp/transformers/llama/modeling.py (outdated, resolved)
@littsk littsk force-pushed the fused_linear_feature branch from 9495473 to 32c36cb Compare July 24, 2023 05:47
@sijunhe (Collaborator) left a comment: Looks OK to me. @ZHUI, could you take a look?

@littsk littsk force-pushed the fused_linear_feature branch from 32c36cb to 6412892 Compare July 24, 2023 07:47
ZHUI previously approved these changes Jul 25, 2023
    gather_output=False,
    has_bias=False,
)
# 为了减少张量并行的通信量,将两个linear合并成一个 (merge the two linears into one to reduce tensor-parallel communication)
Collaborator: Please delete the Chinese comment.

    has_bias=False,
)
# 为了减少张量并行的通信量,将两个linear合并成一个 (merge the two linears into one to reduce tensor-parallel communication)
if config.fuse_mlp_linear:
Collaborator: Would fuse_mlp_linear read better under a different name? Maybe fuse_gate_up_proj, or something else?

littsk (Contributor, Author): Sure, will do.
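For reference, the MLP-side fusion discussed here merges the gate and up projections of the SwiGLU feed-forward block into one matmul. A minimal single-device sketch (the PR routes this through ColumnParallelLinear, and the flag was renamed as suggested above; sizes and names here are illustrative assumptions):

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

hidden_size, intermediate_size = 4096, 11008  # illustrative LLaMA-7B sizes

# Fused gate/up projection: one GEMM instead of two.
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias_attr=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias_attr=False)

def mlp_forward(x):
    gate, up = paddle.split(gate_up_proj(x), 2, axis=-1)
    return down_proj(F.silu(gate) * up)
```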

@littsk littsk force-pushed the fused_linear_feature branch 2 times, most recently from e1e2083 to 7927e54 Compare July 26, 2023 02:44
@littsk littsk force-pushed the fused_linear_feature branch from 7927e54 to 394608f Compare July 26, 2023 02:49
@ZHUI (Collaborator) left a comment: LGTM

@littsk littsk requested review from sijunhe and FeixLiu July 26, 2023 02:59
@FeixLiu (Contributor) left a comment: LGTM for the fuse

@zjjlivein zjjlivein merged commit d1050f4 into PaddlePaddle:develop Jul 26, 2023
triple-Mu pushed a commit to triple-Mu/PaddleNLP that referenced this pull request Aug 3, 2023
Add fused linear for the LLAMA MLP block and multi-head attention block (PaddlePaddle#6425)

* Add fused linear for the LLAMA MLP block and multi-head attention block

* Refactor the config name 'fuse_attn_qkv' to 'fuse_attention_qkv' for improved readability and consistency.

* Added a switch for fuse_mlp_linear and improved the organization of the fused linear implementation.

* add tensor parallel mappings for fused linears
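On the last commit (tensor parallel mappings): a fused column-parallel weight cannot be split naively across ranks, because each rank must receive a slice of every fused sub-matrix (q, k and v), not one whole sub-matrix. A NumPy sketch of that sharding rule, with the function name and weight layout as assumptions rather than PaddleNLP's actual mapping code:

```python
import numpy as np

def split_fused_column_parallel(weight, num_ranks, num_fused=3):
    """Shard a fused column-parallel weight of shape [in, num_fused * out].

    Split each fused sub-matrix (e.g. q, k, v) across ranks separately, then
    concatenate the rank-local pieces, so every rank keeps a slice of every
    sub-matrix.
    """
    parts = np.split(weight, num_fused, axis=-1)            # q, k, v blocks
    shards = [np.split(p, num_ranks, axis=-1) for p in parts]
    return [np.concatenate([s[r] for s in shards], axis=-1)  # per-rank weight
            for r in range(num_ranks)]
```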