diff --git a/examples/sentiment_analysis/textcnn/README.md b/examples/sentiment_analysis/textcnn/README.md
new file mode 100644
index 0000000000000..53e4f41bb8b6c
--- /dev/null
+++ b/examples/sentiment_analysis/textcnn/README.md
@@ -0,0 +1,192 @@
+# Chinese Dialogue Emotion Recognition with the TextCNN Model
+
+Sentiment analysis aims to automatically identify and extract subjective information such as tendencies, stances, evaluations, and opinions from text. One of its tasks is dialogue emotion recognition: given a user utterance from an intelligent dialogue system, it automatically determines the emotion category of the text and gives a corresponding confidence score. The emotion types are positive, negative, and neutral.
+
+This example shows how to fine-tune a pretrained TextCNN model on a robot chat dataset to perform Chinese dialogue emotion recognition.
+
+## Quick Start
+
+### Code Structure
+
+The main files of this project are organized as follows:
+
+```text
+textcnn/
+├── deploy # deployment
+│   └── python
+│       └── predict.py # Python inference deployment example
+├── data.py # data processing script
+├── export_model.py # script that exports dynamic graph parameters to static graph parameters
+├── model.py # model definition script
+├── predict.py # model prediction script
+├── README.md # documentation
+└── train.py # training script for the dialogue emotion recognition task
+```
+
+### Data Preparation
+
+We provide a labeled robot chat dataset that includes a training set (train.tsv), a development set (dev.tsv), and a test set (test.tsv).
+The full dataset can be downloaded and extracted with the following commands:
+
+```shell
+wget https://paddlenlp.bj.bcebos.com/datasets/RobotChat.tar.gz
+tar xvf RobotChat.tar.gz
+```
+
+### Vocabulary Download
+
+Before training, download the vocabulary file robot_chat_word_dict.txt, which is used to build the word-to-id mapping.
+
+```shell
+wget https://paddlenlp.bj.bcebos.com/robot_chat_word_dict.txt
+```
+
+**NOTE:** The choice of vocabulary depends on the application data; select a vocabulary that matches your actual data.
+
+### Pretrained Model Download
+
+We provide a TextCNN model pretrained by Baidu on massive data. Download the pretrained model as follows:
+
+```shell
+wget https://paddlenlp.bj.bcebos.com/models/textcnn.pdparams
+```
+
+### Model Training
+
+With the vocabulary and the pretrained model downloaded, you can fine-tune on the robot chat dataset. Run one of the following commands to train on the training set (train.tsv) and evaluate on the development set (dev.tsv); `--init_from_ckpt=./textcnn.pdparams` specifies the pretrained TextCNN model.
+
+Start on CPU:
+
+```shell
+python train.py --vocab_path=./robot_chat_word_dict.txt \
+    --init_from_ckpt=./textcnn.pdparams \
+    --device=cpu \
+    --lr=5e-5 \
+    --batch_size=64 \
+    --epochs=10 \
+    --save_dir=./checkpoints \
+    --data_path=./RobotChat
+```
+
+Start on GPU:
+
+```shell
+unset CUDA_VISIBLE_DEVICES
+python -m paddle.distributed.launch --gpus "0" train.py \
+    --vocab_path=./robot_chat_word_dict.txt \
+    --init_from_ckpt=./textcnn.pdparams \
+    --device=gpu \
+    --lr=5e-5 \
+    --batch_size=64 \
+    --epochs=10 \
+    --save_dir=./checkpoints \
+    --data_path=./RobotChat
+```
+
+Start on XPU:
+
+```shell
+python train.py --vocab_path=./robot_chat_word_dict.txt \
+    --init_from_ckpt=./textcnn.pdparams \
+    --device=xpu \
+    --lr=5e-5 \
+    --batch_size=64 \
+    --epochs=10 \
+    --save_dir=./checkpoints \
+    --data_path=./RobotChat
+```
+
+The parameters above mean:
+
+* `vocab_path`: Path to the vocabulary file.
+* `init_from_ckpt`: Path of the checkpoint to resume training from.
+* `device`: Device to train on; one of cpu, gpu, or xpu. When training on gpu, the `gpus` argument specifies the GPU card ids.
+* `lr`: Learning rate; defaults to 5e-5.
+* `batch_size`: Batch size; defaults to 64.
+* `epochs`: Number of training epochs; defaults to 10.
+* `save_dir`: Directory where model checkpoints are saved during training.
+* `data_path`: Path to the dataset files.
+
+
+The program automatically runs training, evaluation, and testing, and during training it saves model checkpoints in the specified `save_dir`, for example:
+
+```text
+checkpoints/
+├── 0.pdopt
+├── 0.pdparams
+├── 1.pdopt
+├── 1.pdparams
+├── ...
+└── final.pdparams
+```
+
+**NOTE:**
+
+* To resume training, set init_from_ckpt to the checkpoint name without the file extension. For example, with `--init_from_ckpt=checkpoints/0` the program automatically loads the model parameters `checkpoints/0.pdparams` as well as the optimizer state `checkpoints/0.pdopt`.
+* After dynamic graph training finishes, the dynamic graph parameters can also be exported as static graph parameters; see export_model.py for the code. The static graph parameters are saved to the path specified by `output_path`.
+  Run it as follows:
+
+```shell
+python export_model.py --vocab_path=./robot_chat_word_dict.txt --params_path=./checkpoints/final.pdparams --output_path=./static_graph_params
+```
+
+Here `params_path` is the path of the parameters saved by dynamic graph training, and `output_path` is the export path of the static graph parameters.
+
+The exported model can then be used for deployment; deploy/python/predict.py provides a Python deployment and inference example. Run it as follows:
+
+```shell
+python deploy/python/predict.py --model_file=static_graph_params.pdmodel --params_file=static_graph_params.pdiparams
+```
+
+### Model Prediction
+
+Start prediction:
+
+Start on CPU:
+
+```shell
+python predict.py --vocab_path=./robot_chat_word_dict.txt \
+    --device=cpu \
+    --params_path=./checkpoints/final.pdparams
+```
+
+Start on GPU:
+
+```shell
+export CUDA_VISIBLE_DEVICES=0
+python predict.py --vocab_path=./robot_chat_word_dict.txt \
+    --device=gpu \
+    --params_path=./checkpoints/final.pdparams
+```
+
+Start on XPU:
+
+```shell
+python predict.py --vocab_path=./robot_chat_word_dict.txt \
+    --device=xpu \
+    --params_path=./checkpoints/final.pdparams
+```
+
+Example data to be predicted:
+
+```text
+你再骂我我真的不跟你聊了
+你看看我附近有什么好吃的
+我喜欢画画也喜欢唱歌
+```
+
+After the data is processed by the `preprocess_prediction_data` function, calling the `predict` function outputs the prediction results, for example:
+
+```text
+Data: 你再骂我我真的不跟你聊了      Label: negative
+Data: 你看看我附近有什么好吃的      Label: neutral
+Data: 我喜欢画画也喜欢唱歌      Label: positive
+```
+
+## Reference
+
+TextCNN reference paper:
+
+- [EMNLP2014-Convolutional Neural Networks for Sentence Classification](https://aclanthology.org/D14-1181.pdf)
diff --git a/examples/sentiment_analysis/textcnn/data.py b/examples/sentiment_analysis/textcnn/data.py
new file mode 100644
index 0000000000000..e912c528eb181
--- /dev/null
+++ b/examples/sentiment_analysis/textcnn/data.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+from paddlenlp.datasets import load_dataset
+
+
+def create_dataloader(dataset,
+                      mode='train',
+                      batch_size=1,
+                      batchify_fn=None,
+                      trans_fn=None):
+    """
+    Create dataloader.
+
+    Args:
+        dataset(obj:`paddle.io.Dataset`): Dataset instance.
+        mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
+        batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
+        batchify_fn(obj:`callable`, optional, defaults to `None`): Function to generate mini-batch data by merging
+            the sample list. If None, each field of the samples is only stacked along axis 0
+            (same as :attr:`np.stack(..., axis=0)`).
+        trans_fn(obj:`callable`, optional, defaults to `None`): Function to convert a data sample to input ids, etc.
+
+    Returns:
+        dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
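+
+    Example (a minimal usage sketch; ``my_dataset`` and ``my_batchify_fn`` are
+    placeholder names, not objects defined in this module):
+
+    .. code-block:: python
+
+        loader = create_dataloader(
+            my_dataset, mode='train', batch_size=64, batchify_fn=my_batchify_fn)
+        for input_ids, labels in loader:
+            pass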
+    """
+    if trans_fn:
+        dataset = dataset.map(trans_fn)
+
+    shuffle = True if mode == 'train' else False
+    if mode == "train":
+        sampler = paddle.io.DistributedBatchSampler(
+            dataset=dataset, batch_size=batch_size, shuffle=shuffle)
+    else:
+        sampler = paddle.io.BatchSampler(
+            dataset=dataset, batch_size=batch_size, shuffle=shuffle)
+    dataloader = paddle.io.DataLoader(
+        dataset, batch_sampler=sampler, collate_fn=batchify_fn)
+    return dataloader
+
+
+def preprocess_prediction_data(data, tokenizer, pad_token_id=0, max_ngram_filter_size=3):
+    """
+    Processes the prediction data into the same format as used in training.
+
+    Args:
+        data (obj:`list[str]`): The prediction data whose each element is a tokenized text.
+        tokenizer(obj:`paddlenlp.data.JiebaTokenizer`): It uses jieba to segment the Chinese string.
+        pad_token_id(obj:`int`, optional, defaults to 0): The pad token index.
+        max_ngram_filter_size (obj:`int`, optional, defaults to 3): The maximum n-gram size in the TextCNN model.
+            Users should refer to the ngram_filter_sizes setting in TextCNN; if ngram_filter_sizes=(1, 2, 3),
+            then max_ngram_filter_size=3.
+
+    Returns:
+        examples (obj:`list`): The processed data whose each element
+            is a `list` object, which contains
+
+            - word_ids(obj:`list[int]`): The list of word ids.
+    """
+    examples = []
+    for text in data:
+        ids = tokenizer.encode(text)
+        seq_len = len(ids)
+        # The sequence length should be larger than or equal to the maximum
+        # ngram_filter_size in the TextCNN model.
+        if seq_len < max_ngram_filter_size:
+            ids.extend([pad_token_id] * (max_ngram_filter_size - seq_len))
+        examples.append(ids)
+    return examples
+
+
+def convert_example(example, tokenizer):
+    """Converts a raw example into input word ids and a label."""
+    input_ids = tokenizer.encode(example["text"])
+    input_ids = np.array(input_ids, dtype='int64')
+
+    label = np.array(example["label"], dtype="int64")
+    return input_ids, label
+
+
+def read_custom_data(filename):
+    """Reads data."""
+    with open(filename, 'r', encoding='utf-8') as f:
+        # Skip the header line.
+        next(f)
+        for line in f:
+            data = line.strip().split("\t")
+            label, text = data
+            yield {"text": text, "label": label}
diff --git a/examples/sentiment_analysis/textcnn/deploy/python/predict.py b/examples/sentiment_analysis/textcnn/deploy/python/predict.py
new file mode 100644
index 0000000000000..555f48b02ff1d
--- /dev/null
+++ b/examples/sentiment_analysis/textcnn/deploy/python/predict.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
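+
+# NOTE: This script follows the usual Paddle Inference flow: it loads the static
+# graph model exported by export_model.py (the .pdmodel/.pdiparams files),
+# tokenizes the raw texts with JiebaTokenizer, and runs batched prediction
+# through a paddle.inference predictor.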
+ +import argparse + +import paddle +import paddle.nn.functional as F +from paddle import inference +from paddlenlp.data import JiebaTokenizer, Pad, Vocab + +from data import preprocess_prediction_data + +# yapf: disable +parser = argparse.ArgumentParser() +parser.add_argument("--model_file", type=str, required=True, + default='./static_graph_params.pdmodel', help="The path to model info in static graph.") +parser.add_argument("--params_file", type=str, required=True, + default='./static_graph_params.pdiparams', help="The path to parameters in static graph.") +parser.add_argument("--vocab_path", type=str, default="./robot_chat_word_dict.txt", help="The path to vocabulary.") +parser.add_argument("--max_seq_length", + default=128, type=int, help="The maximum total input sequence length after tokenization. " + "Sequences longer than this will be truncated, sequences shorter will be padded.") +parser.add_argument("--batch_size", default=2, type=int, help="Batch size per GPU/CPU for training.") +parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu'], + default="gpu", help="Select which device to train model, defaults to gpu.") +args = parser.parse_args() +# yapf: enable + + +class Predictor(object): + def __init__(self, model_file, params_file, device, max_seq_length): + self.max_seq_length = max_seq_length + + config = paddle.inference.Config(model_file, params_file) + if device == "gpu": + # set GPU configs accordingly + config.enable_use_gpu(100, 0) + elif device == "cpu": + # set CPU configs accordingly, + # such as enable_mkldnn, set_cpu_math_library_num_threads + config.disable_gpu() + elif device == "xpu": + # set XPU configs accordingly + config.enable_xpu(100) + config.switch_use_feed_fetch_ops(False) + self.predictor = paddle.inference.create_predictor(config) + + self.input_handles = [ + self.predictor.get_input_handle(name) + for name in self.predictor.get_input_names() + ] + + self.output_handle = self.predictor.get_output_handle( + self.predictor.get_output_names()[0]) + + def predict(self, + data, + label_map, + batch_size=1, + pad_token_id=0): + """ + Predicts the data labels. + + Args: + data (obj:`list`): The processed data whose each element + is a `list` object, which contains + + - word_ids(obj:`list[int]`): The list of word ids. + label_map(obj:`dict`): The label id (key) to label str (value) map. + batch_size(obj:`int`, defaults to 1): The number of batch. + pad_token_id(obj:`int`, optional, defaults to 0): The pad token index. + + Returns: + results(obj:`dict`): All the predictions labels. + """ + + # Seperates data into some batches. + batches = [ + data[idx:idx + batch_size] for idx in range(0, len(data), batch_size) + ] + batchify_fn = lambda samples, fn=Pad( + axis=0, pad_val=pad_token_id + ): [data for data in fn(samples)] + + results = [] + for batch in batches: + input_ids = batchify_fn(batch) + self.input_handles[0].copy_from_cpu(input_ids) + self.predictor.run() + logits = paddle.to_tensor(self.output_handle.copy_to_cpu()) + probs = F.softmax(logits, axis=1) + idx = paddle.argmax(probs, axis=1).numpy() + idx = idx.tolist() + labels = [label_map[i] for i in idx] + results.extend(labels) + return results + + +if __name__ == "__main__": + # Define predictor to do prediction. 
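+    # The Predictor wraps paddle.inference: it builds an inference Config from
+    # the exported model/params files, creates the predictor, and keeps handles
+    # to the input and output tensors (copy_from_cpu / copy_to_cpu).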
+ predictor = Predictor(args.model_file, args.params_file, args.device, + args.max_seq_length) + + vocab = Vocab.load_vocabulary( + args.vocab_path, unk_token='[UNK]', pad_token='[PAD]') + pad_token_id = vocab.to_indices('[PAD]') + tokenizer = JiebaTokenizer(vocab) + label_map = {0: 'negative', 1: 'neutral', 2: 'positive'} + + # Firstly pre-processing prediction data and then do predict. + data = [ + '你再骂我我真的不跟你聊了', + '你看看我附近有什么好吃的', + '我喜欢画画也喜欢唱歌' + ] + examples = preprocess_prediction_data(data, tokenizer, pad_token_id) + + results = predictor.predict( + examples, + label_map, + batch_size=args.batch_size, + pad_token_id=pad_token_id) + for idx, text in enumerate(data): + print('Data: {} \t Label: {}'.format(text, results[idx])) diff --git a/examples/sentiment_analysis/textcnn/export_model.py b/examples/sentiment_analysis/textcnn/export_model.py new file mode 100644 index 0000000000000..5386a4f568aeb --- /dev/null +++ b/examples/sentiment_analysis/textcnn/export_model.py @@ -0,0 +1,66 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import paddle +import paddlenlp as ppnlp +from paddlenlp.data import Vocab +from model import TextCNNModel + +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +parser.add_argument("--vocab_path", type=str, default="./robot_chat_word_dict.txt", help="The path to vocabulary.") +parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu'], default="gpu", help="Select which device to train model, defaults to gpu.") +parser.add_argument("--params_path", type=str, default='./checkpoints/final.pdparams', help="The path of model parameter to be loaded.") +parser.add_argument("--output_path", type=str, default='./static_graph_params', help="The path of model parameter in static graph to be saved.") +args = parser.parse_args() +# yapf: enable + + +def main(): + # Load vocab. + if not os.path.exists(args.vocab_path): + raise RuntimeError('The vocab_path can not be found in the path %s' % + args.vocab_path) + + vocab = Vocab.load_vocabulary(args.vocab_path) + label_map = {0: 'negative', 1: 'neutral', 2: 'positive'} + + # Construct the newtork. + vocab_size = len(vocab) + num_classes = len(label_map) + pad_token_id = vocab.to_indices('[PAD]') + + model = TextCNNModel( + vocab_size, + num_classes, + padding_idx=pad_token_id, + ngram_filter_sizes=(1, 2, 3)) + + # Load model parameters. + state_dict = paddle.load(args.params_path) + model.set_dict(state_dict) + model.eval() + + inputs = [paddle.static.InputSpec(shape=[None, None], dtype="int64")] + + model = paddle.jit.to_static(model, input_spec=inputs) + # Save in static graph model. 
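+    # paddle.jit.save writes the static graph files with the `output_path` prefix,
+    # e.g. static_graph_params.pdmodel and static_graph_params.pdiparams, which
+    # deploy/python/predict.py loads via --model_file and --params_file.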
+    paddle.jit.save(model, args.output_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/sentiment_analysis/textcnn/model.py b/examples/sentiment_analysis/textcnn/model.py
new file mode 100644
index 0000000000000..bfe76fd72ca31
--- /dev/null
+++ b/examples/sentiment_analysis/textcnn/model.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+
+from paddlenlp.seq2vec import CNNEncoder
+
+
+class TextCNNModel(nn.Layer):
+    """
+    This class implements the Text Convolutional Neural Network model.
+    At a high level, the model starts by embedding the tokens and running them through
+    a word embedding. Then, we encode these representations with a `CNNEncoder`.
+    The CNN has one convolution layer for each ngram filter size. Each convolution operation gives
+    out a vector of size num_filter. The number of times a convolution layer will be used
+    is `num_tokens - ngram_size + 1`. The corresponding max pooling layer aggregates all these
+    outputs from the convolution layer and outputs the max.
+    Lastly, we take the output of the encoder to create a final representation,
+    which is passed through some feed-forward layers to output the logits (`output_layer`).
+    """
+
+    def __init__(self,
+                 vocab_size,
+                 num_classes,
+                 emb_dim=128,
+                 padding_idx=0,
+                 num_filter=128,
+                 ngram_filter_sizes=(1, 2, 3),
+                 fc_hidden_size=96):
+        super().__init__()
+        self.embedder = nn.Embedding(
+            vocab_size, emb_dim, padding_idx=padding_idx)
+        self.encoder = CNNEncoder(
+            emb_dim=emb_dim,
+            num_filter=num_filter,
+            ngram_filter_sizes=ngram_filter_sizes)
+        self.fc = nn.Linear(self.encoder.get_output_dim(), fc_hidden_size)
+        self.output_layer = nn.Linear(fc_hidden_size, num_classes)
+
+    def forward(self, text):
+        # Shape: (batch_size, num_tokens, embedding_dim)
+        embedded_text = self.embedder(text)
+        # Shape: (batch_size, len(ngram_filter_sizes) * num_filter)
+        encoder_out = paddle.tanh(self.encoder(embedded_text))
+        # Shape: (batch_size, fc_hidden_size)
+        fc_out = paddle.tanh(self.fc(encoder_out))
+        # Shape: (batch_size, num_classes)
+        logits = self.output_layer(fc_out)
+        return logits
diff --git a/examples/sentiment_analysis/textcnn/predict.py b/examples/sentiment_analysis/textcnn/predict.py
new file mode 100644
index 0000000000000..3da9533bdace4
--- /dev/null
+++ b/examples/sentiment_analysis/textcnn/predict.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +import argparse + +import paddle +import paddle.nn.functional as F +import paddlenlp as ppnlp +from paddlenlp.data import JiebaTokenizer, Pad, Vocab + +from model import TextCNNModel +from data import preprocess_prediction_data + +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu'], default="gpu", help="Select which device to train model, defaults to gpu.") +parser.add_argument("--batch_size", type=int, default=1, help="Total examples' number of a batch for training.") +parser.add_argument("--vocab_path", type=str, default="./robot_chat_word_dict.txt", help="The path to vocabulary.") +parser.add_argument("--params_path", type=str, default='./checkpoints/final.pdparams', help="The path of model parameter to be loaded.") +args = parser.parse_args() +# yapf: enable + + +def predict(model, data, label_map, batch_size=1, pad_token_id=0): + """ + Predicts the data labels. + + Args: + model (obj:`paddle.nn.Layer`): A model to classify texts. + data (obj:`list`): The processed data whose each element + is a `list` object, which contains + + - word_ids(obj:`list[int]`): The list of word ids. + label_map(obj:`dict`): The label id (key) to label str (value) map. + batch_size(obj:`int`, defaults to 1): The number of batch. + pad_token_id(obj:`int`, optional, defaults to 0): The pad token index. + + Returns: + results(obj:`dict`): All the predictions labels. + """ + + # Seperates data into some batches. + batches = [ + data[idx:idx + batch_size] for idx in range(0, len(data), batch_size) + ] + batchify_fn = lambda samples, fn=Pad( + axis=0, pad_val=pad_token_id + ): [data for data in fn(samples)] + + results = [] + model.eval() + for batch in batches: + texts = paddle.to_tensor(batchify_fn(batch)) + logits = model(texts) + probs = F.softmax(logits, axis=1) + idx = paddle.argmax(probs, axis=1).numpy() + idx = idx.tolist() + labels = [label_map[i] for i in idx] + results.extend(labels) + return results + + +if __name__ == "__main__": + paddle.set_device(args.device) + + # Load vocab. + vocab = Vocab.load_vocabulary( + args.vocab_path, unk_token='[UNK]', pad_token='[PAD]') + label_map = {0: 'negative', 1: 'neutral', 2: 'positive'} + + # Construct the newtork. + vocab_size = len(vocab) + num_classes = len(label_map) + pad_token_id = vocab.to_indices('[PAD]') + + model = TextCNNModel( + vocab_size, + num_classes, + padding_idx=pad_token_id, + ngram_filter_sizes=(1, 2, 3)) + + # Load model parameters. + state_dict = paddle.load(args.params_path) + model.set_dict(state_dict) + print("Loaded parameters from %s" % args.params_path) + + # Firstly pre-processing prediction data and then do predict. + data = [ + '你再骂我我真的不跟你聊了', + '你看看我附近有什么好吃的', + '我喜欢画画也喜欢唱歌' + ] + tokenizer = JiebaTokenizer(vocab) + examples = preprocess_prediction_data(data, tokenizer, pad_token_id) + + results = predict( + model, + examples, + label_map=label_map, + batch_size=args.batch_size, + pad_token_id=pad_token_id) + for idx, text in enumerate(data): + print('Data: {} \t Label: {}'.format(text, results[idx])) diff --git a/examples/sentiment_analysis/textcnn/train.py b/examples/sentiment_analysis/textcnn/train.py new file mode 100644 index 0000000000000..25ed5ae105d8a --- /dev/null +++ b/examples/sentiment_analysis/textcnn/train.py @@ -0,0 +1,126 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+import argparse
+import os
+import random
+
+import numpy as np
+import paddle
+from paddlenlp.datasets import load_dataset
+from paddlenlp.data import JiebaTokenizer, Pad, Stack, Tuple, Vocab
+
+from data import create_dataloader, convert_example, read_custom_data
+from model import TextCNNModel
+
+# yapf: disable
+parser = argparse.ArgumentParser(__doc__)
+parser.add_argument("--epochs", type=int, default=10, help="Number of epochs for training.")
+parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
+parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate used to train.")
+parser.add_argument("--save_dir", type=str, default='checkpoints/', help="Directory to save model checkpoints.")
+parser.add_argument("--data_path", type=str, default='./RobotChat', help="The path of the dataset to be loaded.")
+parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number of a batch for training.")
+parser.add_argument("--vocab_path", type=str, default="./robot_chat_word_dict.txt", help="The path to vocabulary.")
+parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
+args = parser.parse_args()
+# yapf: enable
+
+
+def set_seed(seed=1000):
+    """Sets random seed."""
+    random.seed(seed)
+    np.random.seed(seed)
+    paddle.seed(seed)
+
+
+if __name__ == "__main__":
+    paddle.set_device(args.device)
+    set_seed()
+
+    # Load vocab.
+    if not os.path.exists(args.vocab_path):
+        raise RuntimeError('The vocab_path cannot be found in the path %s' %
+                           args.vocab_path)
+
+    vocab = Vocab.load_vocabulary(
+        args.vocab_path, unk_token='[UNK]', pad_token='[PAD]')
+
+    # Load datasets.
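+    # Each *.tsv file is expected to have a header row followed by tab-separated
+    # "label<TAB>text" lines; read_custom_data (data.py) skips the header and
+    # yields {"text": ..., "label": ...} dicts for load_dataset to wrap.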
+ dataset_names = ['train.tsv', 'dev.tsv', 'test.tsv'] + train_ds, dev_ds, test_ds = [load_dataset(read_custom_data, \ + filename=os.path.join(args.data_path, dataset_name), lazy=False) for dataset_name in dataset_names] + + tokenizer = JiebaTokenizer(vocab) + trans_fn = partial(convert_example, tokenizer=tokenizer) + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=vocab.token_to_idx.get('[PAD]', 0)), + Stack(dtype='int64') # label + ): [data for data in fn(samples)] + train_loader = create_dataloader( + train_ds, + batch_size=args.batch_size, + mode='train', + batchify_fn=batchify_fn, + trans_fn=trans_fn) + dev_loader = create_dataloader( + dev_ds, + batch_size=args.batch_size, + mode='validation', + batchify_fn=batchify_fn, + trans_fn=trans_fn) + test_loader = create_dataloader( + test_ds, + batch_size=args.batch_size, + mode='test', + batchify_fn=batchify_fn, + trans_fn=trans_fn) + + label_map = {0: 'negative', 1: 'neutral', 2: 'positive'} + vocab_size = len(vocab) + num_classes = len(label_map) + pad_token_id = vocab.to_indices('[PAD]') + + model = TextCNNModel( + vocab_size, + num_classes, + padding_idx=pad_token_id, + ngram_filter_sizes=(1, 2, 3)) + + if args.init_from_ckpt and os.path.isfile(args.init_from_ckpt): + state_dict = paddle.load(args.init_from_ckpt) + model.set_dict(state_dict) + + model = paddle.Model(model) + + optimizer = paddle.optimizer.Adam( + parameters=model.parameters(), learning_rate=args.lr) + + # Define loss and metric. + criterion = paddle.nn.CrossEntropyLoss() + metric = paddle.metric.Accuracy() + + model.prepare(optimizer, criterion, metric) + + # Start training and evaluating. + callback = paddle.callbacks.ProgBarLogger(log_freq=10, verbose=3) + model.fit(train_loader, + dev_loader, + epochs=args.epochs, + save_dir=args.save_dir, + callbacks=callback) + + # Evaluate on test dataset + print('Start to evaluate on test dataset...') + model.evaluate(test_loader, log_freq=len(test_loader))