Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Trainer] update clear_grad #8829

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions docs/trainer.md
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
trainer.md
# PaddleNLP Trainer API

PaddleNLP提供了Trainer训练API,针对训练过程的通用训练配置做了封装,比如:
PaddleNLP 提供了 Trainer 训练 API,针对训练过程的通用训练配置做了封装,比如:

- 优化器、学习率调度等训练配置
- 多卡,混合精度,梯度累积等功能
- checkpoint断点,断点重启(数据集,随机数恢复)
- 日志显示,loss可视化展示等
- checkpoint 断点,断点重启(数据集,随机数恢复)
- 日志显示,loss 可视化展示等

用户输入模型,数据集,就可以使用Trainer API高效快速的实现预训练、微调等任务。
用户输入模型,数据集,就可以使用 Trainer API 高效快速的实现预训练、微调等任务。


## Trainer基本使用方法介绍
## Trainer 基本使用方法介绍

下面是用户使用 Trainer API进行finetune任务的简单示例,这里以中文情感分类数据集`chnsenticorp`为例。
下面是用户使用 Trainer API 进行 finetune 任务的简单示例,这里以中文情感分类数据集`chnsenticorp`为例。
更详细的使用可以参考[CLUE Trainer](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/legacy/examples/benchmark/clue/classification/run_clue_classifier_trainer.py)版本。

1. 导入需要用到的头文件。
- 主要是模型、Tokenizer
- 还有Trainer组件
- 还有 Trainer 组件
- 其中`Trainer`是训练主要入口,用户传入模型,数据集,即可进行训练
- `TrainingArguments` 包含了用户需要的大部分训练参数。
- `PdArgumentParser` 是用户输出参数的工具
Expand All @@ -30,7 +30,7 @@ from paddlenlp.transformers import AutoModelForSequenceClassification, AutoToken
from paddlenlp.trainer import Trainer, TrainingArguments, PdArgumentParser
```
2. 设置好用户参数
- PdArgumentParser 可以接受多个类似`TrainingArguments`的参数。用户可以自定义所需要的`ModelArguments`, `DataArguments`为 tuple 传入 PdArgumentParser即可
- PdArgumentParser 可以接受多个类似`TrainingArguments`的参数。用户可以自定义所需要的`ModelArguments`, `DataArguments`为 tuple 传入 PdArgumentParser 即可
- 这些参数都是通过`python xxx.py --dataset xx --max_seq_length xx`的方式传入。`TrainingArguments`的所有可配置参数见后文。
```python
from dataclasses import dataclass
Expand All @@ -49,8 +49,8 @@ parser = PdArgumentParser(TrainingArguments, DataArguments)
```

3. 加载模型,tokenizer, 数据集
- 注意,这里的数据集,需要输出的是一个dict。dict中的key,需要和模型的输入名称对应。
- 这里的,`labels`如果模型没有使用到,我们还需要额外定义`criterion`,计算最后的loss损失
- 注意,这里的数据集,需要输出的是一个 dict。dict 中的 key,需要和模型的输入名称对应。
- 这里的,`labels`如果模型没有使用到,我们还需要额外定义`criterion`,计算最后的 loss 损失
```python
train_dataset = load_dataset("chnsenticorp", splits=["train"])
model = AutoModelForSequenceClassification.from_pretrained("ernie-3.0-medium-zh", num_classes=len(train_dataset.label_list))
Expand All @@ -64,9 +64,9 @@ def convert_example(example, tokenizer):
train_dataset = train_dataset.map(partial(convert_example, tokenizer=tokenizer))
```

4. 构造Trainer实例,进行模型训练。
- 这里传入`model,criterion,args,train_dataset,tokenizer`这些训练需要的组件,构建了实例化的trainer
- 使用trainer.train()接口开始训练过程。训练完成后,可以保存模型,保存一些日志。
4. 构造 Trainer 实例,进行模型训练。
- 这里传入`model,criterion,args,train_dataset,tokenizer`这些训练需要的组件,构建了实例化的 trainer
- 使用 trainer.train()接口开始训练过程。训练完成后,可以保存模型,保存一些日志。
```python
trainer = Trainer(
model=model,
Expand All @@ -85,24 +85,24 @@ if training_args.do_train:
预训练的使用方式可以参考[ERNIE-1.0 Trainer](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/legacy/model_zoo/ernie-1.0/run_pretrain_trainer.py)版本。


## Trainer进阶分布式能力使用介绍
## Trainer 进阶分布式能力使用介绍

**通用分布式能力**
对于通用的分布式能力, PaddleNLP主要做了数据并行data_parallel, 分布式参数sharding功能的支持.
对于通用的分布式能力, PaddleNLP 主要做了数据并行 data_parallel, 分布式参数 sharding 功能的支持.
这类功能无需用户修改组网, 直接多卡即可运行.

用户使用 `paddle.distruted.launch --devices "0,1,2,3" train.py`即可将运行的程序切换为多卡数据并行.
如果想要使用sharding功能, 减少模型显存占用, 指定参数`--sharding "stage2"`即可. 更多sharding功能配置见参数介绍部分.
如果想要使用 sharding 功能, 减少模型显存占用, 指定参数`--sharding "stage2"`即可. 更多 sharding 功能配置见参数介绍部分.


**混合并行分布式能力**

飞桨4D并行, 即: data parallel + sharding parallel + tensor parallel + pipeline parallel.
飞桨4D 并行, 即: data parallel + sharding parallel + tensor parallel + pipeline parallel.

混合并行这里, 主要添加了 tensor parallel (TP) 和 pipeline parallel(PP)支持.
目前, PaddleNLP主要对一些大模型, 如 GPT, Llama等做了 TP PP支持, 用户可以使用这些策略.
目前, PaddleNLP 主要对一些大模型, 如 GPT, Llama 等做了 TP PP 支持, 用户可以使用这些策略.

相关代码实现可以参考llama训练的[例子](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm)
相关代码实现可以参考 llama 训练的[例子](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm)

流水线并行的组网改造可以参见[modeling_pp.py](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/llama/modeling_pp.py)

Expand All @@ -113,7 +113,7 @@ if training_args.do_train:


## Trainer 实例化参数介绍
Trainer 是一个简单,但功能完整的 Paddle训练和评估模块,并针对 PaddleNLP 模型进行了优化。
Trainer 是一个简单,但功能完整的 Paddle 训练和评估模块,并针对 PaddleNLP 模型进行了优化。

```text
参数:
Expand Down Expand Up @@ -731,4 +731,8 @@ Trainer 是一个简单,但功能完整的 Paddle训练和评估模块,并
Whether to enable MoE (Mixture of Experts) expert parallel training.
(default: False)

--release_grads
是否在训练过程每次迭代后对梯度进行释放,减少峰值显存. 可选,默认为False)
Whether to reduce peak memory usage by releasing gradients after each iteration. (default: False)

```
2 changes: 1 addition & 1 deletion paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
if optimizer_was_run:
self.lr_scheduler.step()

if enable_release_grads:
if args.release_grads or enable_release_grads:
self.optimizer.clear_grad(set_to_zero=False)
if args.pipeline_parallel_degree > 1:
for _, buffers in model._chunk_2_comm_buffers.items():
Expand Down
7 changes: 6 additions & 1 deletion paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ class TrainingArguments:
enable_stage1_broadcast_overlap, overlap stage1 V1 broadcast with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for broadcast overlap forward compute and no other sync could be called during the training for broadcast overlap.
enable_stage1_allgather_overlap, overlap stage1 V2 allgather with next step forward computation. There are some constraints for the overlap, such as the logging_step should be bigger than 1 for allgather overlap forward compute and no other sync could be called during the training for allgather overlap.
disable_stage1_reduce_avg, replace reduce_avg with original reduce_sum+scale in stage1, which can be used for accuracy verification.
enable_release_graHEADds, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration.
enable_release_grads, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration.
recompute (`bool`, *optional*, defaults to `False`):
Recompute the forward pass to calculate gradients. Used for saving memory.
Only support for networks with transformer blocks.
Expand Down Expand Up @@ -355,6 +355,8 @@ class TrainingArguments:
Whether skip profile timer, timer will record time usage of forward/ backward/ step, etc.
distributed_dataloader (`bool`, *optional*):
Whether to use distributed dataloader. Default is `False`.
release_grads (`bool`, *optional*):
Whether to release gradients during training. Default is `False`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""

output_dir: str = field(
Expand Down Expand Up @@ -832,6 +834,9 @@ class TrainingArguments:
default=False,
metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"},
)
release_grads: Optional[bool] = field(
default=False, metadata={"help": "Whether to release gradients during training. Default is `False`."}
)

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
Expand Down
Loading