Add text similarity task for Taskflow #1345
@@ -11,6 +11,7 @@
 - [文本纠错](#文本纠错)
 - [句法分析](#句法分析)
 - [情感分析](#情感分析)
+- [文本相似度](#文本相似度)
 - [知识挖掘-词类知识标注](#知识挖掘-词类知识标注)
 - [知识挖掘-名词短语标注](#知识挖掘-名词短语标注)
 - [生成式问答](#生成式问答)
@@ -31,6 +32,7 @@
 | 文本纠错 | 开放域对话(TODO) |
 | 句法分析 | 自动对联(TODO) |
 | 情感分析 | |
+| 文本相似度 | |
 | 知识挖掘-词类知识标注 | |
 | 知识挖掘-名词短语标注 | |
@@ -174,6 +176,20 @@ senta("作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。")
 >>> [{'text': '作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。', 'label': 'positive', 'score': 0.984320878982544}]
 ```
+
+### 文本相似度
+
+```python
+from paddlenlp import Taskflow
+
+similarity = Taskflow("text_similarity")
+similarity([["世界上什么东西最小", "世界上什么东西最小?"]])
+>>> [{'query': '世界上什么东西最小', 'title': '世界上什么东西最小?', 'similarity': 0.992725}]
+
+similarity = Taskflow("text_similarity", batch_size=2)
+similarity([["光眼睛大就好看吗", "眼睛好看吗?"], ["小蝌蚪找妈妈怎么样", "小蝌蚪找妈妈是谁画的"]])
+>>> [{'query': '光眼睛大就好看吗', 'title': '眼睛好看吗?', 'similarity': 0.7450271}, {'query': '小蝌蚪找妈妈怎么样', 'title': '小蝌蚪找妈妈是谁画的', 'similarity': 0.8192149}]
+```

Review comments on the `batch_size=2` line:

Reviewer: We still need to tell developers why this is `batch_size=2` and what it actually does.

Reviewer: The other examples need the same enhancement of the API parameter descriptions. Otherwise this is easy to misread as "you must set batch_size=2 in order to pass in two samples".

Author: Fixed. The code example has been revised here, and a description of the configurable parameters has been added.
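On the batch_size question raised in the thread: batch_size does not limit how many pairs may be passed in. Any number of input pairs is accepted; the task simply splits them into chunks of at most batch_size pairs per forward pass of the predictor. A minimal sketch of that chunking logic (plain Python, hypothetical helper name, no paddlenlp dependency):

```python
def make_batches(pairs, batch_size=1):
    """Split input pairs into consecutive chunks of at most batch_size items."""
    return [pairs[i:i + batch_size] for i in range(0, len(pairs), batch_size)]

pairs = [["光眼睛大就好看吗", "眼睛好看吗?"], ["小蝌蚪找妈妈怎么样", "小蝌蚪找妈妈是谁画的"]]
print(len(make_batches(pairs, batch_size=1)))  # 2: two batches of one pair each
print(len(make_batches(pairs, batch_size=2)))  # 1: one batch holding both pairs
```

Either way both pairs are scored; batch_size only trades memory for fewer predictor runs.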
 ### 知识挖掘-词类知识标注
 
 ```python
@@ -0,0 +1,179 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddlenlp.transformers import BertModel, BertTokenizer

from ..data import Pad, Tuple
from .utils import static_mode_guard
from .task import Task

usage = r"""
         from paddlenlp import Taskflow

         similarity = Taskflow("text_similarity")
         similarity([["世界上什么东西最小", "世界上什么东西最小?"]])
         '''
         [{'query': '世界上什么东西最小', 'title': '世界上什么东西最小?', 'similarity': 0.992725}]
         '''

         similarity = Taskflow("text_similarity", batch_size=2)
         similarity([["光眼睛大就好看吗", "眼睛好看吗?"], ["小蝌蚪找妈妈怎么样", "小蝌蚪找妈妈是谁画的"]])
         '''
         [{'query': '光眼睛大就好看吗', 'title': '眼睛好看吗?', 'similarity': 0.7450271}, {'query': '小蝌蚪找妈妈怎么样', 'title': '小蝌蚪找妈妈是谁画的', 'similarity': 0.8192149}]
         '''
         """

Review comments on the `batch_size=2` line:

Reviewer: Does batch_size need to be configured manually here? Could it be derived automatically from the size of the input?

Reviewer: Do you mean that with batch_size=1 two samples cannot be passed in at once, or that batch_size is a key parameter of the predictor?

Author: batch_size is currently configured manually, with a default of 1; the intent is to let users set it according to their own hardware. batch_size is a key parameter of the predictor.
class TextSimilarityTask(Task):
    """
    Text similarity task using SimBERT to predict the similarity of a sentence pair.
    Args:
        task (string): The name of the task.
        model (string): The model name in the task.
        kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
    """

    def __init__(self,
                 task,
                 model,
                 batch_size=1,
                 max_seq_len=128,
                 **kwargs):
        super().__init__(task=task, model=model, **kwargs)
        self._static_mode = True
        self._construct_tokenizer(model)
        self._get_inference_model()
        self._batch_size = batch_size
        self._max_seq_len = max_seq_len
        self._usage = usage
    def _construct_input_spec(self):
        """
        Construct the input spec for converting the dygraph model to a static model.
        """
        self._input_spec = [
            paddle.static.InputSpec(
                shape=[None, None], dtype="int64", name='input_ids'),
            # token_type_ids has the same [batch_size, seq_len] shape as input_ids
            paddle.static.InputSpec(
                shape=[None, None], dtype="int64", name='token_type_ids'),
        ]

    def _construct_model(self, model):
        """
        Construct the inference model for the predictor.
        """
        self._model = BertModel.from_pretrained(model, pool_act='linear')
        self._model.eval()

    def _construct_tokenizer(self, model):
        """
        Construct the tokenizer for the predictor.
        """
        self._tokenizer = BertTokenizer.from_pretrained(model)

    def _check_input_text(self, inputs):
        inputs = inputs[0]
        if not all([isinstance(i, list) and i \
                and all(i) and len(i) == 2 for i in inputs]):
            raise TypeError(
                "Invalid input format. Each input should be a list of two "
                "non-empty texts, e.g. [[text1, text2], ...].")
        return inputs

    def _preprocess(self, inputs):
        """
        Transform the raw text to the model inputs; two steps are involved:
        1) Transform the raw text to token ids.
        2) Generate the other model inputs from the raw text and token ids.
        """
        inputs = self._check_input_text(inputs)
        num_workers = self.kwargs[
            'num_workers'] if 'num_workers' in self.kwargs else 0
        lazy_load = self.kwargs[
            'lazy_load'] if 'lazy_load' in self.kwargs else False

        examples = []

        for data in inputs:
            query, title = data[0], data[1]

            query_encoded_inputs = self._tokenizer(
                text=query, max_seq_len=self._max_seq_len)
            query_input_ids = query_encoded_inputs["input_ids"]
            query_token_type_ids = query_encoded_inputs["token_type_ids"]

            title_encoded_inputs = self._tokenizer(
                text=title, max_seq_len=self._max_seq_len)
            title_input_ids = title_encoded_inputs["input_ids"]
            title_token_type_ids = title_encoded_inputs["token_type_ids"]

            examples.append((query_input_ids, query_token_type_ids,
                             title_input_ids, title_token_type_ids))

        batches = [
            examples[idx:idx + self._batch_size]
            for idx in range(0, len(examples), self._batch_size)
        ]

        batchify_fn = lambda samples, fn=Tuple(
            Pad(axis=0, pad_val=self._tokenizer.pad_token_id),  # query_input
            Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id),  # query_segment
            Pad(axis=0, pad_val=self._tokenizer.pad_token_id),  # title_input
            Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id),  # title_segment
        ): [data for data in fn(samples)]

        outputs = {}
        outputs['data_loader'] = batches
        outputs['text'] = inputs
        self._batchify_fn = batchify_fn
        return outputs
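The `Pad` collators above pad every sequence in a batch to the batch's longest length before it is fed to the predictor. A minimal plain-Python stand-in for that behavior (hypothetical helper name; pad value 0 chosen for illustration, where the real code uses the tokenizer's pad token id):

```python
def pad_batch(sequences, pad_val=0):
    """Pad each token-id list to the length of the longest one in the batch."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_val] * (max_len - len(seq)) for seq in sequences]

batch = [[101, 2769, 102], [101, 102]]
print(pad_batch(batch))  # [[101, 2769, 102], [101, 102, 0]]
```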

    def _run_model(self, inputs):
        """
        Run the task model from the outputs of the `_tokenize` function.
        """
        results = []
        with static_mode_guard():
            for batch in inputs['data_loader']:
                q_ids, q_segment_ids, t_ids, t_segment_ids = self._batchify_fn(batch)
                self.input_handles[0].copy_from_cpu(q_ids)
                self.input_handles[1].copy_from_cpu(q_segment_ids)
                self.predictor.run()
                vecs_query = self.output_handle[1].copy_to_cpu()

                self.input_handles[0].copy_from_cpu(t_ids)
                self.input_handles[1].copy_from_cpu(t_segment_ids)
                self.predictor.run()
                vecs_title = self.output_handle[1].copy_to_cpu()

                vecs_query = vecs_query / (vecs_query**2).sum(
                    axis=1, keepdims=True)**0.5
                vecs_title = vecs_title / (vecs_title**2).sum(
                    axis=1, keepdims=True)**0.5
                similarity = (vecs_query * vecs_title).sum(axis=1)
                results.extend(similarity)
        inputs['result'] = results
        return inputs

Review comment: Suggest reflecting the inputs as text1 and text2 consistently throughout the internal code.

Author: Fixed.
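The normalization and dot product in `_run_model` compute cosine similarity between the two sentence vectors: each vector is divided by its L2 norm, then the element-wise product is summed. The same arithmetic on plain Python lists, as a sketch:

```python
import math

def cosine_similarity(vec_a, vec_b):
    """Dot product of L2-normalized vectors, i.e. cosine similarity."""
    norm_a = math.sqrt(sum(x * x for x in vec_a))
    norm_b = math.sqrt(sum(x * x for x in vec_b))
    return sum(a * b for a, b in zip(vec_a, vec_b)) / (norm_a * norm_b)

print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0 (identical direction)
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0 (orthogonal)
```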

    def _postprocess(self, inputs):
        """
        Convert the similarity scores computed by `_run_model` into the final
        list of result dicts.
        """
        final_results = []
        for text, similarity in zip(inputs['text'], inputs['result']):
            result = {}
            result['query'] = text[0]
            result['title'] = text[1]
            result['similarity'] = similarity
            final_results.append(result)
        return final_results

Review comment: Using text1 and text2 as the input keys may be more accurate; query and title tend to suggest short-text-to-long-text matching.

Author: Fixed.