
Compression API Supports ERNIE-M and more Pretrained models #3234

Merged: 10 commits, Sep 19, 2022
docs/compression.md: 32 changes (20 additions, 12 deletions)

@@ -40,13 +40,19 @@ The PaddleNLP model compression API supports compressing ERNIE-class models fine-tuned on downstream tasks

## How to launch model compression

-The compression feature in the model compression API depends on the `paddleslim` package. Run the following command to install it:
+### Environment requirements
+
+- paddlepaddle-gpu >= 2.3
+- paddlenlp >= 2.4.0
+- paddleslim >= 2.3.0
+
+The compression feature in the model compression API depends on the latest `paddleslim` package. Run the following command to install it:

```shell
-pip install paddleslim
+pip install paddleslim -i https://pypi.tuna.tsinghua.edu.cn/simple
```

-It takes roughly four steps
+Using the model compression API takes roughly four steps (sketched in code after this list):

- Step 1: Parse the hyperparameters passed in from the command line with `PdArgumentParser` to obtain the compression arguments `compression_args`;
- Step 2: Instantiate a Trainer and call the `compress()` compression API
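
The first two steps look like this in practice; a minimal sketch assuming a sequence-classification setup on CLUE TNEWS (the dataset and preprocessing wiring is illustrative, not taken from this PR):

```python
import paddle
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import load_dataset
from paddlenlp.trainer import CompressionArguments, PdArgumentParser, Trainer
from paddlenlp.transformers import (AutoModelForSequenceClassification,
                                    AutoTokenizer)

# Step 1: parse compression hyperparameters passed on the command line.
parser = PdArgumentParser(CompressionArguments)
compression_args = parser.parse_args_into_dataclasses()[0]

# Assumed setup: a model already fine-tuned on the downstream task.
model = AutoModelForSequenceClassification.from_pretrained("ernie-3.0-medium-zh")
tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-medium-zh")
train_ds, eval_ds = load_dataset("clue", "tnews", splits=["train", "dev"])

def preprocess(example):
    # Tokenize the sentence and attach the label under the "labels" key.
    tokenized = tokenizer(example["sentence"], max_seq_len=128)
    tokenized["labels"] = example["label"]
    return tokenized

train_ds = train_ds.map(preprocess)
eval_ds = eval_ds.map(preprocess)

# Step 2: instantiate a Trainer and call the compress() API.
trainer = Trainer(model=model,
                  args=compression_args,
                  data_collator=DataCollatorWithPadding(tokenizer),
                  train_dataset=train_ds,
                  eval_dataset=eval_ds,
                  tokenizer=tokenizer,
                  criterion=paddle.nn.CrossEntropyLoss())
trainer.compress()
```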
@@ -81,7 +87,7 @@ python compress.py \
    --output_dir ./compress_models \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
-    --num_train_epochs 4
+    --num_train_epochs 4 \
    --width_mult_list 0.75 \
    --batch_size_list 4 8 16 \
    --batch_num_list 1 \
@@ -111,7 +117,7 @@ compression_args = parser.parse_args_into_dataclasses()

#### Trainer instantiation parameters

-- **--model** The model to be compressed; ERNIE and similar models are currently supported. It is a model fine-tuned on a downstream task. Taking classification tasks as an example, it can be obtained via `AutoModelForSequenceClassification.from_pretrained(model_name_or_path)`, in which case the `model_name_or_path` directory must contain the files model_config.json and model_state.pdparams;
+- **--model** The model to be compressed; structurally similar models such as ERNIE, BERT, RoBERTa, ERNIE-M, ERNIE-Gram, PP-MiniLM, and TinyBERT are currently supported. It is a model fine-tuned on a downstream task; when the pretrained model is ERNIE, it needs to inherit from `ErniePretrainedModel`. Taking classification tasks as an example, it can be obtained via `AutoModelForSequenceClassification.from_pretrained(model_name_or_path)`, in which case the `model_name_or_path` directory must contain the files model_config.json and model_state.pdparams;
- **--data_collator** All three task types can use the [DataCollator classes](../../paddlenlp/data/data_collator.py) predefined by PaddleNLP; `data_collator` can apply operations such as `Pad` to the data. See the [example code](../model_zoo/ernie-3.0/compress_seq_cls.py) for usage;
- **--train_dataset** The training set used by pruning training, task-specific data. For loading custom datasets, see the [documentation](https://huggingface.co/docs/datasets/loading). May be None when pruning is not enabled;
- **--eval_dataset** The evaluation set used by pruning training and also the calibration data used by quantization, task-specific data. For loading custom datasets, see the [documentation](https://huggingface.co/docs/datasets/loading). It is a required argument of the Trainer;
@@ -155,7 +161,7 @@ trainer.compress()

Note the following three conditions:

-- If the model is a custom model, it must support loading via `from_pretrained()` with `pretrained_model_name_or_path` as its only required argument, and its `forward` function must return `logits` or a `tuple of logits`;
+- If the model is a custom model, it needs to inherit from `XXXPretrainedModel`; for example, when the pretrained model is ERNIE, inherit from `ErniePretrainedModel`. The model must support loading via `from_pretrained()` with `pretrained_model_name_or_path` as its only required argument, and its `forward` function must return `logits` or a `tuple of logits`;

- If the model is a custom model, or the dataset is unusual enough that the loss computation in the compression API does not fit your use case, you need to define a custom `custom_dynabert_calc_loss` function. The loss is computed and then differentiated to derive neuron importance for pruning. See the example code below.
- It takes one batch of data as input and returns the model's loss.
@@ -178,8 +184,9 @@ trainer.compress()
    model.eval()
    metric.reset()
    for batch in data_loader:
> **Collaborator:** In this example and the one below, should the `input_ids` and `token_type_ids` keys be written out explicitly, to stay consistent with the Trainer-related code? [screenshot]

> **Contributor (author):** Thanks for the reminder; this has been added.

-        logits = model(batch['input_ids'],
-                       batch['token_type_ids'],
+        logits = model(input_ids=batch['input_ids'],
+                       token_type_ids=batch['token_type_ids'],
+                       # The following line must be included
                        attention_mask=[None, None])
        # Supports paddleslim.nas.ofa.OFA model and nn.layer model.
        if isinstance(model, OFA):
@@ -196,8 +203,9 @@ trainer.compress()

```python
def calc_loss(loss_fct, model, batch, head_mask):
-    logits = model(batch["input_ids"],
-                   batch["token_type_ids"],
+    logits = model(input_ids=batch["input_ids"],
+                   token_type_ids=batch["token_type_ids"],
+                   # The following line must be included
                    attention_mask=[None, head_mask])
    loss = loss_fct(logits, batch["labels"])
    return loss
```
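
The function above is then handed to the compression entry point. The exact keyword is assumed here from the function name the text uses; treat it as a sketch rather than the released signature:

```python
# Assumed wiring: the keyword name mirrors the `custom_dynabert_calc_loss`
# function described above and is not verified against the released API.
trainer.compress(custom_dynabert_calc_loss=calc_loss)
```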
@@ -226,7 +234,7 @@ python compress.py \
    --output_dir ./compress_models \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
-    --num_train_epochs 4
+    --num_train_epochs 4 \
    --width_mult_list 0.75 \
    --batch_size_list 4 8 16 \
    --batch_num_list 1 \
@@ -268,7 +276,7 @@ python compress.py \

- **--logging_steps** Number of update steps between two logs. Defaults to 500;

-- **--save_steps** Number of steps between model evaluations. Defaults to 500
+- **--save_steps** Number of steps between model evaluations. Defaults to 100

- **--optim** Name of the optimizer used by pruning training. Defaults to 'adamw';

model_zoo/ernie-1.0/finetune/sequence_classification.py: 13 changes (8 additions, 5 deletions)

@@ -108,11 +108,14 @@ def convert_clue(example,
                            max_seq_len=max_seq_length)

    if not is_test:
-        return {
-            "input_ids": example['input_ids'],
-            "token_type_ids": example['token_type_ids'],
-            "labels": label
-        }
+        if "token_type_ids" in example:
+            return {
+                "input_ids": example['input_ids'],
+                "token_type_ids": example['token_type_ids'],
+                "labels": label
+            }
+        else:
+            return {"input_ids": example['input_ids'], "labels": label}
    else:
        return {
            "input_ids": example['input_ids'],
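
The branch added above matters because some of the newly supported models, ERNIE-M in particular, do not produce segment ids. A quick way to see the difference, assuming both models are available (the ERNIE-M behavior stated here is inferred from this PR's changes, not re-verified):

```python
from paddlenlp.transformers import AutoTokenizer

ernie_tok = AutoTokenizer.from_pretrained("ernie-3.0-medium-zh")
erniem_tok = AutoTokenizer.from_pretrained("ernie-m-base")

# ERNIE tokenizers emit both keys, so convert_clue keeps token_type_ids.
print("token_type_ids" in ernie_tok("今天天气不错"))   # True
# ERNIE-M has no segment embeddings, so its tokenizer emits no
# token_type_ids (inferred), which is the case the new else-branch handles.
print("token_type_ids" in erniem_tok("今天天气不错"))  # expected: False
```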
model_zoo/ernie-3.0/README.md: 28 changes (23 additions, 5 deletions)

@@ -1340,14 +1340,32 @@ qa_model = AutoModelForQuestionAnswering.from_pretrained("ernie-3.0-medium-zh")

```shell
# Classification task
-python run_seq_cls.py --task_name tnews --model_name_or_path ernie-3.0-medium-zh --do_train
+# This script supports all 7 CLUE classification tasks; their hyperparameters are not all the same, so the classification hyperparameters are configured via config.yml
+python run_seq_cls.py \
+    --task_name tnews \
+    --model_name_or_path ernie-3.0-medium-zh \
+    --do_train

# Sequence labeling task
-python run_token_cls.py --task_name msra_ner --model_name_or_path ernie-3.0-medium-zh --do_train
+python run_token_cls.py \
+    --task_name msra_ner \
+    --model_name_or_path ernie-3.0-medium-zh \
+    --do_train \
+    --num_train_epochs 3 \
+    --learning_rate 0.00005 \
+    --save_steps 100 \
+    --batch_size 32 \
+    --max_seq_length 128 \
+    --remove_unused_columns False

# Reading comprehension task
-python run_qa.py --model_name_or_path ernie-3.0-medium-zh --do_train
+python run_qa.py \
+    --model_name_or_path ernie-3.0-medium-zh \
+    --do_train \
+    --learning_rate 0.00003 \
+    --num_train_epochs 8 \
+    --batch_size 24 \
+    --max_seq_length 512
```

<a name="模型压缩"></a>
@@ -1617,7 +1635,7 @@ For ONNX export and ONNXRuntime deployment, see: [ONNX Export and ONNXRuntime Deployment
- [【Quick Start with ERNIE 3.0】Machine Reading Comprehension in Practice](https://aistudio.baidu.com/aistudio/projectdetail/2017189)

- [【Quick Start with ERNIE 3.0】Dialogue Intent Recognition](https://aistudio.baidu.com/aistudio/projectdetail/2017202?contributionType=1)
-tangtang


## References

paddlenlp/trainer/compression_args.py: 6 changes (3 additions, 3 deletions)

@@ -147,10 +147,10 @@ def print_config(self, args=None, key=""):
            'weight_quantize_type', 'input_infer_model_path'
        ]
        default_arg_dict = {
-            "width_mult_list": [0.75],
-            'batch_size_list': [1],
+            "width_mult_list": ['3/4'],
+            'batch_size_list': [4, 8, 16],
            'algo_list': ['mse', 'KL'],
-            'batch_num_list': [4, 8, 16]
+            'batch_num_list': [1]
> **Collaborator:** Could you explain what the `batch_num_list` parameter means?

> **Contributor (author):** The calibration data is supplied through the data_loader. `batch_num` is the number of batches to run for post-training quantization and is itself a hyperparameter; `batch_num_list` is the list of values to search over for it. For example, batch_num=1 with batch_size=4 means the data_loader with batch_size=4 supplies the calibration data: after running 1 batch, sampling is finished, the scale values are computed, and quantization is complete.

        }
        logger.info("=" * 60)
        if args is None:
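
Combined with the defaults above, the post-training-quantization step is a small grid search over `algo_list`, `batch_size_list`, and `batch_num_list`. A schematic of that loop (illustrative only; `run_ptq_once` is a hypothetical stand-in for the real PTQ call):

```python
# Schematic of the PTQ grid search implied by the defaults above.
for algo in ['mse', 'KL']:          # algo_list
    for batch_size in [4, 8, 16]:   # batch_size_list: calibration batch size
        for batch_num in [1]:       # batch_num_list: calibration batches to run
            # Each combination feeds batch_num batches of batch_size samples
            # through the model to compute quantization scales.
            run_ptq_once(algo, batch_size, batch_num)  # hypothetical helper
```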
paddlenlp/trainer/trainer_compress.py: 131 changes (79 additions, 52 deletions)

@@ -17,6 +17,7 @@
import copy
import math
import numpy as np
+import inspect

import paddle
from paddle.utils import try_import
@@ -77,12 +78,19 @@ def compress(self,
    else:
        # Prefix of `export_model` is 'model'
        self.args.input_filename_prefix = "model"
-    input_spec = [
-        paddle.static.InputSpec(shape=[None, None],
-                                dtype="int64"),  # input_ids
-        paddle.static.InputSpec(shape=[None, None],
-                                dtype="int64")  # token_type_ids
-    ]
+    if 'token_type_ids' in self.train_dataset[0]:
+        input_spec = [
+            paddle.static.InputSpec(shape=[None, None],
+                                    dtype="int64"),  # input_ids
+            paddle.static.InputSpec(shape=[None, None],
+                                    dtype="int64")  # token_type_ids
+        ]
+    else:
+        input_spec = [
+            paddle.static.InputSpec(shape=[None, None],
+                                    dtype="int64")  # input_ids
+        ]

    input_dir = args.output_dir
    export_model(model=self.model,
                 input_spec=input_spec,
@@ -106,7 +114,6 @@ def _dynabert(self, model, output_dir):
    # Each batch is a dict.
    train_dataloader = self.get_train_dataloader()
    eval_dataloader = self.get_eval_dataloader(self.eval_dataset)
-
    if "QuestionAnswering" in model.__class__.__name__:
        eval_dataloader_with_label = self.get_eval_dataloader(
            self.eval_examples)
@@ -291,8 +298,8 @@ def evaluate_qa(model, data_loader):
        all_start_logits = []
        all_end_logits = []
        for batch in data_loader:
-            logits = model(batch['input_ids'],
-                           batch['token_type_ids'],
+            logits = model(input_ids=batch['input_ids'],
+                           token_type_ids=batch['token_type_ids'],
                           attention_mask=[None, None])
            if isinstance(model, OFA):
                start_logits_tensor, end_logits_tensor = logits[0]
@@ -323,12 +330,12 @@ def evaluate_seq_cls(model, data_loader):
        model.eval()
        metric.reset()
        for batch in data_loader:
-            logits = model(batch['input_ids'],
-                           batch['token_type_ids'],
-                           attention_mask=[None, None])
+            labels = batch.pop("labels")
+            batch["attention_mask"] = [None, None]
+            logits = model(**batch)
            if isinstance(model, OFA):
                logits = logits[0]
-            correct = metric.compute(logits, batch['labels'])
+            correct = metric.compute(logits, labels)
            metric.update(correct)
        res = metric.accumulate()
        logger.info("acc: %s, " % res)
@@ -341,8 +348,8 @@ def evaluate_token_cls(model, data_loader):
        model.eval()
        metric.reset()
        for batch in data_loader:
-            logits = model(batch['input_ids'],
-                           batch['token_type_ids'],
+            logits = model(input_ids=batch['input_ids'],
+                           token_type_ids=batch['token_type_ids'],
                           attention_mask=[None, None])
            if isinstance(model, OFA):
                logits = logits[0]
@@ -382,9 +389,14 @@ def evaluate_token_cls(model, data_loader):
                # and use this config in supernet training.
                net_config = utils.dynabert_config(ofa_model, width_mult)
                ofa_model.set_net_config(net_config)
-                logits, teacher_logits = ofa_model(batch['input_ids'],
-                                                   batch['token_type_ids'],
-                                                   attention_mask=[None, None])
+                if "token_type_ids" in batch:
+                    logits, teacher_logits = ofa_model(
+                        input_ids=batch['input_ids'],
+                        token_type_ids=batch['token_type_ids'],
+                        attention_mask=[None, None])
+                else:
+                    logits, teacher_logits = ofa_model(
+                        batch['input_ids'], attention_mask=[None, None])
                rep_loss = ofa_model.calc_distill_loss()
                if isinstance(logits, tuple):
                    logit_loss = 0
@@ -474,10 +486,15 @@ def _dynabert_export(self, ofa_model):
    for name, sublayer in origin_model_new.named_sublayers():
        if isinstance(sublayer, paddle.nn.MultiHeadAttention):
            sublayer.num_heads = int(width_mult * sublayer.num_heads)
-    input_shape = [
-        paddle.static.InputSpec(shape=[None, None], dtype='int64'),
-        paddle.static.InputSpec(shape=[None, None], dtype='int64')
-    ]
+    # Export two input specs only when the dataset provides token_type_ids.
+    if 'token_type_ids' in self.train_dataset[0]:
+        input_shape = [
+            paddle.static.InputSpec(shape=[None, None], dtype='int64'),
+            paddle.static.InputSpec(shape=[None, None], dtype='int64')
+        ]
+    else:
+        input_shape = [
+            paddle.static.InputSpec(shape=[None, None], dtype='int64')
+        ]
    pruned_infer_model_dir = os.path.join(model_dir, "pruned_model")

    net = paddle.jit.to_static(origin_model_new, input_spec=input_shape)
@@ -506,15 +523,20 @@ def _post_training_quantization_grid_search(self, model_dir):
    def _post_training_quantization(algo, batch_size, batch_nums):

        def _batch_generator_func():
-            batch_data = [[], []]
+            has_token_type_ids = "token_type_ids" in self.eval_dataset[0]
+            batch_data = [[], []] if has_token_type_ids else [[]]
            for data in self.eval_dataset:
                batch_data[0].append(data['input_ids'])
-                batch_data[1].append(data['token_type_ids'])
+                if has_token_type_ids:
+                    batch_data[1].append(data['token_type_ids'])
                if len(batch_data[0]) == batch_size:
                    input_ids = Pad(axis=0, pad_val=0)(batch_data[0])
-                    token_type_ids = Pad(axis=0, pad_val=0)(batch_data[1])
-                    yield [input_ids, token_type_ids]
-                    batch_data = [[], []]
+                    if has_token_type_ids:
+                        token_type_ids = Pad(axis=0, pad_val=0)(batch_data[1])
+                        yield [input_ids, token_type_ids]
+                    else:
+                        yield [input_ids]
+                    batch_data = [[], []] if has_token_type_ids else [[]]

        post_training_quantization = PostTrainingQuantization(
            executor=exe,
@@ -565,9 +587,10 @@ def auto_model_forward(self,
                       output_hidden_states=False,
                       output_attentions=False,
                       return_dict=False):
-    wtype = self.pooler.dense.fn.weight.dtype if hasattr(
-        self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype
-
+    kwargs = locals()
+    wtype = self.encoder.layers[0].norm1.fn.weight.dtype if hasattr(
+        self.encoder.layers[0].norm1,
+        'fn') else self.encoder.layers[0].norm1.weight.dtype
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time."
@@ -600,32 +623,36 @@ def auto_model_forward(self,
        attention_mask[0] = paddle.unsqueeze(
            (input_ids == self.pad_token_id).astype(wtype) * -1e4, axis=[1, 2])

-    if "use_task_id" in self.config:
-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids,
-            task_type_ids=task_type_ids,
-            inputs_embeds=inputs_embeds,
-            past_key_values_length=past_key_values_length)
-    else:
-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            past_key_values_length=past_key_values_length)
+    embedding_kwargs_keys = inspect.signature(
+        self.embeddings.forward).parameters.keys()
+    embedding_kwargs = {}
+    for key in embedding_kwargs_keys:
+        if key in kwargs.keys():
+            embedding_kwargs[key] = kwargs[key]
+    embedding_kwargs["input_ids"] = input_ids
+
+    embedding_output = self.embeddings(**embedding_kwargs)

    self.encoder._use_cache = use_cache  # To be consistent with HF
-    encoder_outputs = self.encoder(embedding_output,
-                                   src_mask=attention_mask,
-                                   cache=past_key_values,
-                                   output_attentions=output_attentions,
-                                   output_hidden_states=output_hidden_states,
-                                   return_dict=return_dict)
-
+    encoder_kwargs_keys = inspect.signature(
+        self.encoder.forward).parameters.keys()
+    encoder_kwargs = {}
+    for key in encoder_kwargs_keys:
+        if key == "cache":
+            encoder_kwargs[key] = past_key_values
+        elif key == "src_mask":
+            encoder_kwargs[key] = attention_mask
+        elif key in kwargs:
+            encoder_kwargs[key] = kwargs[key]
+
+    encoder_outputs = self.encoder(embedding_output, **encoder_kwargs)
    if isinstance(encoder_outputs, type(embedding_output)):
        sequence_output = encoder_outputs
-        pooled_output = self.pooler(sequence_output)
+        if hasattr(self, 'pooler'):
+            pooled_output = self.pooler(sequence_output)
+        else:
+            pooled_output = sequence_output[:, 0]
        return (sequence_output, pooled_output)
    else:
        sequence_output = encoder_outputs[0]
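
The rewritten `auto_model_forward` above routes arguments through `inspect.signature` so each submodule receives only the keyword arguments its `forward` actually declares; that is what lets one forward implementation serve backbones whose embeddings and encoders take different parameters. A standalone sketch of the pattern (hypothetical helper, not part of this PR):

```python
import inspect

def call_with_accepted_kwargs(fn, **kwargs):
    # Pass along only the keyword arguments fn's signature declares.
    accepted = inspect.signature(fn).parameters.keys()
    return fn(**{k: v for k, v in kwargs.items() if k in accepted})

# A forward with no token_type_ids parameter (as with ERNIE-M) simply
# never sees that key:
def forward(input_ids, attention_mask=None):
    return input_ids

print(call_with_accepted_kwargs(forward, input_ids=[1, 2],
                                token_type_ids=[0, 0]))  # [1, 2]
```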