add hierarchical text classification #2501
Conversation
@@ -0,0 +1,251 @@
# Multi-label hierarchical classification task
- pre-commit is not installed, so the code style check did not pass.
Installed.
## Introduction to hierarchical classification

In a multi-label hierarchical classification task, each sample carries multiple labels, and the label set has a predefined tree or directed-acyclic-graph structure. Hierarchical multi-label classification has to take the hierarchical relationships within the label set into account when producing its predictions. In real-world scenarios, many label sets, such as those for news, patents, and academic papers, come with a predefined hierarchy, and algorithms are needed to automatically tag text with finer-grained and more accurate labels. This project adopts a generic hierarchical multi-label classification approach that treats the label path of each node as one multi-class label and makes decisions with a single classifier.
Could you use a figure or an example here to illustrate hierarchical classification?
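As a rough illustration of the example requested above (hypothetical labels, not taken from this PR), the generic approach treats each node on a label path as one target of a multi-label problem:

```python
# A minimal sketch with hypothetical WOS-style labels: flatten a hierarchical
# label path (parent -> child) into a multi-hot vector for a single classifier.
label_list = ["CS", "Medical", "CS--Machine learning", "CS--Computer vision", "Medical--Cancer"]
label2id = {label: i for i, label in enumerate(label_list)}

def encode_label_path(path):
    """Encode a path such as ["CS", "CS--Machine learning"] as a multi-hot target."""
    target = [0] * len(label_list)
    for label in path:
        target[label2id[label]] = 1
    return target

print(encode_label_path(["CS", "CS--Machine learning"]))  # [1, 0, 1, 0, 0]
```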
## Introduction to hierarchical classification

In a multi-label hierarchical classification task, each sample carries multiple labels, and the label set has a predefined tree or directed-acyclic-graph structure. Hierarchical multi-label classification has to take the hierarchical relationships within the label set into account when producing its predictions. In real-world scenarios, many label sets, such as those for news, patents, and academic papers, come with a predefined hierarchy, and algorithms are needed to automatically tag text with finer-grained and more accurate labels. This project adopts a generic hierarchical multi-label classification approach that treats the label path of each node as one multi-class label and makes decisions with a single classifier.
多类标签 -> 多分类标签 (wording fix in the Chinese doc: "multi-class label" -> "multi-classification label")
Fixed.
## Model fine-tuning

We use the public hierarchical classification dataset WOS (Web of Science) as an example: the model is trained on the training set, validated on the dev set, and the best model selected on the dev set is evaluated on the test set. WOS is a hierarchical text classification dataset with 7 parent classes and 134 child classes; each sample has one parent-class label and one child-class label, and the two are related hierarchically. The WOS dataset is already built into PaddleNLP.
The sentence "The WOS dataset is already built into PaddleNLP." can be dropped.
Removed.
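For context, a minimal sketch of loading the WOS splits with PaddleNLP's `load_dataset` (assuming the `wos` builder this PR adds; split names follow the usual train/dev/test convention):

```python
from paddlenlp.datasets import load_dataset

# Assumption: the "wos" dataset builder added in this PR exposes train/dev/test splits.
train_ds, dev_ds, test_ds = load_dataset("wos", splits=["train", "dev", "test"])
print(train_ds[0])
```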
Single-GPU training
```shell
$ unset CUDA_VISIBLE_DEVICES
$ python train.py --early_stop
```
For single-GPU training you should not `unset CUDA_VISIBLE_DEVICES`; unsetting is only needed for multi-GPU training.
Removed.
Specifying GPU ids / multi-GPU training
```shell
$ unset CUDA_VISIBLE_DEVICES
$ python -m paddle.distributed.launch --gpus "0,1" train.py --early_stop
```
Here I suggest changing `--gpus "0,1"` to `--gpus "0"`, since most users do not have multiple GPUs. You can add a note that users with multiple GPUs can specify ids such as 0,1.
Changed to:

Specifying GPU ids / multi-GPU training
```shell
unset CUDA_VISIBLE_DEVICES
python -m paddle.distributed.launch --gpus "0" train.py --early_stop
```
With multi-GPU training you can specify several GPU ids, for example `--gpus "0,1"`.
**NOTE:**
* To resume training, set `init_from_ckpt`, e.g. `init_from_ckpt=checkpoints/macro/model_state.pdparams`.
* To train a Chinese hierarchical classification task, simply switch the pretrained model parameter `model_name`. For Chinese tasks "ernie-3.0-base-zh" is recommended; more models are listed under [Transformer pretrained models](https://paddlenlp.readthedocs.io/zh/latest/model_zoo/index.html#transformer).
* To use the ernie-tiny model, first install the sentencepiece dependency, e.g. `pip install sentencepiece`.
Let's drop this part; going forward we will recommend ERNIE-3.0-tiny to users, and its tokenizer does not need sentencepiece.
Fixed.
While running, the program automatically performs training, evaluation, and testing. During training it also saves, under the specified `save_dir`, the model with the best dev-set Macro F1 and the model with the best dev-set Micro F1. The saved model files are organized as follows:
One question here: normally, wouldn't you keep only one of the two models, either the best-macro or the best-micro one?
Changed; by default only the parameters of the best macro F1 model are saved.
* `params_path`: path to the dynamic-graph parameters saved during training; defaults to "./checkpoint/macro/model_state.pdparams".
* `output_path`: path where the static-graph parameters are saved; defaults to "./export".
* `dataset`: training dataset; defaults to the wos dataset.
For model export the `dataset` argument should not be needed.
Changed; `num_classes` is now used to set the number of classes in `AutoModelForSequenceClassification`:

* `num_classes`: number of label classes for the task; defaults to 141, the number of classes in the wos dataset.
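For reference, a minimal sketch of what a static-graph export with `num_classes` might look like (paths and the `ernie-2.0-base-en` checkpoint are assumptions; the actual export_model.py in this PR may differ):

```python
import paddle
from paddlenlp.transformers import AutoModelForSequenceClassification

# Sketch of a static-graph export with num_classes=141 (the wos label count).
model = AutoModelForSequenceClassification.from_pretrained(
    "ernie-2.0-base-en", num_classes=141)
model.set_dict(paddle.load("./checkpoint/macro/model_state.pdparams"))
model.eval()

static_model = paddle.jit.to_static(
    model,
    input_spec=[
        paddle.static.InputSpec(shape=[None, None], dtype="int64"),  # input_ids
        paddle.static.InputSpec(shape=[None, None], dtype="int64"),  # token_type_ids
    ])
paddle.jit.save(static_model, "./export/float32")
```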
```shell
python deploy/paddle2onnx/infer.py --model_path_prefix ./export/wos/1/float32
```
The model input for this script looks a bit odd, because at this point the model pruning has not been done yet.
The static-graph model exported by the example export_model.py can be used for onnxruntime inference.
Start pruning:
```shell
$ python prune.py --output_dir ./export
```
What is the input model for the pruning step?
Changed. Start pruning:
```shell
python prune.py --output_dir ./prune --params_dir ./checkpoint/model_state.pdparams
```
* `params_dir`: model parameter file to load for prediction; defaults to "./checkpoint/macro/model_state.pdparams".
* `model_name_or_path`: pretrained model to use; defaults to "bert-base-uncased".

(All of the above parameters can be passed on the command line, e.g. `python xxx.py --dataset xx --params_dir xx`.)
xxx.py -> prune.py
Changed to:

(All of the above parameters can be passed on the command line, e.g. `python prune.py --dataset xx --params_dir xx`.)
1. Dataset: WOS (English hierarchical classification dataset)

2. GPU: V100, CUDA 11.2
1. Physical machine environment
   - OS: CentOS Linux release 7.5.1804
   - GPU: Tesla V100-SXM2-32GB * 8
   - CPU: Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz * 40
   - CUDA: 11
   - cuDNN: 8.0.4
   - Driver Version: 450.80.02
   - Memory: 502 GB

Please describe the GPU-related information in more detail, following the example above.
Changed to:

- Physical machine environment
  - OS: CentOS Linux release 7.7.1908 (Core)
  - GPU: Tesla V100-SXM2-32GB * 8
  - CPU: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz
  - CUDA: 11.2
  - cuDNN: 8.1.0
  - Driver Version: 460.27.04
  - Memory: 630 GB
| BERT base | 86.06 | 81.29 | 8.80 |
| BERT base + pruning (3/4) | 86.83(+0.77) | 81.08(-0.21) | 6.85 |
| BERT base + pruning (2/3) | 86.77(+0.71) | 80.48(-0.81) | 5.98 |
| BERT base + pruning (1/4) | 86.40(+0.34) | 80.79(-0.5) | 2.51 |
It would probably be best to run these experiments with ERNIE-2.0-en. Also, the experimental details are unclear: for "BERT base + pruning (3/4)", more pruning yet the highest latency?
The number after "pruning" is the retention ratio. The experiments have been rerun with ernie-2.0-base-en:

| Model | Micro F1 | Macro F1 | latency(ms) |
|---|---|---|---|
| ERNIE 2.0 | 85.71 | 80.82 | 8.80 |
| ERNIE 2.0 + pruning (retention ratio 3/4) | 86.83(+1.12) | 81.78(+0.96) | 6.85 |
| ERNIE 2.0 + pruning (retention ratio 2/3) | 86.74(+1.03) | 81.64(+0.82) | 5.98 |
| ERNIE 2.0 + pruning (retention ratio 1/4) | 85.79(+0.08) | 79.53(-1.29) | 2.51 |
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
The imports here should follow PEP 8; please take a look. The usual rule is: standard-library modules first, then third-party modules, then local modules. Packages like paddle generally go later.
Fixed.
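For illustration, the grouping PEP 8 asks for looks roughly like this (module names are just examples drawn from this project; the local import assumes the repo's metric.py is on the path):

```python
# Standard-library modules first
import argparse
import functools

# Third-party modules next
import paddle
from paddlenlp.transformers import AutoTokenizer

# Local project modules last
from metric import MetricReport
```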
)
parser.add_argument(
    '--model_name',
    default="bert-base-uncased",
Switch this to one of the ERNIE EN-series models.
Changed.
if paddle.distributed.get_world_size() > 1:
    paddle.distributed.init_parallel_env()

# load and preprocess dataset
Capitalize the first letter of comments; please fix this throughout.
Fixed.
# batchify dataset
collate_fn = DataCollatorWithPadding(tokenizer)
train_batch_sampler = BatchSampler(
There is an issue here: for multi-GPU training you should use DistributedBatchSampler.
Changed to:

```python
if paddle.distributed.get_world_size() > 1:
    train_batch_sampler = DistributedBatchSampler(train_ds, batch_size=args.batch_size, shuffle=True)
else:
    train_batch_sampler = BatchSampler(train_ds, batch_size=args.batch_size, shuffle=True)
```
lr_scheduler.step()
optimizer.clear_grad()

global_step += 1
Since step starts from 1, isn't the global_step logic here slightly off?
global_step starts from 0.
early_stop_count = 0
best_micro_f1_score = micro_f1_score
model._layers.save_pretrained(save_best_micro_path)
tokenizer.save_pretrained(save_best_micro_path)
The save logic may need more discussion; would a mean value be better? Saving too many models will confuse users.
Decided to keep only the best macro F1 model.
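A minimal sketch of that single-checkpoint logic (variable names follow the snippet quoted above; the exact code in the PR may differ):

```python
# Sketch: inside the training loop, keep only the checkpoint with the best dev-set macro F1.
if macro_f1_score > best_macro_f1_score:
    early_stop_count = 0
    best_macro_f1_score = macro_f1_score
    # The model may be wrapped by paddle.DataParallel for multi-GPU training.
    model_to_save = model._layers if isinstance(model, paddle.DataParallel) else model
    model_to_save.save_pretrained(save_best_path)
    tokenizer.save_pretrained(save_best_path)
```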
tokenizer.save_pretrained(save_best_micro_path)


def test():
My main question here: how does this differ from evaluate? It looks like the model loaded here is the one with the best macro F1; what is the reasoning behind that?
test() was meant to evaluate performance on the test set; the test function has been removed.
if step % 100 == 0:
    logger.info("step %d, %d samples processed" %
                (step, step * args.batch_size))
metric.report()
It looks like report uses sklearn functions, but this is not reflected in the requirements.
Added a requirements.txt.
from paddlenlp.utils.log import logger


# Build the evaluate function for the dev set
Change the comments to English for consistency.
Fixed.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
The import ordering should follow the PEP 8 standard.
Fixed.
@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader):
    """
    Given a dataset, it evals model and computes the metric.
"evals" does not look like a real word; eval -> evaluate.
Changed to "evaluates".
import yaml
import functools
from typing import Optional
import paddle
Please double-check the import ordering here as well.
Fixed.
from metric import MetricReport

nn.MultiHeadAttention._ori_forward = paddle.nn.MultiHeadAttention.forward
nn.MultiHeadAttention._ori_prepare_qkv = nn.MultiHeadAttention._prepare_qkv
For this function swapping, please ask Jiaqi for the reason and add a comment.
Added a comment:

```python
# Paddleslim will modify MultiHeadAttention.forward and MultiHeadAttention._prepare_qkv
# Original forward and _prepare_qkv should be saved before import paddleslim
nn.MultiHeadAttention._ori_forward = paddle.nn.MultiHeadAttention.forward
nn.MultiHeadAttention._ori_prepare_qkv = nn.MultiHeadAttention._prepare_qkv
```
@@ -0,0 +1,58 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Please rename this file to stay consistent with the earlier scripts: export.py -> export_model.py.
Script renamed.
help="The path to model parameters to be loaded.") | ||
parser.add_argument("--output_path", type=str, default='./export', | ||
help="The path of model parameter in static graph to be saved.") | ||
parser.add_argument("--dataset", default="wos", type=str, |
dataset is unused and can be removed.
Fixed.
args = parser.parse_args()


def predict(data, label_list):
Add a @paddle.no_grad to reduce GPU memory usage.
Added.
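For clarity, the change amounts to decorating the inference function, roughly (sketch only, body elided):

```python
import paddle

@paddle.no_grad()  # Skip gradient bookkeeping during inference to reduce GPU memory usage.
def predict(data, label_list):
    # ... run the model on `data` and map the predicted indices to entries of `label_list` ...
    pass
```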
results.append(labels)

for idx, text in enumerate(data):
    label_name = [label_list[r] for r in results[idx]]
When the results are parsed here, this goes back to the earlier question: since this is a hierarchical classification task, the final output still looks like a plain multi-label classification result.
Changed:

```python
for r in results[idx]:
    if r < 7:
        level1.append(label_list[r])
    else:
        level2.append(label_list[r])
print('predicted result:')
print('level 1 : {}   level 2 : {}'.format(', '.join(level1), ', '.join(level2)))
```
## Environment setup

Model conversion and ONNXRuntime inference deployment depend on Paddle2ONNX and ONNXRuntime. Paddle2ONNX converts Paddle models into the ONNX format, with stable operator support for exporting ONNX Opset 7-15; see [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) for details.
Shouldn't the dependency installation here also include paddle2onnx?
paddle2onnx is already installed when paddlenlp is installed.
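For context, a minimal sketch of running a converted ONNX model with onnxruntime (the model path and input names are placeholders and depend on how the model was exported):

```python
import numpy as np
import onnxruntime as ort

# Placeholder path and input names; adjust them to match the exported model.
sess = ort.InferenceSession("./export/model.onnx", providers=["CPUExecutionProvider"])
dummy_inputs = {
    "input_ids": np.zeros((1, 16), dtype="int64"),
    "token_type_ids": np.zeros((1, 16), dtype="int64"),
}
logits = sess.run(output_names=None, input_feed=dummy_inputs)[0]
print(logits.shape)
```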
2. GPU: V100, CUDA 11.2

3. CPU: Intel(R) Xeon(R) Gold 6271C CPU
Please make the hardware information here more formal, in line with the trainer.
Changed to:

2. Physical machine environment
   - OS: CentOS Linux release 7.7.1908 (Core)
   - GPU: Tesla V100-SXM2-32GB * 8
   - CPU: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz
   - CUDA: 11.2
   - cuDNN: 8.1.0
   - Driver Version: 460.27.04
   - Memory: 630 GB
import paddle
import argparse
from predictor import Predictor
Move these imports down to follow the import grouping rule.
Fixed.
import onnxruntime as ort
from paddlenlp.transformers import AutoTokenizer
import paddle.nn.functional as F
from sklearn.metrics import f1_score
Our framework does have an F1 metric: AccuracyAndF1.
sklearn can compute both macro and micro F1.
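For reference, a minimal sketch of computing micro and macro F1 with sklearn on toy multi-label arrays (not the PR's data):

```python
import numpy as np
from sklearn.metrics import f1_score

# Toy multi-label indicator arrays: rows are samples, columns are labels.
y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
y_pred = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0]])

print("micro F1:", f1_score(y_true, y_pred, average="micro"))
print("macro F1:", f1_score(y_true, y_pred, average="macro"))
```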
def printer(self, infer_result, input_datas):
    label = infer_result["label"]
    confidence = infer_result["confidence"]
    for i, input_data in enumerate(input_datas):
Please see whether this can be changed to use logger.
Fixed.
Updated the code and README per the comments. The performance and accuracy results based on ernie-2.0-en will be added later, and the serving deployment part will be added over the next couple of days.

Added the Triton serving deployment code and documentation.

Polished the Triton serving part and added ERNIE 2.0 results and performance on the wos dataset (pruning, FP16, INT8).

Because pre-commit-config.yaml changed and the yapf version was updated, the code style check failed. Merged the latest code from the develop branch to update the pre-commit file.

Added a way for hierarchical classification to read local data: modified paddlenlp/dataset/wos.py to load local datasets in the built-in dataset format. Local dataset loading now supports tree and directed-acyclic-graph label hierarchies and hierarchies of different depths; for training, prediction, and deployment you only need to provide the local dataset path in the specified format plus the parameter configuration.

Traceback (most recent call last): Even though the data can now be read locally, it still raises an error.

preprocess_function in utils has changed; if you need it, please use the latest code (several other files have also changed in their input/output parts, so please pull them as well).

W0624 11:18:51.985149 25272 gpu_context.cc:306] device: 0, cuDNN Version: 8.0. Aren't there a few too many bugs?

Thanks! After pulling the latest code it runs. Can this be applied to matching identical products?

This hierarchical text classification task is currently mainly for classifying text snippets, i.e., determining which classes a text belongs to at the different label levels. Text classification is not recommended for matching identical products; for Chinese product matching you can refer to the semantic indexing task under application.

When will there be Chinese text classification? Why does the example use English?
Added Paddle Serving deployment.
LGTM
When will this be merged into the main branch? I just got it running locally.
PR types
Others
PR changes
Others
Description
Adds a hierarchical classification solution, covering training, prediction, static-graph export, and pruning, with support for multiple deployment options: Paddle Serving, Triton, and onnxruntime.
The main code structure of the project and its description are as follows: