diff --git a/README.md b/README.md
index fa457a239c68..55088c5b5888 100644
--- a/README.md
+++ b/README.md
@@ -13,9 +13,9 @@

 ## News

+* [2022-03-21] PaddleNLP's **one-click prediction tool** [Taskflow](./docs/model_zoo/taskflow.md) has been fully upgraded! 🚀 It offers richer features and a simpler interface, plus new Chinese word segmentation and named entity recognition modes for different scenarios!
 * [2021-12-28] PaddleNLP releases **semantic retrieval, question answering, opinion extraction and sentiment analysis** [industrial application examples](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications), 🚀 so you can build a system quickly! A companion video course is available [here](https://aistudio.baidu.com/aistudio/course/introduce/24902)!
 * [2021-12-12] PaddleNLP v2.2 has been released! :tada: Try the faster text processing of [FasterTokenizer](./examples/faster/faster_ernie), the faster pretrained model [FasterERNIE](./examples/faster/faster_ernie) and the faster text generation of [FasterGeneration](./examples/faster/faster_generation); also new are the 『解语』noun-phrase tagging tool [NPTag](./examples/text_to_knowledge/nptag) and the ultra-fast small Chinese model [PP-MiniLM](./examples/model_compression/pp-minilm)! See the [Release Note](https://github.com/PaddlePaddle/PaddleNLP/releases/tag/v2.1.0) for details.
-* [2021-12-12] PaddlePaddle's new **end-to-end question answering tool** 🚀[RocketQA](https://github.com/PaddlePaddle/RocketQA) has been released! :tada:

 ## Introduction
diff --git a/docs/model_zoo/taskflow.md b/docs/model_zoo/taskflow.md
index 46a34cf6d123..380199d46353 100644
--- a/docs/model_zoo/taskflow.md
+++ b/docs/model_zoo/taskflow.md
@@ -693,3 +693,5 @@ ner = Taskflow("ner", home_path="/workspace")
 ## References
 1. [fxsjy/jieba](https://github.com/fxsjy/jieba)
+2. [ZhuiyiTechnology/simbert](https://github.com/ZhuiyiTechnology/simbert)
+3. [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413)
diff --git a/docs/model_zoo/transformers.rst b/docs/model_zoo/transformers.rst
index c889fbdf5e47..f0744538f9c4 100644
--- a/docs/model_zoo/transformers.rst
+++ b/docs/model_zoo/transformers.rst
@@ -818,7 +818,7 @@ Reference
    `ymcui/Chinese-XLNet <https://github.com/ymcui/Chinese-XLNet>`_,
    `huggingface/xlnet_chinese_large `_,
    `Knover/luge-dialogue `_,
-   `huawei-noah/Pretrained-Language-Model/NEZHA-PyTorch/ <https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-PyTorch>`_
+   `huawei-noah/Pretrained-Language-Model/NEZHA-PyTorch/ <https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-PyTorch>`_, `ZhuiyiTechnology/simbert <https://github.com/ZhuiyiTechnology/simbert>`_

 - Lan, Zhenzhong, et al. "Albert: A lite bert for self-supervised learning of language representations." arXiv preprint arXiv:1909.11942 (2019).
 - Lewis, Mike, et al. "BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension." arXiv preprint arXiv:1910.13461 (2019).
@@ -826,6 +826,7 @@ Reference
 - Zaheer, Manzil, et al. "Big bird: Transformers for longer sequences." arXiv preprint arXiv:2007.14062 (2020).
 - Roller, Stephen, et al. "Blenderbot: Recipes for building an open-domain chatbot." arXiv preprint arXiv:2004.13637 (2020).
 - Roller, Stephen, et al. "Blenderbot-Small: Recipes for building an open-domain chatbot." arXiv preprint arXiv:2004.13637 (2020).
+- Zhang, Zhengyan, et al. "CPM: A Large-scale Generative Chinese Pre-trained Language Model." arXiv preprint arXiv:2012.00413 (2020).
 - Jiang, Zihang, et al. "ConvBERT: Improving BERT with Span-based Dynamic Convolution." arXiv preprint arXiv:2008.02496 (2020).
 - Keskar, Nitish Shirish, et al. "CTRL: A Conditional Transformer Language Model for Controllable Generation." arXiv preprint arXiv:1909.05858 (2019).
 - Sanh, Victor, et al. "DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter." arXiv preprint arXiv:1910.01108 (2019).
diff --git a/examples/dialogue/plato-xl/infer.py b/examples/dialogue/plato-xl/infer.py
index cdf6856cac58..79d6d81c74b0 100644
--- a/examples/dialogue/plato-xl/infer.py
+++ b/examples/dialogue/plato-xl/infer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import time
 import argparse
 from pprint import pprint
@@ -73,6 +74,11 @@ def setup_args():
         default=4,
         type=int,
         help="The number of candidates to produce in beam search. ")
+    parser.add_argument(
+        "--num_return_sequences",
+        default=1,
+        type=int,
+        help="The number of returned sequences. ")

     args = parser.parse_args()

@@ -93,10 +99,17 @@ def postprocess_response(token_ids, tokenizer):


 def infer(args):
+    if args.faster and args.use_fp16_decoding and os.getenv("PPFG_QKV_MEM_OPT",
+                                                            "0") == "1":
+        paddle.set_default_dtype("float16")
+
     model_name = 'plato-xl'
-    model = UnifiedTransformerLMHeadModel.from_pretrained(model_name)
+    model = UnifiedTransformerLMHeadModel.from_pretrained(
+        model_name, load_state_as_np=True)
     tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name)

+    model.eval()
+
     context = [
         "Hi , Becky , what's up ?",
         "Not much , except that my mother-in-law is driving me up the wall .",
@@ -136,6 +149,7 @@ def infer(args):
         top_k=args.topk,
         top_p=args.topp,
         num_beams=args.num_beams,
+        num_return_sequences=args.num_return_sequences,
         use_fp16_decoding=args.use_fp16_decoding,
         use_faster=args.faster)
diff --git a/examples/machine_reading_comprehension/SQuAD/deploy/python/predict.py b/examples/machine_reading_comprehension/SQuAD/deploy/python/predict.py
index 3ed583056a12..582ee07b79e8 100644
--- a/examples/machine_reading_comprehension/SQuAD/deploy/python/predict.py
+++ b/examples/machine_reading_comprehension/SQuAD/deploy/python/predict.py
@@ -23,7 +23,7 @@
 import paddle
 from paddle import inference
 from paddle.io import DataLoader
-from paddlenlp.datasets import load_dataset
+from datasets import load_dataset
 from paddlenlp.data import Pad, Stack, Dict
 from paddlenlp.metrics.squad import squad_evaluate, compute_prediction
diff --git a/paddlenlp/data/sampler.py b/paddlenlp/data/sampler.py
index 6e13a4ab3f5b..305fc5df7c37 100644
--- a/paddlenlp/data/sampler.py
+++ b/paddlenlp/data/sampler.py
@@ -18,7 +18,6 @@
 import six
 import numpy as np

-import paddle.distributed as dist


 class SamplerHelper(object):
@@ -391,6 +390,8 @@ def __len__(self):
             print(list(sampler))  # indices of dataset elements
             # [0, 2]
         """
+        # Import lazily so paddlenlp.data stays usable on Paddle builds
+        # without distributed support.
+        import paddle.distributed as dist
+
         if num_replicas is None:
             num_replicas = dist.get_world_size()
         if rank is None:
diff --git a/paddlenlp/datasets/cnn_dailymail.py b/paddlenlp/datasets/cnn_dailymail.py
index 3eb172a55606..1768da85ab1c 100644
--- a/paddlenlp/datasets/cnn_dailymail.py
+++ b/paddlenlp/datasets/cnn_dailymail.py
@@ -20,7 +20,12 @@
 from paddle.dataset.common import md5file
 from paddle.utils.download import get_path_from_url, _decompress, _get_unique_endpoints
-from paddle.distributed import ParallelEnv
+try:
+    from paddle.distributed import ParallelEnv
+except Exception as e:
+    import warnings
+    warnings.warn("paddle.distributed is not included in your Paddle installation!")
+
 from paddlenlp.utils.env import DATA_HOME
 from paddlenlp.utils.log import logger
 from . import DatasetBuilder
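The hunk above, and the matching ones in dataset.py, xnli.py, and parallel.py below, all apply the same pattern: guard the paddle.distributed import so the package stays importable on Paddle builds without distributed support. A minimal standalone sketch of the same guard (the world_size helper here is illustrative, not part of the patch):

    import warnings

    try:
        import paddle.distributed as dist
    except Exception:
        # Paddle builds without distributed support cannot import this module.
        dist = None
        warnings.warn("paddle.distributed is not included in your Paddle installation.")


    def world_size():
        # Fall back to single-process behavior when distributed is unavailable.
        return dist.get_world_size() if dist is not None else 1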
diff --git a/paddlenlp/datasets/dataset.py b/paddlenlp/datasets/dataset.py
index 5a185378ecd8..8fe2283105cd 100644
--- a/paddlenlp/datasets/dataset.py
+++ b/paddlenlp/datasets/dataset.py
@@ -26,7 +26,12 @@
 import paddlenlp
 import datasets
-import paddle.distributed as dist
+try:
+    import paddle.distributed as dist
+except Exception as e:
+    import warnings
+    warnings.warn("paddle.distributed is not included in your Paddle installation!")
+
 from paddle.io import Dataset, IterableDataset
 from paddle.dataset.common import md5file
 from paddle.utils.download import get_path_from_url, _get_unique_endpoints
@@ -45,13 +50,15 @@

 def load_from_ppnlp(path, **kwargs):
     ppnlp_path = paddlenlp.datasets.__path__[0]
-    path = os.path.split(path)[-1]
-    path = os.path.join(ppnlp_path, path + '.py')
-    return origin_load_dataset(path, **kwargs)
+    # Prefer the copy of a HuggingFace dataset script bundled with
+    # PaddleNLP, if there is one; otherwise fall back to the original path.
+    new_path = os.path.split(path)[-1]
+    new_path = os.path.join(ppnlp_path, 'hf_datasets', new_path + '.py')
+    if os.path.exists(new_path):
+        return origin_load_dataset(new_path, **kwargs)
+    else:
+        return origin_load_dataset(path, **kwargs)

-if os.environ.get('ISINTRANET', '0') == '1':
-    datasets.load_dataset = load_from_ppnlp
+datasets.load_dataset = load_from_ppnlp


 class DatasetTuple:
diff --git a/paddlenlp/datasets/dureader_robust.py b/paddlenlp/datasets/dureader_robust.py
index bce297ca2883..3c49627eb242 100644
--- a/paddlenlp/datasets/dureader_robust.py
+++ b/paddlenlp/datasets/dureader_robust.py
@@ -1,6 +1,4 @@
-# coding=utf-8
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,123 +12,74 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# Lint as: python3
-
+import collections
 import json
 import os
-import datasets
-from datasets.tasks import QuestionAnsweringExtractive
-
-logger = datasets.logging.get_logger(__name__)
-
-_DESCRIPTION = """\
-DureaderRobust is a chinese reading comprehension \
-dataset, designed to evaluate the MRC models from \
-three aspects: over-sensitivity, over-stability \
-and generalization.
-"""
-
-_URL = "https://bj.bcebos.com/paddlenlp/datasets/dureader_robust-data.tar.gz"
-
-
-class DureaderRobustConfig(datasets.BuilderConfig):
-    """BuilderConfig for DureaderRobust."""
-
-    def __init__(self, **kwargs):
-        """BuilderConfig for DureaderRobust.
-
-        Args:
-            **kwargs: keyword arguments forwarded to super.
-        """
-        super(DureaderRobustConfig, self).__init__(**kwargs)
-
-
-class DureaderRobust(datasets.GeneratorBasedBuilder):
-    BUILDER_CONFIGS = [
-        DureaderRobustConfig(
-            name="plain_text",
-            version=datasets.Version("1.0.0", ""),
-            description="Plain text", ),
-    ]
-
-    def _info(self):
-        return datasets.DatasetInfo(
-            description=_DESCRIPTION,
-            features=datasets.Features({
-                "id": datasets.Value("string"),
-                "title": datasets.Value("string"),
-                "context": datasets.Value("string"),
-                "question": datasets.Value("string"),
-                "answers": datasets.features.Sequence({
-                    "text": datasets.Value("string"),
-                    "answer_start": datasets.Value("int32"),
-                }),
-            }),
-            # No default supervised_keys (as we have to pass both question
-            # and context as input).
-            supervised_keys=None,
-            homepage="https://arxiv.org/abs/2004.11142",
-            task_templates=[
-                QuestionAnsweringExtractive(
-                    question_column="question",
-                    context_column="context",
-                    answers_column="answers")
-            ], )
-
-    def _split_generators(self, dl_manager):
-        dl_dir = dl_manager.download_and_extract(_URL)
-
-        return [
-            datasets.SplitGenerator(
-                name=datasets.Split.TRAIN,
-                gen_kwargs={
-                    "filepath": os.path.join(dl_dir, 'dureader_robust-data',
-                                             'train.json')
-                }),
-            datasets.SplitGenerator(
-                name=datasets.Split.VALIDATION,
-                gen_kwargs={
-                    "filepath": os.path.join(dl_dir, 'dureader_robust-data',
-                                             'dev.json')
-                }),
-            datasets.SplitGenerator(
-                name=datasets.Split.TEST,
-                gen_kwargs={
-                    "filepath": os.path.join(dl_dir, 'dureader_robust-data',
-                                             'test.json')
-                }),
-        ]
-
-    def _generate_examples(self, filepath):
-        """This function returns the examples in the raw (text) form."""
-        logger.info("generating examples from = %s", filepath)
-        key = 0
-        with open(filepath, encoding="utf-8") as f:
-            durobust = json.load(f)
-            for article in durobust["data"]:
-                title = article.get("title", "")
-                for paragraph in article["paragraphs"]:
-                    context = paragraph[
-                        "context"]  # do not strip leading blank spaces GH-2585
-                    for qa in paragraph["qas"]:
-                        answer_starts = [
-                            answer["answer_start"]
-                            for answer in qa.get("answers", '')
-                        ]
-                        answers = [
-                            answer["text"] for answer in qa.get("answers", '')
-                        ]
-                        # Features currently used are "context", "question", and "answers".
-                        # Others are extracted here for the ease of future expansions.
-                        yield key, {
-                            "title": title,
-                            "context": context,
-                            "question": qa["question"],
-                            "id": qa["id"],
-                            "answers": {
-                                "answer_start": answer_starts,
-                                "text": answers,
-                            },
-                        }
-                        key += 1
+from paddle.dataset.common import md5file
+from paddle.utils.download import get_path_from_url
+from paddlenlp.utils.env import DATA_HOME
+from . import DatasetBuilder
+
+__all__ = ['DuReaderRobust']
+
+
+class DuReaderRobust(DatasetBuilder):
+    '''
+    The machine reading comprehension dataset (i.e. DuReader robust) is designed
+    to measure the robustness of a reading comprehension model, including the
+    over-sensitivity, over-stability and generalization ability of the model.
+    '''
+
+    URL = 'https://bj.bcebos.com/paddlenlp/datasets/dureader_robust-data.tar.gz'
+    MD5 = '82f3d191a115ec17808856866787606e'
+    META_INFO = collections.namedtuple('META_INFO', ('file', 'md5'))
+    SPLITS = {
+        'train': META_INFO(
+            os.path.join('dureader_robust-data', 'train.json'),
+            '800a3dcb742f9fdf9b11e0a83433d4be'),
+        'dev': META_INFO(
+            os.path.join('dureader_robust-data', 'dev.json'),
+            'ae73cec081eaa28a735204c4898a2222'),
+        'test': META_INFO(
+            os.path.join('dureader_robust-data', 'test.json'),
+            'e0e8aa5c7b6d11b6fc3935e29fc7746f')
+    }
+
+    def _get_data(self, mode, **kwargs):
+        default_root = os.path.join(DATA_HOME, self.__class__.__name__)
+        filename, data_hash = self.SPLITS[mode]
+        fullname = os.path.join(default_root, filename)
+        if not os.path.exists(fullname) or (data_hash and
+                                            not md5file(fullname) == data_hash):
+            get_path_from_url(self.URL, default_root, self.MD5)
+
+        return fullname
+
+    def _read(self, filename, *args):
+        with open(filename, "r", encoding="utf8") as f:
+            input_data = json.load(f)["data"]
+        for entry in input_data:
+            title = entry.get("title", "").strip()
+            for paragraph in entry["paragraphs"]:
+                context = paragraph["context"].strip()
+                for qa in paragraph["qas"]:
+                    qas_id = qa["id"]
+                    question = qa["question"].strip()
+                    answer_starts = [
+                        answer["answer_start"]
+                        for answer in qa.get("answers", [])
+                    ]
+                    answers = [
+                        answer["text"].strip()
+                        for answer in qa.get("answers", [])
+                    ]
+
+                    yield {
+                        'id': qas_id,
+                        'title': title,
+                        'context': context,
+                        'question': question,
+                        'answers': answers,
+                        'answer_starts': answer_starts
+                    }
diff --git a/paddlenlp/datasets/hf_datasets/dureader_robust.py b/paddlenlp/datasets/hf_datasets/dureader_robust.py
new file mode 100644
index 000000000000..bce297ca2883
--- /dev/null
+++ b/paddlenlp/datasets/hf_datasets/dureader_robust.py
@@ -0,0 +1,136 @@
+# coding=utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2020 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+
+import json
+import os
+
+import datasets
+from datasets.tasks import QuestionAnsweringExtractive
+
+logger = datasets.logging.get_logger(__name__)
+
+_DESCRIPTION = """\
+DureaderRobust is a Chinese reading comprehension \
+dataset, designed to evaluate MRC models from \
+three aspects: over-sensitivity, over-stability \
+and generalization.
+"""
+
+_URL = "https://bj.bcebos.com/paddlenlp/datasets/dureader_robust-data.tar.gz"
+
+
+class DureaderRobustConfig(datasets.BuilderConfig):
+    """BuilderConfig for DureaderRobust."""
+
+    def __init__(self, **kwargs):
+        """BuilderConfig for DureaderRobust.
+
+        Args:
+            **kwargs: keyword arguments forwarded to super.
+ """ + super(DureaderRobustConfig, self).__init__(**kwargs) + + +class DureaderRobust(datasets.GeneratorBasedBuilder): + BUILDER_CONFIGS = [ + DureaderRobustConfig( + name="plain_text", + version=datasets.Version("1.0.0", ""), + description="Plain text", ), + ] + + def _info(self): + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=datasets.Features({ + "id": datasets.Value("string"), + "title": datasets.Value("string"), + "context": datasets.Value("string"), + "question": datasets.Value("string"), + "answers": datasets.features.Sequence({ + "text": datasets.Value("string"), + "answer_start": datasets.Value("int32"), + }), + }), + # No default supervised_keys (as we have to pass both question + # and context as input). + supervised_keys=None, + homepage="https://arxiv.org/abs/2004.11142", + task_templates=[ + QuestionAnsweringExtractive( + question_column="question", + context_column="context", + answers_column="answers") + ], ) + + def _split_generators(self, dl_manager): + dl_dir = dl_manager.download_and_extract(_URL) + + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, + gen_kwargs={ + "filepath": os.path.join(dl_dir, 'dureader_robust-data', + 'train.json') + }), + datasets.SplitGenerator( + name=datasets.Split.VALIDATION, + gen_kwargs={ + "filepath": os.path.join(dl_dir, 'dureader_robust-data', + 'dev.json') + }), + datasets.SplitGenerator( + name=datasets.Split.TEST, + gen_kwargs={ + "filepath": os.path.join(dl_dir, 'dureader_robust-data', + 'test.json') + }), + ] + + def _generate_examples(self, filepath): + """This function returns the examples in the raw (text) form.""" + logger.info("generating examples from = %s", filepath) + key = 0 + with open(filepath, encoding="utf-8") as f: + durobust = json.load(f) + for article in durobust["data"]: + title = article.get("title", "") + for paragraph in article["paragraphs"]: + context = paragraph[ + "context"] # do not strip leading blank spaces GH-2585 + for qa in paragraph["qas"]: + answer_starts = [ + answer["answer_start"] + for answer in qa.get("answers", '') + ] + answers = [ + answer["text"] for answer in qa.get("answers", '') + ] + # Features currently used are "context", "question", and "answers". + # Others are extracted here for the ease of future expansions. + yield key, { + "title": title, + "context": context, + "question": qa["question"], + "id": qa["id"], + "answers": { + "answer_start": answer_starts, + "text": answers, + }, + } + key += 1 diff --git a/paddlenlp/datasets/xnli.py b/paddlenlp/datasets/xnli.py index 7616b1e27c45..7aa6f9d5471a 100644 --- a/paddlenlp/datasets/xnli.py +++ b/paddlenlp/datasets/xnli.py @@ -20,7 +20,12 @@ from paddle.dataset.common import md5file from paddle.utils.download import get_path_from_url, _decompress, _get_unique_endpoints -from paddle.distributed import ParallelEnv +try: + from paddle.distributed import ParallelEnv +except Exception as e: + import warnings + warnings.warn("paddle.distributed is not contains in you paddle!") + from paddlenlp.utils.env import DATA_HOME from paddlenlp.utils.log import logger from . 
diff --git a/paddlenlp/ops/distributed/parallel.py b/paddlenlp/ops/distributed/parallel.py
index b5e2b0111c07..9f937a219fc0 100644
--- a/paddlenlp/ops/distributed/parallel.py
+++ b/paddlenlp/ops/distributed/parallel.py
@@ -15,7 +15,11 @@
 import paddle
 import paddle.nn as nn
 from paddle.fluid.framework import in_dygraph_mode
-from paddle.distributed.fleet import fleet
+try:
+    from paddle.distributed.fleet import fleet
+except Exception as e:
+    import warnings
+    warnings.warn("paddle.distributed is not included in your Paddle installation!")

 __all__ = [
     'guard',
diff --git a/paddlenlp/ops/faster_transformer/sample/plato_export_model_sample.py b/paddlenlp/ops/faster_transformer/sample/plato_export_model_sample.py
index a519a393345a..7d62a281a94f 100644
--- a/paddlenlp/ops/faster_transformer/sample/plato_export_model_sample.py
+++ b/paddlenlp/ops/faster_transformer/sample/plato_export_model_sample.py
@@ -51,11 +51,6 @@ def parse_args():
         "--max_out_len", default=64, type=int, help="Maximum output length. ")
     parser.add_argument(
         "--min_out_len", default=1, type=int, help="Minimum output length. ")
-    parser.add_argument(
-        "--num_return_sequence",
-        default=1,
-        type=int,
-        help="The number of returned sequence. ")
     parser.add_argument(
         "--temperature",
         default=1.0,
@@ -95,8 +90,12 @@ def do_predict(args):
     place = "gpu"
     place = paddle.set_device(place)

+    if args.use_fp16_decoding and os.getenv("PPFG_QKV_MEM_OPT", "0") == "1":
+        paddle.set_default_dtype("float16")
+
     model_name = 'plato-xl'
-    model = UnifiedTransformerLMHeadModel.from_pretrained(model_name)
+    model = UnifiedTransformerLMHeadModel.from_pretrained(
+        model_name, load_state_as_np=True)
     tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name)

     plato = FasterUnifiedTransformer(
diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py
index b4d64ca9e7c7..02553bd307e7 100644
--- a/paddlenlp/transformers/__init__.py
+++ b/paddlenlp/transformers/__init__.py
@@ -98,4 +98,4 @@
 from .luke.tokenizer import *
 from .megatronbert.modeling import *
 from .megatronbert.tokenizer import *
-from .semantic_indexing.modeling import *
+from .semantic_search.modeling import *
diff --git a/paddlenlp/transformers/ernie/modeling.py b/paddlenlp/transformers/ernie/modeling.py
index 0480bbc7ed2e..e7c55cff98d0 100644
--- a/paddlenlp/transformers/ernie/modeling.py
+++ b/paddlenlp/transformers/ernie/modeling.py
@@ -238,6 +238,32 @@ class ErniePretrainedModel(PretrainedModel):
             "vocab_size": 30522,
             "pad_token_id": 0,
         },
+        "rocketqa-zh-dureader-cross-encoder": {
+            "attention_probs_dropout_prob": 0.1,
+            "hidden_act": "relu",
+            "hidden_dropout_prob": 0.1,
+            "hidden_size": 768,
+            "initializer_range": 0.02,
+            "max_position_embeddings": 513,
+            "num_attention_heads": 12,
+            "num_hidden_layers": 12,
+            "type_vocab_size": 2,
+            "vocab_size": 18000,
+            "pad_token_id": 0,
+        },
+        "rocketqa-v1-marco-cross-encoder": {
+            "attention_probs_dropout_prob": 0.1,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.1,
+            "hidden_size": 768,
+            "initializer_range": 0.02,
+            "max_position_embeddings": 512,
+            "num_attention_heads": 12,
+            "num_hidden_layers": 12,
+            "type_vocab_size": 4,
+            "vocab_size": 30522,
+            "pad_token_id": 0,
+        },
     }
     resource_files_names = {"model_state": "model_state.pdparams"}
     pretrained_resource_files_map = {
@@ -260,6 +286,10 @@ class ErniePretrainedModel(PretrainedModel):
             "https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa_v1_marco_query_encoder.pdparams",
             "rocketqa-v1-marco-para-encoder":
"https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa_v1_marco_para_encoder.pdparams", + "rocketqa-zh-dureader-cross-encoder": + "https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa_zh_dureader_cross_encoder.pdparams", + "rocketqa-v1-marco-cross-encoder": + "https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa_v1_marco_cross_encoder.pdparams", } } base_model_prefix = "ernie" diff --git a/paddlenlp/transformers/ernie/tokenizer.py b/paddlenlp/transformers/ernie/tokenizer.py index 09110c1aa55c..fd6c89243268 100644 --- a/paddlenlp/transformers/ernie/tokenizer.py +++ b/paddlenlp/transformers/ernie/tokenizer.py @@ -101,6 +101,10 @@ class ErnieTokenizer(PretrainedTokenizer): "https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa-v1-marco-vocab.txt", "rocketqa-v1-marco-para-encoder": "https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa-v1-marco-vocab.txt", + "rocketqa-zh-dureader-cross-encoder": + "https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa-zh-dureader-vocab.txt", + "rocketqa-v1-marco-cross-encoder": + "https://bj.bcebos.com/paddlenlp/models/transformers/rocketqa/rocketqa-v1-marco-vocab.txt", } } pretrained_init_configuration = { @@ -143,6 +147,12 @@ class ErnieTokenizer(PretrainedTokenizer): "rocketqa-v1-marco-para-encoder": { "do_lower_case": True }, + "rocketqa-zh-dureader-cross-encoder": { + "do_lower_case": True + }, + "rocketqa-v1-marco-cross-encoder": { + "do_lower_case": True + }, } def __init__(self, diff --git a/paddlenlp/transformers/semantic_indexing/__init__.py b/paddlenlp/transformers/semantic_search/__init__.py similarity index 100% rename from paddlenlp/transformers/semantic_indexing/__init__.py rename to paddlenlp/transformers/semantic_search/__init__.py diff --git a/paddlenlp/transformers/semantic_indexing/modeling.py b/paddlenlp/transformers/semantic_search/modeling.py similarity index 73% rename from paddlenlp/transformers/semantic_indexing/modeling.py rename to paddlenlp/transformers/semantic_search/modeling.py index 897d6bd94f71..ec9817420a3f 100644 --- a/paddlenlp/transformers/semantic_indexing/modeling.py +++ b/paddlenlp/transformers/semantic_search/modeling.py @@ -18,13 +18,15 @@ from ..ernie.modeling import ErniePretrainedModel -__all__ = ['ErnieDualEncoder'] +__all__ = ['ErnieDualEncoder', 'ErnieCrossEncoder'] class ErnieEncoder(ErniePretrainedModel): - def __init__(self, ernie): + def __init__(self, ernie, dropout=None, num_classes=2): super(ErnieEncoder, self).__init__() self.ernie = ernie # allow ernie to be config + self.dropout = nn.Dropout(dropout if dropout is not None else 0.1) + self.classifier = nn.Linear(768, num_classes) self.apply(self.init_weights) def init_weights(self, layer): @@ -37,15 +39,12 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None): - sequence_output, _ = self.ernie( + sequence_output, pool_output = self.ernie( input_ids, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask) - - # Outputs pooled_embedding - pooled_output = sequence_output[:, 0] - return pooled_output + return sequence_output, pool_output class ErnieDualEncoder(nn.Layer): @@ -92,7 +91,6 @@ def __init__(self, elif title_model_name_or_path is not None: self.title_ernie = ErnieEncoder.from_pretrained( title_model_name_or_path) - self.dropout = nn.Dropout(dropout if dropout is not None else 0.1) def get_semantic_embedding(self, data_loader): self.eval() @@ -116,12 +114,13 @@ def get_pooled_embedding(self, 
         assert (is_query and self.query_ernie is not None) or (not is_query and self.title_ernie), \
             "Please check whether your parameter for `is_query` is consistent with the DualEncoder initialization."
         if is_query:
-            pooled_embedding = self.query_ernie(input_ids, token_type_ids,
-                                                position_ids, attention_mask)
+            sequence_output, _ = self.query_ernie(input_ids, token_type_ids,
+                                                  position_ids, attention_mask)
         else:
-            pooled_embedding = self.title_ernie(input_ids, token_type_ids,
-                                                position_ids, attention_mask)
-        return pooled_embedding
+            sequence_output, _ = self.title_ernie(input_ids, token_type_ids,
+                                                  position_ids, attention_mask)
+        # The [CLS] vector of the last layer is used as the pooled embedding.
+        return sequence_output[:, 0]

     def cosine_sim(self,
                    query_input_ids,
@@ -207,3 +206,66 @@ def forward(self,
         outputs = {"loss": loss, "accuracy": accuracy}

         return outputs
+
+
+class ErnieCrossEncoder(nn.Layer):
+    """
+    Example:
+
+        .. code-block::
+
+            import paddle
+            from paddlenlp.transformers import ErnieCrossEncoder, ErnieTokenizer
+
+            model = ErnieCrossEncoder("rocketqa-zh-dureader-cross-encoder")
+            tokenizer = ErnieTokenizer.from_pretrained("rocketqa-zh-dureader-cross-encoder")
+
+            inputs = tokenizer("你们好", text_pair="你好")
+            inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}
+
+            # Get the similarity probability of the text pair.
+            probs = model.matching(**inputs)
+    """
+
+    def __init__(self, pretrain_model_name_or_path, num_classes=2,
+                 dropout=None):
+        super().__init__()
+        # Forward `num_classes` and `dropout` so that the classification
+        # head of the underlying encoder is configured accordingly.
+        self.ernie = ErnieEncoder.from_pretrained(
+            pretrain_model_name_or_path, num_classes=num_classes,
+            dropout=dropout)
+
+    def matching(self,
+                 input_ids,
+                 token_type_ids=None,
+                 position_ids=None,
+                 attention_mask=None,
+                 return_prob_distribution=False):
+        _, pooled_output = self.ernie(
+            input_ids,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask)
+        pooled_output = self.ernie.dropout(pooled_output)
+        cls_embedding = self.ernie.classifier(pooled_output)
+        probs = F.softmax(cls_embedding, axis=1)
+        if return_prob_distribution:
+            return probs
+        # Probability that the pair is a semantic match (class 1).
+        return probs[:, 1]
+
+    def forward(self,
+                input_ids,
+                token_type_ids=None,
+                position_ids=None,
+                attention_mask=None,
+                labels=None):
+        probs = self.matching(
+            input_ids,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            return_prob_distribution=True)
+        accuracy = paddle.metric.accuracy(input=probs, label=labels)
+        # `matching` returns softmax probabilities, so disable the softmax
+        # inside cross_entropy to avoid applying it twice.
+        loss = F.cross_entropy(input=probs, label=labels, use_softmax=False)
+
+        outputs = {"loss": loss, "accuracy": accuracy}
+
+        return outputs
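For review context, a minimal end-to-end sketch of the new cross encoder, mirroring the class docstring above; it assumes the rocketqa-zh-dureader-cross-encoder weights referenced in this patch are downloadable:

    import paddle
    from paddlenlp.transformers import ErnieCrossEncoder, ErnieTokenizer

    model = ErnieCrossEncoder("rocketqa-zh-dureader-cross-encoder")
    tokenizer = ErnieTokenizer.from_pretrained("rocketqa-zh-dureader-cross-encoder")
    model.eval()

    # Tokenize a query/passage pair; __call__ returns python lists, so wrap
    # each feature in a batch dimension before converting to tensors.
    inputs = tokenizer("国家法定节假日共多少天", text_pair="国家法定节假日共11天")
    inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}

    with paddle.no_grad():
        # matching() returns the probability that the pair is a match (class 1).
        score = model.matching(**inputs)
    print(score.numpy())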
+ ) + + def keys(self): + return self.data.keys() + + def values(self): + return self.data.values() + + def items(self): + return self.data.items() + + def convert_to_unicode(text): """ Converts `text` to Unicode (if it's not already), assuming utf-8 input. @@ -559,7 +583,8 @@ def __call__(self, return_attention_mask=False, return_length=False, return_overflowing_tokens=False, - return_special_tokens_mask=False): + return_special_tokens_mask=False, + return_dict=True): """ Performs tokenization and uses the tokenized tokens to prepare model inputs. It supports sequence or sequence pair as input, and batch input @@ -695,7 +720,8 @@ def __call__(self, return_attention_mask=return_attention_mask, return_length=return_length, return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask) + return_special_tokens_mask=return_special_tokens_mask, + return_dict=return_dict) else: return self.encode( text=text, @@ -1534,7 +1560,8 @@ def batch_encode(self, return_attention_mask=False, return_length=False, return_overflowing_tokens=False, - return_special_tokens_mask=False): + return_special_tokens_mask=False, + return_dict=True): """ Performs tokenization and uses the tokenized tokens to prepare model inputs. It supports batch inputs of sequence or sequence pair. @@ -1765,10 +1792,13 @@ def get_input_ids(text): range(len(encoded_inputs["input_ids"]))) encoded_inputs['overflow_to_sample'] = example_id - for key, value in encoded_inputs.items(): - if key not in batch_outputs: - batch_outputs[key] = [] - batch_outputs[key].append(value) + if return_dict: + for key, value in encoded_inputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + else: + batch_encode_inputs.append(encoded_inputs) if offset + length == len(second_ids): break offset += min(length, stride) @@ -1787,12 +1817,16 @@ def get_input_ids(text): return_overflowing_tokens=return_overflowing_tokens, return_special_tokens_mask=return_special_tokens_mask) - for key, value in encoded_inputs.items(): - if key not in batch_outputs: - batch_outputs[key] = [] - batch_outputs[key].append(value) + if return_dict: + for key, value in encoded_inputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + else: + batch_encode_inputs.append(encoded_inputs) - return batch_outputs + return BatchEncoding( + batch_outputs) if return_dict else batch_encode_inputs def get_offset_mapping(self, text): """ diff --git a/paddlenlp/transformers/unified_transformer/modeling.py b/paddlenlp/transformers/unified_transformer/modeling.py index 49c61f37ad6c..c799b49b4522 100644 --- a/paddlenlp/transformers/unified_transformer/modeling.py +++ b/paddlenlp/transformers/unified_transformer/modeling.py @@ -107,7 +107,7 @@ class UnifiedTransformerPretrainedModel(PretrainedModel): "attention_probs_dropout_prob": 0.1, "normalize_before": True, "max_position_embeddings": 1024, - "type_vocab_size": 2, + "type_vocab_size": 3, "role_type_size": 128, "initializer_range": 0.02, "unk_token_id": 0,