diff --git a/README.md b/README.md
index fa457a239c68..55088c5b5888 100644
--- a/README.md
+++ b/README.md
@@ -13,9 +13,9 @@
## News
+* [2022-03-21] PaddleNLP's **one-click prediction tool** [Taskflow](./docs/model_zoo/taskflow.md) has been fully upgraded! 🚀 Come try the richer features and easier usage, including newly released Chinese word segmentation and named entity recognition modes tailored to different scenarios!
* [2021-12-28] PaddleNLP has released new **semantic retrieval, question answering, review opinion extraction, and sentiment analysis** [industrial application examples](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications) for building systems quickly! 🚀 A companion video course is available [here](https://aistudio.baidu.com/aistudio/course/introduce/24902)!
* [2021-12-12] PaddleNLP v2.2 has been released! :tada: Try the faster text processing of [FasterTokenizer](./examples/faster/faster_ernie), the faster pretrained model [FasterERNIE](./examples/faster/faster_ernie), and the faster text generation of [FasterGeneration](./examples/faster/faster_generation); also new are the 解语 noun-phrase tagging tool [NPTag](./examples/text_to_knowledge/nptag) and the ultra-fast small Chinese model [PP-MiniLM](./examples/model_compression/pp-minilm)! See the [Release Note](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v2.1.0) for full upgrade details.
-* [2021-12-12] PaddlePaddle's new **end-to-end question answering tool** 🚀 [RocketQA](https://github.com/PaddlePaddle/RocketQA) has been released! :tada:
## Introduction
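
A minimal usage sketch of the upgraded one-click API announced above; the task names follow docs/model_zoo/taskflow.md, and the input sentences are illustrative:

```python
# Minimal sketch of the upgraded Taskflow one-click API.
from paddlenlp import Taskflow

seg = Taskflow("word_segmentation")  # new Chinese word segmentation mode
print(seg("第十四届全运会在西安举办"))

ner = Taskflow("ner")                # new named entity recognition mode
print(ner("《孤女》是2010年九州出版社出版的小说"))
```
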
diff --git a/docs/model_zoo/taskflow.md b/docs/model_zoo/taskflow.md
index 46a34cf6d123..380199d46353 100644
--- a/docs/model_zoo/taskflow.md
+++ b/docs/model_zoo/taskflow.md
@@ -693,3 +693,5 @@ ner = Taskflow("ner", home_path="/workspace")
## 参考资料
1. [fxsjy/jieba](https://github.com/fxsjy/jieba)
+2. [ZhuiyiTechnology/simbert](https://github.com/ZhuiyiTechnology/simbert)
+3. [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413)
diff --git a/docs/model_zoo/transformers.rst b/docs/model_zoo/transformers.rst
index c889fbdf5e47..f0744538f9c4 100644
--- a/docs/model_zoo/transformers.rst
+++ b/docs/model_zoo/transformers.rst
@@ -818,7 +818,7 @@ Reference
`ymcui/Chinese-XLNet `_,
`huggingface/xlnet_chinese_large `_,
`Knover/luge-dialogue `_,
- `huawei-noah/Pretrained-Language-Model/NEZHA-PyTorch/ `_
+ `huawei-noah/Pretrained-Language-Model/NEZHA-PyTorch/ `_,
`ZhuiyiTechnology/simbert `_
- Lan, Zhenzhong, et al. "Albert: A lite bert for self-supervised learning of language representations." arXiv preprint arXiv:1909.11942 (2019).
- Lewis, Mike, et al. "BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension." arXiv preprint arXiv:1910.13461 (2019).
@@ -826,6 +826,7 @@ Reference
- Zaheer, Manzil, et al. "Big bird: Transformers for longer sequences." arXiv preprint arXiv:2007.14062 (2020).
- Roller, Stephen, et al. "Blenderbot: Recipes for building an open-domain chatbot." arXiv preprint arXiv:2004.13637 (2020).
- Roller, Stephen, et al. "Blenderbot-Small: Recipes for building an open-domain chatbot." arXiv preprint arXiv:2004.13637 (2020).
+- Zhang, Zhengyan, et al. "CPM: A Large-scale Generative Chinese Pre-trained Language Model." arXiv preprint arXiv:2012.00413 (2020).
- Jiang, Zihang, et al. "ConvBERT: Improving BERT with Span-based Dynamic Convolution." arXiv preprint arXiv:2008.02496 (2020).
- Keskar, Nitish Shirish, et al. "CTRL: A Conditional Transformer Language Model for Controllable Generation." arXiv preprint arXiv:1909.05858 (2019).
- Sanh, Victor, et al. "DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter." arXiv preprint arXiv:1910.01108 (2019).
diff --git a/examples/dialogue/plato-xl/infer.py b/examples/dialogue/plato-xl/infer.py
index cdf6856cac58..79d6d81c74b0 100644
--- a/examples/dialogue/plato-xl/infer.py
+++ b/examples/dialogue/plato-xl/infer.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
import time
import argparse
from pprint import pprint
@@ -73,6 +74,11 @@ def setup_args():
default=4,
type=int,
help="The number of candidate to procedure beam search. ")
+ parser.add_argument(
+ "--num_return_sequences",
+ default=1,
+ type=int,
+ help="The number of returned sequences. ")
args = parser.parse_args()
@@ -93,10 +99,17 @@ def postprocess_response(token_ids, tokenizer):
def infer(args):
+ if args.faster and args.use_fp16_decoding and os.getenv("PPFG_QKV_MEM_OPT",
+ "0") == "1":
+ paddle.set_default_dtype("float16")
+
model_name = 'plato-xl'
- model = UnifiedTransformerLMHeadModel.from_pretrained(model_name)
+ model = UnifiedTransformerLMHeadModel.from_pretrained(
+ model_name, load_state_as_np=True)
tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name)
+ model.eval()
+
context = [
"Hi , Becky , what's up ?",
"Not much , except that my mother-in-law is driving me up the wall .",
@@ -136,6 +149,7 @@ def infer(args):
top_k=args.topk,
top_p=args.topp,
num_beams=args.num_beams,
+ num_return_sequences=args.num_return_sequences,
use_fp16_decoding=args.use_fp16_decoding,
use_faster=args.faster)
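
For context, a hedged sketch of the generation call that the new `--num_return_sequences` flag feeds into; the model and tokenizer are loaded as in `infer()`, the `dialogue_encode` arguments mirror the existing script, and the decoding settings are illustrative:

```python
# Condensed sketch of the extended call site in infer().
inputs = tokenizer.dialogue_encode(
    ["Hi , Becky , what's up ?"],
    add_start_token_as_response=True,
    return_tensors=True,
    is_split_into_words=False)

ids, scores = model.generate(
    **inputs,
    max_length=64,
    decode_strategy="sampling",
    top_k=5,
    num_return_sequences=2,   # now returns two candidate replies per input
    use_fp16_decoding=True,   # pairs with PPFG_QKV_MEM_OPT=1 for fp16 weights
    use_faster=True)
```
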
diff --git a/examples/machine_reading_comprehension/SQuAD/deploy/python/predict.py b/examples/machine_reading_comprehension/SQuAD/deploy/python/predict.py
index 3ed583056a12..582ee07b79e8 100644
--- a/examples/machine_reading_comprehension/SQuAD/deploy/python/predict.py
+++ b/examples/machine_reading_comprehension/SQuAD/deploy/python/predict.py
@@ -23,7 +23,7 @@
import paddle
from paddle import inference
from paddle.io import DataLoader
-from paddlenlp.datasets import load_dataset
+from datasets import load_dataset
from paddlenlp.data import Pad, Stack, Dict
from paddlenlp.metrics.squad import squad_evaluate, compute_prediction
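
The import swap means the deployment script now reads SQuAD through the HuggingFace `datasets` hub loader rather than `paddlenlp.datasets`; a quick sketch of what the new import returns, with the split name an assumption based on typical SQuAD evaluation:

```python
# What the swapped import provides; the feature layout follows the HF "squad" card.
from datasets import load_dataset

dev_ds = load_dataset("squad", split="validation")
sample = dev_ds[0]
print(sample["question"])
print(sample["answers"]["text"], sample["answers"]["answer_start"])
```
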
diff --git a/paddlenlp/data/sampler.py b/paddlenlp/data/sampler.py
index 6e13a4ab3f5b..305fc5df7c37 100644
--- a/paddlenlp/data/sampler.py
+++ b/paddlenlp/data/sampler.py
@@ -18,7 +18,6 @@
import six
import numpy as np
-import paddle.distributed as dist
class SamplerHelper(object):
@@ -391,6 +390,8 @@ def __len__(self):
print(list(sampler)) # indices of dataset elements
# [0, 2]
"""
+ import paddle.distributed as dist
+
if num_replicas is None:
num_replicas = dist.get_world_size()
if rank is None:
diff --git a/paddlenlp/datasets/cnn_dailymail.py b/paddlenlp/datasets/cnn_dailymail.py
index 3eb172a55606..1768da85ab1c 100644
--- a/paddlenlp/datasets/cnn_dailymail.py
+++ b/paddlenlp/datasets/cnn_dailymail.py
@@ -20,7 +20,12 @@
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url, _decompress, _get_unique_endpoints
-from paddle.distributed import ParallelEnv
+try:
+ from paddle.distributed import ParallelEnv
+except Exception:
+ import warnings
+ warnings.warn("paddle.distributed is not included in your Paddle installation!")
+
from paddlenlp.utils.env import DATA_HOME
from paddlenlp.utils.log import logger
from . import DatasetBuilder
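
The same optional-import guard recurs in dataset.py, xnli.py, and parallel.py below; its general shape, with an explicit fallback added here purely for illustration:

```python
# General shape of the guard: keep paddle.distributed optional so the package
# still imports on Paddle builds that lack it. The None fallback is
# illustrative; the patched modules simply warn and leave the name undefined.
try:
    from paddle.distributed import ParallelEnv
except Exception:
    import warnings
    warnings.warn("paddle.distributed is not included in your Paddle "
                  "installation; distributed features are disabled.")
    ParallelEnv = None
```
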
diff --git a/paddlenlp/datasets/dataset.py b/paddlenlp/datasets/dataset.py
index 5a185378ecd8..8fe2283105cd 100644
--- a/paddlenlp/datasets/dataset.py
+++ b/paddlenlp/datasets/dataset.py
@@ -26,7 +26,12 @@
import paddlenlp
import datasets
-import paddle.distributed as dist
+try:
+ import paddle.distributed as dist
+except Exception:
+ import warnings
+ warnings.warn("paddle.distributed is not included in your Paddle installation!")
+
from paddle.io import Dataset, IterableDataset
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url, _get_unique_endpoints
@@ -45,13 +50,15 @@
def load_from_ppnlp(path, **kwargs):
ppnlp_path = paddlenlp.datasets.__path__[0]
- path = os.path.split(path)[-1]
- path = os.path.join(ppnlp_path, path + '.py')
- return origin_load_dataset(path, **kwargs)
+ new_path = os.path.split(path)[-1]
+ new_path = os.path.join(ppnlp_path, 'hf_datasets', new_path + '.py')
+ if os.path.exists(new_path):
+ return origin_load_dataset(new_path, **kwargs)
+ else:
+ return origin_load_dataset(path, **kwargs)
-if os.environ.get('ISINTRANET', '0') == '1':
- datasets.load_dataset = load_from_ppnlp
+datasets.load_dataset = load_from_ppnlp
class DatasetTuple:
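
With the `ISINTRANET` gate removed, `datasets.load_dataset` is always wrapped once `paddlenlp.datasets` is imported; a sketch of the resulting resolution order, with dataset names chosen for illustration:

```python
# Resolution order after the patch: a bundled script under
# paddlenlp/datasets/hf_datasets/ wins; otherwise the call falls through
# to the original HuggingFace loader.
import paddlenlp.datasets        # installs the load_from_ppnlp wrapper
from datasets import load_dataset

local = load_dataset("dureader_robust", split="train")   # bundled script
remote = load_dataset("squad", split="validation")       # falls through to hub
```
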
diff --git a/paddlenlp/datasets/dureader_robust.py b/paddlenlp/datasets/dureader_robust.py
index bce297ca2883..3c49627eb242 100644
--- a/paddlenlp/datasets/dureader_robust.py
+++ b/paddlenlp/datasets/dureader_robust.py
@@ -1,6 +1,4 @@
-# coding=utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,123 +12,74 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Lint as: python3
-
+import collections
import json
import os
-import datasets
-from datasets.tasks import QuestionAnsweringExtractive
-
-logger = datasets.logging.get_logger(__name__)
-
-_DESCRIPTION = """\
-DureaderRobust is a chinese reading comprehension \
-dataset, designed to evaluate the MRC models from \
-three aspects: over-sensitivity, over-stability \
-and generalization.
-"""
-
-_URL = "https://bj.bcebos.com/paddlenlp/datasets/dureader_robust-data.tar.gz"
-
-
-class DureaderRobustConfig(datasets.BuilderConfig):
- """BuilderConfig for DureaderRobust."""
-
- def __init__(self, **kwargs):
- """BuilderConfig for DureaderRobust.
-
- Args:
- **kwargs: keyword arguments forwarded to super.
- """
- super(DureaderRobustConfig, self).__init__(**kwargs)
-
-
-class DureaderRobust(datasets.GeneratorBasedBuilder):
- BUILDER_CONFIGS = [
- DureaderRobustConfig(
- name="plain_text",
- version=datasets.Version("1.0.0", ""),
- description="Plain text", ),
- ]
-
- def _info(self):
- return datasets.DatasetInfo(
- description=_DESCRIPTION,
- features=datasets.Features({
- "id": datasets.Value("string"),
- "title": datasets.Value("string"),
- "context": datasets.Value("string"),
- "question": datasets.Value("string"),
- "answers": datasets.features.Sequence({
- "text": datasets.Value("string"),
- "answer_start": datasets.Value("int32"),
- }),
- }),
- # No default supervised_keys (as we have to pass both question
- # and context as input).
- supervised_keys=None,
- homepage="https://arxiv.org/abs/2004.11142",
- task_templates=[
- QuestionAnsweringExtractive(
- question_column="question",
- context_column="context",
- answers_column="answers")
- ], )
-
- def _split_generators(self, dl_manager):
- dl_dir = dl_manager.download_and_extract(_URL)
-
- return [
- datasets.SplitGenerator(
- name=datasets.Split.TRAIN,
- gen_kwargs={
- "filepath": os.path.join(dl_dir, 'dureader_robust-data',
- 'train.json')
- }),
- datasets.SplitGenerator(
- name=datasets.Split.VALIDATION,
- gen_kwargs={
- "filepath": os.path.join(dl_dir, 'dureader_robust-data',
- 'dev.json')
- }),
- datasets.SplitGenerator(
- name=datasets.Split.TEST,
- gen_kwargs={
- "filepath": os.path.join(dl_dir, 'dureader_robust-data',
- 'test.json')
- }),
- ]
-
- def _generate_examples(self, filepath):
- """This function returns the examples in the raw (text) form."""
- logger.info("generating examples from = %s", filepath)
- key = 0
- with open(filepath, encoding="utf-8") as f:
- durobust = json.load(f)
- for article in durobust["data"]:
- title = article.get("title", "")
- for paragraph in article["paragraphs"]:
- context = paragraph[
- "context"] # do not strip leading blank spaces GH-2585
- for qa in paragraph["qas"]:
- answer_starts = [
- answer["answer_start"]
- for answer in qa.get("answers", '')
- ]
- answers = [
- answer["text"] for answer in qa.get("answers", '')
- ]
- # Features currently used are "context", "question", and "answers".
- # Others are extracted here for the ease of future expansions.
- yield key, {
- "title": title,
- "context": context,
- "question": qa["question"],
- "id": qa["id"],
- "answers": {
- "answer_start": answer_starts,
- "text": answers,
- },
- }
- key += 1
+from paddle.dataset.common import md5file
+from paddle.utils.download import get_path_from_url
+from paddlenlp.utils.env import DATA_HOME
+from . import DatasetBuilder
+
+__all__ = ['DuReaderRobust']
+
+
+class DuReaderRobust(DatasetBuilder):
+ '''
+ The machine reading comprehension dataset (i.e. DuReader robust) is designed
+ to measure the robustness of a reading comprehension model, including the
+ over-sensitivity, over-stability and generalization ability of the model.
+ '''
+
+ URL = 'https://bj.bcebos.com/paddlenlp/datasets/dureader_robust-data.tar.gz'
+ MD5 = '82f3d191a115ec17808856866787606e'
+ META_INFO = collections.namedtuple('META_INFO', ('file', 'md5'))
+ SPLITS = {
+ 'train': META_INFO(
+ os.path.join('dureader_robust-data', 'train.json'),
+ '800a3dcb742f9fdf9b11e0a83433d4be'),
+ 'dev': META_INFO(
+ os.path.join('dureader_robust-data', 'dev.json'),
+ 'ae73cec081eaa28a735204c4898a2222'),
+ 'test': META_INFO(
+ os.path.join('dureader_robust-data', 'test.json'),
+ 'e0e8aa5c7b6d11b6fc3935e29fc7746f')
+ }
+
+ def _get_data(self, mode, **kwargs):
+ default_root = os.path.join(DATA_HOME, self.__class__.__name__)
+ filename, data_hash = self.SPLITS[mode]
+ fullname = os.path.join(default_root, filename)
+ if not os.path.exists(fullname) or (data_hash and
+ not md5file(fullname) == data_hash):
+ get_path_from_url(self.URL, default_root, self.MD5)
+
+ return fullname
+
+ def _read(self, filename, *args):
+ with open(filename, "r", encoding="utf8") as f:
+ input_data = json.load(f)["data"]
+ for entry in input_data:
+ title = entry.get("title", "").strip()
+ for paragraph in entry["paragraphs"]:
+ context = paragraph["context"].strip()
+ for qa in paragraph["qas"]:
+ qas_id = qa["id"]
+ question = qa["question"].strip()
+ answer_starts = [
+ answer["answer_start"]
+ for answer in qa.get("answers", [])
+ ]
+ answers = [
+ answer["text"].strip()
+ for answer in qa.get("answers", [])
+ ]
+
+ yield {
+ 'id': qas_id,
+ 'title': title,
+ 'context': context,
+ 'question': question,
+ 'answers': answers,
+ 'answer_starts': answer_starts
+ }
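
A short usage sketch of the rewritten builder; `splits` follows the paddlenlp `load_dataset` convention, and the printed keys match the dict yielded by `_read` above:

```python
# Consuming the rewritten DatasetBuilder; examples are plain dicts.
from paddlenlp.datasets import load_dataset

train_ds, dev_ds = load_dataset("dureader_robust", splits=("train", "dev"))
print(train_ds[0]["question"])
print(train_ds[0]["answers"], train_ds[0]["answer_starts"])
```
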
diff --git a/paddlenlp/datasets/hf_datasets/dureader_robust.py b/paddlenlp/datasets/hf_datasets/dureader_robust.py
new file mode 100644
index 000000000000..bce297ca2883
--- /dev/null
+++ b/paddlenlp/datasets/hf_datasets/dureader_robust.py
@@ -0,0 +1,136 @@
+# coding=utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+
+import json
+import os
+
+import datasets
+from datasets.tasks import QuestionAnsweringExtractive
+
+logger = datasets.logging.get_logger(__name__)
+
+_DESCRIPTION = """\
+DureaderRobust is a Chinese reading comprehension \
+dataset, designed to evaluate the MRC models from \
+three aspects: over-sensitivity, over-stability \
+and generalization.
+"""
+
+_URL = "https://bj.bcebos.com/paddlenlp/datasets/dureader_robust-data.tar.gz"
+
+
+class DureaderRobustConfig(datasets.BuilderConfig):
+ """BuilderConfig for DureaderRobust."""
+
+ def __init__(self, **kwargs):
+ """BuilderConfig for DureaderRobust.
+
+ Args:
+ **kwargs: keyword arguments forwarded to super.
+ """
+ super(DureaderRobustConfig, self).__init__(**kwargs)
+
+
+class DureaderRobust(datasets.GeneratorBasedBuilder):
+ BUILDER_CONFIGS = [
+ DureaderRobustConfig(
+ name="plain_text",
+ version=datasets.Version("1.0.0", ""),
+ description="Plain text", ),
+ ]
+
+ def _info(self):
+ return datasets.DatasetInfo(
+ description=_DESCRIPTION,
+ features=datasets.Features({
+ "id": datasets.Value("string"),
+ "title": datasets.Value("string"),
+ "context": datasets.Value("string"),
+ "question": datasets.Value("string"),
+ "answers": datasets.features.Sequence({
+ "text": datasets.Value("string"),
+ "answer_start": datasets.Value("int32"),
+ }),
+ }),
+ # No default supervised_keys (as we have to pass both question
+ # and context as input).
+ supervised_keys=None,
+ homepage="https://arxiv.org/abs/2004.11142",
+ task_templates=[
+ QuestionAnsweringExtractive(
+ question_column="question",
+ context_column="context",
+ answers_column="answers")
+ ], )
+
+ def _split_generators(self, dl_manager):
+ dl_dir = dl_manager.download_and_extract(_URL)
+
+ return [
+ datasets.SplitGenerator(
+ name=datasets.Split.TRAIN,
+ gen_kwargs={
+ "filepath": os.path.join(dl_dir, 'dureader_robust-data',
+ 'train.json')
+ }),
+ datasets.SplitGenerator(
+ name=datasets.Split.VALIDATION,
+ gen_kwargs={
+ "filepath": os.path.join(dl_dir, 'dureader_robust-data',
+ 'dev.json')
+ }),
+ datasets.SplitGenerator(
+ name=datasets.Split.TEST,
+ gen_kwargs={
+ "filepath": os.path.join(dl_dir, 'dureader_robust-data',
+ 'test.json')
+ }),
+ ]
+
+ def _generate_examples(self, filepath):
+ """This function returns the examples in the raw (text) form."""
+ logger.info("generating examples from = %s", filepath)
+ key = 0
+ with open(filepath, encoding="utf-8") as f:
+ durobust = json.load(f)
+ for article in durobust["data"]:
+ title = article.get("title", "")
+ for paragraph in article["paragraphs"]:
+ context = paragraph[
+ "context"] # do not strip leading blank spaces GH-2585
+ for qa in paragraph["qas"]:
+ answer_starts = [
+ answer["answer_start"]
+ for answer in qa.get("answers", '')
+ ]
+ answers = [
+ answer["text"] for answer in qa.get("answers", '')
+ ]
+ # Features currently used are "context", "question", and "answers".
+ # Others are extracted here for the ease of future expansions.
+ yield key, {
+ "title": title,
+ "context": context,
+ "question": qa["question"],
+ "id": qa["id"],
+ "answers": {
+ "answer_start": answer_starts,
+ "text": answers,
+ },
+ }
+ key += 1
diff --git a/paddlenlp/datasets/xnli.py b/paddlenlp/datasets/xnli.py
index 7616b1e27c45..7aa6f9d5471a 100644
--- a/paddlenlp/datasets/xnli.py
+++ b/paddlenlp/datasets/xnli.py
@@ -20,7 +20,12 @@
from paddle.dataset.common import md5file
from paddle.utils.download import get_path_from_url, _decompress, _get_unique_endpoints
-from paddle.distributed import ParallelEnv
+try:
+ from paddle.distributed import ParallelEnv
+except Exception:
+ import warnings
+ warnings.warn("paddle.distributed is not included in your Paddle installation!")
+
from paddlenlp.utils.env import DATA_HOME
from paddlenlp.utils.log import logger
from . import DatasetBuilder
diff --git a/paddlenlp/ops/distributed/parallel.py b/paddlenlp/ops/distributed/parallel.py
index b5e2b0111c07..9f937a219fc0 100644
--- a/paddlenlp/ops/distributed/parallel.py
+++ b/paddlenlp/ops/distributed/parallel.py
@@ -15,7 +15,11 @@
import paddle
import paddle.nn as nn
from paddle.fluid.framework import in_dygraph_mode
-from paddle.distributed.fleet import fleet
+try:
+ from paddle.distributed.fleet import fleet
+except Exception:
+ import warnings
+ warnings.warn("paddle.distributed is not included in your Paddle installation!")
__all__ = [
'guard',
diff --git a/paddlenlp/ops/faster_transformer/sample/plato_export_model_sample.py b/paddlenlp/ops/faster_transformer/sample/plato_export_model_sample.py
index a519a393345a..7d62a281a94f 100644
--- a/paddlenlp/ops/faster_transformer/sample/plato_export_model_sample.py
+++ b/paddlenlp/ops/faster_transformer/sample/plato_export_model_sample.py
@@ -51,11 +51,6 @@ def parse_args():
"--max_out_len", default=64, type=int, help="Maximum output length. ")
parser.add_argument(
"--min_out_len", default=1, type=int, help="Minimum output length. ")
- parser.add_argument(
- "--num_return_sequence",
- default=1,
- type=int,
- help="The number of returned sequence. ")
parser.add_argument(
"--temperature",
default=1.0,
@@ -95,8 +90,12 @@ def do_predict(args):
place = "gpu"
place = paddle.set_device(place)
+ if args.use_fp16_decoding and os.getenv("PPFG_QKV_MEM_OPT", "0") == "1":
+ paddle.set_default_dtype("float16")
+
model_name = 'plato-xl'
- model = UnifiedTransformerLMHeadModel.from_pretrained(model_name)
+ model = UnifiedTransformerLMHeadModel.from_pretrained(
+ model_name, load_state_as_np=True)
tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name)
plato = FasterUnifiedTransformer(
diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py
index b4d64ca9e7c7..02553bd307e7 100644
--- a/paddlenlp/transformers/__init__.py
+++ b/paddlenlp/transformers/__init__.py
@@ -98,4 +98,4 @@
from .luke.tokenizer import *
from .megatronbert.modeling import *
from .megatronbert.tokenizer import *
-from .semantic_indexing.modeling import *
+from .semantic_search.modeling import *
diff --git a/paddlenlp/transformers/ernie/modeling.py b/paddlenlp/transformers/ernie/modeling.py
index 0480bbc7ed2e..e7c55cff98d0 100644
--- a/paddlenlp/transformers/ernie/modeling.py
+++ b/paddlenlp/transformers/ernie/modeling.py
@@ -238,6 +238,32 @@ class ErniePretrainedModel(PretrainedModel):
"vocab_size": 30522,
"pad_token_id": 0,
},
+ "rocketqa-zh-dureader-cross-encoder": {
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "relu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "max_position_embeddings": 513,
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "type_vocab_size": 2,
+ "vocab_size": 18000,
+ "pad_token_id": 0,
+ },
+ "rocketqa-v1-marco-cross-encoder": {
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "max_position_embeddings": 512,
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "type_vocab_size": 4,
+ "vocab_size": 30522,
+ "pad_token_id": 0,
+ },
}
resource_files_names = {"model_state": "model_state.pdparams"}
pretrained_resource_files_map = {
@@ -260,6 +286,10 @@ class ErniePretrainedModel(PretrainedModel):
"https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa_v1_marco_query_encoder.pdparams",
"rocketqa-v1-marco-para-encoder":
"https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa_v1_marco_para_encoder.pdparams",
+ "rocketqa-zh-dureader-cross-encoder":
+ "https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa_zh_dureader_cross_encoder.pdparams",
+ "rocketqa-v1-marco-cross-encoder":
+ "https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa_v1_marco_cross_encoder.pdparams",
}
}
base_model_prefix = "ernie"
diff --git a/paddlenlp/transformers/ernie/tokenizer.py b/paddlenlp/transformers/ernie/tokenizer.py
index 09110c1aa55c..fd6c89243268 100644
--- a/paddlenlp/transformers/ernie/tokenizer.py
+++ b/paddlenlp/transformers/ernie/tokenizer.py
@@ -101,6 +101,10 @@ class ErnieTokenizer(PretrainedTokenizer):
"https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa-v1-marco-vocab.txt",
"rocketqa-v1-marco-para-encoder":
"https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa-v1-marco-vocab.txt",
+ "rocketqa-zh-dureader-cross-encoder":
+ "https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa-zh-dureader-vocab.txt",
+ "rocketqa-v1-marco-cross-encoder":
+ "https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa-v1-marco-vocab.txt",
}
}
pretrained_init_configuration = {
@@ -143,6 +147,12 @@ class ErnieTokenizer(PretrainedTokenizer):
"rocketqa-v1-marco-para-encoder": {
"do_lower_case": True
},
+ "rocketqa-zh-dureader-cross-encoder": {
+ "do_lower_case": True
+ },
+ "rocketqa-v1-marco-cross-encoder": {
+ "do_lower_case": True
+ },
}
def __init__(self,
diff --git a/paddlenlp/transformers/semantic_indexing/__init__.py b/paddlenlp/transformers/semantic_search/__init__.py
similarity index 100%
rename from paddlenlp/transformers/semantic_indexing/__init__.py
rename to paddlenlp/transformers/semantic_search/__init__.py
diff --git a/paddlenlp/transformers/semantic_indexing/modeling.py b/paddlenlp/transformers/semantic_search/modeling.py
similarity index 73%
rename from paddlenlp/transformers/semantic_indexing/modeling.py
rename to paddlenlp/transformers/semantic_search/modeling.py
index 897d6bd94f71..ec9817420a3f 100644
--- a/paddlenlp/transformers/semantic_indexing/modeling.py
+++ b/paddlenlp/transformers/semantic_search/modeling.py
@@ -18,13 +18,15 @@
from ..ernie.modeling import ErniePretrainedModel
-__all__ = ['ErnieDualEncoder']
+__all__ = ['ErnieDualEncoder', 'ErnieCrossEncoder']
class ErnieEncoder(ErniePretrainedModel):
- def __init__(self, ernie):
+ def __init__(self, ernie, dropout=None, num_classes=2):
super(ErnieEncoder, self).__init__()
self.ernie = ernie # allow ernie to be config
+ self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
+ self.classifier = nn.Linear(768, num_classes)
self.apply(self.init_weights)
def init_weights(self, layer):
@@ -37,15 +39,12 @@ def forward(self,
token_type_ids=None,
position_ids=None,
attention_mask=None):
- sequence_output, _ = self.ernie(
+ sequence_output, pool_output = self.ernie(
input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask)
-
- # Outputs pooled_embedding
- pooled_output = sequence_output[:, 0]
- return pooled_output
+ return sequence_output, pool_output
class ErnieDualEncoder(nn.Layer):
@@ -92,7 +91,6 @@ def __init__(self,
elif title_model_name_or_path is not None:
self.title_ernie = ErnieEncoder.from_pretrained(
title_model_name_or_path)
- self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
def get_semantic_embedding(self, data_loader):
self.eval()
@@ -116,12 +114,13 @@ def get_pooled_embedding(self,
assert (is_query and self.query_ernie is not None) or (not is_query and self.title_ernie), \
"Please check whether your parameter for `is_query` are consistent with DualEncoder initialization."
if is_query:
- pooled_embedding = self.query_ernie(input_ids, token_type_ids,
- position_ids, attention_mask)
+ sequence_output, _ = self.query_ernie(input_ids, token_type_ids,
+ position_ids, attention_mask)
+
else:
- pooled_embedding = self.title_ernie(input_ids, token_type_ids,
- position_ids, attention_mask)
- return pooled_embedding
+ sequence_output, _ = self.title_ernie(input_ids, token_type_ids,
+ position_ids, attention_mask)
+ return sequence_output[:, 0]
def cosine_sim(self,
query_input_ids,
@@ -207,3 +206,66 @@ def forward(self,
outputs = {"loss": loss, "accuracy": accuracy}
return outputs
+
+
+class ErnieCrossEncoder(nn.Layer):
+ """
+ Example:
+
+ .. code-block::
+
+ import paddle
+ from paddlenlp.transformers import ErnieCrossEncoder, ErnieTokenizer
+
+ model = ErnieCrossEncoder("rocketqa-zh-dureader-cross-encoder")
+ tokenizer = ErnieTokenizer.from_pretrained("rocketqa-zh-dureader-cross-encoder")
+
+ inputs = tokenizer("你们好", text_pair="你好")
+ inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
+
+ # Get similarity probability of text pair.
+ embedding = model.matching(**inputs)
+
+ """
+
+ def __init__(self, pretrain_model_name_or_path, num_classes=2,
+ dropout=None):
+ super().__init__()
+ # Forward num_classes/dropout so they are not silently ignored.
+ self.ernie = ErnieEncoder.from_pretrained(
+ pretrain_model_name_or_path, num_classes=num_classes, dropout=dropout)
+ def matching(self,
+ input_ids,
+ token_type_ids=None,
+ position_ids=None,
+ attention_mask=None,
+ return_prob_distribution=False):
+ _, pooled_output = self.ernie(
+ input_ids,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask)
+ pooled_output = self.ernie.dropout(pooled_output)
+ cls_embedding = self.ernie.classifier(pooled_output)
+ probs = F.softmax(cls_embedding, axis=1)
+ if return_prob_distribution:
+ return probs
+ return probs[:, 1]
+
+ def forward(self,
+ input_ids,
+ token_type_ids=None,
+ position_ids=None,
+ attention_mask=None,
+ labels=None):
+ probs = self.matching(
+ input_ids,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ return_prob_distribution=True)
+ accuracy = paddle.metric.accuracy(input=probs, label=labels)
+ # matching() already applied softmax, so skip it when computing the loss.
+ loss = F.cross_entropy(input=probs, label=labels, use_softmax=False)
+
+ outputs = {"loss": loss, "accuracy": accuracy}
+
+ return outputs
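
To complement the inference example in the class docstring, a training-side sketch of the new cross encoder; the text pair and the label shape `[batch_size, 1]` are assumptions consistent with `paddle.metric.accuracy`:

```python
# Training-side sketch for ErnieCrossEncoder; the checkpoint name comes from
# the configs added above, the text pair and label are illustrative.
import paddle
from paddlenlp.transformers import ErnieCrossEncoder, ErnieTokenizer

model = ErnieCrossEncoder("rocketqa-zh-dureader-cross-encoder")
tokenizer = ErnieTokenizer.from_pretrained("rocketqa-zh-dureader-cross-encoder")

inputs = tokenizer("世界上最高的山是哪座?", text_pair="珠穆朗玛峰是世界最高峰。")
inputs = {k: paddle.to_tensor([v]) for k, v in inputs.items()}
labels = paddle.to_tensor([[1]])              # 1 = relevant pair

outputs = model(**inputs, labels=labels)      # {"loss": ..., "accuracy": ...}
outputs["loss"].backward()
```
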
diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py
index 305f74d5973c..a038d48dd7d0 100644
--- a/paddlenlp/transformers/tokenizer_utils.py
+++ b/paddlenlp/transformers/tokenizer_utils.py
@@ -21,6 +21,7 @@
import os
import six
import unicodedata
+from collections import OrderedDict, UserDict
from shutil import copyfile
from typing import Iterable, Iterator, Optional, List, Any, Callable, Union
@@ -45,6 +46,29 @@
]
+class BatchEncoding(UserDict):
+ def __init__(self, data=None):
+ super().__init__(data)
+
+ def __getitem__(self, item):
+ if isinstance(item, str):
+ return self.data[item]
+ else:
+ raise KeyError(
+ "Indexing with integers is not available when using tokenizer.__call__()"
+ " with return_dict=True. Please set return_dict to False to use integer indexing."
+ )
+
+ def keys(self):
+ return self.data.keys()
+
+ def values(self):
+ return self.data.values()
+
+ def items(self):
+ return self.data.items()
+
+
def convert_to_unicode(text):
"""
Converts `text` to Unicode (if it's not already), assuming utf-8 input.
@@ -559,7 +583,8 @@ def __call__(self,
return_attention_mask=False,
return_length=False,
return_overflowing_tokens=False,
- return_special_tokens_mask=False):
+ return_special_tokens_mask=False,
+ return_dict=True):
"""
Performs tokenization and uses the tokenized tokens to prepare model
inputs. It supports sequence or sequence pair as input, and batch input
@@ -695,7 +720,8 @@ def __call__(self,
return_attention_mask=return_attention_mask,
return_length=return_length,
return_overflowing_tokens=return_overflowing_tokens,
- return_special_tokens_mask=return_special_tokens_mask)
+ return_special_tokens_mask=return_special_tokens_mask,
+ return_dict=return_dict)
else:
return self.encode(
text=text,
@@ -1534,7 +1560,8 @@ def batch_encode(self,
return_attention_mask=False,
return_length=False,
return_overflowing_tokens=False,
- return_special_tokens_mask=False):
+ return_special_tokens_mask=False,
+ return_dict=True):
"""
Performs tokenization and uses the tokenized tokens to prepare model
inputs. It supports batch inputs of sequence or sequence pair.
@@ -1765,10 +1792,13 @@ def get_input_ids(text):
range(len(encoded_inputs["input_ids"])))
encoded_inputs['overflow_to_sample'] = example_id
- for key, value in encoded_inputs.items():
- if key not in batch_outputs:
- batch_outputs[key] = []
- batch_outputs[key].append(value)
+ if return_dict:
+ for key, value in encoded_inputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ batch_outputs[key].append(value)
+ else:
+ batch_encode_inputs.append(encoded_inputs)
if offset + length == len(second_ids):
break
offset += min(length, stride)
@@ -1787,12 +1817,16 @@ def get_input_ids(text):
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask)
- for key, value in encoded_inputs.items():
- if key not in batch_outputs:
- batch_outputs[key] = []
- batch_outputs[key].append(value)
+ if return_dict:
+ for key, value in encoded_inputs.items():
+ if key not in batch_outputs:
+ batch_outputs[key] = []
+ batch_outputs[key].append(value)
+ else:
+ batch_encode_inputs.append(encoded_inputs)
- return batch_outputs
+ return BatchEncoding(
+ batch_outputs) if return_dict else batch_encode_inputs
def get_offset_mapping(self, text):
"""
diff --git a/paddlenlp/transformers/unified_transformer/modeling.py b/paddlenlp/transformers/unified_transformer/modeling.py
index 49c61f37ad6c..c799b49b4522 100644
--- a/paddlenlp/transformers/unified_transformer/modeling.py
+++ b/paddlenlp/transformers/unified_transformer/modeling.py
@@ -107,7 +107,7 @@ class UnifiedTransformerPretrainedModel(PretrainedModel):
"attention_probs_dropout_prob": 0.1,
"normalize_before": True,
"max_position_embeddings": 1024,
- "type_vocab_size": 2,
+ "type_vocab_size": 3,
"role_type_size": 128,
"initializer_range": 0.02,
"unk_token_id": 0,