From 8980d64f5ff5e466401c8f4690823f75a2ba96f9 Mon Sep 17 00:00:00 2001 From: Guo Sheng Date: Fri, 12 Aug 2022 17:13:49 +0800 Subject: [PATCH] Add label and loss support for BERT/RoBERTa/ERNIE (#3013) * Add label and loss support for BERT/RoBERTa. * Add label and loss support for ERNIE. * Update api docs. --- .../source/paddlenlp.dataaug.base_augment.rst | 7 + docs/source/paddlenlp.dataaug.rst | 17 + docs/source/paddlenlp.dataaug.word_delete.rst | 7 + docs/source/paddlenlp.dataaug.word_insert.rst | 7 + .../paddlenlp.dataaug.word_substitute.rst | 7 + docs/source/paddlenlp.dataaug.word_swap.rst | 7 + docs/source/paddlenlp.datasets.rst | 7 +- docs/source/paddlenlp.rst | 1 + .../paddlenlp.taskflow.code_generation.rst | 7 + docs/source/paddlenlp.taskflow.rst | 2 + ...ddlenlp.taskflow.text2image_generation.rst | 7 + ...paddlenlp.transformers.artist.modeling.rst | 7 + docs/source/paddlenlp.transformers.artist.rst | 14 + ...addlenlp.transformers.artist.tokenizer.rst | 7 + ...addlenlp.transformers.codegen.modeling.rst | 7 + .../source/paddlenlp.transformers.codegen.rst | 14 + ...ddlenlp.transformers.codegen.tokenizer.rst | 7 + ...dlenlp.transformers.dallebart.modeling.rst | 7 + .../paddlenlp.transformers.dallebart.rst | 14 + ...lenlp.transformers.dallebart.tokenizer.rst | 7 + ....transformers.ernie_m.faster_tokenizer.rst | 7 + .../source/paddlenlp.transformers.ernie_m.rst | 1 + ...dlenlp.transformers.gau_alpha.modeling.rst | 7 + .../paddlenlp.transformers.gau_alpha.rst | 14 + ...lenlp.transformers.gau_alpha.tokenizer.rst | 7 + .../paddlenlp.transformers.model_outputs.rst | 6 + .../paddlenlp.transformers.opt.modeling.rst | 7 + docs/source/paddlenlp.transformers.opt.rst | 13 + docs/source/paddlenlp.transformers.rst | 8 + ...p.transformers.sentencepiece_model_pb2.rst | 6 + ...transformers.tinybert.faster_tokenizer.rst | 7 + .../paddlenlp.transformers.tinybert.rst | 1 + .../paddlenlp.transformers.xlm.modeling.rst | 7 + docs/source/paddlenlp.transformers.xlm.rst | 14 + .../paddlenlp.transformers.xlm.tokenizer.rst | 7 + paddlenlp/transformers/bert/modeling.py | 216 +++++++++--- paddlenlp/transformers/bert/tokenizer.py | 3 + paddlenlp/transformers/ernie/modeling.py | 207 ++++++++++-- paddlenlp/transformers/model_outputs.py | 44 +++ paddlenlp/transformers/roberta/modeling.py | 314 ++++++++++++++---- tests/transformers/bert/test_modeling.py | 107 +++++- tests/transformers/test_modeling_common.py | 7 +- tests/transformers/test_tokenizer_common.py | 3 +- 43 files changed, 1001 insertions(+), 172 deletions(-) create mode 100644 docs/source/paddlenlp.dataaug.base_augment.rst create mode 100644 docs/source/paddlenlp.dataaug.rst create mode 100644 docs/source/paddlenlp.dataaug.word_delete.rst create mode 100644 docs/source/paddlenlp.dataaug.word_insert.rst create mode 100644 docs/source/paddlenlp.dataaug.word_substitute.rst create mode 100644 docs/source/paddlenlp.dataaug.word_swap.rst create mode 100644 docs/source/paddlenlp.taskflow.code_generation.rst create mode 100644 docs/source/paddlenlp.taskflow.text2image_generation.rst create mode 100644 docs/source/paddlenlp.transformers.artist.modeling.rst create mode 100644 docs/source/paddlenlp.transformers.artist.rst create mode 100644 docs/source/paddlenlp.transformers.artist.tokenizer.rst create mode 100644 docs/source/paddlenlp.transformers.codegen.modeling.rst create mode 100644 docs/source/paddlenlp.transformers.codegen.rst create mode 100644 docs/source/paddlenlp.transformers.codegen.tokenizer.rst create mode 100644 
docs/source/paddlenlp.transformers.dallebart.modeling.rst create mode 100644 docs/source/paddlenlp.transformers.dallebart.rst create mode 100644 docs/source/paddlenlp.transformers.dallebart.tokenizer.rst create mode 100644 docs/source/paddlenlp.transformers.ernie_m.faster_tokenizer.rst create mode 100644 docs/source/paddlenlp.transformers.gau_alpha.modeling.rst create mode 100644 docs/source/paddlenlp.transformers.gau_alpha.rst create mode 100644 docs/source/paddlenlp.transformers.gau_alpha.tokenizer.rst create mode 100644 docs/source/paddlenlp.transformers.model_outputs.rst create mode 100644 docs/source/paddlenlp.transformers.opt.modeling.rst create mode 100644 docs/source/paddlenlp.transformers.opt.rst create mode 100644 docs/source/paddlenlp.transformers.sentencepiece_model_pb2.rst create mode 100644 docs/source/paddlenlp.transformers.tinybert.faster_tokenizer.rst create mode 100644 docs/source/paddlenlp.transformers.xlm.modeling.rst create mode 100644 docs/source/paddlenlp.transformers.xlm.rst create mode 100644 docs/source/paddlenlp.transformers.xlm.tokenizer.rst diff --git a/docs/source/paddlenlp.dataaug.base_augment.rst b/docs/source/paddlenlp.dataaug.base_augment.rst new file mode 100644 index 000000000000..be43595b87cb --- /dev/null +++ b/docs/source/paddlenlp.dataaug.base_augment.rst @@ -0,0 +1,7 @@ +base\_augment +====================================== + +.. automodule:: paddlenlp.dataaug.base_augment + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.dataaug.rst b/docs/source/paddlenlp.dataaug.rst new file mode 100644 index 000000000000..3df52b5ab89c --- /dev/null +++ b/docs/source/paddlenlp.dataaug.rst @@ -0,0 +1,17 @@ +paddlenlp.dataaug +========================= + +.. automodule:: paddlenlp.dataaug + :members: + :no-undoc-members: + :show-inheritance: + + +.. toctree:: + :maxdepth: 4 + + paddlenlp.dataaug.base_augment + paddlenlp.dataaug.word_delete + paddlenlp.dataaug.word_insert + paddlenlp.dataaug.word_substitute + paddlenlp.dataaug.word_swap diff --git a/docs/source/paddlenlp.dataaug.word_delete.rst b/docs/source/paddlenlp.dataaug.word_delete.rst new file mode 100644 index 000000000000..36f4fd46bb83 --- /dev/null +++ b/docs/source/paddlenlp.dataaug.word_delete.rst @@ -0,0 +1,7 @@ +word\_delete +===================================== + +.. automodule:: paddlenlp.dataaug.word_delete + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.dataaug.word_insert.rst b/docs/source/paddlenlp.dataaug.word_insert.rst new file mode 100644 index 000000000000..90aea8210552 --- /dev/null +++ b/docs/source/paddlenlp.dataaug.word_insert.rst @@ -0,0 +1,7 @@ +word\_insert +===================================== + +.. automodule:: paddlenlp.dataaug.word_insert + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.dataaug.word_substitute.rst b/docs/source/paddlenlp.dataaug.word_substitute.rst new file mode 100644 index 000000000000..dd343af4cc47 --- /dev/null +++ b/docs/source/paddlenlp.dataaug.word_substitute.rst @@ -0,0 +1,7 @@ +word\_substitute +========================================= + +.. automodule:: paddlenlp.dataaug.word_substitute + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.dataaug.word_swap.rst b/docs/source/paddlenlp.dataaug.word_swap.rst new file mode 100644 index 000000000000..bedc8d99621e --- /dev/null +++ b/docs/source/paddlenlp.dataaug.word_swap.rst @@ -0,0 +1,7 @@ +word\_swap +=================================== + +.. 
automodule:: paddlenlp.dataaug.word_swap + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.datasets.rst b/docs/source/paddlenlp.datasets.rst index 5567142ddd85..6681f0fb142f 100644 --- a/docs/source/paddlenlp.datasets.rst +++ b/docs/source/paddlenlp.datasets.rst @@ -5,8 +5,13 @@ paddlenlp.datasets :members: :no-undoc-members: + .. toctree:: :maxdepth: 4 - paddlenlp.datasets.dataset + +.. toctree:: + :maxdepth: 4 + + paddlenlp.datasets.dataset diff --git a/docs/source/paddlenlp.rst b/docs/source/paddlenlp.rst index 45be09479d2d..fbc7672055fa 100644 --- a/docs/source/paddlenlp.rst +++ b/docs/source/paddlenlp.rst @@ -11,6 +11,7 @@ paddlenlp :maxdepth: 4 paddlenlp.data + paddlenlp.dataaug paddlenlp.datasets paddlenlp.embeddings paddlenlp.experimental diff --git a/docs/source/paddlenlp.taskflow.code_generation.rst b/docs/source/paddlenlp.taskflow.code_generation.rst new file mode 100644 index 000000000000..7fbd7ab20bb0 --- /dev/null +++ b/docs/source/paddlenlp.taskflow.code_generation.rst @@ -0,0 +1,7 @@ +code\_generation +========================================== + +.. automodule:: paddlenlp.taskflow.code_generation + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.taskflow.rst b/docs/source/paddlenlp.taskflow.rst index 4dfe61684638..c879b3dfc31d 100644 --- a/docs/source/paddlenlp.taskflow.rst +++ b/docs/source/paddlenlp.taskflow.rst @@ -16,6 +16,7 @@ paddlenlp.taskflow .. toctree:: :maxdepth: 4 + paddlenlp.taskflow.code_generation paddlenlp.taskflow.dependency_parsing paddlenlp.taskflow.dialogue paddlenlp.taskflow.information_extraction @@ -28,6 +29,7 @@ paddlenlp.taskflow paddlenlp.taskflow.sentiment_analysis paddlenlp.taskflow.task paddlenlp.taskflow.taskflow + paddlenlp.taskflow.text2image_generation paddlenlp.taskflow.text_correction paddlenlp.taskflow.text_generation paddlenlp.taskflow.text_similarity diff --git a/docs/source/paddlenlp.taskflow.text2image_generation.rst b/docs/source/paddlenlp.taskflow.text2image_generation.rst new file mode 100644 index 000000000000..c21b2636403b --- /dev/null +++ b/docs/source/paddlenlp.taskflow.text2image_generation.rst @@ -0,0 +1,7 @@ +text2image\_generation +================================================ + +.. automodule:: paddlenlp.taskflow.text2image_generation + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.artist.modeling.rst b/docs/source/paddlenlp.transformers.artist.modeling.rst new file mode 100644 index 000000000000..751786f8b99f --- /dev/null +++ b/docs/source/paddlenlp.transformers.artist.modeling.rst @@ -0,0 +1,7 @@ +modeling +============================================= + +.. automodule:: paddlenlp.transformers.artist.modeling + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.artist.rst b/docs/source/paddlenlp.transformers.artist.rst new file mode 100644 index 000000000000..cacfe699e3e5 --- /dev/null +++ b/docs/source/paddlenlp.transformers.artist.rst @@ -0,0 +1,14 @@ +artist +===================================== + +.. automodule:: paddlenlp.transformers.artist + :members: + :no-undoc-members: + :show-inheritance: + + +.. 
toctree:: + :maxdepth: 4 + + paddlenlp.transformers.artist.modeling + paddlenlp.transformers.artist.tokenizer diff --git a/docs/source/paddlenlp.transformers.artist.tokenizer.rst b/docs/source/paddlenlp.transformers.artist.tokenizer.rst new file mode 100644 index 000000000000..0e1a05f005ae --- /dev/null +++ b/docs/source/paddlenlp.transformers.artist.tokenizer.rst @@ -0,0 +1,7 @@ +tokenizer +============================================== + +.. automodule:: paddlenlp.transformers.artist.tokenizer + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.codegen.modeling.rst b/docs/source/paddlenlp.transformers.codegen.modeling.rst new file mode 100644 index 000000000000..94bf3c40080b --- /dev/null +++ b/docs/source/paddlenlp.transformers.codegen.modeling.rst @@ -0,0 +1,7 @@ +modeling +============================================== + +.. automodule:: paddlenlp.transformers.codegen.modeling + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.codegen.rst b/docs/source/paddlenlp.transformers.codegen.rst new file mode 100644 index 000000000000..06b98a2db02c --- /dev/null +++ b/docs/source/paddlenlp.transformers.codegen.rst @@ -0,0 +1,14 @@ +codegen +====================================== + +.. automodule:: paddlenlp.transformers.codegen + :members: + :no-undoc-members: + :show-inheritance: + + +.. toctree:: + :maxdepth: 4 + + paddlenlp.transformers.codegen.modeling + paddlenlp.transformers.codegen.tokenizer diff --git a/docs/source/paddlenlp.transformers.codegen.tokenizer.rst b/docs/source/paddlenlp.transformers.codegen.tokenizer.rst new file mode 100644 index 000000000000..49ebd1510971 --- /dev/null +++ b/docs/source/paddlenlp.transformers.codegen.tokenizer.rst @@ -0,0 +1,7 @@ +tokenizer +=============================================== + +.. automodule:: paddlenlp.transformers.codegen.tokenizer + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.dallebart.modeling.rst b/docs/source/paddlenlp.transformers.dallebart.modeling.rst new file mode 100644 index 000000000000..43b9008645f6 --- /dev/null +++ b/docs/source/paddlenlp.transformers.dallebart.modeling.rst @@ -0,0 +1,7 @@ +modeling +================================================ + +.. automodule:: paddlenlp.transformers.dallebart.modeling + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.dallebart.rst b/docs/source/paddlenlp.transformers.dallebart.rst new file mode 100644 index 000000000000..b4c076fe1095 --- /dev/null +++ b/docs/source/paddlenlp.transformers.dallebart.rst @@ -0,0 +1,14 @@ +dallebart +======================================== + +.. automodule:: paddlenlp.transformers.dallebart + :members: + :no-undoc-members: + :show-inheritance: + + +.. toctree:: + :maxdepth: 4 + + paddlenlp.transformers.dallebart.modeling + paddlenlp.transformers.dallebart.tokenizer diff --git a/docs/source/paddlenlp.transformers.dallebart.tokenizer.rst b/docs/source/paddlenlp.transformers.dallebart.tokenizer.rst new file mode 100644 index 000000000000..dc6ce082462a --- /dev/null +++ b/docs/source/paddlenlp.transformers.dallebart.tokenizer.rst @@ -0,0 +1,7 @@ +tokenizer +================================================= + +.. 
automodule:: paddlenlp.transformers.dallebart.tokenizer + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.ernie_m.faster_tokenizer.rst b/docs/source/paddlenlp.transformers.ernie_m.faster_tokenizer.rst new file mode 100644 index 000000000000..34cc72f0a814 --- /dev/null +++ b/docs/source/paddlenlp.transformers.ernie_m.faster_tokenizer.rst @@ -0,0 +1,7 @@ +faster\_tokenizer +======================================================== + +.. automodule:: paddlenlp.transformers.ernie_m.faster_tokenizer + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.ernie_m.rst b/docs/source/paddlenlp.transformers.ernie_m.rst index c65459cda686..0cb463814bd2 100644 --- a/docs/source/paddlenlp.transformers.ernie_m.rst +++ b/docs/source/paddlenlp.transformers.ernie_m.rst @@ -10,5 +10,6 @@ ernie\_m .. toctree:: :maxdepth: 4 + paddlenlp.transformers.ernie_m.faster_tokenizer paddlenlp.transformers.ernie_m.modeling paddlenlp.transformers.ernie_m.tokenizer diff --git a/docs/source/paddlenlp.transformers.gau_alpha.modeling.rst b/docs/source/paddlenlp.transformers.gau_alpha.modeling.rst new file mode 100644 index 000000000000..9bc4a2e4e01c --- /dev/null +++ b/docs/source/paddlenlp.transformers.gau_alpha.modeling.rst @@ -0,0 +1,7 @@ +modeling +================================================= + +.. automodule:: paddlenlp.transformers.gau_alpha.modeling + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.gau_alpha.rst b/docs/source/paddlenlp.transformers.gau_alpha.rst new file mode 100644 index 000000000000..91684515e801 --- /dev/null +++ b/docs/source/paddlenlp.transformers.gau_alpha.rst @@ -0,0 +1,14 @@ +gau\_alpha +========================================= + +.. automodule:: paddlenlp.transformers.gau_alpha + :members: + :no-undoc-members: + :show-inheritance: + + +.. toctree:: + :maxdepth: 4 + + paddlenlp.transformers.gau_alpha.modeling + paddlenlp.transformers.gau_alpha.tokenizer diff --git a/docs/source/paddlenlp.transformers.gau_alpha.tokenizer.rst b/docs/source/paddlenlp.transformers.gau_alpha.tokenizer.rst new file mode 100644 index 000000000000..40d966a2e0ef --- /dev/null +++ b/docs/source/paddlenlp.transformers.gau_alpha.tokenizer.rst @@ -0,0 +1,7 @@ +tokenizer +================================================== + +.. automodule:: paddlenlp.transformers.gau_alpha.tokenizer + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.model_outputs.rst b/docs/source/paddlenlp.transformers.model_outputs.rst new file mode 100644 index 000000000000..7c3220fc0fd0 --- /dev/null +++ b/docs/source/paddlenlp.transformers.model_outputs.rst @@ -0,0 +1,6 @@ +model\_outputs +============================================ + +.. automodule:: paddlenlp.transformers.model_outputs + :members: + :no-undoc-members: diff --git a/docs/source/paddlenlp.transformers.opt.modeling.rst b/docs/source/paddlenlp.transformers.opt.modeling.rst new file mode 100644 index 000000000000..71e2af7e61dd --- /dev/null +++ b/docs/source/paddlenlp.transformers.opt.modeling.rst @@ -0,0 +1,7 @@ +modeling +========================================== + +.. 
automodule:: paddlenlp.transformers.opt.modeling + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.opt.rst b/docs/source/paddlenlp.transformers.opt.rst new file mode 100644 index 000000000000..445e3c2639b9 --- /dev/null +++ b/docs/source/paddlenlp.transformers.opt.rst @@ -0,0 +1,13 @@ +opt +================================== + +.. automodule:: paddlenlp.transformers.opt + :members: + :no-undoc-members: + :show-inheritance: + + +.. toctree:: + :maxdepth: 4 + + paddlenlp.transformers.opt.modeling diff --git a/docs/source/paddlenlp.transformers.rst b/docs/source/paddlenlp.transformers.rst index 7d15efd49cbb..bff73dcbad1b 100644 --- a/docs/source/paddlenlp.transformers.rst +++ b/docs/source/paddlenlp.transformers.rst @@ -11,6 +11,7 @@ paddlenlp.transformers :maxdepth: 4 paddlenlp.transformers.albert + paddlenlp.transformers.artist paddlenlp.transformers.auto paddlenlp.transformers.bart paddlenlp.transformers.bert @@ -19,8 +20,10 @@ paddlenlp.transformers paddlenlp.transformers.blenderbot paddlenlp.transformers.blenderbot_small paddlenlp.transformers.chinesebert + paddlenlp.transformers.codegen paddlenlp.transformers.convbert paddlenlp.transformers.ctrl + paddlenlp.transformers.dallebart paddlenlp.transformers.distilbert paddlenlp.transformers.electra paddlenlp.transformers.ernie @@ -31,6 +34,7 @@ paddlenlp.transformers paddlenlp.transformers.ernie_m paddlenlp.transformers.fnet paddlenlp.transformers.funnel + paddlenlp.transformers.gau_alpha paddlenlp.transformers.gpt paddlenlp.transformers.layoutlm paddlenlp.transformers.layoutlmv2 @@ -41,6 +45,7 @@ paddlenlp.transformers paddlenlp.transformers.mobilebert paddlenlp.transformers.mpnet paddlenlp.transformers.nezha + paddlenlp.transformers.opt paddlenlp.transformers.ppminilm paddlenlp.transformers.prophetnet paddlenlp.transformers.reformer @@ -56,6 +61,7 @@ paddlenlp.transformers paddlenlp.transformers.transformer paddlenlp.transformers.unified_transformer paddlenlp.transformers.unimo + paddlenlp.transformers.xlm paddlenlp.transformers.xlnet @@ -67,8 +73,10 @@ paddlenlp.transformers paddlenlp.transformers.distill_utils paddlenlp.transformers.export paddlenlp.transformers.generation_utils + paddlenlp.transformers.model_outputs paddlenlp.transformers.model_utils paddlenlp.transformers.optimization + paddlenlp.transformers.sentencepiece_model_pb2 paddlenlp.transformers.tokenizer_utils paddlenlp.transformers.tokenizer_utils_base paddlenlp.transformers.tokenizer_utils_faster diff --git a/docs/source/paddlenlp.transformers.sentencepiece_model_pb2.rst b/docs/source/paddlenlp.transformers.sentencepiece_model_pb2.rst new file mode 100644 index 000000000000..27793b44085a --- /dev/null +++ b/docs/source/paddlenlp.transformers.sentencepiece_model_pb2.rst @@ -0,0 +1,6 @@ +sentencepiece\_model\_pb2 +======================================================= + +.. automodule:: paddlenlp.transformers.sentencepiece_model_pb2 + :members: + :no-undoc-members: diff --git a/docs/source/paddlenlp.transformers.tinybert.faster_tokenizer.rst b/docs/source/paddlenlp.transformers.tinybert.faster_tokenizer.rst new file mode 100644 index 000000000000..dd16db861af0 --- /dev/null +++ b/docs/source/paddlenlp.transformers.tinybert.faster_tokenizer.rst @@ -0,0 +1,7 @@ +faster\_tokenizer +======================================================== + +.. 
automodule:: paddlenlp.transformers.tinybert.faster_tokenizer + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.tinybert.rst b/docs/source/paddlenlp.transformers.tinybert.rst index e73b38da2d7d..84d1f4ba476a 100644 --- a/docs/source/paddlenlp.transformers.tinybert.rst +++ b/docs/source/paddlenlp.transformers.tinybert.rst @@ -10,5 +10,6 @@ tinybert .. toctree:: :maxdepth: 4 + paddlenlp.transformers.tinybert.faster_tokenizer paddlenlp.transformers.tinybert.modeling paddlenlp.transformers.tinybert.tokenizer diff --git a/docs/source/paddlenlp.transformers.xlm.modeling.rst b/docs/source/paddlenlp.transformers.xlm.modeling.rst new file mode 100644 index 000000000000..24df639bdc43 --- /dev/null +++ b/docs/source/paddlenlp.transformers.xlm.modeling.rst @@ -0,0 +1,7 @@ +modeling +========================================== + +.. automodule:: paddlenlp.transformers.xlm.modeling + :members: + :no-undoc-members: + :show-inheritance: diff --git a/docs/source/paddlenlp.transformers.xlm.rst b/docs/source/paddlenlp.transformers.xlm.rst new file mode 100644 index 000000000000..54ec485dd13b --- /dev/null +++ b/docs/source/paddlenlp.transformers.xlm.rst @@ -0,0 +1,14 @@ +xlm +================================== + +.. automodule:: paddlenlp.transformers.xlm + :members: + :no-undoc-members: + :show-inheritance: + + +.. toctree:: + :maxdepth: 4 + + paddlenlp.transformers.xlm.modeling + paddlenlp.transformers.xlm.tokenizer diff --git a/docs/source/paddlenlp.transformers.xlm.tokenizer.rst b/docs/source/paddlenlp.transformers.xlm.tokenizer.rst new file mode 100644 index 000000000000..9ba4a1cddbe0 --- /dev/null +++ b/docs/source/paddlenlp.transformers.xlm.tokenizer.rst @@ -0,0 +1,7 @@ +tokenizer +=========================================== + +.. automodule:: paddlenlp.transformers.xlm.tokenizer + :members: + :no-undoc-members: + :show-inheritance: diff --git a/paddlenlp/transformers/bert/modeling.py b/paddlenlp/transformers/bert/modeling.py index 4295260212ca..8a0d955606f8 100644 --- a/paddlenlp/transformers/bert/modeling.py +++ b/paddlenlp/transformers/bert/modeling.py @@ -619,14 +619,14 @@ def forward(self, Whether to return the attentions tensors of all attention layers. Defaults to `False`. return_dict (bool, optional): - Whether to return a `ModelOutput` object. If `False`, the output + Whether to return a :class:`~paddlenlp.transformers.model_outputs.ModelOutput` object. If `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - An instance of `BaseModelOutputWithPoolingAndCrossAttentions` if + An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions` if `return_dict=True`. Otherwise it returns a tuple of tensors corresponding to ordered and not None (depending on the input arguments) fields of - `BaseModelOutputWithPoolingAndCrossAttentions`. + :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`. Example: .. code-block:: @@ -744,6 +744,8 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + start_positions=None, + end_positions=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -755,19 +757,32 @@ def forward(self, See :class:`BertModel`. token_type_ids (Tensor, optional): See :class:`BertModel`. + position_ids(Tensor, optional): + See :class:`BertModel`. + attention_mask (Tensor, optional): + See :class:`BertModel`. 
+ start_positions (Tensor of shape `(batch_size,)`, optional): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (Tensor of shape `(batch_size,)`, optional): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - tuple: Returns tuple (`start_logits`, `end_logits`). - - With the fields: - - - `start_logits` (Tensor): - A tensor of the input token classification logits, indicates the start position of the labelled span. - Its data type should be float32 and its shape is [batch_size, sequence_length]. - - - `end_logits` (Tensor): - A tensor of the input token classification logits, indicates the end position of the labelled span. - Its data type should be float32 and its shape is [batch_size, sequence_length]. + An instance of :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput`. Example: .. code-block:: @@ -802,12 +817,28 @@ def forward(self, start_logits, end_logits = paddle.unstack(x=logits, axis=0) total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if start_positions.ndim > 1: + start_positions = start_positions.squeeze(-1) + if start_positions.ndim > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = paddle.shape(start_logits)[1] + start_positions = start_positions.clip(0, ignored_index) + end_positions = end_positions.clip(0, ignored_index) + + loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((total_loss, ) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( + loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states, @@ -846,6 +877,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + labels=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -859,12 +891,27 @@ def forward(self, See :class:`BertModel`. position_ids(Tensor, optional): See :class:`BertModel`. - attention_mask (list, optional): + attention_mask (Tensor, optional): See :class:`BertModel`. 
+ labels (Tensor of shape `(batch_size,)`, optional): + Labels for computing the sequence classification/regression loss. + Indices should be in `[0, ..., num_classes - 1]`. If `num_classes == 1` + a regression loss is computed (Mean-Square loss), If `num_classes > 1` + a classification loss is computed (Cross-Entropy). + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `logits`, a tensor of the input text classification logits. - Shape as `[batch_size, num_classes]` and dtype as float32. + An instance of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput`. Example: .. code-block:: @@ -897,14 +944,26 @@ def forward(self, pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) - # TODO(guosheng): Support loss loss = None + if labels is not None: + if self.num_classes == 1: + loss_fct = paddle.nn.MSELoss() + loss = loss_fct(logits, labels) + elif labels.dtype == paddle.int64 or labels.dtype == paddle.int32: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape((-1, self.num_classes)), + labels.reshape((-1, ))) + else: + loss_fct = paddle.nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: output = (logits, ) + outputs[2:] return ((loss, ) + output) if loss is not None else ( output[0] if len(output) == 1 else output) return SequenceClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -942,6 +1001,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + labels=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -957,10 +1017,22 @@ def forward(self, See :class:`BertModel`. attention_mask (list, optional): See :class:`BertModel`. + labels (Tensor of shape `(batch_size, sequence_length)`, optional): + Labels for computing the token classification loss. Indices should be in `[0, ..., num_classes - 1]`. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `logits`, a tensor of the input token classification logits. - Shape as `[batch_size, sequence_length, num_classes]` and dtype as `float32`. + An instance of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput`. Example: .. 
code-block:: @@ -993,14 +1065,18 @@ def forward(self, sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) - # TODO(guosheng): Support loss loss = None + if labels is not None: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape((-1, self.num_classes)), + labels.reshape((-1, ))) if not return_dict: output = (logits, ) + outputs[2:] return ((loss, ) + output) if loss is not None else ( output[0] if len(output) == 1 else output) return TokenClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -1169,6 +1245,8 @@ def forward(self, position_ids=None, attention_mask=None, masked_positions=None, + labels=None, + next_sentence_label=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -1185,20 +1263,30 @@ def forward(self, See :class:`BertModel`. masked_positions(Tensor, optional): See :class:`BertPretrainingHeads`. + labels (Tensor of shape `(batch_size, sequence_length)`, optional): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., vocab_size]`. + next_sentence_label (Tensor of shape `(batch_size,)`, optional): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence + pair (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.bert.BertForPreTrainingOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - tuple: Returns tuple (``prediction_scores``, ``seq_relationship_score``). - - With the fields: - - - `prediction_scores` (Tensor): - The scores of masked token prediction. Its data type should be float32. - If `masked_positions` is None, its shape is [batch_size, sequence_length, vocab_size]. - Otherwise, its shape is [batch_size, mask_token_num, vocab_size]. - - - `seq_relationship_score` (Tensor): - The scores of next sentence prediction. - Its data type should be float32 and its shape is [batch_size, 2]. + An instance of :class:`~paddlenlp.transformers.bert.BertForPreTrainingOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.bert.BertForPreTrainingOutput`. 
""" with paddle.static.amp.fp16_guard(): @@ -1214,6 +1302,16 @@ def forward(self, sequence_output, pooled_output, masked_positions) total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = paddle.nn.CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.reshape( + (-1, prediction_scores.shape[-1])), + labels.reshape((-1, ))) + next_sentence_loss = loss_fct( + seq_relationship_score.reshape((-1, 2)), + next_sentence_label.reshape((-1, ))) + total_loss = masked_lm_loss + next_sentence_loss if not return_dict: output = (prediction_scores, seq_relationship_score) + outputs[2:] @@ -1221,6 +1319,7 @@ def forward(self, output) if total_loss is not None else output return BertForPreTrainingOutput( + loss=total_loss, prediction_logits=prediction_scores, seq_relationship_logits=seq_relationship_score, hidden_states=outputs.hidden_states, @@ -1315,6 +1414,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + labels=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -1330,10 +1430,24 @@ def forward(self, See :class:`BertModel` and shape as [batch_size, num_choice, sequence_length]. attention_mask (list, optional): See :class:`BertModel` and shape as [batch_size, num_choice, sequence_length]. + labels (Tensor of shape `(batch_size, )`, optional): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `reshaped_logits`, a tensor of the multiple choice classification logits. - Shape as `[batch_size, num_choice]` and dtype as `float32`. + An instance of :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput`. Example: .. code-block:: @@ -1416,14 +1530,17 @@ def forward(self, reshaped_logits = logits.reshape( shape=(-1, self.num_choices)) # logits: (bs, num_choice) - # TODO(guosheng): Support loss loss = None + if labels is not None: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) if not return_dict: output = (reshaped_logits, ) + outputs[2:] return ((loss, ) + output) if loss is not None else ( output[0] if len(output) == 1 else output) return MultipleChoiceModelOutput( + loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -1471,6 +1588,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + labels=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -1485,10 +1603,24 @@ def forward(self, See :class:`BertModel`. attention_mask (Tensor, optional): See :class:`BertModel`. 
+ labels (Tensor of shape `(batch_size, sequence_length)`, optional): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., vocab_size]` + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `prediction_scores`, The scores of masked token prediction. - Its data type should be float32 and shape is [batch_size, sequence_length, vocab_size]. + An instance of :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput`. Example: .. code-block:: @@ -1518,8 +1650,13 @@ def forward(self, sequence_output = outputs[0] prediction_scores = self.cls(sequence_output, masked_positions=None) - # TODO(guosheng): Support loss masked_lm_loss = None + if labels is not None: + loss_fct = paddle.nn.CrossEntropyLoss( + ) # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.reshape((-1, prediction_scores.shape[-1])), + labels.reshape((-1, ))) if not return_dict: output = (prediction_scores, ) + outputs[2:] return ((masked_lm_loss, ) + @@ -1527,6 +1664,7 @@ def forward(self, output[0] if len(output) == 1 else output) return MaskedLMOutput( + loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, diff --git a/paddlenlp/transformers/bert/tokenizer.py b/paddlenlp/transformers/bert/tokenizer.py index 0459e470152f..bc140101af24 100644 --- a/paddlenlp/transformers/bert/tokenizer.py +++ b/paddlenlp/transformers/bert/tokenizer.py @@ -497,6 +497,9 @@ def vocab_size(self): return len(self.vocab) + def get_vocab(self): + return dict(self.vocab.token_to_idx, **self.added_tokens_encoder) + def _tokenize(self, text): """ End-to-end tokenization for BERT models. diff --git a/paddlenlp/transformers/ernie/modeling.py b/paddlenlp/transformers/ernie/modeling.py index a559843b99dc..04671be1c5d4 100644 --- a/paddlenlp/transformers/ernie/modeling.py +++ b/paddlenlp/transformers/ernie/modeling.py @@ -657,14 +657,14 @@ def forward(self, Whether to return the attentions tensors of all attention layers. Defaults to `False`. return_dict (bool, optional): - Whether to return a `ModelOutput` object. If `False`, the output + Whether to return a :class:`~paddlenlp.transformers.model_outputs.ModelOutput` object. If `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - An instance of `BaseModelOutputWithPoolingAndCrossAttentions` if + An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions` if `return_dict=True`. Otherwise it returns a tuple of tensors corresponding to ordered and not None (depending on the input arguments) fields of - `BaseModelOutputWithPoolingAndCrossAttentions`. 
+ :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`. Example: .. code-block:: @@ -765,6 +765,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + labels=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -778,10 +779,26 @@ def forward(self, See :class:`ErnieModel`. attention_mask (Tensor, optional): See :class:`ErnieModel`. + labels (Tensor of shape `(batch_size,)`, optional): + Labels for computing the sequence classification/regression loss. + Indices should be in `[0, ..., num_classes - 1]`. If `num_classes == 1` + a regression loss is computed (Mean-Square loss), If `num_classes > 1` + a classification loss is computed (Cross-Entropy). + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `logits`, a tensor of the input text classification logits. - Shape as `[batch_size, num_classes]` and dtype as float32. + An instance of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput`. + Example: .. code-block:: @@ -810,12 +827,24 @@ def forward(self, logits = self.classifier(pooled_output) loss = None + if labels is not None: + if self.num_classes == 1: + loss_fct = paddle.nn.MSELoss() + loss = loss_fct(logits, labels) + elif labels.dtype == paddle.int64 or labels.dtype == paddle.int32: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape((-1, self.num_classes)), + labels.reshape((-1, ))) + else: + loss_fct = paddle.nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits, ) + outputs[2:] return ((loss, ) + output) if loss is not None else ( output[0] if len(output) == 1 else output) return SequenceClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -844,6 +873,8 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + start_positions=None, + end_positions=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -857,20 +888,28 @@ def forward(self, See :class:`ErnieModel`. attention_mask (Tensor, optional): See :class:`ErnieModel`. - + start_positions (Tensor of shape `(batch_size,)`, optional): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (Tensor of shape `(batch_size,)`, optional): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - tuple: Returns tuple (`start_logits`, `end_logits`). - - With the fields: - - - `start_logits` (Tensor): - A tensor of the input token classification logits, indicates the start position of the labelled span. - Its data type should be float32 and its shape is [batch_size, sequence_length]. - - - `end_logits` (Tensor): - A tensor of the input token classification logits, indicates the end position of the labelled span. - Its data type should be float32 and its shape is [batch_size, sequence_length]. + An instance of :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput`. Example: .. code-block:: @@ -901,12 +940,28 @@ def forward(self, start_logits, end_logits = paddle.unstack(x=logits, axis=0) total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if start_positions.ndim > 1: + start_positions = start_positions.squeeze(-1) + if start_positions.ndim > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = paddle.shape(start_logits)[1] + start_positions = start_positions.clip(0, ignored_index) + end_positions = end_positions.clip(0, ignored_index) + + loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((total_loss, ) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( + loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states, @@ -945,6 +1000,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + labels=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -958,10 +1014,22 @@ def forward(self, See :class:`ErnieModel`. attention_mask (Tensor, optional): See :class:`ErnieModel`. + labels (Tensor of shape `(batch_size, sequence_length)`, optional): + Labels for computing the token classification loss. Indices should be in `[0, ..., num_classes - 1]`. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `logits`, a tensor of the input token classification logits. 
- Shape as `[batch_size, sequence_length, num_classes]` and dtype as `float32`. + An instance of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput`. Example: .. code-block:: @@ -990,12 +1058,17 @@ def forward(self, logits = self.classifier(sequence_output) loss = None + if labels is not None: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape((-1, self.num_classes)), + labels.reshape((-1, ))) if not return_dict: output = (logits, ) + outputs[2:] return ((loss, ) + output) if loss is not None else ( output[0] if len(output) == 1 else output) return TokenClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -1130,6 +1203,8 @@ def forward(self, position_ids=None, attention_mask=None, masked_positions=None, + labels=None, + next_sentence_label=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -1143,20 +1218,30 @@ def forward(self, See :class:`ErnieModel`. attention_mask (Tensor, optional): See :class:`ErnieModel`. + labels (Tensor of shape `(batch_size, sequence_length)`, optional): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), + the loss is only computed for the tokens with labels in `[0, ..., vocab_size]`. + next_sentence_label (Tensor of shape `(batch_size,)`, optional): + Labels for computing the next sequence prediction (classification) loss. Input should be a sequence + pair (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a continuation of sequence A, + - 1 indicates sequence B is a random sequence. + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.bert.ErnieForPreTrainingOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - tuple: Returns tuple (``prediction_scores``, ``seq_relationship_score``). - - With the fields: - - - `prediction_scores` (Tensor): - The scores of masked token prediction. Its data type should be float32. - If `masked_positions` is None, its shape is [batch_size, sequence_length, vocab_size]. - Otherwise, its shape is [batch_size, mask_token_num, vocab_size]. - - - `seq_relationship_score` (Tensor): - The scores of next sentence prediction. - Its data type should be float32 and its shape is [batch_size, 2]. + An instance of :class:`~paddlenlp.transformers.bert.ErnieForPreTrainingOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.bert.ErnieForPreTrainingOutput`. 
""" with paddle.static.amp.fp16_guard(): @@ -1172,6 +1257,16 @@ def forward(self, sequence_output, pooled_output, masked_positions) total_loss = None + if labels is not None and next_sentence_label is not None: + loss_fct = paddle.nn.CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.reshape( + (-1, prediction_scores.shape[-1])), + labels.reshape((-1, ))) + next_sentence_loss = loss_fct( + seq_relationship_score.reshape((-1, 2)), + next_sentence_label.reshape((-1, ))) + total_loss = masked_lm_loss + next_sentence_loss if not return_dict: output = (prediction_scores, seq_relationship_score) + outputs[2:] @@ -1179,6 +1274,7 @@ def forward(self, output) if total_loss is not None else output return ErnieForPreTrainingOutput( + loss=total_loss, prediction_logits=prediction_scores, seq_relationship_logits=seq_relationship_score, hidden_states=outputs.hidden_states, @@ -1283,6 +1379,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + labels=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -1297,10 +1394,24 @@ def forward(self, See :class:`ErnieModel`. attention_mask (Tensor, optional): See :class:`ErnieModel`. + labels (Tensor of shape `(batch_size, sequence_length)`, optional): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., vocab_size]` + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `prediction_scores`, The scores of masked token prediction. - Its data type should be float32 and shape is [batch_size, sequence_length, vocab_size]. + An instance of :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput`. Example: .. code-block:: @@ -1331,6 +1442,12 @@ def forward(self, prediction_scores = self.cls(sequence_output, masked_positions=None) masked_lm_loss = None + if labels is not None: + loss_fct = paddle.nn.CrossEntropyLoss( + ) # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.reshape((-1, prediction_scores.shape[-1])), + labels.reshape((-1, ))) if not return_dict: output = (prediction_scores, ) + outputs[2:] return ((masked_lm_loss, ) + @@ -1338,6 +1455,7 @@ def forward(self, output[0] if len(output) == 1 else output) return MaskedLMOutput( + loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -1374,6 +1492,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + labels=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -1389,10 +1508,24 @@ def forward(self, See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length]. 
attention_mask (list, optional): See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length]. + labels (Tensor of shape `(batch_size, )`, optional): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor: Returns tensor `reshaped_logits`, a tensor of the multiple choice classification logits. - Shape as `[batch_size, num_choice]` and dtype as `float32`. + An instance of :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput`. """ # input_ids: [bs, num_choice, seq_l] @@ -1425,12 +1558,16 @@ def forward(self, shape=(-1, self.num_choices)) # logits: (bs, num_choice) loss = None + if labels is not None: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) if not return_dict: output = (reshaped_logits, ) + outputs[2:] return ((loss, ) + output) if loss is not None else ( output[0] if len(output) == 1 else output) return MultipleChoiceModelOutput( + loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, diff --git a/paddlenlp/transformers/model_outputs.py b/paddlenlp/transformers/model_outputs.py index 15586aa9a696..b7c9d31355b2 100644 --- a/paddlenlp/transformers/model_outputs.py +++ b/paddlenlp/transformers/model_outputs.py @@ -590,3 +590,47 @@ class MaskedLMOutput(ModelOutput): logits: paddle.Tensor = None hidden_states: Optional[Tuple[paddle.Tensor]] = None attentions: Optional[Tuple[paddle.Tensor]] = None + + +@dataclass +class CausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`paddle.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`paddle.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Cross attentions weights after the attention softmax, used to compute the weighted average in the + cross-attention heads. + past_key_values (`tuple(tuple(paddle.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `paddle.Tensor` tuples of length `config.n_layers`, with each tuple containing the cached key, + value states of the self-attention and the cross-attention layers if model is used in encoder-decoder + setting. Only relevant if `config.is_decoder = True`. + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + """ + + loss: Optional[paddle.Tensor] = None + logits: paddle.Tensor = None + past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None + hidden_states: Optional[Tuple[paddle.Tensor]] = None + attentions: Optional[Tuple[paddle.Tensor]] = None + cross_attentions: Optional[Tuple[paddle.Tensor]] = None diff --git a/paddlenlp/transformers/roberta/modeling.py b/paddlenlp/transformers/roberta/modeling.py index 761b3ff0f690..57a8601889f8 100644 --- a/paddlenlp/transformers/roberta/modeling.py +++ b/paddlenlp/transformers/roberta/modeling.py @@ -29,6 +29,7 @@ QuestionAnsweringModelOutput, MultipleChoiceModelOutput, MaskedLMOutput, + CausalLMOutputWithCrossAttentions, ModelOutput, ) @@ -401,14 +402,14 @@ def forward(self, Whether to return the attentions tensors of all attention layers. Defaults to `False`. return_dict (bool, optional): - Whether to return a `ModelOutput` object. If `False`, the output + Whether to return a :class:`~paddlenlp.transformers.model_outputs.ModelOutput` object. If `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - An instance of `BaseModelOutputWithPoolingAndCrossAttentions` if + An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions` if `return_dict=True`. Otherwise it returns a tuple of tensors corresponding to ordered and not None (depending on the input arguments) fields of - `BaseModelOutputWithPoolingAndCrossAttentions`. + :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPoolingAndCrossAttentions`. Example: .. code-block:: @@ -496,6 +497,8 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + start_positions=None, + end_positions=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -509,27 +512,28 @@ def forward(self, See :class:`RobertaModel`. attention_mask (Tensor, optional): See :class:`RobertaModel`. + start_positions (Tensor of shape `(batch_size,)`, optional): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (Tensor of shape `(batch_size,)`, optional): + Labels for position (index) of the end of the labelled span for computing the token classification loss. 
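These output dataclasses are what the heads hand back when `return_dict=True`; with the default `return_dict=False` the same values arrive positionally, with the loss first whenever labels were passed. A tiny illustration of the named access, using dummy tensors rather than a real forward pass:

.. code-block::

    import paddle
    from paddlenlp.transformers.model_outputs import MaskedLMOutput

    output = MaskedLMOutput(loss=paddle.to_tensor([0.7]),
                            logits=paddle.randn([1, 13, 100]))
    print(output.loss)          # named access with return_dict=True
    print(output.logits.shape)  # [1, 13, 100]
    # with return_dict=False the forward pass instead returns (loss, logits, ...)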
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. output_hidden_states (bool, optional): - See :class:`RobertaModel`. + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - tuple: Returns tuple (`start_logits`, `end_logits`) by default if output_hidden_states is `False`. - Returns tuple (`start_logits`, `end_logits`, `encoder_outputs`) if output_hidden_states is set to `True`. - - With the fields: - - - `start_logits` (Tensor): - A tensor of the input token classification logits, indicates the start position of the labelled span. - Its data type should be float32 and its shape is [batch_size, sequence_length]. - - - `end_logits` (Tensor): - A tensor of the input token classification logits, indicates the end position of the labelled span. - Its data type should be float32 and its shape is [batch_size, sequence_length]. - - - `encoder_outputs` (List(Tensor)): - A list of Tensor containing hidden-states of the model at each hidden layer in the Transformer encoder. - The length of the list is `num_hidden_layers`. - Each Tensor has a data type of float32 and a shape of [batch_size, sequence_length, hidden_size]. + An instance of :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput`. Example: .. code-block:: @@ -560,12 +564,28 @@ def forward(self, start_logits, end_logits = paddle.unstack(x=logits, axis=0) total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, the split adds an extra dimension; squeeze it + if start_positions.ndim > 1: + start_positions = start_positions.squeeze(-1) + if end_positions.ndim > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = paddle.shape(start_logits)[1] + start_positions = start_positions.clip(0, ignored_index) + end_positions = end_positions.clip(0, ignored_index) + + loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 if not return_dict: output = (start_logits, end_logits) + outputs[2:] return ((total_loss, ) + output) if total_loss is not None else output return QuestionAnsweringModelOutput( + loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states, @@ -604,6 +624,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + labels=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -617,23 +638,25 @@ def forward(self, See :class:`RobertaModel`. attention_mask (Tensor, optional): See :class:`RobertaModel`. + labels (Tensor of shape `(batch_size,)`, optional): + Labels for computing the sequence classification/regression loss.
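The span loss above first squeezes any extra label dimension, clamps start/end positions that fall outside the sequence to `sequence_length`, excludes those clamped entries via `ignore_index`, and finally averages the two cross entropies. Roughly, on dummy data (all sizes invented for illustration):

.. code-block::

    import paddle

    batch_size, seq_len = 2, 16
    start_logits = paddle.randn([batch_size, seq_len])
    end_logits = paddle.randn([batch_size, seq_len])
    start_positions = paddle.to_tensor([3, 40])      # 40 lies outside the sequence
    end_positions = paddle.to_tensor([5, 60])

    ignored_index = seq_len                          # clamped positions are ignored
    start_positions = start_positions.clip(0, ignored_index)
    end_positions = end_positions.clip(0, ignored_index)

    loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=ignored_index)
    total_loss = (loss_fct(start_logits, start_positions) +
                  loss_fct(end_logits, end_positions)) / 2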
+ Indices should be in `[0, ..., num_classes - 1]`. If `num_classes == 1` + a regression loss is computed (Mean-Square loss), If `num_classes > 1` + a classification loss is computed (Cross-Entropy). output_hidden_states (bool, optional): - See :class:`RobertaModel`. + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor or tuple: Returns tensor `logits` by default. - Returns tuple (`logits`, `encoder_outputs`) if output_hidden_states is set to `True`. - - With the fields: - - - `logits` (Tensor): - a tensor of the input text classification logits. - Its data type should be float32 and it has a shape of [batch_size, num_classes]. - - - `encoder_outputs` (List(Tensor)): - A list of Tensor containing hidden-states of the model at each hidden layer in the Transformer encoder. - The length of the list is `num_hidden_layers`. - Each Tensor has a data type of float32 and a shape of [batch_size, sequence_length, hidden_size]. + An instance of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput`. Example: .. code-block:: @@ -662,12 +685,24 @@ def forward(self, logits = self.classifier(pooled_output) loss = None + if labels is not None: + if self.num_classes == 1: + loss_fct = paddle.nn.MSELoss() + loss = loss_fct(logits, labels) + elif labels.dtype == paddle.int64 or labels.dtype == paddle.int32: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape((-1, self.num_classes)), + labels.reshape((-1, ))) + else: + loss_fct = paddle.nn.BCEWithLogitsLoss() + loss = loss_fct(logits, labels) if not return_dict: output = (logits, ) + outputs[2:] return ((loss, ) + output) if loss is not None else ( output[0] if len(output) == 1 else output) return SequenceClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -705,6 +740,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + labels=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -718,23 +754,22 @@ def forward(self, See :class:`RobertaModel`. attention_mask (Tensor, optional): See :class:`RobertaModel`. + labels (Tensor of shape `(batch_size, sequence_length)`, optional): + Labels for computing the token classification loss. Indices should be in `[0, ..., num_classes - 1]`. output_hidden_states (bool, optional): - See :class:`RobertaModel`. + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor or tuple: Returns tensor `logits` by default. - Returns tuple (`logits`, `encoder_outputs`) if output_hidden_states is set to `True`. 
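The sequence-classification head now chooses its criterion from the label type: mean-squared error when `num_classes == 1` (regression), cross entropy when the labels are integer class ids, and BCE-with-logits when they are float multi-label targets. A sketch of the three branches on dummy tensors (shapes are assumptions, not from the patch):

.. code-block::

    import paddle

    logits = paddle.randn([4, 3])                    # [batch_size, num_classes]

    # 1) regression (num_classes == 1): MSE against float targets
    reg_loss = paddle.nn.MSELoss()(paddle.randn([4, 1]), paddle.randn([4, 1]))

    # 2) single-label classification: integer labels, cross entropy
    int_labels = paddle.randint(0, 3, [4])
    ce_loss = paddle.nn.CrossEntropyLoss()(logits.reshape((-1, 3)),
                                           int_labels.reshape((-1, )))

    # 3) multi-label classification: float 0/1 targets, BCE with logits
    float_labels = paddle.ones([4, 3], dtype="float32")
    bce_loss = paddle.nn.BCEWithLogitsLoss()(logits, float_labels)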
- - With the fields: - - - `logits` (Tensor): - a tensor of the input token classification logits. - Shape as `[batch_size, sequence_length, num_classes]` and dtype as `float32`. - - - `encoder_outputs` (List(Tensor)): - A list of Tensor containing hidden-states of the model at each hidden layer in the Transformer encoder. - The length of the list is `num_hidden_layers`. - Each Tensor has a data type of float32 and a shape of [batch_size, sequence_length, hidden_size]. + An instance of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput`. Example: .. code-block:: @@ -764,12 +799,17 @@ def forward(self, logits = self.classifier(sequence_output) loss = None + if labels is not None: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape((-1, self.num_classes)), + labels.reshape((-1, ))) if not return_dict: output = (logits, ) + outputs[2:] return ((loss, ) + output) if loss is not None else ( output[0] if len(output) == 1 else output) return TokenClassifierOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -777,6 +817,20 @@ class RobertaForMultipleChoice(RobertaPretrainedModel): + """ + RoBERTa Model with a linear layer on top of the hidden-states output layer, + designed for multiple choice tasks like RocStories/SWAG tasks. + + Args: + roberta (:class:`RobertaModel`): + An instance of RobertaModel. + num_choices (int, optional): + The number of choices. Defaults to `2`. + dropout (float, optional): + The dropout probability for output of RoBERTa. + If None, use the same value as `hidden_dropout_prob` of `RobertaModel` + instance `roberta`. Defaults to None. + """ def __init__(self, roberta): super().__init__() @@ -792,9 +846,93 @@ def forward(self, token_type_ids=None, attention_mask=None, position_ids=None, + labels=None, output_hidden_states=False, output_attentions=False, return_dict=False): + r""" + The RobertaForMultipleChoice forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`RobertaModel` and shape as [batch_size, num_choice, sequence_length]. + token_type_ids(Tensor, optional): + See :class:`RobertaModel` and shape as [batch_size, num_choice, sequence_length]. + position_ids(Tensor, optional): + See :class:`RobertaModel` and shape as [batch_size, num_choice, sequence_length]. + attention_mask (list, optional): + See :class:`RobertaModel` and shape as [batch_size, num_choice, sequence_length]. + labels (Tensor of shape `(batch_size, )`, optional): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + output_hidden_states (bool, optional): + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`.
+ + Returns: + An instance of :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.MultipleChoiceModelOutput`. + + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers import BertForMultipleChoice, BertTokenizer + from paddlenlp.data import Pad, Dict + + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + model = BertForMultipleChoice.from_pretrained('bert-base-uncased', num_choices=2) + + data = [ + { + "question": "how do you turn on an ipad screen?", + "answer1": "press the volume button.", + "answer2": "press the lock button.", + "label": 1, + }, + { + "question": "how do you indent something?", + "answer1": "leave a space before starting the writing", + "answer2": "press the spacebar", + "label": 0, + }, + ] + + text = [] + text_pair = [] + for d in data: + text.append(d["question"]) + text_pair.append(d["answer1"]) + text.append(d["question"]) + text_pair.append(d["answer2"]) + + inputs = tokenizer(text, text_pair) + batchify_fn = lambda samples, fn=Dict( + { + "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id), # input_ids + "token_type_ids": Pad( + axis=0, pad_val=tokenizer.pad_token_type_id + ), # token_type_ids + } + ): fn(samples) + inputs = batchify_fn(inputs) + + reshaped_logits = model( + input_ids=paddle.to_tensor(inputs[0], dtype="int64"), + token_type_ids=paddle.to_tensor(inputs[1], dtype="int64"), + ) + print(reshaped_logits.shape) + # [2, 2] + + """ num_choices = input_ids.shape[1] @@ -823,12 +961,16 @@ def forward(self, reshaped_logits = logits.reshape((-1, num_choices)) loss = None + if labels is not None: + loss_fct = paddle.nn.CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) if not return_dict: output = (reshaped_logits, ) + outputs[2:] return ((loss, ) + output) if loss is not None else ( output[0] if len(output) == 1 else output) return MultipleChoiceModelOutput( + loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -868,6 +1010,7 @@ def forward(self, attention_mask=None, token_type_ids=None, position_ids=None, + labels=None, output_hidden_states=False, output_attentions=False, return_dict=False): @@ -882,23 +1025,24 @@ def forward(self, See :class:`RobertaModel`. attention_mask (Tensor, optional): See :class:`RobertaModel`. + labels (Tensor of shape `(batch_size, sequence_length)`, optional): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., vocab_size]` output_hidden_states (bool, optional): - See :class:`RobertaModel`. + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - Tensor or tuple: Returns tensor `prediction_scores` by default. - Returns tuple (`prediction_scores`, `encoder_outputs`) if output_hidden_states is set to `True`. 
- - With the fields: - - - `prediction_scores` (Tensor): - The scores of masked token prediction. - Its data type should be float32 and shape is [batch_size, sequence_length, vocab_size]. - - - `encoder_outputs` (List(Tensor)): - A list of Tensor containing hidden-states of the model at each hidden layer in the Transformer encoder. - The length of the list is `num_hidden_layers`. - Each Tensor has a data type of float32 and a shape of [batch_size, sequence_length, hidden_size]. + An instance of :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.MaskedLMOutput`. Example: .. code-block:: @@ -929,6 +1073,12 @@ def forward(self, prediction_scores = self.lm_head(sequence_output) masked_lm_loss = None + if labels is not None: + loss_fct = paddle.nn.CrossEntropyLoss( + ) # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.reshape((-1, prediction_scores.shape[-1])), + labels.reshape((-1, ))) if not return_dict: output = (prediction_scores, ) + outputs[2:] return ((masked_lm_loss, ) + @@ -936,6 +1086,7 @@ def forward(self, output[0] if len(output) == 1 else output) return MaskedLMOutput( + loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, @@ -996,6 +1147,7 @@ def forward(self, attention_mask=None, token_type_ids=None, position_ids=None, + labels=None, past_key_values=None, use_cache=None, output_attentions=False, @@ -1017,19 +1169,24 @@ def forward(self, See :class:`RobertaModel`. attention_mask (Tensor, optional): See :class:`RobertaModel`. - output_attentions (bool, optional): - See :class:`RobertaModel`. + labels (Tensor of shape `(batch_size, sequence_length)`, optional): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., vocab_size]`. output_hidden_states (bool, optional): - See :class:`RobertaModel`. + Whether to return the hidden states of all layers. + Defaults to `False`. + output_attentions (bool, optional): + Whether to return the attentions tensors of all attention layers. + Defaults to `False`. return_dict (bool, optional): - See :class:`RobertaModel`. - - + Whether to return a :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithCrossAttentions` object. If + `False`, the output will be a tuple of tensors. Defaults to `False`. Returns: - An instance of `MaskedLMOutput` if `return_dict=True`. Otherwise it - returns a tuple of tensors corresponding to ordered and not None - (depending on the input arguments) fields of `MaskedLMOutput`. + An instance of :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithCrossAttentions` if `return_dict=True`. + Otherwise it returns a tuple of tensors corresponding to ordered and + not None (depending on the input arguments) fields of :class:`~paddlenlp.transformers.model_outputs.CausalLMOutputWithCrossAttentions`. Example: .. 
code-block:: @@ -1048,6 +1205,8 @@ def forward(self, # [1, 13, 30522] """ + if labels is not None: + use_cache = False outputs = self.roberta(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -1062,12 +1221,21 @@ def forward(self, prediction_scores = self.lm_head(sequence_output) lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :] + labels = labels[:, 1:] + loss_fct = paddle.nn.CrossEntropyLoss() + lm_loss = loss_fct( + shifted_prediction_scores.reshape( + (-1, prediction_scores.shape[-1])), labels.reshape((-1, ))) if not return_dict: output = (prediction_scores, ) + outputs[2:] return ((lm_loss, ) + output) if lm_loss is not None else ( output[0] if len(output) == 1 else output) - return MaskedLMOutput( + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, logits=prediction_scores, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, diff --git a/tests/transformers/bert/test_modeling.py b/tests/transformers/bert/test_modeling.py index e968c2becc65..7b2a7e093a86 100644 --- a/tests/transformers/bert/test_modeling.py +++ b/tests/transformers/bert/test_modeling.py @@ -32,6 +32,7 @@ def __init__( is_training=True, use_input_mask=True, use_token_type_ids=True, + use_labels=True, vocab_size=99, hidden_size=32, num_hidden_layers=5, @@ -58,6 +59,7 @@ def __init__( self.is_training = is_training self.use_input_mask = use_input_mask self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers @@ -92,8 +94,18 @@ def prepare_config_and_inputs(self): token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], + self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], + self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + config = self.get_config() - return config, input_ids, token_type_ids, input_mask + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels def get_config(self): return { @@ -119,6 +131,9 @@ def create_and_check_model( input_ids, token_type_ids, input_mask, + sequence_labels, + token_labels, + choice_labels, ): model = BertModel(**config) model.eval() @@ -139,14 +154,19 @@ def create_and_check_for_masked_lm( input_ids, token_type_ids, input_mask, + sequence_labels, + token_labels, + choice_labels, ): model = BertForMaskedLM(BertModel(**config)) model.eval() result = model(input_ids, attention_mask=input_mask, - token_type_ids=token_type_ids) + token_type_ids=token_type_ids, + labels=token_labels) self.parent.assertEqual( - result.shape, [self.batch_size, self.seq_length, self.vocab_size]) + result[1].shape, + [self.batch_size, self.seq_length, self.vocab_size]) def create_and_check_model_past_large_inputs( self, @@ -154,6 +174,9 @@ def create_and_check_model_past_large_inputs( input_ids, token_type_ids, input_mask, + sequence_labels, + token_labels, + choice_labels, ): model = BertModel(**config) model.eval() @@ -200,12 +223,39 @@ def create_and_check_model_past_large_inputs( output_from_no_past_slice, atol=1e-3)) + def create_and_check_for_pretraining( + self, + config, + input_ids, + token_type_ids, + 
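The causal-LM loss above is a next-token objective: the scores at position `t` are matched against the token observed at `t + 1`, so the scores are truncated by one position and the labels shifted left before the cross entropy. A toy illustration (vocabulary size and lengths are made up):

.. code-block::

    import paddle

    batch_size, seq_len, vocab_size = 1, 6, 50
    prediction_scores = paddle.randn([batch_size, seq_len, vocab_size])
    labels = paddle.randint(0, vocab_size, [batch_size, seq_len])

    shifted_scores = prediction_scores[:, :-1, :]    # predictions for steps 0..seq_len-2
    shifted_labels = labels[:, 1:]                   # the tokens that actually follow

    lm_loss = paddle.nn.CrossEntropyLoss()(
        shifted_scores.reshape((-1, vocab_size)),
        shifted_labels.reshape((-1, )))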
input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = BertForPretraining(BertModel(**config)) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + labels=token_labels, + next_sentence_label=sequence_labels, + ) + self.parent.assertEqual( + result[1].shape, + [self.batch_size, self.seq_length, self.vocab_size]) + self.parent.assertEqual(result[2].shape, [self.batch_size, 2]) + def create_and_check_for_multiple_choice( self, config, input_ids, token_type_ids, input_mask, + sequence_labels, + token_labels, + choice_labels, ): model = BertForMultipleChoice(BertModel(**config), num_choices=self.num_choices) @@ -220,23 +270,34 @@ def create_and_check_for_multiple_choice( multiple_choice_inputs_ids, attention_mask=multiple_choice_input_mask, token_type_ids=multiple_choice_token_type_ids, + labels=choice_labels, ) - self.parent.assertEqual(result.shape, + self.parent.assertEqual(result[1].shape, [self.batch_size, self.num_choices]) - def create_and_check_for_question_answering(self, config, input_ids, - token_type_ids, input_mask): + def create_and_check_for_question_answering( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): model = BertForQuestionAnswering(BertModel(**config)) model.eval() result = model( input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, + start_positions=sequence_labels, + end_positions=sequence_labels, ) - self.parent.assertEqual(result[0].shape, - [self.batch_size, self.seq_length]) self.parent.assertEqual(result[1].shape, [self.batch_size, self.seq_length]) + self.parent.assertEqual(result[2].shape, + [self.batch_size, self.seq_length]) def create_and_check_for_sequence_classification( self, @@ -244,16 +305,18 @@ def create_and_check_for_sequence_classification( input_ids, token_type_ids, input_mask, + sequence_labels, + token_labels, + choice_labels, ): model = BertForSequenceClassification(BertModel(**config), num_classes=self.num_classes) model.eval() - result = model( - input_ids, - attention_mask=input_mask, - token_type_ids=token_type_ids, - ) - self.parent.assertEqual(result.shape, + result = model(input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + labels=sequence_labels) + self.parent.assertEqual(result[1].shape, [self.batch_size, self.num_classes]) def create_and_check_for_token_classification( @@ -262,15 +325,20 @@ def create_and_check_for_token_classification( input_ids, token_type_ids, input_mask, + sequence_labels, + token_labels, + choice_labels, ): model = BertForTokenClassification(BertModel(**config), num_classes=self.num_classes) model.eval() result = model(input_ids, attention_mask=input_mask, - token_type_ids=token_type_ids) + token_type_ids=token_type_ids, + labels=token_labels) self.parent.assertEqual( - result.shape, [self.batch_size, self.seq_length, self.num_classes]) + result[1].shape, + [self.batch_size, self.seq_length, self.num_classes]) def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() @@ -279,6 +347,9 @@ def prepare_config_and_inputs_for_common(self): input_ids, token_type_ids, input_mask, + sequence_labels, + token_labels, + choice_labels, ) = config_and_inputs inputs_dict = { "input_ids": input_ids, @@ -322,6 +393,10 @@ def test_for_multiple_choice(self): self.model_tester.create_and_check_for_multiple_choice( *config_and_inputs) + def test_for_pretraining(self): + config_and_inputs = 
self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_pretraining(*config_and_inputs) + def test_for_question_answering(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_question_answering( diff --git a/tests/transformers/test_modeling_common.py b/tests/transformers/test_modeling_common.py index 3ef8e99fb554..d823ca23ec74 100644 --- a/tests/transformers/test_modeling_common.py +++ b/tests/transformers/test_modeling_common.py @@ -72,13 +72,8 @@ def _make_model_instance(self, config, model_class): return model_class(self.base_model_class(**config)) def test_save_load(self): - config, input_ids, token_type_ids, input_mask = self.model_tester.prepare_config_and_inputs( + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common( ) - inputs_dict = { - "input_ids": input_ids, - "token_type_ids": token_type_ids, - "attention_mask": input_mask, - } for model_class in self.all_model_classes: model = self._make_model_instance(config, model_class) model.eval() diff --git a/tests/transformers/test_tokenizer_common.py b/tests/transformers/test_tokenizer_common.py index 5fafe4c9de16..cc8daf06f7c9 100644 --- a/tests/transformers/test_tokenizer_common.py +++ b/tests/transformers/test_tokenizer_common.py @@ -437,8 +437,7 @@ def test_tokenizers_common_ids_setters(self): "mask_token", ] - vocab = dict(tokenizer.vocab._token_to_idx, - **tokenizer.added_tokens_encoder) + vocab = tokenizer.get_vocab() token_id_to_test_setters = next(iter(vocab.values())) token_to_test_setters = tokenizer.convert_ids_to_tokens( token_id_to_test_setters, skip_special_tokens=False)
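With labels supplied, every head now returns the loss first (or as the `loss` field when `return_dict=True`), which is what the updated tests index as `result[0]`/`result[1]`. An end-to-end sketch of the new calling convention; it reuses the `bert-base-uncased` checkpoint already shown in the docstrings above, but is illustrative rather than part of the patch:

.. code-block::

    import paddle
    from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                          num_classes=2)
    model.eval()

    inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
    inputs = {k: paddle.to_tensor([v]) for k, v in inputs.items()}
    labels = paddle.to_tensor([1])

    loss, logits = model(**inputs, labels=labels)
    print(loss)          # scalar training loss
    print(logits.shape)  # [1, 2]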