diff --git a/applications/text_classification/hierarchical/README.md b/applications/text_classification/hierarchical/README.md
index 2585783e6d44..fecb8ecbe3d0 100644
--- a/applications/text_classification/hierarchical/README.md
+++ b/applications/text_classification/hierarchical/README.md
@@ -36,7 +36,7 @@
- 【标注成本高、标注样本较少的小样本场景】 👉 [提示学习层次分类方案](./few-shot#readme)
-- 【标签类别不固定场景】 👉 [语义索引层次分类方案](./retrieval_based#readme)
+- 【标签类别不固定场景、标签数量众多】 👉 [语义索引层次分类方案](./retrieval_based#readme)
@@ -92,20 +92,13 @@ python3 -m pip install scikit-learn==1.0.2
```text
hierarchical/
+├── few-shot # 小样本学习方案
+├── retrieval_based # 语义索引方案
+├── analysis # 分析模块
├── deploy # 部署
│ └── predictor # 离线部署
-│ │ ├── infer.py # 测试脚本
-│ │ ├── predictor.py 离线部署脚本
-│ │ └── README.md # 离线部署使用说明
│ ├── paddle_serving # PaddleServing在线服务化部署
-│ │ ├──config.yml # 服务端的配置文件
-│ │ ├──rpc_client.py # 客户端预测脚本
-│ │ ├──service.py # 服务端的脚本
-│ │ └── README.md # 在线服务化部署使用说明
│ └── triton_serving # Triton在线服务化部署
-│ ├── README.md # Triton部署使用说明
-│ ├── seqcls_grpc_client.py # 客户端预测脚本
-│ └── models
├── train.py # 训练评估脚本
├── predict.py # 预测脚本
├── export_model.py # 静态图模型导出脚本
@@ -130,7 +123,7 @@ data/
└── data.txt # 待预测数据文件(可选)
```
-**训练、开发、测试数据集文件:** 文本与标签类别名用tab符`'\t'`分隔开,标签中多个标签之间用英文逗号`','`分隔开。
+**训练、开发、测试数据集文件:** 文本与标签类别名用tab符`'\t'`分隔开,标签中多个标签之间用英文逗号`','`分隔开,文本中避免出现tab符`'\t'`。
- train.txt/dev.txt/test.txt 文件格式:
```text
@@ -267,10 +260,16 @@ checkpoint/
* 英文和中文以外语言的文本分类任务,推荐使用基于96种语言(涵盖法语、日语、韩语、德语、西班牙语等几乎所有常见语言)进行预训练的多语言预训练模型"ernie-m-base"、"ernie-m-large",详情请参见[ERNIE-M论文](https://arxiv.org/pdf/2012.15674.pdf)。
#### 2.4.2 训练评估与模型优化
-训练后的模型我们可以使用 [模型分析模块](./analysis) 对每个类别分别进行评估,并输出预测错误样本(bad case),默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`:
+文本分类预测过程中常会遇到诸如"模型为什么会预测出错误的结果","如何提升模型的表现"等问题。[Analysis模块](./analysis) 提供了**模型评估、可解释性分析、数据优化**等功能,旨在帮助开发者更好地分析文本分类模型预测结果和对模型效果进行优化。
+
+
+
+
+
+**模型评估:** 训练后的模型我们可以使用 [Analysis模块](./analysis) 对每个类别分别进行评估,并输出预测错误样本(bad case),默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`:
```shell
-python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32 --bad_case_path "./bad_case.txt" --dataset_dir "data" --params_path "./checkpoint"
+python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32 --bad_case_file "bad_case.txt" --dataset_dir "data" --params_path "./checkpoint"
```
输出打印示例:
@@ -297,15 +296,38 @@ python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32
预测错误的样本保存在bad_case.txt文件中:
```text
-Prediction Label Text
-组织关系,组织关系##解雇 组织关系,组织关系##加盟,组织关系##裁员 据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已将前锋萨加巴-科纳特裁掉。此前他与猛龙签下了一份Exhibit10合同。在被裁掉后,科纳特下赛季大概率将前往猛龙的发展联盟球队效力。
-组织关系,组织关系##解雇 组织关系,组织关系##裁员 冠军射手被裁掉,欲加入湖人队,但湖人却无意,冠军射手何去何从
-组织关系,组织关系##裁员 组织关系,组织关系##退出,组织关系##裁员 有多名魅族员工表示,从6月份开始,魅族开始了新一轮裁员,重点裁员区域是营销和线下。裁员占比超过30%,剩余员工将不过千余人,魅族的知名工程师,爱讲真话的洪汉生已经从钉钉里退出了,外界传言说他去了OPPO。
-人生,人生##死亡,灾害/意外,灾害/意外##坍/垮塌 灾害/意外,灾害/意外##坍/垮塌 冲刺千亿的美的置业贵阳项目倒塌致8人死亡已责令全面停工
+Text Label Prediction
+据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已将前锋萨加巴-科纳特裁掉。此前他与猛龙签下了一份Exhibit10合同。在被裁掉后,科纳特下赛季大概率将前往猛龙的发展联盟球队效力。 组织关系,组织关系##加盟,组织关系##裁员 组织关系,组织关系##解雇
+冠军射手被裁掉,欲加入湖人队,但湖人却无意,冠军射手何去何从 组织关系,组织关系##裁员 组织关系,组织关系##解雇
+6月7日报道,IBM将裁员超过1000人。IBM周四确认,将裁减一千多人。据知情人士称,此次裁员将影响到约1700名员工,约占IBM全球逾34万员工中的0.5%。IBM股价今年累计上涨16%,但该公司4月发布的财报显示,一季度营收下降5%,低于市场预期。 组织关系,组织关系##裁员 组织关系,组织关系##裁员,财经/交易
+有多名魅族员工表示,从6月份开始,魅族开始了新一轮裁员,重点裁员区域是营销和线下。裁员占比超过30%,剩余员工将不过千余人,魅族的知名工程师,爱讲真话的洪汉生已经从钉钉里退出了,外界传言说他去了OPPO。 组织关系,组织关系##退出,组织关系##裁员 组织关系,组织关系##裁员
+...
+```
+
+**可解释性分析:** 基于[TrustAI](https://github.com/PaddlePaddle/TrustAI)提供单词和句子级别的模型可解释性分析,帮助理解模型预测结果,用于错误样本(bad case)分析,细节详见[训练评估与模型优化指南](analysis/README.md)。
+
+- 单词级别可解释性分析,也即分析待预测样本中哪一些单词对模型预测结果起重要作用。以下图为例,用颜色深浅表示单词对预测结果的重要性。
+
+
+
+
+- 句子级别可解释性分析 ,也即分析对待预测样本的模型预测结果与训练集中中哪些样本有重要关系。下面的例子表明句子级别可解释性分析可以帮助理解待预测样本的预测结果与训练集中样本之间的关联。
+```text
+text: 据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已将前锋萨加巴-科纳特裁掉。此前他与猛龙签下了一份Exhibit10合同。在被裁掉后,科纳特下赛季大概率将前往猛龙的发展联盟球队效力。
+predict label: 组织关系,组织关系##解雇
+label: 组织关系,组织关系##加盟,组织关系##裁员
+examples with positive influence
+support1 text: 尼克斯官方今日宣布,他们已经裁掉了前锋扎克-欧文,后者昨日才与尼克斯签约。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.99357
+support2 text: 活塞官方今日宣布,他们已经签下了克雷格-斯沃德,并且裁掉了托德-威瑟斯。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.98344
+support3 text: 孟菲斯灰熊今年宣布,球队已经签下后卫达斯蒂-汉纳斯(DustyHannahs,版头图)并裁掉马特-穆尼。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.98219
...
```
-模型表现常常受限于数据质量,在analysis模块中我们提供了基于[TrustAI](https://github.com/PaddlePaddle/TrustAI)的稀疏数据筛选、脏数据清洗、数据增强三种优化方案助力开发者提升模型效果,更多模型评估和优化方案细节详见[训练评估与模型优化指南](analysis/README.md)。
+**数据优化:** 结合[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[数据增强API](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/dataaug.md)提供了**稀疏数据筛选、脏数据清洗、数据增强**三种优化策略,从多角度优化训练数据提升模型效果,策略细节详见[训练评估与模型优化指南](analysis/README.md)。
+
+- 稀疏数据筛选主要是解决数据不均衡、训练数据覆盖不足的问题,通过数据增强和数据标注两种方式解决这一问题。
+- 脏数据清洗可以帮助开发者筛选训练集中错误标注的数据,对这些数据重新进行人工标注,得到标注正确的数据再重新进行训练。
+- 数据增强策略提供多种数据增强方案,可以快速扩充数据,提高模型泛化性和鲁棒性。
#### 2.4.3 模型预测
训练结束后,输入待预测数据(data.txt)和类别标签对照列表(label.txt),使用训练好的模型进行,默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`:
@@ -337,7 +359,7 @@ python predict.py --device "gpu" --max_seq_length 128 --batch_size 32 --dataset_
python export_model.py --params_path ./checkpoint/ --output_path ./export
```
-如果使用ERNIE M作为预训练模型,运行方式:
+如果使用多语言模型 ERNIE M作为预训练模型,运行方式:
```shell
python export_model.py --params_path ./checkpoint/ --output_path ./export --multilingual
```
@@ -418,8 +440,6 @@ prune/
3. ERNIE Base、Medium、Mini、Micro、Nano的模型宽度(multi head数量)为12,ERNIE Xbase、Large 模型宽度(multi head数量)为16,保留比例`width_mult`乘以宽度(multi haed数量)应为整数。
-4. **压缩API暂不支持多语言预训练模型ERNIE-M**,相关功能正在加紧开发中。
-
#### 2.5.3 部署方案
- 离线部署搭建请参考[离线部署](deploy/predictor/README.md)。
diff --git a/applications/text_classification/hierarchical/analysis/README.md b/applications/text_classification/hierarchical/analysis/README.md
index 4ec93d74985d..a683594da3e0 100644
--- a/applications/text_classification/hierarchical/analysis/README.md
+++ b/applications/text_classification/hierarchical/analysis/README.md
@@ -1,18 +1,29 @@
-# 层次分类训练评估与模型优化指南
+# 训练评估与模型优化指南
**目录**
- * [analysis模块介绍](#analysis模块介绍)
+ * [Analysis模块介绍](#Analysis模块介绍)
+ * [环境准备](#环境准备)
* [模型评估](#模型评估)
- * [稀疏数据筛选方案](#稀疏数据筛选方案)
- * [脏数据清洗方案](#脏数据清洗方案)
- * [数据增强策略方案](#数据增强策略方案)
+ * [可解释性分析](#可解释性分析)
+ * [单词级别可解释性分析](#单词级别可解释性分析)
+ * [句子级别可解释性分析](#句子级别可解释性分析)
+ * [数据优化](#数据优化)
+ * [稀疏数据筛选方案](#稀疏数据筛选方案)
+ * [脏数据清洗方案](#脏数据清洗方案)
+ * [数据增强策略方案](#数据增强策略方案)
-## analysis模块介绍
+## Analysis模块介绍
-analysis模块提供了**模型评估**脚本对整体分类情况和每个类别分别进行评估,并打印预测错误样本,帮助开发者分析模型表现找到训练和预测数据中存在的问题问题。同时基于[可信AI工具集](https://github.com/PaddlePaddle/TrustAI)和[数据增强API](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/dataaug.md)提供了**稀疏数据筛选、脏数据清洗、数据增强**三种优化方案从多角度帮助开发者提升模型效果。
+Analysis模块提供了**模型评估、可解释性分析、数据优化**等功能,旨在帮助开发者更好地分析文本分类模型预测结果和对模型效果进行优化。
+
+- **模型评估:** 对整体分类情况和每个类别分别进行评估,并打印预测错误样本,帮助开发者分析模型表现找到训练和预测数据中存在的问题。
+
+- **可解释性分析:** 基于[TrustAI](https://github.com/PaddlePaddle/TrustAI)提供单词和句子级别的模型可解释性分析,帮助理解模型预测结果。
+
+- **数据优化:** 结合[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[数据增强API](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/dataaug.md)提供了**稀疏数据筛选、脏数据清洗、数据增强**三种优化策略,从多角度优化训练数据提升模型效果。
-
+
以下是本项目主要代码结构及说明:
@@ -20,10 +31,27 @@ analysis模块提供了**模型评估**脚本对整体分类情况和每个类
```text
analysis/
├── evaluate.py # 评估脚本
+├── sent_interpret.py # 句子级别可解释性分析脚本
+├── word_interpret.py # 单词级别可解释性分析notebook
├── sparse.py # 稀疏数据筛选脚本
├── dirty.py # 脏数据清洗脚本
├── aug.py # 数据增强脚本
-└── README.md # 层次分类训练评估与模型优化指南
+└── README.md # 训练评估与模型优化指南
+```
+
+## 环境准备
+需要可解释性分析和数据优化需要安装相关环境。
+- trustai >= 0.1.7
+- interpretdl >= 0.7.0
+
+**安装TrustAI**(可选)如果使用可解释性分析和数据优化中稀疏数据筛选和脏数据清洗需要安装TrustAI。
+```shell
+pip install trustai==0.1.7
+```
+
+**安装InterpretDL**(可选)如果使用词级别可解释性分析GradShap方法,需要安装InterpretDL
+```shell
+pip install interpretdl==0.7.0
```
## 模型评估
@@ -37,14 +65,14 @@ python evaluate.py \
--params_path "../checkpoint" \
--max_seq_length 128 \
--batch_size 32 \
- --bad_case_path "./bad_case.txt"
+ --bad_case_file "bad_case.txt"
```
默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`
可支持配置的参数:
-* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含train.txt、dev.txt和label.txt文件;默认为None。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
* `max_seq_length`:分词器tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数;默认为128。
@@ -54,6 +82,7 @@ python evaluate.py \
* `label_file`:本地数据集中标签集文件名;默认为"label.txt"。
* `bad_case_path`:开发集中预测错误样本保存路径;默认为"/bad_case.txt"。
+
输出打印示例:
```text
@@ -78,30 +107,109 @@ python evaluate.py \
预测错误的样本保存在bad_case.txt文件中:
```text
-Prediction Label Text
-组织关系,组织关系##解雇 组织关系,组织关系##加盟,组织关系##裁员 据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已将前锋萨加巴-科纳特裁掉。此前他与猛龙签下了一份Exhibit10合同。在被裁掉后,科纳特下赛季大概率将前往猛龙的发展联盟球队效力。
-组织关系,组织关系##解雇 组织关系,组织关系##裁员 冠军射手被裁掉,欲加入湖人队,但湖人却无意,冠军射手何去何从
-组织关系,组织关系##裁员 组织关系,组织关系##退出,组织关系##裁员 有多名魅族员工表示,从6月份开始,魅族开始了新一轮裁员,重点裁员区域是营销和线下。裁员占比超过30%,剩余员工将不过千余人,魅族的知名工程师,爱讲真话的洪汉生已经从钉钉里退出了,外界传言说他去了OPPO。
-人生,人生##死亡,灾害/意外,灾害/意外##坍/垮塌 灾害/意外,灾害/意外##坍/垮塌 冲刺千亿的美的置业贵阳项目倒塌致8人死亡已责令全面停工
+Text Label Prediction
+据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已将前锋萨加巴-科纳特裁掉。此前他与猛龙签下了一份Exhibit10合同。在被裁掉后,科纳特下赛季大概率将前往猛龙的发展联盟球队效力。 组织关系,组织关系##加盟,组织关系##裁员 组织关系,组织关系##解雇
+冠军射手被裁掉,欲加入湖人队,但湖人却无意,冠军射手何去何从 组织关系,组织关系##裁员 组织关系,组织关系##解雇
+6月7日报道,IBM将裁员超过1000人。IBM周四确认,将裁减一千多人。据知情人士称,此次裁员将影响到约1700名员工,约占IBM全球逾34万员工中的0.5%。IBM股价今年累计上涨16%,但该公司4月发布的财报显示,一季度营收下降5%,低于市场预期。 组织关系,组织关系##裁员 组织关系,组织关系##裁员,财经/交易
+有多名魅族员工表示,从6月份开始,魅族开始了新一轮裁员,重点裁员区域是营销和线下。裁员占比超过30%,剩余员工将不过千余人,魅族的知名工程师,爱讲真话的洪汉生已经从钉钉里退出了,外界传言说他去了OPPO。 组织关系,组织关系##退出,组织关系##裁员 组织关系,组织关系##裁员
...
```
-## 稀疏数据筛选方案
+## 可解释性分析
+"模型为什么会预测出这个结果?"是文本分类任务开发者时常遇到的问题,如何分析错误样本(bad case)是文本分类任务落地中重要一环,本项目基于TrustAI开源了基于词级别和句子级别的模型可解释性分析方法,帮助开发者更好地理解文本分类模型与数据,有助于后续的模型优化与数据清洗标注。
+
+### 单词级别可解释性分析
+本项目开源模型的词级别可解释性分析Notebook,提供LIME、Integrated Gradient、GradShap 三种分析方法,支持分析微调后模型的预测结果,开发者可以通过更改**数据目录**和**模型目录**在自己的任务中使用Jupyter Notebook进行数据分析。
+
+运行 [word_interpret.ipynb](./word_interpreter.ipynb) 代码,即可分析影响样本预测结果的关键词以及可视化所有词对预测结果的贡献情况,颜色越深代表这个词对预测结果影响越大:
+
+
+
+
+### 句子级别可解释性分析
+本项目基于特征相似度([FeatureSimilarity](https://arxiv.org/abs/2104.04128))算法,计算对样本预测结果正影响的训练数据,帮助理解模型的预测结果与训练集数据的关系。
-稀疏数据指缺乏足够训练数据支持导致低置信度的待预测数据,简单来说,由于模型在训练过程中没有学习到足够与待预测样本相似的数据,模型难以正确预测样本所属类别。本项目中稀疏数据筛选基于TrustAI(可信AI)工具集,利用基于特征相似度的实例级证据分析方法,抽取开发集中样本的支持训练证据,并计算支持证据平均分(通常为得分前三的支持训练证据均分)。分数较低的样本表明其训练证据不足,在训练集中较为稀疏,实验表明模型在这些样本上表现也相对较差。更多细节详见[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[实例级证据分析](https://github.com/PaddlePaddle/TrustAI/blob/main/trustai/interpretation/example_level/README.md)。
+待分析数据文件`interpret_input_file`应为以下三种格式中的一种:
+**格式一:包括文本、标签、预测结果**
+```text
+<文本>'\t'<标签>'\t'<预测结果>
+...
+```
-稀疏数据筛选旨在开发集中挖掘缺乏训练证据支持的稀疏数据,通常可以采用**数据增强**或**少量数据标注**的两种低成本方式,提升模型预测效果。
+**格式二:包括文本、标签**
+```text
+<文本>'\t'<标签>
+...
+```
-**安装TrustAI**
+**格式三:只包括文本**
+```text
+<文本>
+准予原告胡某甲与被告韩某甲离婚。
+...
+```
+
+我们可以运行代码,得到支持样本模型预测结果的训练数据:
```shell
-pip install trustai==0.1.4
+python sent_interpret.py \
+ --device "gpu" \
+ --dataset_dir "../data" \
+ --params_path "../checkpoint/" \
+ --max_seq_length 128 \
+ --batch_size 16 \
+ --top_k 3 \
+ --train_file "train.txt" \
+ --interpret_input_file "bad_case.txt" \
+ --interpret_result_file "sent_interpret.txt"
+```
+
+默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`
+
+可支持配置的参数:
+
+* `device`: 选用什么设备进行训练,可可选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含dev.txt和label.txt文件;默认为None。
+* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
+* `max_seq_length`:分词器tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数;默认为128。
+* `batch_size`:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。
+* `seed`:随机种子,默认为3。
+* `top_k`:筛选支持训练证据数量;默认为3。
+* `train_file`:本地数据集中训练集文件名;默认为"train.txt"。
+* `interpret_input_file`:本地数据集中待分析文件名;默认为"bad_case.txt"。
+* `interpret_result_file`:保存句子级别可解释性结果文件名;默认为"sent_interpret.txt"。
+
+可解释性结果保存在 `interpret_result_file` 文件中:
+```text
+text: 据猛龙随队记者JoshLewenberg报道,消息人士透露,猛龙已将前锋萨加巴-科纳特裁掉。此前他与猛龙签下了一份Exhibit10合同。在被裁掉后,科纳特下赛季大概率将前往猛龙的发展联盟球队效力。
+predict label: 组织关系,组织关系##解雇
+label: 组织关系,组织关系##加盟,组织关系##裁员
+examples with positive influence
+support1 text: 尼克斯官方今日宣布,他们已经裁掉了前锋扎克-欧文,后者昨日才与尼克斯签约。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.99357
+support2 text: 活塞官方今日宣布,他们已经签下了克雷格-斯沃德,并且裁掉了托德-威瑟斯。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.98344
+support3 text: 孟菲斯灰熊今年宣布,球队已经签下后卫达斯蒂-汉纳斯(DustyHannahs,版头图)并裁掉马特-穆尼。 label: 组织关系,组织关系##加盟,组织关系##解雇 score: 0.98219
+...
```
-### 稀疏数据识别--数据增强
+## 数据优化
+
+### 稀疏数据筛选方案
+
+稀疏数据筛选适用于文本分类中**数据不平衡或训练数据覆盖不足**的场景,简单来说,就是由于模型在训练过程中没有学习到足够与待预测样本相似的数据,模型难以正确预测样本所属类别的情况。稀疏数据筛选旨在开发集中挖掘缺乏训练证据支持的数据,通常可以采用**数据增强**或**少量数据标注**的两种低成本方式,提升模型在开发集的预测效果。
+
+本项目中稀疏数据筛选基于TrustAI,利用基于特征相似度的实例级证据分析方法,抽取开发集中样本的支持训练证据,并计算支持证据平均分(通常为得分前三的支持训练证据均分)。分数较低的样本表明其训练证据不足,在训练集中较为稀疏,实验表明模型在这些样本上表现也相对较差。更多细节详见[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[实例级证据分析](https://github.com/PaddlePaddle/TrustAI/blob/main/trustai/interpretation/example_level/README.md)。
+
+
+#### 稀疏数据识别—数据增强
+
+这里我们将介绍稀疏数据识别—数据增强流程:
+
+- **稀疏数据识别:** 挖掘开发集中的缺乏训练证据支持数据,记为稀疏数据集(Sparse Dataset);
-这里我们将介绍稀疏数据识别--数据增强流程,首先使用数据增强脚本挖掘开发集中的稀疏数据,然后筛选训练集中对稀疏数据的支持数据进行数据增强,然后将得到的数据增强后的支持数据加入到训练集中进行训练。
+- **数据增强**:将稀疏数据集在训练集中的支持证据应用数据增强策略,这些数据增强后的训练数据记为支持数据集(Support Dataset);
-现在我们进行稀疏数据识别--数据增强,得到新增训练数据:
+- **重新训练模型:** 将支持数据集加入到原有的训练集获得新的训练集,重新训练新的文本分类模型。
+
+现在我们进行稀疏数据识别-数据增强,得到支持数据集:
```shell
python sparse.py \
@@ -119,7 +227,7 @@ python sparse.py \
可支持配置的参数:
-* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含dev.txt和label.txt文件;默认为None。
* `aug_strategy`:数据增强类型,可选"duplicate","substitute", "insert", "delete", "swap";默认为"substitute"。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
@@ -156,9 +264,15 @@ cat ../data/train.txt ../data/support.txt > ../data/train_sparse_aug.txt
|训练集+支持增强集(delete) |90.56| 80.26|
|训练集+支持增强集(swap) |90.18|80.05|
-### 稀疏数据识别--数据标注
+#### 稀疏数据识别-数据标注
+
+本方案能够有针对性进行数据标注,相比于随机标注数据更好提高模型预测效果。这里我们将介绍稀疏数据识别-数据标注流程:
+
+- **稀疏数据识别:** 挖掘开发集中的缺乏训练证据支持数据,记为稀疏数据集(Sparse Dataset);
+
+- **数据标注**:在未标注数据集中筛选稀疏数据集的支持证据,并进行数据标注,记为支持数据集(Support Dataset);
-这里我们将介绍稀疏数据识别--数据标注流程,首先使用数据增强脚本挖掘开发集中的稀疏数据,然后筛选对稀疏数据支持的未标注数据,然后将得到支持数据进行标注后加入到训练集中进行训练。
+- **重新训练模型:** 将支持数据集加入到原有的训练集获得新的训练集,重新训练新的文本分类模型。
现在我们进行稀疏数据识别--数据标注,得到待标注数据:
@@ -179,7 +293,7 @@ python sparse.py \
可支持配置的参数:
-* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含dev.txt和label.txt文件;默认为None。
* `annotate`:选择稀疏数据识别--数据标注模式;默认为False。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
@@ -216,14 +330,13 @@ cat ../data/train.txt ../data/support.txt > ../data/train_sparse_annotate.txt
|训练集+策略采样集(100) |91.12|**84.13**|
|训练集+随机采样集(100) |**91.24**|81.66|
-## 脏数据清洗方案
+### 脏数据清洗方案
-训练数据标注质量对模型效果有较大影响,但受限于标注人员水平、标注任务难易程度等影响,训练数据中都存在一定比例的标注较差的数据(脏数据)。当标注数据规模较大时,数据标注检查就成为一个难题。本项目中脏数据清洗基于TrustAI(可信AI)工具集,利用基于表示点方法的实例级证据分析方法,计算训练数据对模型的影响分数,分数高的训练数据表明对模型影响大,这些数据有较大概率为脏数据(标注错误样本)。更多细节详见[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[实例级证据分析](https://github.com/PaddlePaddle/TrustAI/blob/main/trustai/interpretation/example_level/README.md)。
+脏数据清洗方案是基于已训练好的文本分类模型,筛选出训练数据集中标注错误的数据,再由人工检查重新标注,获得标注正确的数据集进行重新训练。我们将介绍脏数据清洗流程:
-**安装TrustAI**
-```shell
-pip install trustai==0.1.4
-```
+- **脏数据筛选:** 基于TrustAI中表示点方法,计算训练数据对文本分类模型的影响分数,分数高的训练数据表明对模型影响大,这些数据有较大概率为标注错误样本,记为脏数据集(Dirty Dataset)。
+
+- **数据清洗、训练:** 将筛选出的脏数据由人工重新检查,为数据打上正确的标签。将清洗后的训练数据重新放入文本分类模型进行训练。
现在我们进行脏数据识别,脏数据保存在`"train_dirty.txt"`,剩余训练数据保存在`"train_dirty_rest.txt"`:
@@ -246,16 +359,16 @@ python dirty.py \
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含train.txt和label.txt文件;默认为None。
* `max_seq_length`:分词器tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数;默认为128。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
-* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `batch_size`:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。
* `seed`:随机种子,默认为3。
* `dirty_file`:保存脏数据文件名,默认为"train_dirty.txt"。
* `rest_file`:保存剩余数据(非脏数据)文件名,默认为"train_dirty_rest.txt"。
* `train_file`:本地数据集中训练集文件名;默认为"train.txt"。
-* `dirty_threshold`:筛选脏数据用于重新标注的阈值,只选择影响分数大于阈值作为有效数据,默认为0。
+* `dirty_threshold`:筛选脏数据用于重新标注的阈值,只选择影响分数大于阈值作为支持数据,默认为0。
-我们将筛选出脏数据进行重新标注,可以将`train_dirty.txt`直接导入标注工具doccano帮助更快重新标注,详情请参考[文本分类任务doccano数据标注使用指南](../../doccano.md)进行文本分类数据标注。然后将已重新标注的脏数据`train_dirty.txt`与剩余训练集数据`train_dirty_rest.txt`合并得到新的训练集`train_clean.txt`重新进行训练:
+我们将筛选出脏数据进行人工检查重新标注,可以将`train_dirty.txt`直接导入标注工具doccano帮助更快重新标注,详情请参考[文本分类任务doccano数据标注使用指南](../../doccano.md)进行文本分类数据标注。然后将已重新标注的脏数据`train_dirty.txt`与剩余训练集数据`train_dirty_rest.txt`合并得到新的训练集`train_clean.txt`重新进行训练:
```shell
cat ../data/train_dirty_rest.txt ../data/train_dirty.txt > ../data/train_clean.txt
@@ -273,7 +386,7 @@ cat ../data/train_dirty_rest.txt ../data/train_dirty.txt > ../data/train_clean.t
|训练集(2000,含200条脏数据) + 脏数据清洗(100)|92.48|**87.83**|
|训练集(2000,含200条脏数据) + 脏数据清洗(150)|**92.55**|83.73|
-## 数据增强策略方案
+### 数据增强策略方案
在数据量较少或某些类别样本量较少时,也可以通过数据增强策略的方式,生成更多的训练数据,提升模型效果。
@@ -289,11 +402,11 @@ python aug.py \
* `train_path`:待增强训练数据集文件路径;默认为"../data/train.txt"。
* `aug_path`:增强生成的训练数据集文件路径;默认为"../data/train_aug.txt"。
-* `aug_strategy`:数据增强策略,可选"mix", "substitute", "insert", "delete", "swap"为多种数据策略混合使用;默认为"substitute"。
+* `aug_strategy`:数据增强策略,可选"mix", "substitute", "insert", "delete", "swap","mix"为多种数据策略混合使用;默认为"substitute"。
* `aug_type`:词替换/词插入增强类型,可选"synonym", "homonym", "mlm",建议在GPU环境下使用mlm类型;默认为"synonym"。
* `create_n`:生成的句子数量,默认为2。
* `aug_percent`:生成词替换百分比,默认为0.1。
-* `device`: 选用什么设备进行增强,选择cpu、gpu、xpu、npu,仅在使用mlm类型有影响;默认为"gpu"。
+* `device`: 选用什么设备进行增强,可选择cpu、gpu、xpu、npu,仅在使用mlm类型有影响;默认为"gpu"。
生成的增强数据保存在`"aug.txt"`文件中,与训练集数据`train.txt`合并得到新的训练集`train_aug.txt`重新进行训练:
diff --git a/applications/text_classification/hierarchical/analysis/evaluate.py b/applications/text_classification/hierarchical/analysis/evaluate.py
index f0db5a5d62cd..7bf8a2ebe2cd 100644
--- a/applications/text_classification/hierarchical/analysis/evaluate.py
+++ b/applications/text_classification/hierarchical/analysis/evaluate.py
@@ -37,7 +37,7 @@
parser.add_argument("--train_file", type=str, default="train.txt", help="Train dataset file name")
parser.add_argument("--dev_file", type=str, default="dev.txt", help="Dev dataset file name")
parser.add_argument("--label_file", type=str, default="label.txt", help="Label file name")
-parser.add_argument("--bad_case_path", type=str, default="./bad_case.txt", help="Bad case saving file path")
+parser.add_argument("--bad_case_file", type=str, default="./bad_case.txt", help="Bad case saving file path")
args = parser.parse_args()
# yapf: enable
@@ -236,23 +236,22 @@ def evaluate():
report[str(i)]['recall'] * 100,
report[str(i)]['f1-score'] * 100))
logger.info("----------------------------")
-
- with open(args.bad_case_path, 'w', encoding="utf-8") as f:
- f.write("Prediction\tLabel\tText\n")
+ bad_case_path = os.path.join(args.dataset_dir, args.bad_case_file)
+ with open(bad_case_path, 'w', encoding="utf-8") as f:
+ f.write("Text\tLabel\tPrediction\n")
for i in range(len(preds)):
for p, l in zip(preds[i], labels[i]):
if (p and l == 0) or (not p and l == 1):
pred_n = [
label_map[i] for i, pp in enumerate(preds[i]) if pp
]
- f.write(",".join(pred_n) + "\t" +
+ f.write(dev_ds.data[i]["text"] + "\t" +
dev_ds.data[i]["label_n"] + "\t" +
- dev_ds.data[i]["text"] + "\n")
+ ",".join(pred_n) + "\n")
break
f.close()
- logger.info("Bad case in dev dataset saved in {}".format(
- args.bad_case_path))
+ logger.info("Bad case in dev dataset saved in {}".format(bad_case_path))
return
diff --git a/applications/text_classification/hierarchical/analysis/sent_interpret.py b/applications/text_classification/hierarchical/analysis/sent_interpret.py
new file mode 100644
index 000000000000..9a586d24a5b4
--- /dev/null
+++ b/applications/text_classification/hierarchical/analysis/sent_interpret.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import random
+import os
+import argparse
+import numpy as np
+
+import paddle
+import paddle.nn.functional as F
+from paddle.io import DataLoader, BatchSampler
+from paddlenlp.data import DataCollatorWithPadding
+from paddlenlp.datasets import load_dataset
+from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer, LinearDecayWithWarmup
+from paddlenlp.utils.log import logger
+from trustai.interpretation import FeatureSimilarityModel
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument('--device', default="gpu", help="Select which device to train model, defaults to gpu.")
+parser.add_argument("--dataset_dir", required=True, type=str, help="The dataset directory should include train.txt,dev.txt and test.txt files.")
+parser.add_argument("--params_path", default="../checkpoint/", type=str, help="The path to model parameters to be loaded.")
+parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.")
+parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.")
+parser.add_argument("--seed", type=int, default=3, help="random seed for initialization")
+parser.add_argument("--top_k", type=int, default=3, help="Top K important training data.")
+parser.add_argument("--train_file", type=str, default="train.txt", help="Train dataset file name")
+parser.add_argument("--interpret_input_file", type=str, default="bad_case.txt", help="interpretation file name")
+parser.add_argument("--interpret_result_file", type=str, default="sent_interpret.txt", help="interpreted file name")
+args = parser.parse_args()
+# yapf: enable
+
+
+def set_seed(seed):
+ """
+ Set random seed
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ paddle.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+
+
+def read_local_dataset(path):
+ """
+ Read dataset file
+ """
+ with open(path, 'r', encoding='utf-8') as f:
+ for line in f:
+ items = line.strip().split('\t')
+ if items[0] == 'Text':
+ continue
+ if len(items) == 3:
+ yield {'text': items[0], 'label': items[1], 'predict': items[2]}
+ elif len(items) == 2:
+ yield {'text': items[0], 'label': items[1], 'predict': ''}
+ elif len(items) == 1:
+ yield {'text': items[0], 'label': '', 'predict': ''}
+ else:
+ logger.info(line.strip())
+ raise ValueError("{} should be in fixed format.".format(path))
+
+
+def preprocess_function(examples, tokenizer, max_seq_length):
+ """
+ Preprocess dataset
+ """
+ result = tokenizer(text=examples["text"], max_seq_len=max_seq_length)
+ return result
+
+
+class LocalDataCollatorWithPadding(DataCollatorWithPadding):
+ """
+ Convert the result of DataCollatorWithPadding from dict dictionary to a list
+ """
+
+ def __call__(self, features):
+ batch = super().__call__(features)
+ batch = list(batch.values())
+ return batch
+
+
+def find_positive_influence_data():
+
+ set_seed(args.seed)
+ paddle.set_device(args.device)
+
+ # Define model & tokenizer
+ if os.path.exists(os.path.join(
+ args.params_path, "model_state.pdparams")) and os.path.exists(
+ os.path.join(args.params_path,
+ "model_config.json")) and os.path.exists(
+ os.path.join(args.params_path,
+ "tokenizer_config.json")):
+ model = AutoModelForSequenceClassification.from_pretrained(
+ args.params_path)
+ tokenizer = AutoTokenizer.from_pretrained(args.params_path)
+ else:
+ raise ValueError("The {} should exist.".format(args.params_path))
+
+ # Prepare & preprocess dataset
+ train_path = os.path.join(args.dataset_dir, args.train_file)
+ interpret_path = os.path.join(args.dataset_dir, args.interpret_input_file)
+
+ train_ds = load_dataset(read_local_dataset, path=train_path, lazy=False)
+ interpret_ds = load_dataset(read_local_dataset,
+ path=interpret_path,
+ lazy=False)
+ trans_func = functools.partial(preprocess_function,
+ tokenizer=tokenizer,
+ max_seq_length=args.max_seq_length)
+
+ train_ds = train_ds.map(trans_func)
+ interpret_ds = interpret_ds.map(trans_func)
+
+ # Batchify dataset
+ collate_fn = LocalDataCollatorWithPadding(tokenizer)
+ train_batch_sampler = BatchSampler(train_ds,
+ batch_size=args.batch_size,
+ shuffle=False)
+ interpret_batch_sampler = BatchSampler(interpret_ds,
+ batch_size=args.batch_size,
+ shuffle=False)
+ train_data_loader = DataLoader(dataset=train_ds,
+ batch_sampler=train_batch_sampler,
+ collate_fn=collate_fn)
+ interpret_data_loader = DataLoader(dataset=interpret_ds,
+ batch_sampler=interpret_batch_sampler,
+ collate_fn=collate_fn)
+
+ # Classifier_layer_name is the layer name of the last output layer
+ feature_sim = FeatureSimilarityModel(model,
+ train_data_loader,
+ classifier_layer_name="classifier")
+ # Feature similarity analysis & select sparse data
+ analysis_result = []
+ for batch in interpret_data_loader:
+ analysis_result += feature_sim(batch, sample_num=args.top_k)
+ with open(os.path.join(args.dataset_dir, args.interpret_result_file),
+ 'w') as f:
+ for i in range(len(analysis_result)):
+ f.write("text: " + interpret_ds.data[i]["text"] + '\n')
+ if 'predict' in interpret_ds.data[i]:
+ f.write("predict label: " + interpret_ds.data[i]["predict"] +
+ '\n')
+ if 'label' in interpret_ds.data[i]:
+ f.write("label: " + interpret_ds.data[i]["label"] + '\n')
+ f.write("examples with positive influence\n")
+ for i, (idx, score) in enumerate(
+ zip(analysis_result[i].pos_indexes,
+ analysis_result[i].pos_scores)):
+ f.write("support{} text: ".format(i + 1) +
+ train_ds.data[idx]["text"] + '\t' + "label: " +
+ train_ds.data[idx]["label"] + '\t' + "score: " +
+ "{:.5f}".format(score) + '\n')
+ f.close()
+
+
+if __name__ == "__main__":
+ find_positive_influence_data()
diff --git a/applications/text_classification/hierarchical/analysis/word_interpret.ipynb b/applications/text_classification/hierarchical/analysis/word_interpret.ipynb
new file mode 100644
index 000000000000..8c25b87a2711
--- /dev/null
+++ b/applications/text_classification/hierarchical/analysis/word_interpret.ipynb
@@ -0,0 +1,366 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 词级别可解释性分析\n",
+ "本项目提供模型的词级别可解释性分析,包括LIME、Integrated Gradient、GradShap 三种分析方法,支持分析微调后模型的预测结果,开发者可以通过更改**数据目录**和**模型目录**在自己的任务中使用此项目进行数据分析。\n",
+ "\n",
+ "![image](https://user-images.githubusercontent.com/63761690/195334753-78cc2dc8-a5ba-4460-9fde-3b1bb704c053.png)\n",
+ " \n",
+ "\n",
+ "## 1.导入Python模块与参数配置\n",
+ "首先我们导入必要的导入必要python模块和设置配置参数,词级别可解释性分析算法支持三种待分析的文本 `INTERPRETER_FILE` 数据文件格式:\n",
+ "\n",
+ "**格式一:包括文本、标签、预测结果**\n",
+ "```text\n",
+ "<文本>'\\t'<标签>'\\t'<预测结果>\n",
+ "...\n",
+ "```\n",
+ "\n",
+ "**格式二:包括文本、标签**\n",
+ "```text\n",
+ "<文本>'\\t'<标签>\n",
+ "...\n",
+ "```\n",
+ "\n",
+ "**格式三:只包括文本**\n",
+ "```text\n",
+ "<文本>\n",
+ "...\n",
+ "```\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import functools\n",
+ "import random\n",
+ "import os\n",
+ "import argparse\n",
+ "\n",
+ "import jieba\n",
+ "import numpy as np \n",
+ "from trustai.interpretation import VisualizationTextRecord\n",
+ "from trustai.interpretation import get_word_offset\n",
+ "import paddle\n",
+ "from paddle.io import DataLoader, BatchSampler\n",
+ "from paddlenlp.data import DataCollatorWithPadding\n",
+ "from paddlenlp.datasets import load_dataset\n",
+ "from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from trustai.interpretation import VisualizationTextRecord\n",
+ "from trustai.interpretation import get_word_offset\n",
+ "import paddle\n",
+ "from paddle.io import DataLoader, BatchSampler\n",
+ "from paddlenlp.data import DataCollatorWithPadding\n",
+ "from paddlenlp.datasets import load_dataset\n",
+ "from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 预先定义配置参数\n",
+ "\n",
+ "# 运行环境,可选\"cpu\",\"gpu\",\"gpu:x\"(x为gpu编号)\n",
+ "DEVICE = \"gpu\"\n",
+ "# 数据路径\n",
+ "DATASET_DIR = \"../data\" \n",
+ "# 训练模型保存路径\n",
+ "PARAM_PATH = \"../checkpoint/\" \n",
+ "# tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数\n",
+ "MAX_LENGTH = 128 \n",
+ "# 批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数\n",
+ "BATCH_SIZE = 1 \n",
+ "# 待分析解释的数据\n",
+ "INTERPRETER_FILE = \"bad_case.txt\"\n",
+ "# 可选 \"ig\",\"lime\",\"grad\" ,可以根据实际任务效果选择解释器\n",
+ "# \"grad\":GradShap方法依赖interpretdl\n",
+ "# !pip install interpretdl\n",
+ "INTERPRETER = \"ig\"\n",
+ "# 分析句子中TOP K关键词,K值\n",
+ "KEY_WORDS_NUM = 5"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def read_local_dataset(path):\n",
+ " \"\"\"\n",
+ " Read dataset file\n",
+ " \"\"\"\n",
+ " with open(path, 'r', encoding='utf-8') as f:\n",
+ " for line in f:\n",
+ " items = line.strip().split('\\t')\n",
+ " if items[0] == 'Text':\n",
+ " continue\n",
+ " if len(items) == 3:\n",
+ " yield {'text': items[0], 'label': items[1], 'predict': items[2]}\n",
+ " elif len(items) == 2:\n",
+ " yield {'text': items[0], 'label': items[1], 'predict': ''}\n",
+ " elif len(items) == 1:\n",
+ " yield {'text': items[0], 'label': '', 'predict': ''}\n",
+ " else:\n",
+ " raise ValueError(\"{} should be in fixed format.\".format(path))\n",
+ "\n",
+ "def preprocess_function(examples, tokenizer, max_seq_length):\n",
+ " \"\"\"\n",
+ " Preprocess dataset\n",
+ " \"\"\"\n",
+ " result = tokenizer(text=examples[\"text\"], max_seq_len=max_seq_length)\n",
+ " return result\n",
+ "\n",
+ "class LocalDataCollatorWithPadding(DataCollatorWithPadding):\n",
+ " \"\"\"\n",
+ " Convert the result of DataCollatorWithPadding from dict dictionary to a list\n",
+ " \"\"\"\n",
+ "\n",
+ " def __call__(self, features):\n",
+ " batch = super().__call__(features)\n",
+ " batch = list(batch.values())\n",
+ " return batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m[2022-10-12 11:45:49,858] [ INFO]\u001b[0m - We are using to load '/workspace/PaddleNLP/applications/text_classification/hierarchical/checkpoint/'.\u001b[0m\n",
+ "W1012 11:45:49.861358 26086 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2\n",
+ "W1012 11:45:49.865923 26086 gpu_resources.cc:91] device: 0, cuDNN Version: 8.1.\n",
+ "\u001b[32m[2022-10-12 11:45:52,912] [ INFO]\u001b[0m - We are using to load '/workspace/PaddleNLP/applications/text_classification/hierarchical/checkpoint/'.\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "paddle.set_device(DEVICE)\n",
+ "\n",
+ "# Define model & tokenizer\n",
+ "if os.path.exists(os.path.join(\n",
+ " PARAM_PATH, \"model_state.pdparams\")) and os.path.exists(\n",
+ " os.path.join(PARAM_PATH,\n",
+ " \"model_config.json\")) and os.path.exists(\n",
+ " os.path.join(PARAM_PATH,\n",
+ " \"tokenizer_config.json\")):\n",
+ " model = AutoModelForSequenceClassification.from_pretrained(\n",
+ " PARAM_PATH)\n",
+ " tokenizer = AutoTokenizer.from_pretrained(PARAM_PATH)\n",
+ "else:\n",
+ " raise ValueError(\"The {} should exist.\".format(PARAM_PATH))\n",
+ "\n",
+ "# Prepare & preprocess dataset\n",
+ "interpret_path = os.path.join(DATASET_DIR, INTERPRETER_FILE)\n",
+ "\n",
+ "\n",
+ "interpret_ds = load_dataset(read_local_dataset, path=interpret_path, lazy=False)\n",
+ "trans_func = functools.partial(preprocess_function,\n",
+ " tokenizer=tokenizer,\n",
+ " max_seq_length=MAX_LENGTH)\n",
+ "\n",
+ "interpret_ds = interpret_ds.map(trans_func)\n",
+ "\n",
+ "# Batchify dataset\n",
+ "collate_fn = LocalDataCollatorWithPadding(tokenizer)\n",
+ "interpret_batch_sampler = BatchSampler(interpret_ds,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " shuffle=False)\n",
+ "interpret_data_loader = DataLoader(dataset=interpret_ds,\n",
+ " batch_sampler=interpret_batch_sampler,\n",
+ " collate_fn=collate_fn)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Start token level interpretion, it will take some time...\n",
+ "Building prefix dict from the default dictionary ...\n",
+ "Loading model from cache /tmp/jieba.cache\n",
+ "Loading model cost 0.746 seconds.\n",
+ "Prefix dict has been built successfully.\n",
+ "Start word level alignment, it will take some time...\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Init an interpreter\n",
+ "if INTERPRETER == 'ig':\n",
+ " from trustai.interpretation.token_level import IntGradInterpreter\n",
+ " interpreter = IntGradInterpreter(model)\n",
+ "elif INTERPRETER == 'lime':\n",
+ " from trustai.interpretation.token_level import LIMEInterpreter\n",
+ " interpreter = LIMEInterpreter(model, unk_id=tokenizer.convert_tokens_to_ids('[UNK]'), pad_id=tokenizer.convert_tokens_to_ids('[PAD]'))\n",
+ "else:\n",
+ " from trustai.interpretation.token_level import GradShapInterpreter\n",
+ " interpreter = GradShapInterpreter(model)\n",
+ "\n",
+ "# Use interpreter to get the importance scores for all data\n",
+ "print(\"Start token level interpretion, it will take some time...\")\n",
+ "analysis_result = []\n",
+ "for batch in interpret_data_loader:\n",
+ " analysis_result += interpreter(tuple(batch))\n",
+ "\n",
+ "# Add CLS and SEP tags to both original text and standard splited tokens\n",
+ "contexts = []\n",
+ "words = []\n",
+ "for i in range(len(interpret_ds)):\n",
+ " text = interpret_ds.data[i][\"text\"]\n",
+ " contexts.append(\"[CLS]\" + text + \"[SEP]\")\n",
+ " words.append([\"[CLS]\"] + list(jieba.cut(text)) + [\"[SEP]\"])\n",
+ "\n",
+ "# Get the offset map of tokenized tokens and standard splited tokens\n",
+ "print(\"Start word level alignment, it will take some time...\")\n",
+ "ori_offset_maps = []\n",
+ "word_offset_maps = []\n",
+ "for i in range(len(contexts)):\n",
+ " ori_offset_maps.append(tokenizer.get_offset_mapping(contexts[i]))\n",
+ " word_offset_maps.append(get_word_offset(contexts[i], words[i]))\n",
+ "\n",
+ "align_res = interpreter.alignment(analysis_result, contexts, words, word_offset_maps, ori_offset_maps, special_tokens=[\"[CLS]\", '[SEP]'],rationale_num=KEY_WORDS_NUM)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from IPython.core.display import display, HTML\n",
+ "class Visualization(VisualizationTextRecord):\n",
+ "\n",
+ " def __init__(self, interpret_res, true_label=None, pred_label=None, words=None):\n",
+ " if words is not None:\n",
+ " self.words = words\n",
+ " else:\n",
+ " self.words = interpret_res.words\n",
+ " self.pred_label = pred_label if pred_label is not None else ''\n",
+ " self.true_label = true_label if true_label is not None else ''\n",
+ " self.key_words = \" \".join(set(interpret_res.rationale_tokens))\n",
+ " word_attributions = interpret_res.word_attributions\n",
+ " _max = max(word_attributions)\n",
+ " _min = min(word_attributions)\n",
+ " self.word_attributions = [(word_imp - _min) / (_max - _min) for word_imp in word_attributions]\n",
+ "\n",
+ " def record_html(self):\n",
+ " \"\"\"change all informations to html\"\"\"\n",
+ " return \"\".join([\n",
+ " \"\",\n",
+ " self._format_class(self.true_label),\n",
+ " self._format_class(self.pred_label),\n",
+ " self._format_class(self.key_words),\n",
+ " self._format_word_attributions(),\n",
+ " \"
\",\n",
+ " ])\n",
+ " def _format_class(self, label):\n",
+ " return '{label} | '.format(label=label)\n",
+ "\n",
+ "def visualize_text(text_records):\n",
+ " \"\"\"visualize text\"\"\"\n",
+ " html = [\"\"]\n",
+ " rows = [\"Label | \"\n",
+ " \"Prediction | \"\n",
+ " \"Key words | \"\n",
+ " \"Important visualization | \"]\n",
+ " for record in text_records:\n",
+ " rows.append(record.record_html())\n",
+ " html.append(\"\".join(rows))\n",
+ " html.append(\"
---|
\")\n",
+ " html = HTML(\"\".join(html))\n",
+ " display(html)\n",
+ " return html.data\n",
+ "\n",
+ "\n",
+ "def visualize(interpret_res, ds):\n",
+ " records = []\n",
+ " for i in range(len(interpret_res)):\n",
+ " records.append(Visualization(interpret_res[i], true_label=ds.data[i][\"label\"], pred_label=ds.data[i][\"predict\"]))\n",
+ " html = visualize_text(records)\n",
+ " return html"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "Label | Prediction | Key words | Important visualization |
---|
组织关系,组织关系##加盟,组织关系##裁员 | 组织关系,组织关系##解雇 | 。 特裁 签下 此前 掉 | [CLS] 据 猛龙 随队 记者 JoshLewenberg 报道 , 消息人士 透露 , 猛龙 已 将 前锋 萨 加巴 - 科纳 特裁 掉 。 此前 他 与 猛龙 签下 了 一份 Exhibit10 合同 。 在 被 裁掉 后 , 科纳 特下 赛季 大 概率 将 前往 猛龙 的 发展 联盟 球队 效力 。 [SEP] |
|
组织关系,组织关系##裁员 | 组织关系,组织关系##解雇 | 加入 湖人队 裁掉 被 何去何从 | [CLS] 冠军 射手 被 裁掉 , 欲 加入 湖人队 , 但 湖人 却 无意 , 冠军 射手 何去何从 [SEP] |
|
组织关系,组织关系##裁员 | 组织关系,组织关系##裁员,财经/交易 | 裁员 超过 1000 将 裁减 | [CLS] 6 月 7 日 报道 , IBM 将 裁员 超过 1000 人 。 IBM 周四 确认 , 将 裁减 一千多 人 。 据 知情 人士 称 , 此次 裁员 将 影响 到 约 1700 名 员工 , 约 占 IBM 全球 逾 34 万 员工 中 的 0.5% 。 IBM 股价 今年 累计 上涨 16% , 但 该 公司 4 月 发布 的 财报 显示 , 一季度 营收 下降 5% , 低于 市场 预期 。 [SEP] |
|
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# process for vbisualize\n",
+ "html = visualize(align_res, interpret_ds)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.7.13 64-bit",
+ "metadata": {
+ "interpreter": {
+ "hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90"
+ }
+ },
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13-final"
+ },
+ "orig_nbformat": 2
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/applications/text_classification/multi_class/README.md b/applications/text_classification/multi_class/README.md
index e4a45a80760a..5da9922dfd94 100644
--- a/applications/text_classification/multi_class/README.md
+++ b/applications/text_classification/multi_class/README.md
@@ -37,7 +37,7 @@
- 【标注成本高、标注样本较少的小样本场景】 👉 [提示学习多分类方案](./few-shot#readme)
-- 【标签类别不固定场景】 👉 [语义索引多分类方案](./retrieval_based#readme)
+- 【标签类别不固定场景、标签类别众多】 👉 [语义索引多分类方案](./retrieval_based#readme)
@@ -95,20 +95,13 @@ python3 -m pip install scikit-learn==1.0.2
```text
multi_class/
+├── few-shot # 小样本学习方案
+├── retrieval_based # 语义索引方案
+├── analysis # 分析模块
├── deploy # 部署
│ └── predictor # 离线部署
-│ │ ├── infer.py # 测试脚本
-│ │ ├── predictor.py 离线部署脚本
-│ │ └── README.md # 离线部署使用说明
│ ├── paddle_serving # PaddleServing在线服务化部署
-│ │ ├──config.yml # 服务端的配置文件
-│ │ ├──rpc_client.py # 客户端预测脚本
-│ │ ├──service.py # 服务端的脚本
-│ │ └── README.md # 在线服务化部署使用说明
│ └── triton_serving # Triton在线服务化部署
-│ ├── README.md # Triton部署使用说明
-│ ├── seqcls_grpc_client.py # 客户端预测脚本
-│ └── models
├── train.py # 训练评估脚本
├── predict.py # 预测脚本
├── export_model.py # 静态图模型导出脚本
@@ -132,7 +125,7 @@ data/
└── data.txt # 待预测数据文件
```
-**训练、开发、测试数据集** 文件中文本与标签类别名用tab符`'\t'`分隔开。
+**训练、开发、测试数据集** 文件中文本与标签类别名用tab符`'\t'`分隔开,文本中避免出现tab符`'\t'`。
- train.txt/dev.txt/test.txt 文件格式:
```text
@@ -273,10 +266,16 @@ checkpoint/
#### 2.4.2 训练评估与模型优化
-训练后的模型我们可以使用 [模型分析模块](./analysis) 对每个类别分别进行评估,并输出预测错误样本(bad case),默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`:
+文本分类预测过程中常会遇到诸如"模型为什么会预测出错误的结果","如何提升模型的表现"等问题。[Analysis模块](./analysis) 提供了**模型评估、可解释性分析、数据优化**等功能,旨在帮助开发者更好地分析文本分类模型预测结果和对模型效果进行优化。
+
+
+
+
+
+**模型评估:** 训练后的模型我们可以使用 [Analysis模块](./analysis) 对每个类别分别进行评估,并输出预测错误样本(bad case),默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`:
```shell
-python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32 --bad_case_path "./bad_case.txt" --dataset_dir "data" --params_path "./checkpoint"
+python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32 --bad_case_file "bad_case.txt" --dataset_dir "data" --params_path "./checkpoint"
```
输出打印示例:
@@ -301,15 +300,40 @@ python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32
预测错误的样本保存在bad_case.txt文件中:
```text
-Confidence Prediction Label Text
-0.77 注意事项 其他 您好,请问一岁三个月的孩子可以服用复方锌布颗粒吗?
-0.94 就医建议 其他 输卵管粘连的基本检查
-0.78 病情诊断 其他 经常干呕恶心,这是生病了吗
-0.79 后果表述 其他 吃左旋肉碱后的不良反应
+Text Label Prediction
+您好,请问一岁三个月的孩子可以服用复方锌布颗粒吗? 其他 注意事项
+输卵管粘连的基本检查 其他 就医建议
+会是胎动么? 其他 病情诊断
+经常干呕恶心,这是生病了吗 其他 病情诊断
+菏泽哪个医院治疗白癜风比较好?怎么治好 就医建议 治疗方案
+...
+```
+
+**可解释性分析:** 基于[TrustAI](https://github.com/PaddlePaddle/TrustAI)提供单词和句子级别的模型可解释性分析,帮助理解模型预测结果,用于错误样本(bad case)分析,细节详见[训练评估与模型优化指南](analysis/README.md)。
+
+- 单词级别可解释性分析,也即分析待预测样本中哪一些单词对模型预测结果起重要作用。以下图为例,用颜色深浅表示单词对预测结果的重要性。
+
+
+
+
+- 句子级别可解释性分析 ,也即分析对待预测样本的模型预测结果与训练集中中哪些样本有重要关系。下面的例子表明句子级别可解释性分析可以帮助理解待预测样本的预测结果与训练集中样本之间的关联。
+```text
+text: 您好,请问一岁三个月的孩子可以服用复方锌布颗粒吗?
+predict label: 注意事项
+label: 其他
+examples with positive influence
+support1 text: 感冒期间钙产品要继续服用吗? 钙尔奇就可以,也可以吃婴儿吃的乳钙 label: 注意事项 score: 0.96602
+support2 text: 打喷嚏可以吃布洛芬缓释胶囊么 label: 注意事项 score: 0.95687
+support3 text: 孕后期可以足疗吗 label: 注意事项 score: 0.94021
...
```
-模型表现常常受限于数据质量,在analysis模块中我们提供了基于[TrustAI](https://github.com/PaddlePaddle/TrustAI)的稀疏数据筛选、脏数据清洗、数据增强三种优化方案助力开发者提升模型效果,更多模型评估和优化方案细节详见[训练评估与模型优化指南](analysis/README.md)。
+**数据优化:** 结合[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[数据增强API](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/dataaug.md)提供了**稀疏数据筛选、脏数据清洗、数据增强**三种优化策略,从多角度优化训练数据提升模型效果,策略细节详见[训练评估与模型优化指南](analysis/README.md)。
+
+- 稀疏数据筛选主要是解决数据不均衡、训练数据覆盖不足的问题,通过数据增强和数据标注两种方式解决这一问题。
+- 脏数据清洗可以帮助开发者筛选训练集中错误标注的数据,对这些数据重新进行人工标注,得到标注正确的数据再重新进行训练。
+- 数据增强策略提供多种数据增强方案,可以快速扩充数据,提高模型泛化性和鲁棒性。
+
#### 2.4.3 模型预测
训练结束后,输入待预测数据(data.txt)和类别标签对照列表(label.txt),使用训练好的模型进行,默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`:
@@ -338,7 +362,7 @@ python predict.py --device "gpu" --max_seq_length 128 --batch_size 32 --dataset_
python export_model.py --params_path ./checkpoint/ --output_path ./export
```
-如果使用ERNIE M作为预训练模型,运行方式:
+如果使用多语言模型 ERNIE M作为预训练模型,运行方式:
```shell
python export_model.py --params_path ./checkpoint/ --output_path ./export --multilingual
```
@@ -421,7 +445,6 @@ prune/
3. ERNIE Base、Medium、Mini、Micro、Nano的模型宽度(multi head数量)为12,ERNIE Xbase、Large 模型宽度(multi head数量)为16,保留比例`width_mult`乘以宽度(multi haed数量)应为整数。
-4. **压缩API暂不支持多语言预训练模型ERNIE-M**,相关功能正在加紧开发中。
#### 2.5.3 部署方案
diff --git a/applications/text_classification/multi_class/analysis/README.md b/applications/text_classification/multi_class/analysis/README.md
index d9a47b44016c..20f31c0ed581 100644
--- a/applications/text_classification/multi_class/analysis/README.md
+++ b/applications/text_classification/multi_class/analysis/README.md
@@ -1,18 +1,29 @@
-# 多分类训练评估与模型优化指南
+# 训练评估与模型优化指南
**目录**
- * [analysis模块介绍](#analysis模块介绍)
+ * [Analysis模块介绍](#Analysis模块介绍)
+ * [环境准备](#环境准备)
* [模型评估](#模型评估)
- * [稀疏数据筛选方案](#稀疏数据筛选方案)
- * [脏数据清洗方案](#脏数据清洗方案)
- * [数据增强策略方案](#数据增强策略方案)
+ * [可解释性分析](#可解释性分析)
+ * [单词级别可解释性分析](#单词级别可解释性分析)
+ * [句子级别可解释性分析](#句子级别可解释性分析)
+ * [数据优化](#数据优化)
+ * [稀疏数据筛选方案](#稀疏数据筛选方案)
+ * [脏数据清洗方案](#脏数据清洗方案)
+ * [数据增强策略方案](#数据增强策略方案)
-## analysis模块介绍
+## Analysis模块介绍
-analysis模块提供了**模型评估**脚本对整体分类情况和每个类别分别进行评估,并打印预测错误样本,帮助开发者分析模型表现找到训练和预测数据中存在的问题问题。同时基于[可信AI工具集](https://github.com/PaddlePaddle/TrustAI)和[数据增强API](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/dataaug.md)提供了**稀疏数据筛选、脏数据清洗、数据增强**三种优化方案从多角度帮助开发者提升模型效果。
+Analysis模块提供了**模型评估、可解释性分析、数据优化**等功能,旨在帮助开发者更好地分析文本分类模型预测结果和对模型效果进行优化。
+
+- **模型评估:** 对整体分类情况和每个类别分别进行评估,并打印预测错误样本,帮助开发者分析模型表现找到训练和预测数据中存在的问题。
+
+- **可解释性分析:** 基于[TrustAI](https://github.com/PaddlePaddle/TrustAI)提供单词和句子级别的模型可解释性分析,帮助理解模型预测结果。
+
+- **数据优化:** 结合[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[数据增强API](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/dataaug.md)提供了**稀疏数据筛选、脏数据清洗、数据增强**三种优化策略,从多角度优化训练数据提升模型效果。
-
+
以下是本项目主要代码结构及说明:
@@ -20,10 +31,27 @@ analysis模块提供了**模型评估**脚本对整体分类情况和每个类
```text
analysis/
├── evaluate.py # 评估脚本
+├── sent_interpret.py # 句子级别可解释性分析脚本
+├── word_interpret.py # 单词级别可解释性分析notebook
├── sparse.py # 稀疏数据筛选脚本
├── dirty.py # 脏数据清洗脚本
├── aug.py # 数据增强脚本
-└── README.md # 多分类训练评估与模型优化指南
+└── README.md # 训练评估与模型优化指南
+```
+
+## 环境准备
+需要可解释性分析和数据优化需要安装相关环境。
+- trustai >= 0.1.7
+- interpretdl >= 0.7.0
+
+**安装TrustAI**(可选)如果使用可解释性分析和数据优化中稀疏数据筛选和脏数据清洗需要安装TrustAI。
+```shell
+pip install trustai==0.1.7
+```
+
+**安装InterpretDL**(可选)如果使用词级别可解释性分析GradShap方法,需要安装InterpretDL
+```shell
+pip install interpretdl==0.7.0
```
## 模型评估
@@ -37,14 +65,14 @@ python evaluate.py \
--params_path "../checkpoint" \
--max_seq_length 128 \
--batch_size 32 \
- --bad_case_path "./bad_case.txt"
+ --bad_case_file "bad_case.txt"
```
默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`
可支持配置的参数:
-* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含train.txt、dev.txt和label.txt文件;默认为None。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
* `max_seq_length`:分词器tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数;默认为128。
@@ -76,30 +104,109 @@ python evaluate.py \
预测错误的样本保存在bad_case.txt文件中:
```text
-Confidence Prediction Label Text
-0.77 注意事项 其他 您好,请问一岁三个月的孩子可以服用复方锌布颗粒吗?
-0.94 就医建议 其他 输卵管粘连的基本检查
-0.78 病情诊断 其他 经常干呕恶心,这是生病了吗
-0.79 后果表述 其他 吃左旋肉碱后的不良反应
+Text Label Prediction
+您好,请问一岁三个月的孩子可以服用复方锌布颗粒吗? 其他 注意事项
+输卵管粘连的基本检查 其他 就医建议
+会是胎动么? 其他 病情诊断
+经常干呕恶心,这是生病了吗 其他 病情诊断
+菏泽哪个医院治疗白癜风比较好?怎么治好 就医建议 治疗方案
...
```
-## 稀疏数据筛选方案
+## 可解释性分析
+"模型为什么会预测出这个结果?"是文本分类任务开发者时常遇到的问题,如何分析错误样本(bad case)是文本分类任务落地中重要一环,本项目基于TrustAI开源了基于词级别和句子级别的模型可解释性分析方法,帮助开发者更好地理解文本分类模型与数据,有助于后续的模型优化与数据清洗标注。
-稀疏数据指缺乏足够训练数据支持导致低置信度的待预测数据,简单来说,由于模型在训练过程中没有学习到足够与待预测样本相似的数据,模型难以正确预测样本所属类别。本项目中稀疏数据筛选基于TrustAI(可信AI)工具集,利用基于特征相似度的实例级证据分析方法,抽取开发集中样本的支持训练证据,并计算支持证据平均分(通常为得分前三的支持训练证据均分)。分数较低的样本表明其训练证据不足,在训练集中较为稀疏,实验表明模型在这些样本上表现也相对较差。更多细节详见[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[实例级证据分析](https://github.com/PaddlePaddle/TrustAI/blob/main/trustai/interpretation/example_level/README.md)。
+### 单词级别可解释性分析
+本项目开源模型的词级别可解释性分析Notebook,提供LIME、Integrated Gradient、GradShap 三种分析方法,支持分析微调后模型的预测结果,开发者可以通过更改**数据目录**和**模型目录**在自己的任务中使用Jupyter Notebook进行数据分析。
-稀疏数据筛选旨在开发集中挖掘缺乏训练证据支持的稀疏数据,通常可以采用**数据增强**或**少量数据标注**的两种低成本方式,提升模型预测效果。
+运行 [word_interpret.ipynb](./word_interpreter.ipynb) 代码,即可分析影响样本预测结果的关键词以及可视化所有词对预测结果的贡献情况,颜色越深代表这个词对预测结果影响越大:
+
+
+
-**安装TrustAI**
+### 句子级别可解释性分析
+本项目基于特征相似度([FeatureSimilarity](https://arxiv.org/abs/2104.04128))算法,计算对样本预测结果正影响的训练数据,帮助理解模型的预测结果与训练集数据的关系。
+
+待分析数据文件`interpret_input_file`应为以下三种格式中的一种:
+**格式一:包括文本、标签、预测结果**
+```text
+<文本>'\t'<标签>'\t'<预测结果>
+...
+```
+
+**格式二:包括文本、标签**
+```text
+<文本>'\t'<标签>
+...
+```
+
+**格式三:只包括文本**
+```text
+<文本>
+准予原告胡某甲与被告韩某甲离婚。
+...
+```
+
+我们可以运行代码,得到支持样本模型预测结果的训练数据:
```shell
-pip install trustai==0.1.4
+python sent_interpret.py \
+ --device "gpu" \
+ --dataset_dir "../data" \
+ --params_path "../checkpoint/" \
+ --max_seq_length 128 \
+ --batch_size 16 \
+ --top_k 3 \
+ --train_file "train.txt" \
+ --interpret_input_file "bad_case.txt" \
+ --interpret_result_file "sent_interpret.txt"
```
-### 稀疏数据识别--数据增强
+默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`
+
+可支持配置的参数:
+
+* `device`: 选用什么设备进行训练,可可选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含dev.txt和label.txt文件;默认为None。
+* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
+* `max_seq_length`:分词器tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数;默认为128。
+* `batch_size`:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。
+* `seed`:随机种子,默认为3。
+* `top_k`:筛选支持训练证据数量;默认为3。
+* `train_file`:本地数据集中训练集文件名;默认为"train.txt"。
+* `interpret_input_file`:本地数据集中待分析文件名;默认为"bad_case.txt"。
+* `interpret_result_file`:保存句子级别可解释性结果文件名;默认为"sent_interpret.txt"。
+
+可解释性结果保存在 `interpret_result_file` 文件中:
+```text
+text: 您好,请问一岁三个月的孩子可以服用复方锌布颗粒吗?
+predict label: 注意事项
+label: 其他
+examples with positive influence
+support1 text: 感冒期间钙产品要继续服用吗? 钙尔奇就可以,也可以吃婴儿吃的乳钙 label: 注意事项 score: 0.96602
+support2 text: 打喷嚏可以吃布洛芬缓释胶囊么 label: 注意事项 score: 0.95687
+support3 text: 孕后期可以足疗吗 label: 注意事项 score: 0.94021
+...
+```
+## 数据优化
-这里我们将介绍稀疏数据识别--数据增强流程,首先使用数据增强脚本挖掘开发集中的稀疏数据,然后筛选训练集中对稀疏数据的支持数据进行数据增强,然后将得到的数据增强后的支持数据加入到训练集中进行训练。
+### 稀疏数据筛选方案
-现在我们进行稀疏数据识别--数据增强,得到新增训练数据:
+稀疏数据筛选适用于文本分类中**数据不平衡或训练数据覆盖不足**的场景,简单来说,就是由于模型在训练过程中没有学习到足够与待预测样本相似的数据,模型难以正确预测样本所属类别的情况。稀疏数据筛选旨在开发集中挖掘缺乏训练证据支持的数据,通常可以采用**数据增强**或**少量数据标注**的两种低成本方式,提升模型在开发集的预测效果。
+
+本项目中稀疏数据筛选基于TrustAI,利用基于特征相似度的实例级证据分析方法,抽取开发集中样本的支持训练证据,并计算支持证据平均分(通常为得分前三的支持训练证据均分)。分数较低的样本表明其训练证据不足,在训练集中较为稀疏,实验表明模型在这些样本上表现也相对较差。更多细节详见[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[实例级证据分析](https://github.com/PaddlePaddle/TrustAI/blob/main/trustai/interpretation/example_level/README.md)。
+
+
+#### 稀疏数据识别—数据增强
+
+这里我们将介绍稀疏数据识别—数据增强流程:
+
+- **稀疏数据识别:** 挖掘开发集中的缺乏训练证据支持数据,记为稀疏数据集(Sparse Dataset);
+
+- **数据增强**:将稀疏数据集在训练集中的支持证据应用数据增强策略,这些数据增强后的训练数据记为支持数据集(Support Dataset);
+
+- **重新训练模型:** 将支持数据集加入到原有的训练集获得新的训练集,重新训练新的文本分类模型。
+
+现在我们进行稀疏数据识别-数据增强,得到支持数据集:
```shell
python sparse.py \
@@ -117,7 +224,7 @@ python sparse.py \
可支持配置的参数:
-* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含dev.txt和label.txt文件;默认为None。
* `aug_strategy`:数据增强类型,可选"duplicate","substitute", "insert", "delete", "swap";默认为"substitute"。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
@@ -153,9 +260,15 @@ cat ../data/train.txt ../data/support.txt > ../data/train_sparse_aug.txt
|训练集+支持增强集(delete) |74.27|
|训练集+支持增强集(swap) |73.66|
-### 稀疏数据识别--数据标注
+#### 稀疏数据识别-数据标注
+
+本方案能够有针对性进行数据标注,相比于随机标注数据更好提高模型预测效果。这里我们将介绍稀疏数据识别-数据标注流程:
+
+- **稀疏数据识别:** 挖掘开发集中的缺乏训练证据支持数据,记为稀疏数据集(Sparse Dataset);
+
+- **数据标注**:在未标注数据集中筛选稀疏数据集的支持证据,并进行数据标注,记为支持数据集(Support Dataset);
-这里我们将介绍稀疏数据识别--数据标注流程,首先使用数据增强脚本挖掘开发集中的稀疏数据,然后筛选对稀疏数据支持的未标注数据,然后将得到支持数据进行标注后加入到训练集中进行训练。
+- **重新训练模型:** 将支持数据集加入到原有的训练集获得新的训练集,重新训练新的文本分类模型。
现在我们进行稀疏数据识别--数据标注,得到待标注数据:
@@ -176,7 +289,7 @@ python sparse.py \
可支持配置的参数:
-* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含dev.txt和label.txt文件;默认为None。
* `annotate`:选择稀疏数据识别--数据标注模式;默认为False。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
@@ -213,14 +326,13 @@ cat ../data/train.txt ../data/support.txt > ../data/train_sparse_annotate.txt
|训练集+策略采样集(100) |**77.64**|
|训练集+随机采样集(100) |76.37|
-## 脏数据清洗方案
+### 脏数据清洗方案
-训练数据标注质量对模型效果有较大影响,但受限于标注人员水平、标注任务难易程度等影响,训练数据中都存在一定比例的标注较差的数据(脏数据)。当标注数据规模较大时,数据标注检查就成为一个难题。本项目中脏数据清洗基于TrustAI(可信AI)工具集,利用基于表示点方法的实例级证据分析方法,计算训练数据对模型的影响分数,分数高的训练数据表明对模型影响大,这些数据有较大概率为脏数据(标注错误样本)。更多细节详见[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[实例级证据分析](https://github.com/PaddlePaddle/TrustAI/blob/main/trustai/interpretation/example_level/README.md)。
+脏数据清洗方案是基于已训练好的文本分类模型,筛选出训练数据集中标注错误的数据,再由人工检查重新标注,获得标注正确的数据集进行重新训练。我们将介绍脏数据清洗流程:
-**安装TrustAI**
-```shell
-pip install trustai==0.1.4
-```
+- **脏数据筛选:** 基于TrustAI中表示点方法,计算训练数据对文本分类模型的影响分数,分数高的训练数据表明对模型影响大,这些数据有较大概率为标注错误样本,记为脏数据集(Dirty Dataset)。
+
+- **数据清洗、训练:** 将筛选出的脏数据由人工重新检查,为数据打上正确的标签。将清洗后的训练数据重新放入文本分类模型进行训练。
现在我们进行脏数据识别,脏数据保存在`"train_dirty.txt"`,剩余训练数据保存在`"train_dirty_rest.txt"`:
@@ -243,7 +355,7 @@ python dirty.py \
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含train.txt和label.txt文件;默认为None。
* `max_seq_length`:分词器tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数;默认为128。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
-* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `batch_size`:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。
* `seed`:随机种子,默认为3。
* `dirty_file`:保存脏数据文件名,默认为"train_dirty.txt"。
@@ -252,7 +364,7 @@ python dirty.py \
* `dirty_threshold`:筛选脏数据用于重新标注的阈值,只选择影响分数大于阈值作为支持数据,默认为0。
-我们将筛选出脏数据进行重新标注,可以将`train_dirty.txt`直接导入标注工具doccano帮助更快重新标注,详情请参考[文本分类任务doccano数据标注使用指南](../../doccano.md)进行文本分类数据标注。然后将已重新标注的脏数据`train_dirty.txt`与剩余训练集数据`train_dirty_rest.txt`合并得到新的训练集`train_clean.txt`重新进行训练:
+我们将筛选出脏数据进行人工检查重新标注,可以将`train_dirty.txt`直接导入标注工具doccano帮助更快重新标注,详情请参考[文本分类任务doccano数据标注使用指南](../../doccano.md)进行文本分类数据标注。然后将已重新标注的脏数据`train_dirty.txt`与剩余训练集数据`train_dirty_rest.txt`合并得到新的训练集`train_clean.txt`重新进行训练:
```shell
cat ../data/train_dirty_rest.txt ../data/train_dirty.txt > ../data/train_clean.txt
@@ -270,7 +382,7 @@ cat ../data/train_dirty_rest.txt ../data/train_dirty.txt > ../data/train_clean.t
|训练集(500,含100条脏数据) + 脏数据清洗(100)|69.36|
|训练集(500,含100条脏数据) + 脏数据清洗(150)|73.15|
-## 数据增强策略方案
+### 数据增强策略方案
在数据量较少或某些类别样本量较少时,也可以通过数据增强策略的方式,生成更多的训练数据,提升模型效果。
@@ -290,7 +402,7 @@ python aug.py \
* `aug_type`:词替换/词插入增强类型,可选"synonym", "homonym", "mlm",建议在GPU环境下使用mlm类型;默认为"synonym"。
* `create_n`:生成的句子数量,默认为2。
* `aug_percent`:生成词替换百分比,默认为0.1。
-* `device`: 选用什么设备进行增强,选择cpu、gpu、xpu、npu,仅在使用mlm类型有影响;默认为"gpu"。
+* `device`: 选用什么设备进行增强,可选择cpu、gpu、xpu、npu,仅在使用mlm类型有影响;默认为"gpu"。
生成的增强数据保存在`"aug.txt"`文件中,与训练集数据`train.txt`合并得到新的训练集`train_aug.txt`重新进行训练:
diff --git a/applications/text_classification/multi_class/analysis/evaluate.py b/applications/text_classification/multi_class/analysis/evaluate.py
index bbd72a39be86..9aa272c650f9 100644
--- a/applications/text_classification/multi_class/analysis/evaluate.py
+++ b/applications/text_classification/multi_class/analysis/evaluate.py
@@ -37,7 +37,7 @@
parser.add_argument("--train_file", type=str, default="train.txt", help="Train dataset file name")
parser.add_argument("--dev_file", type=str, default="dev.txt", help="Dev dataset file name")
parser.add_argument("--label_file", type=str, default="label.txt", help="Label file name")
-parser.add_argument("--bad_case_path", type=str, default="./bad_case.txt", help="Bad case saving file path")
+parser.add_argument("--bad_case_file", type=str, default="bad_case.txt", help="Bad case saving file name")
args = parser.parse_args()
# yapf: enable
@@ -192,17 +192,16 @@ def evaluate():
logger.info("Evaluation examples in dev dataset: 0 (0%)")
logger.info("----------------------------")
- with open(args.bad_case_path, 'w', encoding="utf-8") as f:
- f.write("Confidence\tPrediction\tLabel\tText\n")
+ bad_case_path = os.path.join(args.dataset_dir, args.bad_case_file)
+ with open(bad_case_path, 'w', encoding="utf-8") as f:
+ f.write("Text\tLabel\tPrediction\n")
for i, (p, l) in enumerate(zip(preds, labels)):
p, l = int(p), int(l)
if p != l:
- f.write("{:.2f}".format(probs[i][p]) + "\t" + label_list[p] +
- "\t" + label_list[l] + "\t" + dev_ds.data[i]["text"] +
- "\n")
+ f.write(dev_ds.data[i]["text"] + "\t" + label_list[l] + "\t" +
+ label_list[p] + "\n")
f.close()
- logger.info("Bad case in dev dataset saved in {}".format(
- args.bad_case_path))
+ logger.info("Bad case in dev dataset saved in {}".format(bad_case_path))
return
diff --git a/applications/text_classification/multi_class/analysis/sent_interpret.py b/applications/text_classification/multi_class/analysis/sent_interpret.py
new file mode 100644
index 000000000000..9a586d24a5b4
--- /dev/null
+++ b/applications/text_classification/multi_class/analysis/sent_interpret.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import random
+import os
+import argparse
+import numpy as np
+
+import paddle
+import paddle.nn.functional as F
+from paddle.io import DataLoader, BatchSampler
+from paddlenlp.data import DataCollatorWithPadding
+from paddlenlp.datasets import load_dataset
+from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer, LinearDecayWithWarmup
+from paddlenlp.utils.log import logger
+from trustai.interpretation import FeatureSimilarityModel
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument('--device', default="gpu", help="Select which device to train model, defaults to gpu.")
+parser.add_argument("--dataset_dir", required=True, type=str, help="The dataset directory should include train.txt,dev.txt and test.txt files.")
+parser.add_argument("--params_path", default="../checkpoint/", type=str, help="The path to model parameters to be loaded.")
+parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.")
+parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.")
+parser.add_argument("--seed", type=int, default=3, help="random seed for initialization")
+parser.add_argument("--top_k", type=int, default=3, help="Top K important training data.")
+parser.add_argument("--train_file", type=str, default="train.txt", help="Train dataset file name")
+parser.add_argument("--interpret_input_file", type=str, default="bad_case.txt", help="interpretation file name")
+parser.add_argument("--interpret_result_file", type=str, default="sent_interpret.txt", help="interpreted file name")
+args = parser.parse_args()
+# yapf: enable
+
+
+def set_seed(seed):
+ """
+ Set random seed
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ paddle.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+
+
+def read_local_dataset(path):
+ """
+ Read dataset file
+ """
+ with open(path, 'r', encoding='utf-8') as f:
+ for line in f:
+ items = line.strip().split('\t')
+ if items[0] == 'Text':
+ continue
+ if len(items) == 3:
+ yield {'text': items[0], 'label': items[1], 'predict': items[2]}
+ elif len(items) == 2:
+ yield {'text': items[0], 'label': items[1], 'predict': ''}
+ elif len(items) == 1:
+ yield {'text': items[0], 'label': '', 'predict': ''}
+ else:
+ logger.info(line.strip())
+ raise ValueError("{} should be in fixed format.".format(path))
+
+
+def preprocess_function(examples, tokenizer, max_seq_length):
+ """
+ Preprocess dataset
+ """
+ result = tokenizer(text=examples["text"], max_seq_len=max_seq_length)
+ return result
+
+
+class LocalDataCollatorWithPadding(DataCollatorWithPadding):
+ """
+ Convert the result of DataCollatorWithPadding from dict dictionary to a list
+ """
+
+ def __call__(self, features):
+ batch = super().__call__(features)
+ batch = list(batch.values())
+ return batch
+
+
+def find_positive_influence_data():
+
+ set_seed(args.seed)
+ paddle.set_device(args.device)
+
+ # Define model & tokenizer
+ if os.path.exists(os.path.join(
+ args.params_path, "model_state.pdparams")) and os.path.exists(
+ os.path.join(args.params_path,
+ "model_config.json")) and os.path.exists(
+ os.path.join(args.params_path,
+ "tokenizer_config.json")):
+ model = AutoModelForSequenceClassification.from_pretrained(
+ args.params_path)
+ tokenizer = AutoTokenizer.from_pretrained(args.params_path)
+ else:
+ raise ValueError("The {} should exist.".format(args.params_path))
+
+ # Prepare & preprocess dataset
+ train_path = os.path.join(args.dataset_dir, args.train_file)
+ interpret_path = os.path.join(args.dataset_dir, args.interpret_input_file)
+
+ train_ds = load_dataset(read_local_dataset, path=train_path, lazy=False)
+ interpret_ds = load_dataset(read_local_dataset,
+ path=interpret_path,
+ lazy=False)
+ trans_func = functools.partial(preprocess_function,
+ tokenizer=tokenizer,
+ max_seq_length=args.max_seq_length)
+
+ train_ds = train_ds.map(trans_func)
+ interpret_ds = interpret_ds.map(trans_func)
+
+ # Batchify dataset
+ collate_fn = LocalDataCollatorWithPadding(tokenizer)
+ train_batch_sampler = BatchSampler(train_ds,
+ batch_size=args.batch_size,
+ shuffle=False)
+ interpret_batch_sampler = BatchSampler(interpret_ds,
+ batch_size=args.batch_size,
+ shuffle=False)
+ train_data_loader = DataLoader(dataset=train_ds,
+ batch_sampler=train_batch_sampler,
+ collate_fn=collate_fn)
+ interpret_data_loader = DataLoader(dataset=interpret_ds,
+ batch_sampler=interpret_batch_sampler,
+ collate_fn=collate_fn)
+
+ # Classifier_layer_name is the layer name of the last output layer
+ feature_sim = FeatureSimilarityModel(model,
+ train_data_loader,
+ classifier_layer_name="classifier")
+ # Feature similarity analysis & select sparse data
+ analysis_result = []
+ for batch in interpret_data_loader:
+ analysis_result += feature_sim(batch, sample_num=args.top_k)
+ with open(os.path.join(args.dataset_dir, args.interpret_result_file),
+ 'w') as f:
+ for i in range(len(analysis_result)):
+ f.write("text: " + interpret_ds.data[i]["text"] + '\n')
+ if 'predict' in interpret_ds.data[i]:
+ f.write("predict label: " + interpret_ds.data[i]["predict"] +
+ '\n')
+ if 'label' in interpret_ds.data[i]:
+ f.write("label: " + interpret_ds.data[i]["label"] + '\n')
+ f.write("examples with positive influence\n")
+ for i, (idx, score) in enumerate(
+ zip(analysis_result[i].pos_indexes,
+ analysis_result[i].pos_scores)):
+ f.write("support{} text: ".format(i + 1) +
+ train_ds.data[idx]["text"] + '\t' + "label: " +
+ train_ds.data[idx]["label"] + '\t' + "score: " +
+ "{:.5f}".format(score) + '\n')
+ f.close()
+
+
+if __name__ == "__main__":
+ find_positive_influence_data()
diff --git a/applications/text_classification/multi_class/analysis/word_interpret.ipynb b/applications/text_classification/multi_class/analysis/word_interpret.ipynb
new file mode 100644
index 000000000000..0b83eec1bdf5
--- /dev/null
+++ b/applications/text_classification/multi_class/analysis/word_interpret.ipynb
@@ -0,0 +1,359 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 词级别可解释性分析\n",
+ "本项目提供模型的词级别可解释性分析,包括LIME、Integrated Gradient、GradShap 三种分析方法,支持分析微调后模型的预测结果,开发者可以通过更改**数据目录**和**模型目录**在自己的任务中使用此项目进行数据分析。\n",
+ "\n",
+ "![image](https://user-images.githubusercontent.com/63761690/195086276-6ee16e96-4ec3-4a0f-821f-37546d21746b.png)\n",
+ " \n",
+ "\n",
+ "## 1.导入Python模块与参数配置\n",
+ "首先我们导入必要的导入必要python模块和设置配置参数,词级别可解释性分析算法支持三种待分析的文本 `INTERPRETER_FILE` 数据文件格式:\n",
+ "\n",
+ "**格式一:包括文本、标签、预测结果**\n",
+ "```text\n",
+ "<文本>'\\t'<标签>'\\t'<预测结果>\n",
+ "...\n",
+ "```\n",
+ "\n",
+ "**格式二:包括文本、标签**\n",
+ "```text\n",
+ "<文本>'\\t'<标签>\n",
+ "...\n",
+ "```\n",
+ "\n",
+ "**格式三:只包括文本**\n",
+ "```text\n",
+ "<文本>\n",
+ "...\n",
+ "```\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import functools\n",
+ "import random\n",
+ "import os\n",
+ "import argparse\n",
+ "\n",
+ "import jieba\n",
+ "import numpy as np \n",
+ "from trustai.interpretation import VisualizationTextRecord\n",
+ "from trustai.interpretation import get_word_offset\n",
+ "import paddle\n",
+ "from paddle.io import DataLoader, BatchSampler\n",
+ "from paddlenlp.data import DataCollatorWithPadding\n",
+ "from paddlenlp.datasets import load_dataset\n",
+ "from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from trustai.interpretation import VisualizationTextRecord\n",
+ "from trustai.interpretation import get_word_offset\n",
+ "import paddle\n",
+ "from paddle.io import DataLoader, BatchSampler\n",
+ "from paddlenlp.data import DataCollatorWithPadding\n",
+ "from paddlenlp.datasets import load_dataset\n",
+ "from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 预先定义配置参数\n",
+ "\n",
+ "# 运行环境,可选\"cpu\",\"gpu\",\"gpu:x\"(x为gpu编号)\n",
+ "DEVICE = \"gpu\"\n",
+ "# 数据路径\n",
+ "DATASET_DIR = \"../data\" \n",
+ "# 训练模型保存路径\n",
+ "PARAM_PATH = \"../checkpoint/\" \n",
+ "# tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数\n",
+ "MAX_LENGTH = 128 \n",
+ "# 批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数\n",
+ "BATCH_SIZE = 1 \n",
+ "# 待分析解释的数据\n",
+ "INTERPRETER_FILE = \"bad_case.txt\"\n",
+ "# 可选 \"ig\",\"lime\",\"grad\" ,可以根据实际任务效果选择解释器\n",
+ "# \"grad\":GradShap方法依赖interpretdl\n",
+ "# !pip install interpretdl\n",
+ "INTERPRETER = \"ig\"\n",
+ "# 分析句子中TOP K关键词,K值\n",
+ "KEY_WORDS_NUM = 5"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def read_local_dataset(path):\n",
+ " \"\"\"\n",
+ " Read dataset file\n",
+ " \"\"\"\n",
+ " with open(path, 'r', encoding='utf-8') as f:\n",
+ " for line in f:\n",
+ " items = line.strip().split('\\t')\n",
+ " if items[0] == 'Text':\n",
+ " continue\n",
+ " if len(items) == 3:\n",
+ " yield {'text': items[0], 'label': items[1], 'predict': items[2]}\n",
+ " elif len(items) == 2:\n",
+ " yield {'text': items[0], 'label': items[1], 'predict': ''}\n",
+ " elif len(items) == 1:\n",
+ " yield {'text': items[0], 'label': '', 'predict': ''}\n",
+ " else:\n",
+ " raise ValueError(\"{} should be in fixed format.\".format(path))\n",
+ "\n",
+ "def preprocess_function(examples, tokenizer, max_seq_length):\n",
+ " \"\"\"\n",
+ " Preprocess dataset\n",
+ " \"\"\"\n",
+ " result = tokenizer(text=examples[\"text\"], max_seq_len=max_seq_length)\n",
+ " return result\n",
+ "\n",
+ "class LocalDataCollatorWithPadding(DataCollatorWithPadding):\n",
+ " \"\"\"\n",
+ " Convert the result of DataCollatorWithPadding from dict dictionary to a list\n",
+ " \"\"\"\n",
+ "\n",
+ " def __call__(self, features):\n",
+ " batch = super().__call__(features)\n",
+ " batch = list(batch.values())\n",
+ " return batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\u001b[32m[2022-10-11 12:17:29,041] [ INFO]\u001b[0m - We are using to load '/workspace/PaddleNLP/applications/text_classification/multi_class/checkpoint/'.\u001b[0m\n",
+ "W1011 12:17:29.044690 79080 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2\n",
+ "W1011 12:17:29.051118 79080 gpu_resources.cc:91] device: 0, cuDNN Version: 8.1.\n",
+ "\u001b[32m[2022-10-11 12:17:32,517] [ INFO]\u001b[0m - We are using to load '/workspace/PaddleNLP/applications/text_classification/multi_class/checkpoint/'.\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "paddle.set_device(DEVICE)\n",
+ "\n",
+ "# Define model & tokenizer\n",
+ "if os.path.exists(os.path.join(\n",
+ " PARAM_PATH, \"model_state.pdparams\")) and os.path.exists(\n",
+ " os.path.join(PARAM_PATH,\n",
+ " \"model_config.json\")) and os.path.exists(\n",
+ " os.path.join(PARAM_PATH,\n",
+ " \"tokenizer_config.json\")):\n",
+ " model = AutoModelForSequenceClassification.from_pretrained(\n",
+ " PARAM_PATH)\n",
+ " tokenizer = AutoTokenizer.from_pretrained(PARAM_PATH)\n",
+ "else:\n",
+ " raise ValueError(\"The {} should exist.\".format(PARAM_PATH))\n",
+ "\n",
+ "# Prepare & preprocess dataset\n",
+ "interpret_path = os.path.join(DATASET_DIR, INTERPRETER_FILE)\n",
+ "\n",
+ "\n",
+ "interpret_ds = load_dataset(read_local_dataset, path=interpret_path, lazy=False)\n",
+ "trans_func = functools.partial(preprocess_function,\n",
+ " tokenizer=tokenizer,\n",
+ " max_seq_length=MAX_LENGTH)\n",
+ "\n",
+ "interpret_ds = interpret_ds.map(trans_func)\n",
+ "\n",
+ "# Batchify dataset\n",
+ "collate_fn = LocalDataCollatorWithPadding(tokenizer)\n",
+ "interpret_batch_sampler = BatchSampler(interpret_ds,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " shuffle=False)\n",
+ "interpret_data_loader = DataLoader(dataset=interpret_ds,\n",
+ " batch_sampler=interpret_batch_sampler,\n",
+ " collate_fn=collate_fn)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Start token level interpretion, it will take some time...\n",
+ "Building prefix dict from the default dictionary ...\n",
+ "Loading model from cache /tmp/jieba.cache\n",
+ "Loading model cost 1.005 seconds.\n",
+ "Prefix dict has been built successfully.\n",
+ "Start word level alignment, it will take some time...\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Init an interpreter\n",
+ "if INTERPRETER == 'ig':\n",
+ " from trustai.interpretation.token_level import IntGradInterpreter\n",
+ " interpreter = IntGradInterpreter(model)\n",
+ "elif INTERPRETER == 'lime':\n",
+ " from trustai.interpretation.token_level import LIMEInterpreter\n",
+ " interpreter = LIMEInterpreter(model, unk_id=tokenizer.convert_tokens_to_ids('[UNK]'), pad_id=tokenizer.convert_tokens_to_ids('[PAD]'))\n",
+ "else:\n",
+ " from trustai.interpretation.token_level import GradShapInterpreter\n",
+ " interpreter = GradShapInterpreter(model)\n",
+ "\n",
+ "# Use interpreter to get the importance scores for all data\n",
+ "print(\"Start token level interpretion, it will take some time...\")\n",
+ "analysis_result = []\n",
+ "for batch in interpret_data_loader:\n",
+ " analysis_result += interpreter(tuple(batch))\n",
+ "\n",
+ "# Add CLS and SEP tags to both original text and standard splited tokens\n",
+ "contexts = []\n",
+ "words = []\n",
+ "for i in range(len(interpret_ds)):\n",
+ " text = interpret_ds.data[i][\"text\"]\n",
+ " contexts.append(\"[CLS]\" + text + \"[SEP]\")\n",
+ " words.append([\"[CLS]\"] + list(jieba.cut(text)) + [\"[SEP]\"])\n",
+ "\n",
+ "# Get the offset map of tokenized tokens and standard splited tokens\n",
+ "print(\"Start word level alignment, it will take some time...\")\n",
+ "ori_offset_maps = []\n",
+ "word_offset_maps = []\n",
+ "for i in range(len(contexts)):\n",
+ " ori_offset_maps.append(tokenizer.get_offset_mapping(contexts[i]))\n",
+ " word_offset_maps.append(get_word_offset(contexts[i], words[i]))\n",
+ "\n",
+ "align_res = interpreter.alignment(analysis_result, contexts, words, word_offset_maps, ori_offset_maps, special_tokens=[\"[CLS]\", '[SEP]'],rationale_num=KEY_WORDS_NUM)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from IPython.core.display import display, HTML\n",
+ "class Visualization(VisualizationTextRecord):\n",
+ "\n",
+ " def __init__(self, interpret_res, true_label=None, pred_label=None, words=None):\n",
+ " if words is not None:\n",
+ " self.words = words\n",
+ " else:\n",
+ " self.words = interpret_res.words\n",
+ " self.pred_label = pred_label if pred_label is not None else ''\n",
+ " self.true_label = true_label if true_label is not None else ''\n",
+ " self.key_words = \" \".join(set(interpret_res.rationale_tokens))\n",
+ " word_attributions = interpret_res.word_attributions\n",
+ " _max = max(word_attributions)\n",
+ " _min = min(word_attributions)\n",
+ " self.word_attributions = [(word_imp - _min) / (_max - _min) for word_imp in word_attributions]\n",
+ "\n",
+ " def record_html(self):\n",
+ " \"\"\"change all informations to html\"\"\"\n",
+ " return \"\".join([\n",
+ " \"\",\n",
+ " self._format_class(self.true_label),\n",
+ " self._format_class(self.pred_label),\n",
+ " self._format_class(self.key_words),\n",
+ " self._format_word_attributions(),\n",
+ " \"
\",\n",
+ " ])\n",
+ " def _format_class(self, label):\n",
+ " return '{label} | '.format(label=label)\n",
+ "\n",
+ "def visualize_text(text_records):\n",
+ " \"\"\"visualize text\"\"\"\n",
+ " html = [\"\"]\n",
+ " rows = [\"Label | \"\n",
+ " \"Prediction | \"\n",
+ " \"Key words | \"\n",
+ " \"Important visualization | \"]\n",
+ " for record in text_records:\n",
+ " rows.append(record.record_html())\n",
+ " html.append(\"\".join(rows))\n",
+ " html.append(\"
---|
\")\n",
+ " html = HTML(\"\".join(html))\n",
+ " display(html)\n",
+ " return html.data\n",
+ "\n",
+ "\n",
+ "def visualize(interpret_res, ds):\n",
+ " records = []\n",
+ " for i in range(len(interpret_res)):\n",
+ " records.append(Visualization(interpret_res[i], true_label=ds.data[i][\"label\"], pred_label=ds.data[i][\"predict\"]))\n",
+ " html = visualize_text(records)\n",
+ " return html"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "Label | Prediction | Key words | Important visualization |
---|
其他 | 注意事项 | 月 服用 请问 的 可以 | [CLS] 您好 , 请问 一岁 三个 月 的 孩子 可以 服用 复方 锌 布 颗粒 吗 ? [SEP] |
|
其他 | 就医建议 | 输卵管 基本 检查 粘连 的 | [CLS] 输卵管 粘连 的 基本 检查 [SEP] |
|
其他 | 病情诊断 | 胎动 么 ? 是 会 | [CLS] 会 是 胎动 么 ? [SEP] |
|
其他 | 病情诊断 | 这是 经常 干呕 了 生病 | [CLS] 经常 干呕 恶心 , 这是 生病 了 吗 [SEP] |
|
就医建议 | 治疗方案 | 治 治疗 菏泽 怎么 白癜风 | [CLS] 菏泽 哪个 医院 治疗 白癜风 比较 好 ? 怎么 治 好 [SEP] |
|
其他 | 后果表述 | 左旋 不良反应 吃 的 肉碱 | [CLS] 吃 左旋 肉碱 后 的 不良反应 [SEP] |
|
注意事项 | 其他 | 上 出血 吗 做爱 环后 | [CLS] 上 环后 出血 可以 做爱 吗 [SEP] |
|
病情诊断 | 病因分析 | 感冒 了 呀 怎么 会 | [CLS] 孩子 感冒 了 怎么 会 喘息 呀 ? [SEP] |
|
其他 | 治疗方案 | 孕 周 21 | [CLS] 孕 21 周 [SEP] |
|
其他 | 指标解读 | 谱 心肌 意义 酶 ? | [CLS] 心肌 酶 谱 五项 意义 ? [SEP] |
|
病情诊断 | 其他 | 家长 判断 吃 吃饱 怎么 | [CLS] 家长 怎么 判断 孩子 吃饱 了 呢 ? 怎么 都 不肯 吃 了 就是 饱 了 [SEP] |
|
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# process for vbisualize\n",
+ "html = visualize(align_res, interpret_ds)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.7.13 64-bit",
+ "metadata": {
+ "interpreter": {
+ "hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90"
+ }
+ },
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13-final"
+ },
+ "orig_nbformat": 2
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/applications/text_classification/multi_label/README.md b/applications/text_classification/multi_label/README.md
index b474cb47f5dd..bad8c214a976 100644
--- a/applications/text_classification/multi_label/README.md
+++ b/applications/text_classification/multi_label/README.md
@@ -37,7 +37,6 @@
- 【标注成本高、标注样本较少的小样本场景】 👉 [提示学习多标签分类方案](./few-shot#readme)
-
## 2. 快速开始
@@ -92,20 +91,12 @@ python3 -m pip install scikit-learn==1.0.2
```text
multi_label/
+├── few-shot # 小样本学习方案
+├── analysis # 分析模块
├── deploy # 部署
│ └── predictor # 离线部署
-│ │ ├── infer.py # 测试脚本
-│ │ ├── predictor.py 离线部署脚本
-│ │ └── README.md # 离线部署使用说明
│ ├── paddle_serving # PaddleServing在线服务化部署
-│ │ ├──config.yml # 服务端的配置文件
-│ │ ├──rpc_client.py # 客户端预测脚本
-│ │ ├──service.py # 服务端的脚本
-│ │ └── README.md # 在线服务化部署使用说明
│ └── triton_serving # Triton在线服务化部署
-│ ├── README.md # Triton部署使用说明
-│ ├── seqcls_grpc_client.py # 客户端预测脚本
-│ └── models
├── train.py # 训练评估脚本
├── predict.py # 预测脚本
├── export_model.py # 静态图模型导出脚本
@@ -128,7 +119,7 @@ data/
├── label.txt # 分类标签文件
└── data.txt # 待预测数据文件(可选)
```
-**训练、开发、测试数据集**文件中文本与标签类别名用tab符`'\t'`分隔开,标签中多个标签之间用`','`逗号分隔开。
+**训练、开发、测试数据集**文件中文本与标签类别名用tab符`'\t'`分隔开,标签中多个标签之间用`','`逗号分隔开,文本中避免出现tab符`'\t'`。
- train.txt/dev.txt/test.txt 文件格式:
```text
@@ -267,10 +258,16 @@ checkpoint/
#### 2.4.2 训练评估与模型优化
-训练后的模型我们可以使用 [模型分析模块](./analysis) 对每个类别分别进行评估,并输出预测错误样本(bad case),默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`:
+文本分类预测过程中常会遇到诸如"模型为什么会预测出错误的结果","如何提升模型的表现"等问题。[Analysis模块](./analysis) 提供了**模型评估、可解释性分析、数据优化**等功能,旨在帮助开发者更好地分析文本分类模型预测结果和对模型效果进行优化。
+
+
+
+
+
+**模型评估:** 训练后的模型我们可以使用 [Analysis模块](./analysis) 对每个类别分别进行评估,并输出预测错误样本(bad case),默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`:
```shell
-python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32 --bad_case_path "./bad_case.txt" --dataset_dir "data" --params_path "./checkpoint"
+python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32 --bad_case_file "bad_case.txt" --dataset_dir "data" --params_path "./checkpoint"
```
输出打印示例:
@@ -295,15 +292,37 @@ python analysis/evaluate.py --device "gpu" --max_seq_length 128 --batch_size 32
预测错误的样本保存在bad_case.txt文件中:
```text
-Prediction Label Text
-不动产分割 不动产分割,有夫妻共同财产 2014年,王X以其与肖X协议离婚时未分割该套楼房的首付款为由,起诉至法院,要求分得楼房的首付款15万元。
-婚后分居,准予离婚 二次起诉离婚,准予离婚,婚后分居,法定离婚 但原、被告对已建立起的夫妻感情不够珍惜,因琐事即发生吵闹并最终分居,对夫妻感情造成了严重的影响,现原、被告已分居六年有余,且经人民法院判决不准离婚后仍未和好,夫妻感情确已破裂,依法应准予原、被告离婚。
-婚后有子女,限制行为能力子女抚养 婚后有子女 婚后生有一女,取名彭某乙,已11岁,现已由被告从铁炉白族乡中心小学转入走马镇李桥小学读书。
-婚后分居 不履行家庭义务,婚后分居 2015年2月23日,被告将原告赶出家门,原告居住于娘家待产,双方分居至今。
+Text Label Prediction
+2014年,王X以其与肖X协议离婚时未分割该套楼房的首付款为由,起诉至法院,要求分得楼房的首付款15万元。 不动产分割,有夫妻共同财产 不动产分割
+但原、被告对已建立起的夫妻感情不够珍惜,因琐事即发生吵闹并最终分居,对夫妻感情造成了严重的影响,现原、被告已分居六年有余,且经人民法院判决不准离婚后仍未和好,夫妻感情确已破裂,依法应准予原、被告离婚。 二次起诉离婚,准予离婚,婚后分居,法定离婚 婚后分居,准予离婚
+婚后生有一女,取名彭某乙,已11岁,现已由被告从铁炉白族乡中心小学转入走马镇李桥小学读书。 婚后有子女 婚后有子女,限制行为能力子女抚养
...
```
+**可解释性分析:** 基于[TrustAI](https://github.com/PaddlePaddle/TrustAI)提供单词和句子级别的模型可解释性分析,帮助理解模型预测结果,用于错误样本(bad case)分析,细节详见[训练评估与模型优化指南](analysis/README.md)。
+
+- 单词级别可解释性分析,也即分析待预测样本中哪一些单词对模型预测结果起重要作用。以下图为例,用颜色深浅表示单词对预测结果的重要性。
+
+
+
+
+- 句子级别可解释性分析 ,也即分析对待预测样本的模型预测结果与训练集中中哪些样本有重要关系。下面的例子表明句子级别可解释性分析可以帮助理解待预测样本的预测结果与训练集中样本之间的关联。
+```text
+text: 2015年2月23日,被告将原告赶出家门,原告居住于娘家待产,双方分居至今。
+predict label: 婚后分居
+label: 不履行家庭义务,婚后分居
+examples with positive influence
+support1 text: 2014年中秋节原告回了娘家,原、被告分居至今。 label: 婚后分居 score: 0.99942
+support2 text: 原告于2013年8月13日离开被告家,分居至今。 label: 婚后分居 score: 0.99916
+support3 text: 2014年4月,被告外出务工,双方分居至今。 label: 婚后分居 score: 0.99902
+...
+```
+
+**数据优化:** 结合[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[数据增强API](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/dataaug.md)提供了**稀疏数据筛选、脏数据清洗、数据增强**三种优化策略,从多角度优化训练数据提升模型效果,策略细节详见[训练评估与模型优化指南](analysis/README.md)。
+
+- 稀疏数据筛选主要是解决数据不均衡、训练数据覆盖不足的问题,通过数据增强和数据标注两种方式解决这一问题。
+- 脏数据清洗可以帮助开发者筛选训练集中错误标注的数据,对这些数据重新进行人工标注,得到标注正确的数据再重新进行训练。
+- 数据增强策略提供多种数据增强方案,可以快速扩充数据,提高模型泛化性和鲁棒性。
-模型表现常常受限于数据质量,在analysis模块中我们提供了基于[TrustAI](https://github.com/PaddlePaddle/TrustAI)的稀疏数据筛选、脏数据清洗、数据增强三种优化方案助力开发者提升模型效果,更多模型评估和优化方案细节详见[训练评估与模型优化指南](analysis/README.md)。
#### 2.4.3 模型预测
训练结束后,输入待预测数据(data.txt)和类别标签对照列表(label.txt),使用训练好的模型进行,默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`:
@@ -333,12 +352,15 @@ python predict.py --device "gpu" --max_seq_length 128 --batch_size 32 --dataset_
```shell
python export_model.py --params_path ./checkpoint/ --output_path ./export
```
-如果使用ERNIE M作为预训练模型,运行方式:
+
+如果使用多语言模型 ERNIE M作为预训练模型,运行方式:
+
```shell
python export_model.py --params_path ./checkpoint/ --output_path ./export --multilingual
```
可支持配置的参数:
+
* `multilingual`:是否为多语言任务(是否使用ERNIE M作为预训练模型);默认为False。
* `params_path`:动态图训练保存的参数路径;默认为"./checkpoint/"。
* `output_path`:静态图图保存的参数路径;默认为"./export"。
@@ -415,7 +437,6 @@ prune/
3. ERNIE Base、Medium、Mini、Micro、Nano的模型宽度(multi head数量)为12,ERNIE Xbase、Large 模型宽度(multi head数量)为16,保留比例`width_mult`乘以宽度(multi haed数量)应为整数。
-4. **压缩API暂不支持多语言预训练模型ERNIE-M**,相关功能正在加紧开发中。
#### 2.5.3 部署方案
diff --git a/applications/text_classification/multi_label/analysis/README.md b/applications/text_classification/multi_label/analysis/README.md
index 84b8ea6b1382..56927a8c2973 100644
--- a/applications/text_classification/multi_label/analysis/README.md
+++ b/applications/text_classification/multi_label/analysis/README.md
@@ -1,18 +1,29 @@
-# 多标签训练评估与模型优化指南
+# 训练评估与模型优化指南
**目录**
- * [analysis模块介绍](#analysis模块介绍)
+ * [Analysis模块介绍](#Analysis模块介绍)
+ * [环境准备](#环境准备)
* [模型评估](#模型评估)
- * [稀疏数据筛选方案](#稀疏数据筛选方案)
- * [脏数据清洗方案](#脏数据清洗方案)
- * [数据增强策略方案](#数据增强策略方案)
+ * [可解释性分析](#可解释性分析)
+ * [单词级别可解释性分析](#单词级别可解释性分析)
+ * [句子级别可解释性分析](#句子级别可解释性分析)
+ * [数据优化](#数据优化)
+ * [稀疏数据筛选方案](#稀疏数据筛选方案)
+ * [脏数据清洗方案](#脏数据清洗方案)
+ * [数据增强策略方案](#数据增强策略方案)
-## analysis模块介绍
+## Analysis模块介绍
-analysis模块提供了**模型评估**脚本对整体分类情况和每个类别分别进行评估,并打印预测错误样本,帮助开发者分析模型表现找到训练和预测数据中存在的问题问题。同时基于[可信AI工具集](https://github.com/PaddlePaddle/TrustAI)和[数据增强API](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/dataaug.md)提供了**稀疏数据筛选、脏数据清洗、数据增强**三种优化方案从多角度帮助开发者提升模型效果。
+Analysis模块提供了**模型评估、可解释性分析、数据优化**等功能,旨在帮助开发者更好地分析文本分类模型预测结果和对模型效果进行优化。
+
+- **模型评估:** 对整体分类情况和每个类别分别进行评估,并打印预测错误样本,帮助开发者分析模型表现找到训练和预测数据中存在的问题。
+
+- **可解释性分析:** 基于[TrustAI](https://github.com/PaddlePaddle/TrustAI)提供单词和句子级别的模型可解释性分析,帮助理解模型预测结果。
+
+- **数据优化:** 结合[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[数据增强API](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/dataaug.md)提供了**稀疏数据筛选、脏数据清洗、数据增强**三种优化策略,从多角度优化训练数据提升模型效果。
-
+
以下是本项目主要代码结构及说明:
@@ -20,10 +31,27 @@ analysis模块提供了**模型评估**脚本对整体分类情况和每个类
```text
analysis/
├── evaluate.py # 评估脚本
+├── sent_interpret.py # 句子级别可解释性分析脚本
+├── word_interpret.py # 单词级别可解释性分析notebook
├── sparse.py # 稀疏数据筛选脚本
├── dirty.py # 脏数据清洗脚本
├── aug.py # 数据增强脚本
-└── README.md # 多标签训练评估与模型优化指南
+└── README.md # 训练评估与模型优化指南
+```
+
+## 环境准备
+需要可解释性分析和数据优化需要安装相关环境。
+- trustai >= 0.1.7
+- interpretdl >= 0.7.0
+
+**安装TrustAI**(可选)如果使用可解释性分析和数据优化中稀疏数据筛选和脏数据清洗需要安装TrustAI。
+```shell
+pip install trustai==0.1.7
+```
+
+**安装InterpretDL**(可选)如果使用词级别可解释性分析GradShap方法,需要安装InterpretDL
+```shell
+pip install interpretdl==0.7.0
```
## 模型评估
@@ -37,14 +65,14 @@ python evaluate.py \
--params_path "../checkpoint" \
--max_seq_length 128 \
--batch_size 32 \
- --bad_case_path "./bad_case.txt"
+ --bad_case_file "bad_case.txt"
```
默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`
可支持配置的参数:
-* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含train.txt、dev.txt和label.txt文件;默认为None。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
* `max_seq_length`:分词器tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数;默认为128。
@@ -76,30 +104,108 @@ python evaluate.py \
预测错误的样本保存在bad_case.txt文件中:
```text
-Prediction Label Text
-不动产分割 不动产分割,有夫妻共同财产 2014年,王X以其与肖X协议离婚时未分割该套楼房的首付款为由,起诉至法院,要求分得楼房的首付款15万元。
-婚后分居,准予离婚 二次起诉离婚,准予离婚,婚后分居,法定离婚 但原、被告对已建立起的夫妻感情不够珍惜,因琐事即发生吵闹并最终分居,对夫妻感情造成了严重的影响,现原、被告已分居六年有余,且经人民法院判决不准离婚后仍未和好,夫妻感情确已破裂,依法应准予原、被告离婚。
-婚后有子女,限制行为能力子女抚养 婚后有子女 婚后生有一女,取名彭某乙,已11岁,现已由被告从铁炉白族乡中心小学转入走马镇李桥小学读书。
-婚后分居 不履行家庭义务,婚后分居 2015年2月23日,被告将原告赶出家门,原告居住于娘家待产,双方分居至今。
+Text Label Prediction
+2014年,王X以其与肖X协议离婚时未分割该套楼房的首付款为由,起诉至法院,要求分得楼房的首付款15万元。 不动产分割,有夫妻共同财产 不动产分割
+但原、被告对已建立起的夫妻感情不够珍惜,因琐事即发生吵闹并最终分居,对夫妻感情造成了严重的影响,现原、被告已分居六年有余,且经人民法院判决不准离婚后仍未和好,夫妻感情确已破裂,依法应准予原、被告离婚。 二次起诉离婚,准予离婚,婚后分居,法定离婚 婚后分居,准予离婚
+婚后生有一女,取名彭某乙,已11岁,现已由被告从铁炉白族乡中心小学转入走马镇李桥小学读书。 婚后有子女 婚后有子女,限制行为能力子女抚养
...
```
+## 可解释性分析
+"模型为什么会预测出这个结果?"是文本分类任务开发者时常遇到的问题,如何分析错误样本(bad case)是文本分类任务落地中重要一环,本项目基于TrustAI开源了基于词级别和句子级别的模型可解释性分析方法,帮助开发者更好地理解文本分类模型与数据,有助于后续的模型优化与数据清洗标注。
+
+### 单词级别可解释性分析
+本项目开源模型的词级别可解释性分析Notebook,提供LIME、Integrated Gradient、GradShap 三种分析方法,支持分析微调后模型的预测结果,开发者可以通过更改**数据目录**和**模型目录**在自己的任务中使用Jupyter Notebook进行数据分析。
+
+运行 [word_interpret.ipynb](./word_interpreter.ipynb) 代码,即可分析影响样本预测结果的关键词以及可视化所有词对预测结果的贡献情况,颜色越深代表这个词对预测结果影响越大:
+
+
+
+
+### 句子级别可解释性分析
+本项目基于特征相似度([FeatureSimilarity](https://arxiv.org/abs/2104.04128))算法,计算对样本预测结果正影响的训练数据,帮助理解模型的预测结果与训练集数据的关系。
-## 稀疏数据筛选方案
+待分析数据文件`interpret_input_file`应为以下三种格式中的一种:
+**格式一:包括文本、标签、预测结果**
+```text
+<文本>'\t'<标签>'\t'<预测结果>
+...
+```
-稀疏数据指缺乏足够训练数据支持导致低置信度的待预测数据,简单来说,由于模型在训练过程中没有学习到足够与待预测样本相似的数据,模型难以正确预测样本所属类别。本项目中稀疏数据筛选基于TrustAI(可信AI)工具集,利用基于特征相似度的实例级证据分析方法,抽取开发集中样本的支持训练证据,并计算支持证据平均分(通常为得分前三的支持训练证据均分)。分数较低的样本表明其训练证据不足,在训练集中较为稀疏,实验表明模型在这些样本上表现也相对较差。更多细节详见[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[实例级证据分析](https://github.com/PaddlePaddle/TrustAI/blob/main/trustai/interpretation/example_level/README.md)。
+**格式二:包括文本、标签**
+```text
+<文本>'\t'<标签>
+...
+```
-稀疏数据筛选旨在开发集中挖掘缺乏训练证据支持的稀疏数据,通常可以采用**数据增强**或**少量数据标注**的两种低成本方式,提升模型预测效果。
+**格式三:只包括文本**
+```text
+<文本>
+准予原告胡某甲与被告韩某甲离婚。
+...
+```
-**安装TrustAI**
+我们可以运行代码,得到支持样本模型预测结果的训练数据:
```shell
-pip install trustai==0.1.4
+python sent_interpret.py \
+ --device "gpu" \
+ --dataset_dir "../data" \
+ --params_path "../checkpoint/" \
+ --max_seq_length 128 \
+ --batch_size 16 \
+ --top_k 3 \
+ --train_file "train.txt" \
+ --interpret_input_file "bad_case.txt" \
+ --interpret_result_file "sent_interpret.txt"
+```
+
+默认在GPU环境下使用,在CPU环境下修改参数配置为`--device "cpu"`
+
+可支持配置的参数:
+
+* `device`: 选用什么设备进行训练,可可选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含dev.txt和label.txt文件;默认为None。
+* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
+* `max_seq_length`:分词器tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数;默认为128。
+* `batch_size`:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。
+* `seed`:随机种子,默认为3。
+* `top_k`:筛选支持训练证据数量;默认为3。
+* `train_file`:本地数据集中训练集文件名;默认为"train.txt"。
+* `interpret_input_file`:本地数据集中待分析文件名;默认为"bad_case.txt"。
+* `interpret_result_file`:保存句子级别可解释性结果文件名;默认为"sent_interpret.txt"。
+
+可解释性结果保存在 `interpret_result_file` 文件中:
+```text
+text: 2015年2月23日,被告将原告赶出家门,原告居住于娘家待产,双方分居至今。
+predict label: 婚后分居
+label: 不履行家庭义务,婚后分居
+examples with positive influence
+support1 text: 2014年中秋节原告回了娘家,原、被告分居至今。 label: 婚后分居 score: 0.99942
+support2 text: 原告于2013年8月13日离开被告家,分居至今。 label: 婚后分居 score: 0.99916
+support3 text: 2014年4月,被告外出务工,双方分居至今。 label: 婚后分居 score: 0.99902
+...
```
-### 稀疏数据识别--数据增强
-这里我们将介绍稀疏数据识别--数据增强流程,首先使用数据增强脚本挖掘开发集中的稀疏数据,然后筛选训练集中对稀疏数据的支持数据进行数据增强,然后将得到的数据增强后的支持数据加入到训练集中进行训练。
+## 数据优化
+
+### 稀疏数据筛选方案
-现在我们进行稀疏数据识别--数据增强,得到新增训练数据:
+稀疏数据筛选适用于文本分类中**数据不平衡或训练数据覆盖不足**的场景,简单来说,就是由于模型在训练过程中没有学习到足够与待预测样本相似的数据,模型难以正确预测样本所属类别的情况。稀疏数据筛选旨在开发集中挖掘缺乏训练证据支持的数据,通常可以采用**数据增强**或**少量数据标注**的两种低成本方式,提升模型在开发集的预测效果。
+
+本项目中稀疏数据筛选基于TrustAI,利用基于特征相似度的实例级证据分析方法,抽取开发集中样本的支持训练证据,并计算支持证据平均分(通常为得分前三的支持训练证据均分)。分数较低的样本表明其训练证据不足,在训练集中较为稀疏,实验表明模型在这些样本上表现也相对较差。更多细节详见[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[实例级证据分析](https://github.com/PaddlePaddle/TrustAI/blob/main/trustai/interpretation/example_level/README.md)。
+
+
+#### 稀疏数据识别—数据增强
+
+这里我们将介绍稀疏数据识别—数据增强流程:
+
+- **稀疏数据识别:** 挖掘开发集中的缺乏训练证据支持数据,记为稀疏数据集(Sparse Dataset);
+
+- **数据增强**:将稀疏数据集在训练集中的支持证据应用数据增强策略,这些数据增强后的训练数据记为支持数据集(Support Dataset);
+
+- **重新训练模型:** 将支持数据集加入到原有的训练集获得新的训练集,重新训练新的文本分类模型。
+
+现在我们进行稀疏数据识别-数据增强,得到支持数据集:
```shell
python sparse.py \
@@ -117,7 +223,7 @@ python sparse.py \
可支持配置的参数:
-* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含dev.txt和label.txt文件;默认为None。
* `aug_strategy`:数据增强类型,可选"duplicate","substitute", "insert", "delete", "swap";默认为"substitute"。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
@@ -154,9 +260,15 @@ cat ../data/train.txt ../data/support.txt > ../data/train_sparse_aug.txt
|训练集+支持增强集(delete) |84.83| 51.04|
|训练集+支持增强集(swap) |**84.84**|51.06|
-### 稀疏数据识别--数据标注
+#### 稀疏数据识别-数据标注
-这里我们将介绍稀疏数据识别--数据标注流程,首先使用数据增强脚本挖掘开发集中的稀疏数据,然后筛选对稀疏数据支持的未标注数据,然后将得到支持数据进行标注后加入到训练集中进行训练。
+本方案能够有针对性进行数据标注,相比于随机标注数据更好提高模型预测效果。这里我们将介绍稀疏数据识别-数据标注流程:
+
+- **稀疏数据识别:** 挖掘开发集中的缺乏训练证据支持数据,记为稀疏数据集(Sparse Dataset);
+
+- **数据标注**:在未标注数据集中筛选稀疏数据集的支持证据,并进行数据标注,记为支持数据集(Support Dataset);
+
+- **重新训练模型:** 将支持数据集加入到原有的训练集获得新的训练集,重新训练新的文本分类模型。
现在我们进行稀疏数据识别--数据标注,得到待标注数据:
@@ -177,7 +289,7 @@ python sparse.py \
可支持配置的参数:
-* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含dev.txt和label.txt文件;默认为None。
* `annotate`:选择稀疏数据识别--数据标注模式;默认为False。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
@@ -214,14 +326,13 @@ cat ../data/train.txt ../data/support.txt > ../data/train_sparse_annotate.txt
|训练集+策略采样集(100) |**86.14**|56.93|
|训练集+随机采样集(100) |84.69|50.76|
-## 脏数据清洗方案
+### 脏数据清洗方案
-训练数据标注质量对模型效果有较大影响,但受限于标注人员水平、标注任务难易程度等影响,训练数据中都存在一定比例的标注较差的数据(脏数据)。当标注数据规模较大时,数据标注检查就成为一个难题。本项目中脏数据清洗基于TrustAI(可信AI)工具集,利用基于表示点方法的实例级证据分析方法,计算训练数据对模型的影响分数,分数高的训练数据表明对模型影响大,这些数据有较大概率为脏数据(标注错误样本)。更多细节详见[TrustAI](https://github.com/PaddlePaddle/TrustAI)和[实例级证据分析](https://github.com/PaddlePaddle/TrustAI/blob/main/trustai/interpretation/example_level/README.md)。
+脏数据清洗方案是基于已训练好的文本分类模型,筛选出训练数据集中标注错误的数据,再由人工检查重新标注,获得标注正确的数据集进行重新训练。我们将介绍脏数据清洗流程:
-**安装TrustAI**
-```shell
-pip install trustai==0.1.4
-```
+- **脏数据筛选:** 基于TrustAI中表示点方法,计算训练数据对文本分类模型的影响分数,分数高的训练数据表明对模型影响大,这些数据有较大概率为标注错误样本,记为脏数据集(Dirty Dataset)。
+
+- **数据清洗、训练:** 将筛选出的脏数据由人工重新检查,为数据打上正确的标签。将清洗后的训练数据重新放入文本分类模型进行训练。
现在我们进行脏数据识别,脏数据保存在`"train_dirty.txt"`,剩余训练数据保存在`"train_dirty_rest.txt"`:
@@ -244,21 +355,21 @@ python dirty.py \
* `dataset_dir`:必须,本地数据集路径,数据集路径中应包含train.txt和label.txt文件;默认为None。
* `max_seq_length`:分词器tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数;默认为128。
* `params_path`:保存训练模型的目录;默认为"../checkpoint/"。
-* `device`: 选用什么设备进行训练,选择cpu、gpu、xpu、npu;默认为"gpu"。
+* `device`: 选用什么设备进行训练,可选择cpu、gpu、xpu、npu;默认为"gpu"。
* `batch_size`:批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数;默认为32。
* `seed`:随机种子,默认为3。
* `dirty_file`:保存脏数据文件名,默认为"train_dirty.txt"。
* `rest_file`:保存剩余数据(非脏数据)文件名,默认为"train_dirty_rest.txt"。
* `train_file`:本地数据集中训练集文件名;默认为"train.txt"。
-* `dirty_num`:筛选脏数据用于重新标注的数量;默认为100。
-* `dirty_threshold`:筛选脏数据用于重新标注的阈值,只选择影响分数大于阈值作为有效数据,默认为0。
+* `dirty_threshold`:筛选脏数据用于重新标注的阈值,只选择影响分数大于阈值作为支持数据,默认为0。
-我们将筛选出脏数据进行重新标注,可以将`train_dirty.txt`直接导入标注工具doccano帮助更快重新标注,详情请参考[文本分类任务doccano数据标注使用指南](../../doccano.md)进行文本分类数据标注。然后将已重新标注的脏数据`train_dirty.txt`与剩余训练集数据`train_dirty_rest.txt`合并得到新的训练集`train_clean.txt`重新进行训练:
+我们将筛选出脏数据进行人工检查重新标注,可以将`train_dirty.txt`直接导入标注工具doccano帮助更快重新标注,详情请参考[文本分类任务doccano数据标注使用指南](../../doccano.md)进行文本分类数据标注。然后将已重新标注的脏数据`train_dirty.txt`与剩余训练集数据`train_dirty_rest.txt`合并得到新的训练集`train_clean.txt`重新进行训练:
```shell
cat ../data/train_dirty_rest.txt ../data/train_dirty.txt > ../data/train_clean.txt
```
+
**方案效果**
我们在CAIL2019—婚姻家庭要素提取数据集抽取部分训练数据(训练集数据规模:500)进行实验,取50条数据进行脏数据处理,也即50条训练数据为标签错误数据。选择不同`dirty_num`应用脏数据清洗策略进行评测:
@@ -291,7 +402,7 @@ python aug.py \
* `aug_type`:词替换/词插入增强类型,可选"synonym", "homonym", "mlm",建议在GPU环境下使用mlm类型;默认为"synonym"。
* `create_n`:生成的句子数量,默认为2。
* `aug_percent`:生成词替换百分比,默认为0.1。
-* `device`: 选用什么设备进行增强,选择cpu、gpu、xpu、npu,仅在使用mlm类型有影响;默认为"gpu"。
+* `device`: 选用什么设备进行增强,可选择cpu、gpu、xpu、npu,仅在使用mlm类型有影响;默认为"gpu"。
生成的增强数据保存在`"aug.txt"`文件中,与训练集数据`train.txt`合并得到新的训练集`train_aug.txt`重新进行训练:
diff --git a/applications/text_classification/multi_label/analysis/evaluate.py b/applications/text_classification/multi_label/analysis/evaluate.py
index b79127c70426..b7e32acf5dc7 100644
--- a/applications/text_classification/multi_label/analysis/evaluate.py
+++ b/applications/text_classification/multi_label/analysis/evaluate.py
@@ -37,7 +37,7 @@
parser.add_argument("--train_file", type=str, default="train.txt", help="Train dataset file name")
parser.add_argument("--dev_file", type=str, default="dev.txt", help="Dev dataset file name")
parser.add_argument("--label_file", type=str, default="label.txt", help="Label file name")
-parser.add_argument("--bad_case_path", type=str, default="./bad_case.txt", help="Bad case saving file path")
+parser.add_argument("--bad_case_file", type=str, default="bad_case.txt", help="Bad case saving file name")
args = parser.parse_args()
# yapf: enable
@@ -201,23 +201,22 @@ def evaluate():
report[str(i)]['recall'] * 100,
report[str(i)]['f1-score'] * 100))
logger.info("----------------------------")
-
- with open(args.bad_case_path, 'w', encoding="utf-8") as f:
- f.write("Prediction\tLabel\tText\n")
+ bad_case_path = os.path.join(args.dataset_dir, args.bad_case_file)
+ with open(bad_case_path, 'w', encoding="utf-8") as f:
+ f.write("Text\tLabel\tPrediction\n")
for i in range(len(preds)):
for p, l in zip(preds[i], labels[i]):
if (p and l == 0) or (not p and l == 1):
pred_n = [
label_map[i] for i, pp in enumerate(preds[i]) if pp
]
- f.write(",".join(pred_n) + "\t" +
+ f.write(dev_ds.data[i]["text"] + "\t" +
dev_ds.data[i]["label_n"] + "\t" +
- dev_ds.data[i]["text"] + "\n")
+ ",".join(pred_n) + "\n")
break
f.close()
- logger.info("Bad case in dev dataset saved in {}".format(
- args.bad_case_path))
+ logger.info("Bad case in dev dataset saved in {}".format(bad_case_path))
return
diff --git a/applications/text_classification/multi_label/analysis/sent_interpret.py b/applications/text_classification/multi_label/analysis/sent_interpret.py
new file mode 100644
index 000000000000..9a586d24a5b4
--- /dev/null
+++ b/applications/text_classification/multi_label/analysis/sent_interpret.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import random
+import os
+import argparse
+import numpy as np
+
+import paddle
+import paddle.nn.functional as F
+from paddle.io import DataLoader, BatchSampler
+from paddlenlp.data import DataCollatorWithPadding
+from paddlenlp.datasets import load_dataset
+from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer, LinearDecayWithWarmup
+from paddlenlp.utils.log import logger
+from trustai.interpretation import FeatureSimilarityModel
+
+# yapf: disable
+parser = argparse.ArgumentParser()
+parser.add_argument('--device', default="gpu", help="Select which device to train model, defaults to gpu.")
+parser.add_argument("--dataset_dir", required=True, type=str, help="The dataset directory should include train.txt,dev.txt and test.txt files.")
+parser.add_argument("--params_path", default="../checkpoint/", type=str, help="The path to model parameters to be loaded.")
+parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.")
+parser.add_argument("--batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.")
+parser.add_argument("--seed", type=int, default=3, help="random seed for initialization")
+parser.add_argument("--top_k", type=int, default=3, help="Top K important training data.")
+parser.add_argument("--train_file", type=str, default="train.txt", help="Train dataset file name")
+parser.add_argument("--interpret_input_file", type=str, default="bad_case.txt", help="interpretation file name")
+parser.add_argument("--interpret_result_file", type=str, default="sent_interpret.txt", help="interpreted file name")
+args = parser.parse_args()
+# yapf: enable
+
+
+def set_seed(seed):
+ """
+ Set random seed
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ paddle.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+
+
+def read_local_dataset(path):
+ """
+ Read dataset file
+ """
+ with open(path, 'r', encoding='utf-8') as f:
+ for line in f:
+ items = line.strip().split('\t')
+ if items[0] == 'Text':
+ continue
+ if len(items) == 3:
+ yield {'text': items[0], 'label': items[1], 'predict': items[2]}
+ elif len(items) == 2:
+ yield {'text': items[0], 'label': items[1], 'predict': ''}
+ elif len(items) == 1:
+ yield {'text': items[0], 'label': '', 'predict': ''}
+ else:
+ logger.info(line.strip())
+ raise ValueError("{} should be in fixed format.".format(path))
+
+
+def preprocess_function(examples, tokenizer, max_seq_length):
+ """
+ Preprocess dataset
+ """
+ result = tokenizer(text=examples["text"], max_seq_len=max_seq_length)
+ return result
+
+
+class LocalDataCollatorWithPadding(DataCollatorWithPadding):
+ """
+ Convert the result of DataCollatorWithPadding from dict dictionary to a list
+ """
+
+ def __call__(self, features):
+ batch = super().__call__(features)
+ batch = list(batch.values())
+ return batch
+
+
+def find_positive_influence_data():
+
+ set_seed(args.seed)
+ paddle.set_device(args.device)
+
+ # Define model & tokenizer
+ if os.path.exists(os.path.join(
+ args.params_path, "model_state.pdparams")) and os.path.exists(
+ os.path.join(args.params_path,
+ "model_config.json")) and os.path.exists(
+ os.path.join(args.params_path,
+ "tokenizer_config.json")):
+ model = AutoModelForSequenceClassification.from_pretrained(
+ args.params_path)
+ tokenizer = AutoTokenizer.from_pretrained(args.params_path)
+ else:
+ raise ValueError("The {} should exist.".format(args.params_path))
+
+ # Prepare & preprocess dataset
+ train_path = os.path.join(args.dataset_dir, args.train_file)
+ interpret_path = os.path.join(args.dataset_dir, args.interpret_input_file)
+
+ train_ds = load_dataset(read_local_dataset, path=train_path, lazy=False)
+ interpret_ds = load_dataset(read_local_dataset,
+ path=interpret_path,
+ lazy=False)
+ trans_func = functools.partial(preprocess_function,
+ tokenizer=tokenizer,
+ max_seq_length=args.max_seq_length)
+
+ train_ds = train_ds.map(trans_func)
+ interpret_ds = interpret_ds.map(trans_func)
+
+ # Batchify dataset
+ collate_fn = LocalDataCollatorWithPadding(tokenizer)
+ train_batch_sampler = BatchSampler(train_ds,
+ batch_size=args.batch_size,
+ shuffle=False)
+ interpret_batch_sampler = BatchSampler(interpret_ds,
+ batch_size=args.batch_size,
+ shuffle=False)
+ train_data_loader = DataLoader(dataset=train_ds,
+ batch_sampler=train_batch_sampler,
+ collate_fn=collate_fn)
+ interpret_data_loader = DataLoader(dataset=interpret_ds,
+ batch_sampler=interpret_batch_sampler,
+ collate_fn=collate_fn)
+
+ # Classifier_layer_name is the layer name of the last output layer
+ feature_sim = FeatureSimilarityModel(model,
+ train_data_loader,
+ classifier_layer_name="classifier")
+ # Feature similarity analysis & select sparse data
+ analysis_result = []
+ for batch in interpret_data_loader:
+ analysis_result += feature_sim(batch, sample_num=args.top_k)
+ with open(os.path.join(args.dataset_dir, args.interpret_result_file),
+ 'w') as f:
+ for i in range(len(analysis_result)):
+ f.write("text: " + interpret_ds.data[i]["text"] + '\n')
+ if 'predict' in interpret_ds.data[i]:
+ f.write("predict label: " + interpret_ds.data[i]["predict"] +
+ '\n')
+ if 'label' in interpret_ds.data[i]:
+ f.write("label: " + interpret_ds.data[i]["label"] + '\n')
+ f.write("examples with positive influence\n")
+ for i, (idx, score) in enumerate(
+ zip(analysis_result[i].pos_indexes,
+ analysis_result[i].pos_scores)):
+ f.write("support{} text: ".format(i + 1) +
+ train_ds.data[idx]["text"] + '\t' + "label: " +
+ train_ds.data[idx]["label"] + '\t' + "score: " +
+ "{:.5f}".format(score) + '\n')
+ f.close()
+
+
+if __name__ == "__main__":
+ find_positive_influence_data()
diff --git a/applications/text_classification/multi_label/analysis/word_interpret.ipynb b/applications/text_classification/multi_label/analysis/word_interpret.ipynb
new file mode 100644
index 000000000000..07dd3f85c839
--- /dev/null
+++ b/applications/text_classification/multi_label/analysis/word_interpret.ipynb
@@ -0,0 +1,381 @@
+{
+ "metadata": {
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13-final"
+ },
+ "orig_nbformat": 2,
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3.7.13 64-bit",
+ "metadata": {
+ "interpreter": {
+ "hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90"
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2,
+ "cells": [
+ {
+ "source": [
+ "# 词级别可解释性分析\n",
+ "本项目提供模型的词级别可解释性分析,包括LIME、Integrated Gradient、GradShap 三种分析方法,支持分析微调后模型的预测结果,开发者可以通过更改**数据目录**和**模型目录**在自己的任务中使用此项目进行数据分析。\n",
+ "\n",
+ "![image](https://user-images.githubusercontent.com/63761690/192739675-63145d59-23c6-416f-bf71-998fd4995254.png)\n",
+ "\n",
+ "## 1.导入Python模块与参数配置\n",
+ "首先我们导入必要的导入必要python模块和设置配置参数,词级别可解释性分析算法支持三种待分析的文本 `INTERPRETER_FILE` 数据文件格式:\n",
+ "\n",
+ "**格式一:包括文本、标签、预测结果**\n",
+ "```text\n",
+ "<文本>'\\t'<标签>'\\t'<预测结果>\n",
+ "...\n",
+ "```\n",
+ "\n",
+ "**格式二:包括文本、标签**\n",
+ "```text\n",
+ "<文本>'\\t'<标签>\n",
+ "...\n",
+ "```\n",
+ "\n",
+ "**格式三:只包括文本**\n",
+ "```text\n",
+ "<文本>\n",
+ "准予原告胡某甲与被告韩某甲离婚。\n",
+ "...\n",
+ "```\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "source": [
+ "import functools\n",
+ "import random\n",
+ "import os\n",
+ "import argparse\n",
+ "\n",
+ "import jieba\n",
+ "import numpy as np\n",
+ "from trustai.interpretation import VisualizationTextRecord\n",
+ "from trustai.interpretation import get_word_offset\n",
+ "import paddle\n",
+ "from paddle.io import DataLoader, BatchSampler\n",
+ "from paddlenlp.data import DataCollatorWithPadding\n",
+ "from paddlenlp.datasets import load_dataset\n",
+ "from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer\n"
+ ],
+ "cell_type": "code",
+ "metadata": {},
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "grep: warning: GREP_OPTIONS is deprecated; please use an alias or script\n",
+ "/usr/local/lib/python3.7/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "/usr/local/lib/python3.7/dist-packages/paddlenlp/transformers/image_utils.py:213: DeprecationWarning: BILINEAR is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BILINEAR instead.\n",
+ " resample=Image.BILINEAR,\n",
+ "/usr/local/lib/python3.7/dist-packages/paddlenlp/transformers/image_utils.py:379: DeprecationWarning: NEAREST is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.NEAREST or Dither.NONE instead.\n",
+ " resample=Image.NEAREST,\n",
+ "/usr/local/lib/python3.7/dist-packages/paddlenlp/transformers/ernie_vil/feature_extraction.py:65: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.\n",
+ " resample=Image.BICUBIC,\n",
+ "/usr/local/lib/python3.7/dist-packages/paddlenlp/transformers/clip/feature_extraction.py:64: DeprecationWarning: BICUBIC is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.BICUBIC instead.\n",
+ " resample=Image.BICUBIC,\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 预先定义配置参数\n",
+ "\n",
+ "# 运行环境,可选\"cpu\",\"gpu\",\"gpu:x\"(x为gpu编号)\n",
+ "DEVICE = \"gpu\"\n",
+ "# 数据路径\n",
+ "DATASET_DIR = \"../data\" \n",
+ "# 训练模型保存路径\n",
+ "PARAM_PATH = \"../checkpoint/\" \n",
+ "# tokenizer使用的最大序列长度,ERNIE模型最大不能超过2048。请根据文本长度选择,通常推荐128、256或512,若出现显存不足,请适当调低这一参数\n",
+ "MAX_LENGTH = 128 \n",
+ "# 批处理大小,请结合显存情况进行调整,若出现显存不足,请适当调低这一参数\n",
+ "BATCH_SIZE = 1 \n",
+ "# 待分析解释的数据\n",
+ "INTERPRETER_FILE = \"bad_case.txt\"\n",
+ "# 可选 \"ig\",\"lime\",\"grad\" ,可以根据实际任务效果选择解释器\n",
+ "# \"grad\":GradShap方法依赖interpretdl\n",
+ "# !pip install interpretdl\n",
+ "INTERPRETER = \"ig\"\n",
+ "# 分析句子中TOP K关键词,K值\n",
+ "KEY_WORDS_NUM = 5"
+ ]
+ },
+ {
+ "source": [
+ "## 2.读取待分析数据"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def read_local_dataset(path):\n",
+ " \"\"\"\n",
+ " Read dataset file\n",
+ " \"\"\"\n",
+ " with open(path, 'r', encoding='utf-8') as f:\n",
+ " for line in f:\n",
+ " items = line.strip().split('\\t')\n",
+ " if items[0] == 'Text':\n",
+ " continue\n",
+ " if len(items) == 3:\n",
+ " yield {'text': items[0], 'label': items[1], 'predict': items[2]}\n",
+ " elif len(items) == 2:\n",
+ " yield {'text': items[0], 'label': items[1], 'predict': ''}\n",
+ " elif len(items) == 1:\n",
+ " yield {'text': items[0], 'label': '', 'predict': ''}\n",
+ " else:\n",
+ " raise ValueError(\"{} should be in fixed format.\".format(path))\n",
+ "\n",
+ "def preprocess_function(examples, tokenizer, max_seq_length):\n",
+ " \"\"\"\n",
+ " Preprocess dataset\n",
+ " \"\"\"\n",
+ " result = tokenizer(text=examples[\"text\"], max_seq_len=max_seq_length)\n",
+ " return result\n",
+ "\n",
+ "class LocalDataCollatorWithPadding(DataCollatorWithPadding):\n",
+ " \"\"\"\n",
+ " Convert the result of DataCollatorWithPadding from dict dictionary to a list\n",
+ " \"\"\"\n",
+ "\n",
+ " def __call__(self, features):\n",
+ " batch = super().__call__(features)\n",
+ " batch = list(batch.values())\n",
+ " return batch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "\u001b[32m[2022-09-28 04:51:03,566] [ INFO]\u001b[0m - We are using to load '../checkpoint/'.\u001b[0m\n",
+ "W0928 04:51:03.570216 4827 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2\n",
+ "W0928 04:51:03.575362 4827 gpu_resources.cc:91] device: 0, cuDNN Version: 8.1.\n",
+ "\u001b[32m[2022-09-28 04:51:06,542] [ INFO]\u001b[0m - We are using to load '../checkpoint/'.\u001b[0m\n"
+ ]
+ }
+ ],
+ "source": [
+ "paddle.set_device(DEVICE)\n",
+ "\n",
+ "# Define model & tokenizer\n",
+ "if os.path.exists(os.path.join(\n",
+ " PARAM_PATH, \"model_state.pdparams\")) and os.path.exists(\n",
+ " os.path.join(PARAM_PATH,\n",
+ " \"model_config.json\")) and os.path.exists(\n",
+ " os.path.join(PARAM_PATH,\n",
+ " \"tokenizer_config.json\")):\n",
+ " model = AutoModelForSequenceClassification.from_pretrained(\n",
+ " PARAM_PATH)\n",
+ " tokenizer = AutoTokenizer.from_pretrained(PARAM_PATH)\n",
+ "else:\n",
+ " raise ValueError(\"The {} should exist.\".format(PARAM_PATH))\n",
+ "\n",
+ "# Prepare & preprocess dataset\n",
+ "interpret_path = os.path.join(DATASET_DIR, INTERPRETER_FILE)\n",
+ "\n",
+ "\n",
+ "interpret_ds = load_dataset(read_local_dataset, path=interpret_path, lazy=False)\n",
+ "trans_func = functools.partial(preprocess_function,\n",
+ " tokenizer=tokenizer,\n",
+ " max_seq_length=MAX_LENGTH)\n",
+ "\n",
+ "interpret_ds = interpret_ds.map(trans_func)\n",
+ "\n",
+ "# Batchify dataset\n",
+ "collate_fn = LocalDataCollatorWithPadding(tokenizer)\n",
+ "interpret_batch_sampler = BatchSampler(interpret_ds,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " shuffle=False)\n",
+ "interpret_data_loader = DataLoader(dataset=interpret_ds,\n",
+ " batch_sampler=interpret_batch_sampler,\n",
+ " collate_fn=collate_fn)\n"
+ ]
+ },
+ {
+ "source": [
+ "## 3.开始数据可解释性分析\n",
+ "数据量较大时,数据分析时间较长,请耐心等待"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Start token level interpretion, it will take some time...\n",
+ "Building prefix dict from the default dictionary ...\n",
+ "Loading model from cache /tmp/jieba.cache\n",
+ "Loading model cost 0.751 seconds.\n",
+ "Prefix dict has been built successfully.\n",
+ "Start word level alignment, it will take some time...\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Init an interpreter\n",
+ "if INTERPRETER == 'ig':\n",
+ " from trustai.interpretation.token_level import IntGradInterpreter\n",
+ " interpreter = IntGradInterpreter(model)\n",
+ "elif INTERPRETER == 'lime':\n",
+ " from trustai.interpretation.token_level import LIMEInterpreter\n",
+ " interpreter = LIMEInterpreter(model, unk_id=tokenizer.convert_tokens_to_ids('[UNK]'), pad_id=tokenizer.convert_tokens_to_ids('[PAD]'))\n",
+ "else:\n",
+ " from trustai.interpretation.token_level import GradShapInterpreter\n",
+ " interpreter = GradShapInterpreter(model)\n",
+ "\n",
+ "# Use interpreter to get the importance scores for all data\n",
+ "print(\"Start token level interpretion, it will take some time...\")\n",
+ "analysis_result = []\n",
+ "for batch in interpret_data_loader:\n",
+ " analysis_result += interpreter(tuple(batch))\n",
+ "\n",
+ "# Add CLS and SEP tags to both original text and standard splited tokens\n",
+ "contexts = []\n",
+ "words = []\n",
+ "for i in range(len(interpret_ds)):\n",
+ " text = interpret_ds.data[i][\"text\"]\n",
+ " contexts.append(\"[CLS]\" + text + \"[SEP]\")\n",
+ " words.append([\"[CLS]\"] + list(jieba.cut(text)) + [\"[SEP]\"])\n",
+ "\n",
+ "# Get the offset map of tokenized tokens and standard splited tokens\n",
+ "print(\"Start word level alignment, it will take some time...\")\n",
+ "ori_offset_maps = []\n",
+ "word_offset_maps = []\n",
+ "for i in range(len(contexts)):\n",
+ " ori_offset_maps.append(tokenizer.get_offset_mapping(contexts[i]))\n",
+ " word_offset_maps.append(get_word_offset(contexts[i], words[i]))\n",
+ "\n",
+ "align_res = interpreter.alignment(analysis_result, contexts, words, word_offset_maps, ori_offset_maps, special_tokens=[\"[CLS]\", '[SEP]'],rationale_num=KEY_WORDS_NUM)"
+ ]
+ },
+ {
+ "source": [
+ "## 4.数据可解释性分析结果可视化\n",
+ "使用用颜色深浅可视化方式代表句子中词对预测结果的重要程度"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from IPython.core.display import display, HTML\n",
+ "class Visualization(VisualizationTextRecord):\n",
+ "\n",
+ " def __init__(self, interpret_res, true_label=None, pred_label=None, words=None):\n",
+ " if words is not None:\n",
+ " self.words = words\n",
+ " else:\n",
+ " self.words = interpret_res.words\n",
+ " self.pred_label = pred_label if pred_label is not None else ''\n",
+ " self.true_label = true_label if true_label is not None else ''\n",
+ " self.key_words = \" \".join(set(interpret_res.rationale_tokens))\n",
+ " word_attributions = interpret_res.word_attributions\n",
+ " _max = max(word_attributions)\n",
+ " _min = min(word_attributions)\n",
+ " self.word_attributions = [(word_imp - _min) / (_max - _min) for word_imp in word_attributions]\n",
+ "\n",
+ " def record_html(self):\n",
+ " \"\"\"change all informations to html\"\"\"\n",
+ " return \"\".join([\n",
+ " \"\",\n",
+ " self._format_class(self.true_label),\n",
+ " self._format_class(self.pred_label),\n",
+ " self._format_class(self.key_words),\n",
+ " self._format_word_attributions(),\n",
+ " \"
\",\n",
+ " ])\n",
+ " def _format_class(self, label):\n",
+ " return '{label} | '.format(label=label)\n",
+ "\n",
+ "def visualize_text(text_records):\n",
+ " \"\"\"visualize text\"\"\"\n",
+ " html = [\"\"]\n",
+ " rows = [\"Label | \"\n",
+ " \"Prediction | \"\n",
+ " \"Key words | \"\n",
+ " \"Important visualization | \"]\n",
+ " for record in text_records:\n",
+ " rows.append(record.record_html())\n",
+ " html.append(\"\".join(rows))\n",
+ " html.append(\"
---|
\")\n",
+ " html = HTML(\"\".join(html))\n",
+ " display(html)\n",
+ " return html.data\n",
+ "\n",
+ "\n",
+ "def visualize(interpret_res, ds):\n",
+ " records = []\n",
+ " for i in range(len(interpret_res)):\n",
+ " records.append(Visualization(interpret_res[i], true_label=ds.data[i][\"label\"], pred_label=ds.data[i][\"predict\"]))\n",
+ " html = visualize_text(records)\n",
+ " return html"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": "",
+ "text/html": "Label | Prediction | Key words | Important visualization |
---|
不履行家庭义务,婚后分居 | 婚后分居 | 至今 双方 出 分居 。 | [CLS] 2015 年 2 月 23 日 , 被告 将 原告 赶 出 家门 , 原告 居住 于 娘家 待产 , 双方 分居 至今 。 [SEP] |
|
婚后有子女,限制行为能力子女抚养 | 婚后有子女,限制行为能力子女抚养,不履行离婚协议 | 财产 符合 付清 欠条 抚养 | [CLS] 被告 孙某 辩称 : 离婚 协议 中 关于 财产 分割 给付 资金 不 符合 法律 规定 , 只有 离婚 和 子女 抚养 符合 法律 规定 , 就 没有 么 , 协议 不 代表 被告 的 真实 意思 表示 , 离婚 协议 中 没有 约定 付款 时间 , 而且 被告 给 原告 出具 了 欠条 是 5 年内 付清 , 原告 未 在 期满 后 起诉 , 应 驳回 。 [SEP] |
|
存在非婚生子,支付抚养费,限制行为能力子女抚养 | 限制行为能力子女抚养,存在非婚生子 | 赵某 并非 认可 之女 表示 | [CLS] 被告 董某 认可 赵某 乙 并非 原告 之女 , 并 表示 愿意 自行 抚养 赵某 乙 。 [SEP] |
|
准予离婚 | 准予离婚,法定离婚 | 原告 韩某 准予 离婚 。 | [CLS] 准予 原告 胡某 甲 与 被告 韩某 甲 离婚 。 [SEP] |
|
"
+ },
+ "metadata": {}
+ }
+ ],
+ "source": [
+ "# process for vbisualize\n",
+ "html = visualize(align_res, interpret_ds)"
+ ]
+ }
+ ]
+}
\ No newline at end of file