Add model Prohetnet #1698

Merged

Commits (8):

- a5cf998 add Prohetnet model
- 6f984b9 update prohetnet
- 555ad31 update format
- 74e7318 pre commit
- fb76a3b add prophetnet example
- 80e2dca update tokenizer.py, run_train.sh, train_prophetnet.py
- 7518275 remove evaluate/gigaword/__init__.py
- 76385a7 Merge branch 'develop' into develop
File: README.md (new file, +234 lines)
# ProphetNet

## Model Introduction

ProphetNet is a novel seq2seq pre-training model. At each time step during training, ProphetNet learns to predict the next N tokens simultaneously. This self-supervised objective encourages the model to take tokens farther in the future into account and prevents overfitting to strong local correlations.
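Following the formulation in the ProphetNet paper (see the reference at the end of this document), the future n-gram objective sketched above can be written as below, where x is the source sequence, y the target of length T, n the number of future tokens predicted per step, and the weights α_j balance the n prediction streams:

```latex
% Future n-gram prediction loss: at every position t the model predicts
% y_t, y_{t+1}, ..., y_{t+n-1} in parallel, each stream weighted by alpha_j.
\mathcal{L} = -\sum_{j=0}^{n-1} \alpha_j \left( \sum_{t=1}^{T-j} \log p_\theta\left(y_{t+j} \mid y_{<t}, x\right) \right)
```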
This project is an open-source text summarization example of ProphetNet on PaddlePaddle 2.2, including code for fine-tuning and generation on the CNN/DailyMail and Gigaword datasets.
### Dependencies

```
pip install -r requirements.txt
python -m pip install paddlepaddle-gpu==2.2.2.post112 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
pip install paddlenlp==2.2.3
```

### Code Structure

The main files in this project are organized as follows:

```text
├── train_prophetnet.py      # main entry point for model fine-tuning
├── generate.py              # main entry point for generation
├── eval.py                  # entry point for evaluating generated results
├── uncase_tokenize_data.py  # data preprocessing
├── uncompress_data.sh       # data decompression script
├── run_train.sh             # model training script
├── run_eval.sh              # model evaluation script
├── requirements.txt         # dependency list
└── README.md                # this documentation
```
### Data Preparation

Download the GLGE dataset: [link](https://drive.google.com/file/d/1F4zppa9Gqrh6iNyVsZJkxfbm5waalqEA/view)

Download the GLGE test set: [link](https://drive.google.com/file/d/11lDXIG87dChIfukq3x2Wx4r5_duCRm_J/view)

Place glge_public.tar and glge_hidden_v1.1.tar.gz in the project root directory, then decompress them:

```
bash uncompress_data.sh
```

### Download Pre-trained Weights and Vocabulary

Download the model weights and vocabulary from this [link](https://pan.baidu.com/s/1FOnd01rNvDJoONYegacq1Q) (extraction code: o28q) and place them in the project root directory.
### Data Preprocessing

```
python uncase_tokenize_data.py --dataset <DATASET>
```

Note:

- `<DATASET>` can be `cnndm` or `gigaword`.
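As a purely illustrative sketch of the "uncase" part of this step (the sample sentence is hypothetical; the actual script also tokenizes the text against the ProphetNet vocabulary):

```python
# Uncased preprocessing lowercases raw text so that it matches the
# uncased vocabulary used by prophetnet-large-uncased.
line = "Police Arrest Five Anti-Nuclear Protesters ."
print(line.lower())  # -> "police arrest five anti-nuclear protesters ."
```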
### Model Training

```
bash run_train.sh <DATASET>
```

Or run the fine-tuning program directly:

- cnndm:

```
python train_prophetnet.py \
    --dataset=cnndm \
    --pretrained_model_path=./model_state.pdparams \
    --batch_size=4 \
    --epochs=4 \
    --lr=0.0001 \
    --warmup_init_lr=1e-07 \
    --warmup_updates=1000 \
    --clip_norm=0.1 \
    --num_workers=4 \
    --output_dir=./ckpt/cnndm
```

- gigaword:

```
python train_prophetnet.py \
    --dataset=gigaword \
    --pretrained_model_path=./model_state.pdparams \
    --batch_size=16 \
    --epochs=6 \
    --lr=0.0001 \
    --warmup_init_lr=1e-07 \
    --warmup_updates=1000 \
    --clip_norm=0.1 \
    --num_workers=8 \
    --output_dir=./ckpt/gigaword
```

The parameters are described as follows (the learning-rate schedule implied by the warmup flags is sketched after this list):

- `dataset`: the dataset to fine-tune on, either cnndm or gigaword.

- `pretrained_model_path`: path to the local pre-trained weights used for initialization, e.g. ./model_state.pdparams.

- `batch_size`: training batch size.

- `epochs`: number of training epochs.

- `lr`: peak learning rate.

- `warmup_init_lr`: initial learning rate for warmup.

- `warmup_updates`: number of warmup steps.

- `clip_norm`: gradient clipping threshold.

- `num_workers`: number of data-loading workers.

- `output_dir`: directory where the fine-tuned weights are saved.
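The `lr`, `warmup_init_lr`, and `warmup_updates` flags match the inverse-square-root schedule used by the original fairseq implementation of ProphetNet. Below is a minimal sketch assuming that schedule; the actual scheduler in train_prophetnet.py may differ in detail:

```python
def inverse_sqrt_lr(step, lr=1e-4, warmup_init_lr=1e-7, warmup_updates=1000):
    """Sketch of an inverse-sqrt schedule (assumed, not taken from this repo)."""
    if step < warmup_updates:
        # Warmup: interpolate linearly from warmup_init_lr up to lr.
        return warmup_init_lr + (lr - warmup_init_lr) * step / warmup_updates
    # After warmup: decay proportionally to 1/sqrt(step).
    return lr * (warmup_updates / step) ** 0.5
```

With the defaults above, the rate climbs to `lr` over the first 1000 updates and then decays smoothly, which is why `warmup_init_lr` is set far below `lr`.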
Fine-tuned model weights:

- cnndm: [link](https://pan.baidu.com/s/1cemrUDxkqEW9raoasJ_VKw), extraction code: 1egi

- gigaword: [link](https://pan.baidu.com/s/1qRH2FStT3vNQtDjZLkYJBQ), extraction code: on5v
### Model Evaluation

Evaluation uses the [evaluation scripts](https://pan.baidu.com/s/1FOnd01rNvDJoONYegacq1Q) from the original ProphetNet source code. These scripts depend on pyrouge, so install rouge beforehand:

```
pip install git+https://github.com/pltrdy/pyrouge
```
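For reference, the reported scores are ROUGE F-measures. As general background (this is the standard definition, not code from this repository), ROUGE-N combines n-gram precision and recall against the reference summary:

```latex
% ROUGE-N F-measure: harmonic mean of n-gram precision P_n and recall R_n.
% R_n = matched n-grams / n-grams in the reference summary
% P_n = matched n-grams / n-grams in the generated summary
F_n = \frac{2 \, P_n \, R_n}{P_n + R_n}
```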
Then run the evaluation script:

```
bash run_eval.sh <DATASET>
```

Or run the generation program directly:

- cnndm:

```
python generate.py \
    --dataset=cnndm \
    --vocab_file=./prophetnet.tokenizer \
    --output_path=./generate/cnndm/generate.txt \
    --min_target_length=45 \
    --max_target_length=110 \
    --decode_strategy=beam_search \
    --num_beams=4 \
    --length_penalty=1.2 \
    --batch_size=16 \
    --ignore_pad_token_for_loss=True \
    --early_stopping=True \
    --logging_steps=100 \
    --device=gpu

python eval.py --dataset cnndm --generated ./generate/cnndm/generate.txt
```

- gigaword:

```
python generate.py \
    --dataset=gigaword \
    --vocab_file=./prophetnet.tokenizer \
    --output_path=./generate/gigaword/generate.txt \
    --min_target_length=1 \
    --max_target_length=200 \
    --decode_strategy=beam_search \
    --num_beams=4 \
    --length_penalty=1.6 \
    --batch_size=16 \
    --ignore_pad_token_for_loss=True \
    --early_stopping=True \
    --logging_steps=100 \
    --device=gpu

python eval.py --dataset gigaword --generated ./generate/gigaword/generate.txt
```

The parameters are described as follows (a sketch of how `length_penalty` enters beam scoring follows this list):

- `dataset`: the dataset to evaluate, either cnndm or gigaword.

- `vocab_file`: vocabulary file.

- `output_path`: path where generated results are saved.

- `min_target_length`: minimum decoding length.

- `max_target_length`: maximum decoding length.

- `decode_strategy`: decoding strategy.

- `num_beams`: beam width for beam search.

- `length_penalty`: exponential length penalty applied during beam search.

- `batch_size`: batch size for generation.

- `ignore_pad_token_for_loss`: ignore padding tokens when computing the loss.

- `early_stopping`: whether generation stops once the end token is produced.

- `logging_steps`: logging interval in steps.

- `device`: device to run on.
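As a minimal sketch of a common formulation (PaddleNLP's internal normalization may differ in detail), beam search divides each hypothesis's total log-probability by its length raised to `length_penalty`, so values above 1.0 favor longer outputs:

```python
def length_normalized_score(sum_logprob, length, length_penalty=1.2):
    # sum_logprob is a sum of log-probabilities, hence negative; a larger
    # denominator shrinks its magnitude, so longer hypotheses are penalized
    # less as length_penalty grows.
    return sum_logprob / (length ** length_penalty)
```

For example, a 40-token hypothesis with total log-probability -20 scores about -0.24, while a 20-token hypothesis with the same total scores about -0.55, so the longer candidate ranks higher.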
### Fine-tuning Test Accuracy

> #### Results on the CNN/DM test set:

| Model | Optimizer | batch_size | Dataset | ROUGE-1 | ROUGE-2 | ROUGE-L |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| prophetnet-large-uncased | Adam | 4 | CNN/DM | 44.17 | 21.24 | 41.36 |

> #### Results on the Gigaword test set:

| Model | Optimizer | batch_size | Dataset | ROUGE-1 | ROUGE-2 | ROUGE-L |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| prophetnet-large-uncased | Adam | 16 | gigaword | 38.92 | 19.81 | 36.06 |

### Experimental Environment

- GPU: 1x RTX 3090; CPU: Intel i7-11700K
- Ubuntu 18.04
### References

1. Qi W, Yan Y, Gong Y, et al. ProphetNet: Predicting future n-gram for sequence-to-sequence pre-training[J]. arXiv preprint arXiv:2001.04063, 2020.
File: eval.py (new file, +73 lines)
import argparse
import os
import re
import sys
from os import listdir
from os.path import isfile, join

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    type=str,
    help="choose 'all' or one of the supported datasets, e.g. cnndm, gigaword.")
parser.add_argument("--generated", type=str, help="generated output file.")

args = parser.parse_args()

data_root_path = 'data'

support_dataset = ['cnndm', 'gigaword']

# Regex templates for parsing the metric lines printed by the GLGE
# evaluation scripts (raw strings keep the backslashes literal).
files2rouge_template = r'.*ROUGE-1 Average_F: (?P<rouge1_f>\d+(\.\d*)?|\.\d+).*ROUGE-2 Average_F: (?P<rouge2_f>\d+(\.\d*)?|\.\d+).*ROUGE-L Average_F: (?P<rougeL_f>\d+(\.\d*)?|\.\d+).*'
# gigaword_template = r'.*ROUGE-1: (?P<rouge1_f>\d+(\.\d*)?|\.\d+).*ROUGE-2: (?P<rouge2_f>\d+(\.\d*)?|\.\d+).*ROUGE-L: (?P<rougeL_f>\d+(\.\d*)?|\.\d+).*'
# Templates for other GLGE tasks; unused for cnndm/gigaword.
qg_template = r'.*Bleu_4: (?P<bleu4>\d+(\.\d*)?|\.\d+).*METEOR: (?P<meteor>\d+(\.\d*)?|\.\d+).*ROUGE_L: (?P<rougeL>\d+(\.\d*)?|\.\d+).*'
personachat_template = r'.*?(?P<d1>[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?).*?(?P<d2>[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?).*Bleu_1: (?P<bleu1>\d+(\.\d*)?|\.\d+).*Bleu_2: (?P<bleu2>\d+(\.\d*)?|\.\d+).*'


def scale_up(d):
    # Convert fractional scores (0-1) to percentages.
    return {k: float(d[k]) * 100 for k in d.keys()}


def eval_one_dataset(dataset, generated_file):
    golden_file = f"{data_root_path}/{dataset}_data/test.tgt"

    eval_template = {
        'cnndm':
        f"python ./evaluate/cnndm/postprocess_cnn_dm.py --generated {generated_file} --golden {golden_file}",
        'gigaword':
        f"python ./evaluate/gigaword/eval.py --perl --pred {generated_file} --gold {golden_file}",
    }

    cmd = eval_template[dataset]
    try:
        # Run the dataset-specific evaluation script and parse its output.
        output = os.popen(cmd).read()
        if dataset in ['cnndm', 'gigaword']:
            d = re.search(files2rouge_template,
                          output.replace("\n", " ")).groupdict()
            d = scale_up(d)
            print(
                f"{dataset}\trouge1/rouge2/rougeL\t{d['rouge1_f']:.2f}/{d['rouge2_f']:.2f}/{d['rougeL_f']:.2f}"
            )
    except Exception:
        print("Unexpected error:", sys.exc_info()[0])
        print(f"{dataset} evaluate failed!")


if args.dataset != 'all':
    eval_one_dataset(args.dataset, args.generated)
else:
    # Treat --generated as a directory and evaluate every supported
    # dataset whose output folder is found beneath it.
    output_root_path = args.generated
    onlyfolders = [
        f for f in listdir(output_root_path)
        if not isfile(join(output_root_path, f))
    ]
    for dataset in support_dataset:
        for folder in onlyfolders:
            if folder.startswith(dataset):
                for hypo_file in listdir(join(output_root_path, folder)):
                    if 'hypo' in hypo_file or 'score' in hypo_file:
                        generated_file = join(output_root_path, folder,
                                              hypo_file)
                        print(f"{dataset}\tpredict_file:{generated_file}")
                        eval_one_dataset(dataset, generated_file)
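A quick, self-contained way to sanity-check the files2rouge parsing above; the sample string here is hypothetical and only mimics the three metric fields the regex looks for (the real script output contains much more surrounding text):

```python
import re

FILES2ROUGE_TEMPLATE = (
    r'.*ROUGE-1 Average_F: (?P<rouge1_f>\d+(\.\d*)?|\.\d+)'
    r'.*ROUGE-2 Average_F: (?P<rouge2_f>\d+(\.\d*)?|\.\d+)'
    r'.*ROUGE-L Average_F: (?P<rougeL_f>\d+(\.\d*)?|\.\d+).*')

# Hypothetical output fragment containing the fields the regex captures.
sample = ("ROUGE-1 Average_F: 0.44170 ROUGE-2 Average_F: 0.21240 "
          "ROUGE-L Average_F: 0.41360")
match = re.search(FILES2ROUGE_TEMPLATE, sample)
print(match.groupdict())
# {'rouge1_f': '0.44170', 'rouge2_f': '0.21240', 'rougeL_f': '0.41360'}
```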
Review comment: Using `warmup_steps` would be better.

Author reply: Updated.