From 3e68984baa230481225ca040611dc4ba307eaf14 Mon Sep 17 00:00:00 2001 From: Linjie Chen <40840292+linjieccc@users.noreply.github.com> Date: Wed, 26 Jan 2022 17:41:30 +0800 Subject: [PATCH] Fix WordTag decode bug (#1642) * fix wordtag decode * Update README.md * Update README.md * Update codestyle --- paddlenlp/taskflow/knowledge_mining.py | 6 +++--- paddlenlp/taskflow/named_entity_recognition.py | 9 +++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/paddlenlp/taskflow/knowledge_mining.py b/paddlenlp/taskflow/knowledge_mining.py index 4d0a9f7cace2..5cd2dd2a328f 100644 --- a/paddlenlp/taskflow/knowledge_mining.py +++ b/paddlenlp/taskflow/knowledge_mining.py @@ -191,7 +191,6 @@ class WordTagTask(Task): def __init__(self, model, task, - batch_size=1, params_path=None, tag_path=None, term_schema_path=None, @@ -419,11 +418,12 @@ def _reset_offset(self, pred_words): def _decode(self, batch_texts, batch_pred_tags): batch_results = [] for sent_index in range(len(batch_texts)): + sent = batch_texts[sent_index] tags = [ self._index_to_tags[index] - for index in batch_pred_tags[sent_index][self.summary_num:-1] + for index in batch_pred_tags[sent_index][self.summary_num:len( + sent) + self.summary_num] ] - sent = batch_texts[sent_index] if self._custom: self._custom.parse_customization(sent, tags, prefix=True) sent_out = [] diff --git a/paddlenlp/taskflow/named_entity_recognition.py b/paddlenlp/taskflow/named_entity_recognition.py index 55860cefa7af..695a7fc27268 100644 --- a/paddlenlp/taskflow/named_entity_recognition.py +++ b/paddlenlp/taskflow/named_entity_recognition.py @@ -87,11 +87,12 @@ def __init__(self, model, task, **kwargs): def _decode(self, batch_texts, batch_pred_tags): batch_results = [] for sent_index in range(len(batch_texts)): + sent = batch_texts[sent_index] tags = [ self._index_to_tags[index] - for index in batch_pred_tags[sent_index][self.summary_num:-1] + for index in batch_pred_tags[sent_index][self.summary_num:len( + sent) + self.summary_num] ] - sent = batch_texts[sent_index] if self._custom: self._custom.parse_customization(sent, tags, prefix=True) sent_out = [] @@ -100,12 +101,12 @@ def _decode(self, batch_texts, batch_pred_tags): for ind, tag in enumerate(tags): if partial_word == "": partial_word = sent[ind] - tags_out.append(tag.split('-')[1]) + tags_out.append(tag.split('-')[-1]) continue if tag.startswith("B") or tag.startswith("S") or tag.startswith( "O"): sent_out.append(partial_word) - tags_out.append(tag.split('-')[1]) + tags_out.append(tag.split('-')[-1]) partial_word = sent[ind] continue partial_word += sent[ind]