
supports llama-dybatch-V1 #6676

Merged
8 commits merged into PaddlePaddle:develop on Aug 22, 2023

Conversation

@carryyu (Contributor) commented Aug 10, 2023

PR types

New features

PR changes

Models

Description

supports llama-dybatch-V1

@paddle-bot (bot) commented Aug 10, 2023

Thanks for your contribution!

@wj-Mcat (Contributor) left a comment

I think most of the work is excellent; there are a few points I'd like to discuss with you.

Also, when you have time, please add the related unit tests as well; anything merged into paddlenlp these days generally requires them.

paddlenlp/ops/generation/encode_rotary_qk.cu (outdated; resolved)
paddlenlp/ops/generation/setup_cuda.py (outdated; resolved)
paddlenlp/transformers/llama/modeling.py (outdated; resolved)

@qingqing01 (Collaborator) left a comment

Usage documentation needs to be added.

llm/llama/dybatch/export_generation_model.py (outdated; resolved)

@carryyu (Contributor, Author) commented Aug 14, 2023

> Usage documentation needs to be added.

Added.

csrc/encode_rotary_qk.cu (outdated; resolved)
@@ -0,0 +1,21 @@
# LLaMA DyBatch
(Collaborator) commented

The model usage under the LLM directory is now largely unified; the fine-tuning, prediction, and quantization scripts all share one set.

Can the dynamic-insertion (dybatch) scripts be unified in the same way?

(Contributor, Author) replied

This part may be hard to unify: various quantization methods are still to come, and putting everything together would be less clear. How about adding a jump link in the main README instead?

@@ -0,0 +1,247 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

I suggest placing this under the paddlenlp/transformers directory; @wj-Mcat, please help review the organization.

@@ -0,0 +1,147 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Please check whether the files in this directory can be removed; instead, add an --enable_dybatch argument to the files in the llm/llama directory and handle it with a branch there (a hypothetical sketch follows).
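
A hypothetical illustration of that suggestion. The flag name --enable_dybatch comes from this comment; the wiring shown here is assumed, not the PR's actual code:

```python
# Sketch: branch between the existing path and the dybatch InferenceModel
# path inside the llm/llama scripts via a single CLI switch.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_dybatch",
    action="store_true",
    help="Export/predict with the dynamic-batching InferenceModel path.",
)
args = parser.parse_args()

if args.enable_dybatch:
    print("dybatch InferenceModel path")  # placeholder for the dybatch branch
else:
    print("default path")                 # placeholder for the existing branch
```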

@heavengate left a comment

Please confirm whether the unit tests should be added under the tests/transformer/llama directory.

@CLAassistant commented Aug 18, 2023

CLA assistant check
All committers have signed the CLA.

@codecov (bot) commented Aug 18, 2023

Codecov Report

Merging #6676 (22df0d1) into develop (e0a9f4e) will decrease coverage by 0.35%.
Report is 1 commit behind head on develop.
The diff coverage is 0.44%.

❗ Current head 22df0d1 differs from pull request most recent head e276db4. Consider uploading reports for the commit e276db4 to get more accurate results

@@             Coverage Diff             @@
##           develop    #6676      +/-   ##
===========================================
- Coverage    60.85%   60.50%   -0.35%     
===========================================
  Files          534      539       +5     
  Lines        78870    79322     +452     
===========================================
+ Hits         47995    47996       +1     
- Misses       30875    31326     +451     
Files Changed Coverage Δ
paddlenlp/experimental/transformers/__init__.py 0.00% <0.00%> (ø)
...erimental/transformers/fused_transformer_layers.py 0.00% <0.00%> (ø)
...enlp/experimental/transformers/generation_utils.py 0.00% <0.00%> (ø)
...dlenlp/experimental/transformers/llama/__init__.py 0.00% <0.00%> (ø)
...dlenlp/experimental/transformers/llama/modeling.py 0.00% <0.00%> (ø)
paddlenlp/utils/import_utils.py 85.71% <50.00%> (-0.96%) ⬇️
paddlenlp/transformers/llama/modeling.py 69.95% <100.00%> (ø)

... and 1 file with indirect coverage changes

@@ -0,0 +1,19 @@
# LLaMA Inference

  1. Following the llm directory's organization, please delete the .sh files here and document the Python commands instead, distinguishing single-GPU and multi-GPU usage.
  2. @wj-Mcat, please take a look at the unified script for splitting multi-GPU weights.

(Contributor) replied

> Following the llm directory's organization, please delete the .sh files here and document the Python commands instead, distinguishing single-GPU and multi-GPU usage.

They have been deleted in the latest commit.

export FLAGS_new_executor_serial_run=1
export FLAGS_allocator_strategy=naive_best_fit
export FLAGS_fraction_of_gpu_memory_to_use=0.95
export FLAGS_use_cutlass_fmha=1

  1. Please remove the non-essential flags, such as the log-related ones.
  2. For the remaining flags, briefly explain what they do in the README (see the annotated sketch below).
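
A rough annotation of what these flags appear to do; the descriptions are my assumptions and should be checked against the Paddle documentation before going into the README:

```python
# Sketch: set the inference flags from Python and note their (assumed) purpose.
import os

# Run the new (standalone) executor's ops serially instead of concurrently.
os.environ["FLAGS_new_executor_serial_run"] = "1"
# Use the legacy best-fit GPU allocator rather than the auto-growth allocator.
os.environ["FLAGS_allocator_strategy"] = "naive_best_fit"
# Pre-allocate 95% of the available GPU memory up front.
os.environ["FLAGS_fraction_of_gpu_memory_to_use"] = "0.95"
# Enable the CUTLASS-based fused multi-head attention kernels.
os.environ["FLAGS_use_cutlass_fmha"] = "1"

import paddle  # import after the flags are set so they take effect
```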

@wj-Mcat (Contributor) commented Aug 21, 2023

So far, the InferenceModel work is complete:

  • single-GPU dynamic-graph, dynamic-to-static, and static-graph verification
  • multi-GPU dynamic-graph, dynamic-to-static, and static-graph verification
  • README documentation adjustments
  • InferenceModel adjustments

Comment on lines +156 to +159
if paddle.in_dynamic_mode():
    y_is_distributed = y.is_distributed
else:
    y_is_distributed = tensor_parallel_degree > 1
(Contributor) commented

In dynamic-graph mode y.is_distributed reflects the real value, but in static-graph mode y.is_distributed is always False, which changes the dimension of the final logits and therefore hurts decoding accuracy.

This is an adaptation made here specifically for static graphs.
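
A minimal sketch of where such a flag typically matters, assuming it gates the tensor-parallel gather of vocabulary-sharded logits; the function name and the gather details are assumptions, not this PR's actual code:

```python
# Sketch (assumed wiring): decide whether the lm_head weight is sharded and,
# if so, gather the per-rank vocabulary shards back into full logits.
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet


def matmul_with_gather(x, y, tensor_parallel_degree, tensor_parallel_output=True):
    if paddle.in_dynamic_mode():
        # Dynamic graph: the weight's is_distributed attribute is reliable.
        y_is_distributed = y.is_distributed
    else:
        # Static graph: is_distributed is always False, so fall back to the
        # configured parallel degree (the adaptation described above).
        y_is_distributed = tensor_parallel_degree > 1

    logits = paddle.matmul(x, y)
    if y_is_distributed and not tensor_parallel_output:
        # Each rank holds a slice of the vocabulary dimension; gather and
        # concatenate so the logits cover the full vocabulary.
        group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
        shards = []
        dist.all_gather(shards, logits, group=group)
        logits = paddle.concat(shards, axis=-1)
    return logits
```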

return None


class DygraphInferencePredictor(BasePredictor):
(Collaborator) commented

The naming here could be revisited later; the previous understanding was that "dygraph" means dynamic graph while "inference" means static-graph inference.

Please note it as a TODO for yourself.

(Contributor) replied

There is no particularly elegant, fitting name for this yet:

  • DygraphInferencePredictor (middle of the road)
  • DygraphinferenceModelPredictor (too long)
  • DIPredictor (an abbreviation, but cryptic)

If anyone has a suitable name, feel free to join the discussion.

@@ -242,53 +250,296 @@ def _infer(self, inputs: dict[str, np.ndarray]):
return decoded_ids


def create_predictor(predictor_args: PredictorArgument, model_args: ModelArgument):
class StaticInferencePredictor(BasePredictor):
(Collaborator) commented

This is also where dynamic-batch and non-dynamic-batch are distinguished.

(Contributor) replied

There are two flags here:

  • mode: dygraph or static
  • inference_model: a bool

These two flags together control the four cases (see the dispatch sketch below).
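
A minimal dispatch sketch under the assumption that create_predictor selects a predictor class from these two flags; DygraphInferencePredictor and StaticInferencePredictor appear in this PR, while the other two class names are assumed placeholders for the non-InferenceModel paths:

```python
# Sketch: map (mode, inference_model) onto the four predictor classes.
# DygraphInferencePredictor / StaticInferencePredictor come from this PR;
# DygraphPredictor / StaticGraphPredictor are assumed names.
def create_predictor(predictor_args, model_args):
    if predictor_args.inference_model:
        if predictor_args.mode == "dygraph":
            return DygraphInferencePredictor(predictor_args, model_args)
        return StaticInferencePredictor(predictor_args, model_args)
    if predictor_args.mode == "dygraph":
        return DygraphPredictor(predictor_args, model_args)
    return StaticGraphPredictor(predictor_args, model_args)
```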

(Contributor) commented

dygraph -> dynamic

@wawltor (Collaborator) left a comment

LGTM

@wawltor wawltor merged commit b3b650c into PaddlePaddle:develop Aug 22, 2023