Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add bad case analysis for text classification #3385

Merged
merged 6 commits into from
Nov 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 43 additions & 23 deletions applications/text_classification/hierarchical/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

- 【标注成本高、标注样本较少的小样本场景】 👉 [提示学习层次分类方案](./few-shot#readme)

- 【标签类别不固定场景】 👉 [语义索引层次分类方案](./retrieval_based#readme)
- 【标签类别不固定场景、标签数量众多】 👉 [语义索引层次分类方案](./retrieval_based#readme)

<a name="快速开始"></a>

Expand Down Expand Up @@ -92,20 +92,13 @@ python3 -m pip install scikit-learn==1.0.2

```text
hierarchical/
├── few-shot # 小样本学习方案
├── retrieval_based # 语义索引方案
├── analysis # 分析模块
├── deploy # 部署
│   └── predictor # 离线部署
│   │ ├── infer.py # 测试脚本
│   │ ├── predictor.py 离线部署脚本
│   │ └── README.md # 离线部署使用说明
│ ├── paddle_serving # PaddleServing在线服务化部署
│   │ ├──config.yml # 服务端的配置文件
│   │ ├──rpc_client.py # 客户端预测脚本
│   │ ├──service.py # 服务端的脚本
│   │ └── README.md # 在线服务化部署使用说明
│   └── triton_serving # Triton在线服务化部署
│ ├── README.md # Triton部署使用说明
│ ├── seqcls_grpc_client.py # 客户端预测脚本
│   └── models
├── train.py # 训练评估脚本
├── predict.py # 预测脚本
├── export_model.py # 静态图模型导出脚本
Expand All @@ -130,7 +123,7 @@ data/
└── data.txt # 待预测数据文件(可选)
```

**训练、开发、测试数据集文件:** 文本与标签类别名用tab符`'\t'`分隔开,标签中多个标签之间用英文逗号`','`分隔开。
**训练、开发、测试数据集文件:** 文本与标签类别名用tab符`'\t'`分隔开,标签中多个标签之间用英文逗号`','`分隔开,文本中避免出现tab符`'\t'`

- train.txt/dev.txt/test.txt 文件格式:
```text
Expand Down Expand Up @@ -267,10 +260,16 @@ checkpoint/
* 英文和中文以外语言的文本分类任务,推荐使用基于96种语言(涵盖法语、日语、韩语、德语、西班牙语等几乎所有常见语言)进行预训练的多语言预训练模型"ernie-m-base"、"ernie-m-large",详情请参见[ERNIE-M论文](https://arxiv.org/pdf/2012.15674.pdf)。
#### 2.4.2 训练评估与模型优化

训练后的模型我们可以使用 [模型分析模块](./analysis) 对每个类别分别进行评估,并输出预测错误样本(bad case),默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`:
文本分类预测过程中常会遇到诸如"模型为什么会预测出错误的结果","如何提升模型的表现"等问题。[Analysis模块](./analysis) 提供了**模型评估、可解释性分析、数据优化**等功能,旨在帮助开发者更好地分析文本分类模型预测结果和对模型效果进行优化。

<div align="center">
<img src="https://user-images.githubusercontent.com/63761690/195241942-70068989-df17-4f53-9f71-c189d8c5c88d.png" width="600">
</div>

**模型评估:** 训练后的模型我们可以使用 [Analysis模块](./analysis) 对每个类别分别进行评估,并输出预测错误样本(bad case),默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`:

```shell
python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32 --bad_case_path "./bad_case.txt" --dataset_dir "data" --params_path "./checkpoint"
python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32 --bad_case_file "bad_case.txt" --dataset_dir "data" --params_path "./checkpoint"
```

输出打印示例:
Expand All @@ -297,15 +296,38 @@ python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32
预测错误的样本保存在bad_case.txt文件中:

```text
Prediction Label Text
组织关系,组织关系##解雇 组织关系,组织关系##加盟,组织关系##裁员 据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已将前锋萨加巴-科纳特裁掉。此前他与猛龙签下了一份Exhibit10合同。在被裁掉后,科纳特下赛季大概率将前往猛龙的发展联盟球队效力。
组织关系,组织关系##解雇 组织关系,组织关系##裁员 冠军射手被裁掉,欲加入湖人队,但湖人却无意,冠军射手何去何从
组织关系,组织关系##裁员 组织关系,组织关系##退出,组织关系##裁员 有多名魅族员工表示,从6月份开始,魅族开始了新一轮裁员,重点裁员区域是营销和线下。裁员占比超过30%,剩余员工将不过千余人,魅族的知名工程师,爱讲真话的洪汉生已经从钉钉里退出了,外界传言说他去了OPPO。
人生,人生##死亡,灾害/意外,灾害/意外##坍/垮塌 灾害/意外,灾害/意外##坍/垮塌 冲刺千亿的美的置业贵阳项目倒塌致8人死亡已责令全面停工
Text Label Prediction
据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已将前锋萨加巴-科纳特裁掉。此前他与猛龙签下了一份Exhibit10合同。在被裁掉后,科纳特下赛季大概率将前往猛龙的发展联盟球队效力。 组织关系,组织关系##加盟,组织关系##裁员 组织关系,组织关系##解雇
冠军射手被裁掉,欲加入湖人队,但湖人却无意,冠军射手何去何从 组织关系,组织关系##裁员 组织关系,组织关系##解雇
6月7日报道,IBM将裁员超过1000人。IBM周四确认,将裁减一千多人。据知情人士称,此次裁员将影响到约1700名员工,约占IBM全球逾34万员工中的0.5%。IBM股价今年累计上涨16%,但该公司4月发布的财报显示,一季度营收下降5%,低于市场预期。 组织关系,组织关系##裁员 组织关系,组织关系##裁员,财经/交易
有多名魅族员工表示,从6月份开始,魅族开始了新一轮裁员,重点裁员区域是营销和线下。裁员占比超过30%,剩余员工将不过千余人,魅族的知名工程师,爱讲真话的洪汉生已经从钉钉里退出了,外界传言说他去了OPPO。 组织关系,组织关系##退出,组织关系##裁员 组织关系,组织关系##裁员
...
```

**可解释性分析:** 基于[TrustAI](https://github.com/PaddlePaddle/TrustAI)提供单词和句子级别的模型可解释性分析,帮助理解模型预测结果,用于错误样本(bad case)分析,细节详见[训练评估与模型优化指南](analysis/README.md)。

- 单词级别可解释性分析,也即分析待预测样本中哪一些单词对模型预测结果起重要作用。以下图为例,用颜色深浅表示单词对预测结果的重要性。
<div align="center">
<img src="https://user-images.githubusercontent.com/63761690/195334753-78cc2dc8-a5ba-4460-9fde-3b1bb704c053.png" width="1000">
</div>

- 句子级别可解释性分析 ,也即分析对待预测样本的模型预测结果与训练集中中哪些样本有重要关系。下面的例子表明句子级别可解释性分析可以帮助理解待预测样本的预测结果与训练集中样本之间的关联。
```text
text: 据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已将前锋萨加巴-科纳特裁掉。此前他与猛龙签下了一份Exhibit10合同。在被裁掉后,科纳特下赛季大概率将前往猛龙的发展联盟球队效力。
predict label: 组织关系,组织关系##解雇
label: 组织关系,组织关系##加盟,组织关系##裁员
examples with positive influence
support1 text: 尼克斯官方今日宣布,他们已经裁掉了前锋扎克-欧文,后者昨日才与尼克斯签约。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.99357
support2 text: 活塞官方今日宣布,他们已经签下了克雷格-斯沃德,并且裁掉了托德-威瑟斯。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.98344
support3 text: 孟菲斯灰熊今年宣布,球队已经签下后卫达斯蒂-汉纳斯(DustyHannahs,版头图)并裁掉马特-穆尼。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.98219
...
```

模型表现常常受限于数据质量,在analysis模块中我们提供了基于[TrustAI](https://github.com/PaddlePaddle/TrustAI)的稀疏数据筛选、脏数据清洗、数据增强三种优化方案助力开发者提升模型效果,更多模型评估和优化方案细节详见[训练评估与模型优化指南](analysis/README.md)。
**数据优化:** 结合[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[数据增强API](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/dataaug.md)提供了**稀疏数据筛选、脏数据清洗、数据增强**三种优化策略,从多角度优化训练数据提升模型效果,策略细节详见[训练评估与模型优化指南](analysis/README.md)。

- 稀疏数据筛选主要是解决数据不均衡、训练数据覆盖不足的问题,通过数据增强和数据标注两种方式解决这一问题。
- 脏数据清洗可以帮助开发者筛选训练集中错误标注的数据,对这些数据重新进行人工标注,得到标注正确的数据再重新进行训练。
- 数据增强策略提供多种数据增强方案,可以快速扩充数据,提高模型泛化性和鲁棒性。

#### 2.4.3 模型预测
训练结束后,输入待预测数据(data.txt)和类别标签对照列表(label.txt),使用训练好的模型进行,默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`:
Expand Down Expand Up @@ -337,7 +359,7 @@ python predict.py --device "gpu" --max_seq_length 128 --batch_size 32 --dataset_
python export_model.py --params_path ./checkpoint/ --output_path ./export
```

如果使用ERNIE M作为预训练模型,运行方式:
如果使用多语言模型 ERNIE M作为预训练模型,运行方式:
```shell
python export_model.py --params_path ./checkpoint/ --output_path ./export --multilingual
```
Expand Down Expand Up @@ -418,8 +440,6 @@ prune/

3. ERNIE Base、Medium、Mini、Micro、Nano的模型宽度(multi head数量)为12,ERNIE Xbase、Large 模型宽度(multi head数量)为16,保留比例`width_mult`乘以宽度(multi haed数量)应为整数。

4. **压缩API暂不支持多语言预训练模型ERNIE-M**,相关功能正在加紧开发中。

#### 2.5.3 部署方案

- 离线部署搭建请参考[离线部署](deploy/predictor/README.md)。
Expand Down
Loading