Add model Prophetnet #1698

Merged
merged 8 commits into from
Mar 7, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions examples/text_summarization/prophetnet/README.md
@@ -0,0 +1,234 @@
# Prophetnet

## Model Overview

ProphetNet is a new seq2seq pre-training model. During training, ProphetNet learns to predict the next N tokens at each time step simultaneously. This self-supervised objective encourages the model to plan for tokens further in the future and prevents overfitting to strong local correlations.
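
As a rough illustration of this objective (not the actual training code in this example), the sketch below builds the future n-gram target streams for a toy sequence; the helper name and N=2 are illustrative assumptions.

```python
# Illustrative sketch of ProphetNet's future n-gram objective (hypothetical helper,
# not code from this repo). At decoder position t the model is trained to predict
# the next N tokens y[t], y[t+1], ..., y[t+N-1] instead of only y[t].

def future_ngram_targets(target_ids, n=2, pad_id=0):
    """Return N target streams; stream i is the target sequence shifted left by i."""
    length = len(target_ids)
    streams = []
    for i in range(n):
        shifted = target_ids[i:] + [pad_id] * i  # pad the tail once the sequence ends
        streams.append(shifted[:length])
    return streams

# With N=2, every position predicts both the next token and the one after it.
print(future_ngram_targets([11, 12, 13, 14], n=2))
# [[11, 12, 13, 14], [12, 13, 14, 0]]
```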

This project is an open-source text summarization example of ProphetNet on PaddlePaddle 2.2. It includes the code for fine-tuning and generation on the CNN/DailyMail and Gigaword datasets.

### Dependencies

```
pip install -r requirements.txt
python -m pip install paddlepaddle-gpu==2.2.2.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
pip install paddlenlp==2.2.3
```
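
Optionally, you can sanity-check that the GPU build of PaddlePaddle and PaddleNLP were installed correctly before moving on (a quick check, not required by the example scripts):

```python
import paddle
import paddlenlp

print(paddle.__version__)     # expect 2.2.2
print(paddlenlp.__version__)  # expect 2.2.3
paddle.utils.run_check()      # runs a small program to verify the PaddlePaddle installation
```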

### Code Structure

The main code structure of this project is as follows:

```text
├── train_prophetnet.py # main entry point for model fine-tuning
├── generate.py # main entry point for generation
├── eval.py # entry point for evaluating generated results
├── uncase_tokenize_data.py # data preprocessing
├── uncompress_data.sh # data decompression script
├── run_train.sh # model training script
├── run_eval.sh # model evaluation script
├── requirements.txt # dependency file
└── README.md # documentation
```

### Data Preparation

Download the GLGE datasets: [link](https://drive.google.com/file/d/1F4zppa9Gqrh6iNyVsZJkxfbm5waalqEA/view)

Download the GLGE test sets: [link](https://drive.google.com/file/d/11lDXIG87dChIfukq3x2Wx4r5_duCRm_J/view)

Place glge_public.tar and glge_hidden_v1.1.tar.gz in the project root directory, then run:

```
bash uncompress_data.sh
```

### Download the Pre-trained Weights and Vocabulary

Download the model weights and vocabulary from this [link](https://pan.baidu.com/s/1FOnd01rNvDJoONYegacq1Q) (extraction code: o28q) and place them in the project root directory.

### Data Preprocessing

```
python uncase_tokenize_data.py --dataset <DATASET>
```

Notes:

- `<DATASET>` can be either `cnndm` or `gigaword`.
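
The exact preprocessing lives in `uncase_tokenize_data.py`; conceptually, each source/target line is lowercased ("uncased") and split into subword tokens before training. Below is a minimal sketch of that idea, assuming a BERT-style uncased tokenizer from PaddleNLP; the tokenizer and output format used by the script may differ.

```python
from paddlenlp.transformers import BertTokenizer

# Assumption: an uncased WordPiece vocabulary is used here only for illustration;
# the script's actual tokenizer and file handling may differ.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def uncase_tokenize(line):
    # Lowercase the raw text, then join its subword tokens with spaces.
    return " ".join(tokenizer.tokenize(line.lower()))

print(uncase_tokenize("Police Arrest Suspect After London Robbery."))
```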

### Model Training

```
bash run_train.sh <DATASET>
```

Or run the fine-tuning program directly:

- cnndm:

```
python train_prophetnet.py \
--dataset=cnndm \
--pretrained_model_path=./model_state.pdparams \
--batch_size=4 \
--epochs=4 \
--lr=0.0001 \
--warmup_init_lr=1e-07 \
--warmup_updates=1000 \
--clip_norm=0.1 \
--num_workers=4 \
--output_dir=./ckpt/cnndm
```

- gigaword:

```
python train_prophetnet.py \
--dataset=gigaword \
--pretrained_model_path=./model_state.pdparams \
--batch_size=16 \
--epochs=6 \
--lr=0.0001 \
--warmup_init_lr=1e-07 \
--warmup_updates=1000 \
--clip_norm=0.1 \
--num_workers=8 \
--output_dir=./ckpt/gigaword
```

> **Review comment (Contributor):** It would be better to use `warmup_steps`.
>
> **Author:** Fixed.

The parameters are described as follows:

- `dataset` specifies the dataset, either cnndm or gigaword.

- `pretrained_model_path` is the path to the local pre-trained weights used for initialization, e.g. ./model_state.pdparams.

- `batch_size` is the training batch size.

- `epochs` is the number of training epochs.

- `lr` is the learning rate.

- `warmup_init_lr` is the initial warmup learning rate.

- `warmup_updates` is the number of warmup steps (see the schedule sketch after this list).

- `clip_norm` is the gradient clipping norm.

- `num_workers` is the number of data-loading workers.

- `output_dir` is the directory where the fine-tuned weights are saved.
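
The `warmup_init_lr` / `warmup_updates` / `lr` flags mirror the inverse-sqrt schedule used in the original ProphetNet (fairseq) recipe; whether this example applies exactly that decay is an assumption, and the authoritative logic is in `train_prophetnet.py`. A minimal sketch of such a schedule:

```python
def inverse_sqrt_lr(step, lr=1e-4, warmup_init_lr=1e-7, warmup_updates=1000):
    # Linear warmup from warmup_init_lr to lr over warmup_updates steps,
    # then decay proportionally to 1/sqrt(step).
    if step < warmup_updates:
        return warmup_init_lr + (lr - warmup_init_lr) * step / warmup_updates
    return lr * (warmup_updates / step) ** 0.5

for step in (1, 500, 1000, 4000):
    print(step, inverse_sqrt_lr(step))  # rises to 1e-4 at step 1000, then decays
```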

Fine-tuned model weights:

- cnndm: [link](https://pan.baidu.com/s/1cemrUDxkqEW9raoasJ_VKw), extraction code: 1egi

- gigaword: [link](https://pan.baidu.com/s/1qRH2FStT3vNQtDjZLkYJBQ), extraction code: on5v

### Model Evaluation

Evaluation uses the [evaluation scripts](https://pan.baidu.com/s/1FOnd01rNvDJoONYegacq1Q) from the original ProphetNet source code. These scripts depend on pyrouge, which needs to be installed beforehand:

```
pip install git+https://github.com/pltrdy/pyrouge
```

```
bash run_eval.sh <DATASET>
```

Or run the generation and evaluation programs directly:

- cnndm:

```
python generate.py \
--dataset=cnndm \
--vocab_file=./prophetnet.tokenizer \
--output_path=./generate/cnndm/generate.txt \
--min_target_length=45 \
--max_target_length=110 \
--decode_strategy=beam_search \
--num_beams=4 \
--length_penalty=1.2 \
--batch_size=16 \
--ignore_pad_token_for_loss=True \
--early_stopping=True \
--logging_steps=100 \
--device=gpu

python eval.py --dataset cnndm --generated ./generate/cnndm/generate.txt
```

- gigaword:

```
python generate.py \
--dataset=gigaword \
--vocab_file=./prophetnet.tokenizer \
--output_path=./generate/gigaword/generate.txt \
--min_target_length=1 \
--max_target_length=200 \
--decode_strategy=beam_search \
--num_beams=4 \
--length_penalty=1.6 \
--batch_size=16 \
--ignore_pad_token_for_loss=True \
--early_stopping=True \
--logging_steps=100 \
--device=gpu

python eval.py --dataset gigaword --generated ./generate/gigaword/generate.txt
```

The parameters are described as follows:

- `dataset` specifies the dataset, either cnndm or gigaword.

- `vocab_file` is the path to the vocabulary/tokenizer file.

- `output_path` is the path where the generated results are saved.

- `min_target_length` is the minimum decoding length.

- `max_target_length` is the maximum decoding length.

- `decode_strategy` is the decoding strategy.

- `num_beams` is the beam width for beam_search decoding.

- `length_penalty` is the exponential length penalty for beam_search decoding (see the sketch after this list).

- `batch_size` is the batch size used during evaluation.

- `ignore_pad_token_for_loss` ignores padding tokens when computing the loss.

- `early_stopping` controls whether prediction stops once the end token is generated.

- `logging_steps` is the interval between log messages.

- `device` is the device to run on.
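
For intuition, `length_penalty` rescales each beam's summed log-probability by a power of its length, so values above 1.0 favor longer candidates. A common form of this normalization is sketched below; the exact formula used by PaddleNLP's `generate()` may differ, so treat this only as an illustration:

```python
def normalized_score(token_logprobs, length_penalty=1.2):
    # Divide the summed log-probability by len(sequence) ** length_penalty.
    return sum(token_logprobs) / (len(token_logprobs) ** length_penalty)

short = [-0.5, -0.6, -0.7]             # 3-token candidate
long = [-0.5, -0.6, -0.7, -0.4, -0.5]  # 5-token candidate
print(normalized_score(short))  # ~ -0.48
print(normalized_score(long))   # ~ -0.39, so the longer candidate ranks higher
```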

### Fine-tuning Test Results

> #### Test results on the CNN/DM dataset:

|Model|Optimizer|batch_size|Dataset|ROUGE-1|ROUGE-2|ROUGE-L|
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|prophetnet-large-uncased|Adam|4|CNN/DM|44.17|21.24|41.36|

> #### Test results on the Gigaword dataset:

|Model|Optimizer|batch_size|Dataset|ROUGE-1|ROUGE-2|ROUGE-L|
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|prophetnet-large-uncased|Adam|16|gigaword|38.92|19.81|36.06|
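
For reference, ROUGE-1, ROUGE-2, and ROUGE-L are F1 scores over unigram, bigram, and longest-common-subsequence overlap between the generated and reference summaries. A toy unigram (ROUGE-1) example is below; the numbers in the tables above come from the pyrouge-based evaluation scripts, not from this snippet:

```python
from collections import Counter

def rouge1_f(candidate, reference):
    # Unigram overlap between candidate and reference, reported as an F1 score.
    cand, ref = Counter(candidate.split()), Counter(reference.split())
    overlap = sum((cand & ref).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

print(rouge1_f("police arrested the suspect", "police arrest suspect in london"))  # ~0.44
```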

### Experimental Environment

- GPU: 1 × NVIDIA RTX 3090; CPU: Intel i7-11700K
- Ubuntu 18.04

### References

1. Qi W, Yan Y, Gong Y, et al. ProphetNet: Predicting future n-gram for sequence-to-sequence pre-training[J]. arXiv preprint arXiv:2001.04063, 2020.
73 changes: 73 additions & 0 deletions examples/text_summarization/prophetnet/eval.py
@@ -0,0 +1,73 @@
# Evaluate generated summaries with the dataset-specific scripts shipped with the
# original ProphetNet code (ROUGE via pyrouge for cnndm and gigaword).
import argparse
import os
import re
import sys
from os import listdir
from os.path import isfile, join

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    type=str,
    help="choose 'all' or one of the supported datasets: cnndm, gigaword.")
parser.add_argument("--generated", type=str, help="generated output file.")

args = parser.parse_args()

data_root_path = 'data'

support_dataset = ['cnndm', 'gigaword']
# Regex templates for extracting scores from the evaluation scripts' output.
files2rouge_template = r'.*ROUGE-1 Average_F: (?P<rouge1_f>\d+(\.\d*)?|\.\d+).*ROUGE-2 Average_F: (?P<rouge2_f>\d+(\.\d*)?|\.\d+).*ROUGE-L Average_F: (?P<rougeL_f>\d+(\.\d*)?|\.\d+).*'
# gigaword_template = r'.*ROUGE-1: (?P<rouge1_f>\d+(\.\d*)?|\.\d+).*ROUGE-2: (?P<rouge2_f>\d+(\.\d*)?|\.\d+).*ROUGE-L: (?P<rougeL_f>\d+(\.\d*)?|\.\d+).*'
qg_template = r'.*Bleu_4: (?P<bleu4>\d+(\.\d*)?|\.\d+).*METEOR: (?P<meteor>\d+(\.\d*)?|\.\d+).*ROUGE_L: (?P<rougeL>\d+(\.\d*)?|\.\d+).*'
personachat_template = r'.*?(?P<d1>[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?).*?(?P<d2>[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?).*Bleu_1: (?P<bleu1>\d+(\.\d*)?|\.\d+).*Bleu_2: (?P<bleu2>\d+(\.\d*)?|\.\d+).*'


def scale_up(d):
    # Convert the captured score strings to percentages.
    return {k: float(d[k]) * 100 for k in d.keys()}


def eval_one_dataset():
    golden_file = f"{data_root_path}/{args.dataset}_data/test.tgt"

    # Each dataset is scored by its own script from the original ProphetNet evaluation code.
    eval_template = {
        'cnndm':
        f"python ./evaluate/cnndm/postprocess_cnn_dm.py --generated {generated_file} --golden {golden_file}",
        'gigaword':
        f"python ./evaluate/gigaword/eval.py --perl --pred {generated_file} --gold {golden_file}",
    }

    cmd = eval_template[args.dataset]
    try:
        output = os.popen(cmd).read()
        if args.dataset in ['cnndm', 'gigaword']:
            d = re.search(files2rouge_template,
                          output.replace("\n", " ")).groupdict()
            d = scale_up(d)
            print(
                f"{args.dataset}\trouge1/rouge2/rougeL\t{d['rouge1_f']:.2f}/{d['rouge2_f']:.2f}/{d['rougeL_f']:.2f}"
            )
    except:
        print("Unexpected error:", sys.exc_info()[0])
        print(f"{args.dataset} evaluate failed!")


if args.dataset != 'all':
    generated_file = args.generated
    eval_one_dataset()
else:
    # Evaluate every supported dataset found under the given output directory.
    output_root_path = args.generated
    onlyfolders = [
        f for f in listdir(output_root_path)
        if not isfile(join(args.generated, f))
    ]
    for dataset in support_dataset:
        for folder in onlyfolders:
            if folder.startswith(dataset):
                for hypo_file in listdir(args.generated + '/' + folder):
                    if 'hypo' in hypo_file or 'score' in hypo_file:
                        generated_file = args.generated + '/' + folder + '/' + hypo_file
                        print(f"{dataset}\tpredict_file:{generated_file}")
                        args.dataset = dataset
                        args.generated = generated_file  # fixed typo: was args.gnerated
                        eval_one_dataset()