diff --git a/docs/model_zoo/taskflow.md b/docs/model_zoo/taskflow.md index 44ec9474221d..7987c86dee54 100644 --- a/docs/model_zoo/taskflow.md +++ b/docs/model_zoo/taskflow.md @@ -11,6 +11,7 @@ - [文本纠错](#文本纠错) - [句法分析](#句法分析) - [情感分析](#情感分析) + - [文本相似度](#文本相似度) - [知识挖掘-词类知识标注](#知识挖掘-词类知识标注) - [知识挖掘-名词短语标注](#知识挖掘-名词短语标注) - [生成式问答](#生成式问答) @@ -31,6 +32,7 @@ | 文本纠错 | 开放域对话(TODO) | | 句法分析 | 自动对联(TODO) | | 情感分析 | | +| 文本相似度 | | | 知识挖掘-词类知识标注 | | | 知识挖掘-名词短语标注 | | @@ -79,6 +81,10 @@ seg(["第十四届全运会在西安举办", "三亚是一个美丽的城市"]) >>> [['第十四届', '全运会', '在', '西安', '举办'], ['三亚', '是', '一个', '美丽', '的', '城市']] ``` +可配置参数: + +* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。 + ### 词性标注 ```python @@ -92,6 +98,10 @@ tag(["第十四届全运会在西安举办", "三亚是一个美丽的城市"]) >>> [[('第十四届', 'm'), ('全运会', 'nz'), ('在', 'p'), ('西安', 'LOC'), ('举办', 'v')], [('三亚', 'LOC'), ('是', 'v'), ('一个', 'm'), ('美丽', 'a'), ('的', 'u'), ('城市', 'n')]] ``` +可配置参数: + +* `batch_size`:批处理大小,请结合机器情况进行调整,默认值为1。 + ### 命名实体识别 ```python @@ -100,12 +110,15 @@ from paddlenlp import Taskflow ner = Taskflow("ner") ner("《孤女》是2010年九州出版社出版的小说,作者是余兼羽") >>> [('《', 'w'), ('孤女', '作品类_实体'), ('》', 'w'), ('是', '肯定词'), ('2010年', '时间类'), ('九州出版社', '组织机构类'), ('出版', '场景事件'), ('的', '助词'), ('小说', '作品类_概念'), (',', 'w'), ('作者', '人物类_概念'), ('是', '肯定词'), ('余兼羽', '人物类_实体')] -ner = Taskflow("ner", batch_size=2) -ner(["热梅茶是一道以梅子为主要原料制作的茶饮", - "《孤女》是2010年九州出版社出版的小说,作者是余兼羽"]) + +ner(["热梅茶是一道以梅子为主要原料制作的茶饮", "《孤女》是2010年九州出版社出版的小说,作者是余兼羽"]) >>> [[('热梅茶', '饮食类_饮品'), ('是', '肯定词'), ('一道', '数量词'), ('以', '介词'), ('梅子', '饮食类'), ('为', '肯定词'), ('主要原料', '物体类'), ('制作', '场景事件'), ('的', '助词'), ('茶饮', '饮食类_饮品')], [('《', 'w'), ('孤女', '作品类_实体'), ('》', 'w'), ('是', '肯定词'), ('2010年', '时间类'), ('九州出版社', '组织机构类'), ('出版', '场景事件'), ('的', '助词'), ('小说', '作品类_概念'), (',', 'w'), ('作者', '人物类_概念'), ('是', '肯定词'), ('余兼羽', '人物类_实体')]] ``` +可配置参数: + +* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。 + ### 文本纠错 ```python @@ -120,6 +133,10 @@ corrector(['遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇 >>> [{'source': '遇到逆竟时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。', 'target': '遇到逆境时,我们必须勇于面对,而且要愈挫愈勇,这样我们才能朝著成功之路前进。', 'errors': [{'position': 3, 'correction': {'竟': '境'}}]}, {'source': '人生就是如此,经过磨练才能让自己更加拙壮,才能使自己更加乐观。', 'target': '人生就是如此,经过磨练才能让自己更加茁壮,才能使自己更加乐观。', 'errors': [{'position': 18, 'correction': {'拙': '茁'}}]}] ``` +可配置参数: + +* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。 + ### 句法分析 ```python @@ -156,6 +173,15 @@ import cv2 cv2.imwrite('test.png', result) ``` +可配置参数: + +* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。 +* `tree`:确保输出结果是正确的依存句法树,默认为True。 +* `prob`:是否输出每个弧对应的概率值,默认为False。 +* `use_pos`:是否返回词性标签,默认为False。 +* `use_cuda`:是否使用GPU进行切词,默认为False。 +* `return_visual`:是否返回句法树的可视化结果,默认为False。 + ### 情感分析 ```python @@ -174,6 +200,24 @@ senta("作为老的四星酒店,房间依然很整洁,相当不错。机场 >>> [{'text': '作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。', 'label': 'positive', 'score': 0.984320878982544}] ``` +### 文本相似度 + +```python +from paddlenlp import Taskflow + +similarity = Taskflow("text_similarity") +similarity([["世界上什么东西最小", "世界上什么东西最小?"]]) +>>> [{'text1': '世界上什么东西最小', 'text2': '世界上什么东西最小?', 'similarity': 0.992725}] + +similarity([["光眼睛大就好看吗", "眼睛好看吗?"], ["小蝌蚪找妈妈怎么样", "小蝌蚪找妈妈是谁画的"]]) +>>> [{'text1': '光眼睛大就好看吗', 'text2': '眼睛好看吗?', 'similarity': 0.74502707}, {'text1': '小蝌蚪找妈妈怎么样', 'text2': '小蝌蚪找妈妈是谁画的', 'similarity': 0.8192149}] +``` + +可配置参数: + +* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。 +* `max_seq_len`:最大序列长度,默认为128。 + ### 知识挖掘-词类知识标注 ```python @@ -183,12 +227,20 @@ wordtag = Taskflow("knowledge_mining") wordtag("《孤女》是2010年九州出版社出版的小说,作者是余兼羽") >>> [{'text': '《孤女》是2010年九州出版社出版的小说,作者是余兼羽', 'items': [{'item': '《', 'offset': 0, 'wordtag_label': 'w', 'length': 1}, {'item': '孤女', 'offset': 1, 'wordtag_label': '作品类_实体', 'length': 2}, {'item': '》', 'offset': 3, 'wordtag_label': 'w', 'length': 1}, {'item': '是', 'offset': 4, 'wordtag_label': '肯定词', 'length': 1, 'termid': '肯定否定词_cb_是'}, {'item': '2010年', 'offset': 5, 'wordtag_label': '时间类', 'length': 5, 'termid': '时间阶段_cb_2010年'}, {'item': '九州出版社', 'offset': 10, 'wordtag_label': '组织机构类', 'length': 5, 'termid': '组织机构_eb_九州出版社'}, {'item': '出版', 'offset': 15, 'wordtag_label': '场景事件', 'length': 2, 'termid': '场景事件_cb_出版'}, {'item': '的', 'offset': 17, 'wordtag_label': '助词', 'length': 1, 'termid': '助词_cb_的'}, {'item': '小说', 'offset': 18, 'wordtag_label': '作品类_概念', 'length': 2, 'termid': '小说_cb_小说'}, {'item': ',', 'offset': 20, 'wordtag_label': 'w', 'length': 1}, {'item': '作者', 'offset': 21, 'wordtag_label': '人物类_概念', 'length': 2, 'termid': '人物_cb_作者'}, {'item': '是', 'offset': 23, 'wordtag_label': '肯定词', 'length': 1, 'termid': '肯定否定词_cb_是'}, {'item': '余兼羽', 'offset': 24, 'wordtag_label': '人物类_实体', 'length': 3}]}] -wordtag= Taskflow("knowledge_mining", batch_size=2) wordtag(["热梅茶是一道以梅子为主要原料制作的茶饮", "《孤女》是2010年九州出版社出版的小说,作者是余兼羽"]) >>> [{'text': '热梅茶是一道以梅子为主要原料制作的茶饮', 'items': [{'item': '热梅茶', 'offset': 0, 'wordtag_label': '饮食类_饮品', 'length': 3}, {'item': '是', 'offset': 3, 'wordtag_label': '肯定词', 'length': 1, 'termid': '肯定否定词_cb_是'}, {'item': '一道', 'offset': 4, 'wordtag_label': '数量词', 'length': 2}, {'item': '以', 'offset': 6, 'wordtag_label': '介词', 'length': 1, 'termid': '介词_cb_以'}, {'item': '梅子', 'offset': 7, 'wordtag_label': '饮食类', 'length': 2, 'termid': '饮食_cb_梅'}, {'item': '为', 'offset': 9, 'wordtag_label': '肯定词', 'length': 1, 'termid': '肯定否定词_cb_为'}, {'item': '主要原料', 'offset': 10, 'wordtag_label': '物体类', 'length': 4, 'termid': '物品_cb_主要原料'}, {'item': '制作', 'offset': 14, 'wordtag_label': '场景事件', 'length': 2, 'termid': '场景事件_cb_制作'}, {'item': '的', 'offset': 16, 'wordtag_label': '助词', 'length': 1, 'termid': '助词_cb_的'}, {'item': '茶饮', 'offset': 17, 'wordtag_label': '饮食类_饮品', 'length': 2, 'termid': '饮品_cb_茶饮'}]}, {'text': '《孤女》是2010年九州出版社出版的小说,作者是余兼羽', 'items': [{'item': '《', 'offset': 0, 'wordtag_label': 'w', 'length': 1}, {'item': '孤女', 'offset': 1, 'wordtag_label': '作品类_实体', 'length': 2}, {'item': '》', 'offset': 3, 'wordtag_label': 'w', 'length': 1}, {'item': '是', 'offset': 4, 'wordtag_label': '肯定词', 'length': 1, 'termid': '肯定否定词_cb_是'}, {'item': '2010年', 'offset': 5, 'wordtag_label': '时间类', 'length': 5, 'termid': '时间阶段_cb_2010年'}, {'item': '九州出版社', 'offset': 10, 'wordtag_label': '组织机构类', 'length': 5, 'termid': '组织机构_eb_九州出版社'}, {'item': '出版', 'offset': 15, 'wordtag_label': '场景事件', 'length': 2, 'termid': '场景事件_cb_出版'}, {'item': '的', 'offset': 17, 'wordtag_label': '助词', 'length': 1, 'termid': '助词_cb_的'}, {'item': '小说', 'offset': 18, 'wordtag_label': '作品类_概念', 'length': 2, 'termid': '小说_cb_小说'}, {'item': ',', 'offset': 20, 'wordtag_label': 'w', 'length': 1}, {'item': '作者', 'offset': 21, 'wordtag_label': '人物类_概念', 'length': 2, 'termid': '人物_cb_作者'}, {'item': '是', 'offset': 23, 'wordtag_label': '肯定词', 'length': 1, 'termid': '肯定否定词_cb_是'}, {'item': '余兼羽', 'offset': 24, 'wordtag_label': '人物类_实体', 'length': 3}]}] ``` +可配置参数: + +* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。 +* `linking`:实现基于词类的linking,默认为True。 +* `params_path`:使用自定义模型参数,默认为None。 +* `tag_path`:使用自定义标签文件,默认为None。 +* `term_schema_path`:使用自定义TermType词类体系,默认为None。 +* `term_data_path`:使用自定义百科知识树文件,默认为None。 + ### 知识挖掘-名词短语标注 ```python @@ -198,7 +250,6 @@ nptag = Taskflow("knowledge_mining", model="nptag") nptag("糖醋排骨") >>> [{'text': '糖醋排骨', 'label': '菜品'}] -nptag = Taskflow("knowledge_mining", model="nptag", batch_size=2) nptag(["糖醋排骨", "红曲霉菌"]) >>> [{'text': '糖醋排骨', 'label': '菜品'}, {'text': '红曲霉菌', 'label': '微生物'}] @@ -208,6 +259,12 @@ nptag(["糖醋排骨", "红曲霉菌"]) >>> [{'text': '糖醋排骨', 'label': '菜品', 'category': '饮食类_菜品'}, {'text': '红曲霉菌', 'label': '微生物', 'category': '生物类_微生物'}] ``` +可配置参数: + +* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。 +* `max_seq_len`:最大序列长度,默认为64。 +* `linking`:实现与WordTag类别标签的linking,默认为False。 + ### 生成式问答 ```python @@ -221,6 +278,10 @@ qa(["中国国土面积有多大?", "中国的首都在哪里?"]) >>> [{'text': '中国国土面积有多大?', 'answer': '960万平方公里。'}, {'text': '中国的首都在哪里?', 'answer': '北京。'}] ``` +可配置参数: + +* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。 + ### 智能写诗 ```python @@ -234,6 +295,10 @@ poetry(["林密不见人", "举头邀明月"]) >>> [{'text': '林密不见人', 'answer': ',但闻人语响。'}, {'text': '举头邀明月', 'answer': ',低头思故乡。'}] ``` +可配置参数: + +* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。 + ## FAQ ### Q1 Taskflow如何修改任务保存路径? diff --git a/paddlenlp/taskflow/taskflow.py b/paddlenlp/taskflow/taskflow.py index fac3e7e51ea1..adc47017204e 100644 --- a/paddlenlp/taskflow/taskflow.py +++ b/paddlenlp/taskflow/taskflow.py @@ -28,6 +28,7 @@ from .question_answering import QuestionAnsweringTask from .dependency_parsing import DDParserTask from .text_correction import CSCTask +from .text_similarity import TextSimilarityTask warnings.simplefilter(action='ignore', category=Warning, lineno=0, append=False) @@ -166,6 +167,17 @@ "model": "csc-ernie-1.0" } }, + 'text_similarity': { + "models": { + "simbert-base-chinese": { + "task_class": TextSimilarityTask, + "task_flag": "text_similarity-simbert-base-chinese" + }, + }, + "default": { + "model": "simbert-base-chinese" + } + }, } diff --git a/paddlenlp/taskflow/text_similarity.py b/paddlenlp/taskflow/text_similarity.py new file mode 100644 index 000000000000..d43347cf0fba --- /dev/null +++ b/paddlenlp/taskflow/text_similarity.py @@ -0,0 +1,179 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddlenlp.transformers import BertModel, BertTokenizer + +from ..data import Pad, Tuple +from .utils import static_mode_guard +from .task import Task + +usage = r""" + from paddlenlp import Taskflow + + similarity = Taskflow("text_similarity") + similarity([["世界上什么东西最小", "世界上什么东西最小?"]]) + ''' + [{'text1': '世界上什么东西最小', 'text2': '世界上什么东西最小?', 'similarity': 0.992725}] + ''' + + similarity = Taskflow("text_similarity", batch_size=2) + similarity([["光眼睛大就好看吗", "眼睛好看吗?"], ["小蝌蚪找妈妈怎么样", "小蝌蚪找妈妈是谁画的"]]) + ''' + [{'text1': '光眼睛大就好看吗', 'text2': '眼睛好看吗?', 'similarity': 0.74502707}, {'text1': '小蝌蚪找妈妈怎么样', 'text2': '小蝌蚪找妈妈是谁画的', 'similarity': 0.8192149}] + ''' + """ + +class TextSimilarityTask(Task): + """ + Text similarity task using SimBERT to predict the similarity of sentence pair. + Args: + task(string): The name of task. + model(string): The model name in the task. + kwargs (dict, optional): Additional keyword arguments passed along to the specific task. + """ + + def __init__(self, + task, + model, + batch_size=1, + max_seq_len=128, + **kwargs): + super().__init__(task=task, model=model, **kwargs) + self._static_mode = True + self._construct_tokenizer(model) + self._get_inference_model() + self._batch_size = batch_size + self._max_seq_len = max_seq_len + self._usage = usage + + def _construct_input_spec(self): + """ + Construct the input spec for the convert dygraph model to static model. + """ + self._input_spec = [ + paddle.static.InputSpec( + shape=[None, None], dtype="int64", name='input_ids'), + paddle.static.InputSpec( + shape=[None], dtype="int64", name='token_type_ids'), + ] + + def _construct_model(self, model): + """ + Construct the inference model for the predictor. + """ + self._model = BertModel.from_pretrained(model, pool_act='linear') + self._model.eval() + + def _construct_tokenizer(self, model): + """ + Construct the tokenizer for the predictor. + """ + self._tokenizer = BertTokenizer.from_pretrained(model) + + def _check_input_text(self, inputs): + inputs = inputs[0] + if not all([isinstance(i, list) and i \ + and all(i) and len(i) == 2 for i in inputs]): + raise TypeError( + "Invalid input format.") + return inputs + + def _preprocess(self, inputs): + """ + Transform the raw text to the model inputs, two steps involved: + 1) Transform the raw text to token ids. + 2) Generate the other model inputs from the raw text and token ids. + """ + inputs = self._check_input_text(inputs) + num_workers = self.kwargs[ + 'num_workers'] if 'num_workers' in self.kwargs else 0 + lazy_load = self.kwargs[ + 'lazy_load'] if 'lazy_load' in self.kwargs else False + + examples = [] + + for data in inputs: + text1, text2 = data[0], data[1] + + text1_encoded_inputs = self._tokenizer( + text=text1, max_seq_len=self._max_seq_len) + text1_input_ids = text1_encoded_inputs["input_ids"] + text1_token_type_ids = text1_encoded_inputs["token_type_ids"] + + text2_encoded_inputs = self._tokenizer( + text=text2, max_seq_len=self._max_seq_len) + text2_input_ids = text2_encoded_inputs["input_ids"] + text2_token_type_ids = text2_encoded_inputs["token_type_ids"] + + examples.append((text1_input_ids, text1_token_type_ids, + text2_input_ids, text2_token_type_ids)) + + batches = [ + examples[idx:idx + self._batch_size] + for idx in range(0, len(examples), self._batch_size) + ] + + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=self._tokenizer.pad_token_id), # text1_input_ids + Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id), # text1_token_type_ids + Pad(axis=0, pad_val=self._tokenizer.pad_token_id), # text2_input_ids + Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id), # text2_token_type_ids + ): [data for data in fn(samples)] + + outputs = {} + outputs['data_loader'] = batches + outputs['text'] = inputs + self._batchify_fn = batchify_fn + return outputs + + def _run_model(self, inputs): + """ + Run the task model from the outputs of the `_tokenize` function. + """ + results = [] + with static_mode_guard(): + for batch in inputs['data_loader']: + text1_ids, text1_segment_ids, text2_ids, text2_segment_ids = self._batchify_fn(batch) + self.input_handles[0].copy_from_cpu(text1_ids) + self.input_handles[1].copy_from_cpu(text1_segment_ids) + self.predictor.run() + vecs_text1 = self.output_handle[1].copy_to_cpu() + + self.input_handles[0].copy_from_cpu(text2_ids) + self.input_handles[1].copy_from_cpu(text2_segment_ids) + self.predictor.run() + vecs_text2 = self.output_handle[1].copy_to_cpu() + + vecs_text1 = vecs_text1 / (vecs_text1**2).sum(axis=1, + keepdims=True)**0.5 + vecs_text2 = vecs_text2 / (vecs_text2**2).sum(axis=1, + keepdims=True)**0.5 + similarity = (vecs_text1 * vecs_text2).sum(axis=1) + results.extend(similarity) + inputs['result'] = results + return inputs + + def _postprocess(self, inputs): + """ + The model output is tag ids, this function will convert the model output to raw text. + """ + final_results = [] + for text, similarity in zip(inputs['text'], inputs['result']): + result = {} + result['text1'] = text[0] + result['text2'] = text[1] + result['similarity'] = similarity + final_results.append(result) + return final_results \ No newline at end of file