diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index d325f3ebbc8c..fb766d25a4c9 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -110,6 +110,7 @@ from .mbart.configuration import * from .megatronbert.modeling import * from .megatronbert.tokenizer import * +from .megatronbert.configuration import * from .prophetnet.modeling import * from .prophetnet.tokenizer import * from .mobilebert.modeling import * diff --git a/paddlenlp/transformers/megatronbert/configuration.py b/paddlenlp/transformers/megatronbert/configuration.py new file mode 100644 index 000000000000..2bc3e695a4cf --- /dev/null +++ b/paddlenlp/transformers/megatronbert/configuration.py @@ -0,0 +1,156 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MegatronBert model configuration""" +from __future__ import annotations + +from paddlenlp.transformers.configuration_utils import PretrainedConfig + +__all__ = [ + "MegatronBert_PRETRAINED_INIT_CONFIGURATION", + "MegatronBert_PRETRAINED_RESOURCE_FILES_MAP", + "MegatronBertConfig", +] + +MegatronBert_PRETRAINED_INIT_CONFIGURATION = { + "megatronbert-cased": { + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "max_position_embeddings": 512, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "type_vocab_size": 2, + "vocab_size": 29056, + "pad_token_id": 0, + }, + "megatronbert-uncased": { + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "max_position_embeddings": 512, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "type_vocab_size": 2, + "vocab_size": 30592, + "pad_token_id": 0, + }, +} + +MegatronBert_PRETRAINED_RESOURCE_FILES_MAP = { + "model_state": { + "megatronbert-cased": "http://bj.bcebos.com/paddlenlp/models/transformers/megatron-bert/megatronbert-cased/model_state.pdparams", + "megatronbert-uncased": "http://bj.bcebos.com/paddlenlp/models/transformers/megatron-bert/megatronbert-uncased/model_state.pdparams", + } +} + + +class MegatronBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MegatronBertModel`]. It is used to instantiate a + MEGATRON_BERT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the MEGATRON_BERT + [nvidia/megatron-bert-uncased-345m](https://huggingface.co/nvidia/megatron-bert-uncased-345m) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (int): + Vocabulary size of `inputs_ids` in `MegatronBertModel`.
It is also the vocab size of the token embedding matrix. + Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling `MegatronBert`. + hidden_size (int, optional): + Dimensionality of the encoder layer and pooler layer. Defaults to `1024`. + pad_token_id (int, optional): + The index of padding token in the token vocabulary. + Defaults to `0`. + type_vocab_size (int, optional): + The vocabulary size of `token_type_ids`. + Defaults to `2`. + hidden_act (str, optional): + The non-linear activation function in the feed-forward layer. + ``"gelu"``, ``"relu"`` and any other paddle supported activation functions + are supported. Defaults to `"gelu"`. + attention_probs_dropout_prob (float, optional): + The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention target. + Defaults to `0.1`. + num_attention_heads (int, optional): + Number of attention heads for each attention layer in the Transformer encoder. + Defaults to `16`. + num_hidden_layers (int, optional): + Number of hidden layers in the Transformer encoder. Defaults to `24`. + max_position_embeddings (int, optional): + The maximum value of the dimensionality of position encoding, which dictates the maximum supported length of an input + sequence. Defaults to `512`. + hidden_dropout_prob (float, optional): + The dropout probability for all fully connected layers in the embeddings and encoder. + Defaults to `0.1`. + intermediate_size (int, optional): + Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors + to ff layers are firstly projected from `hidden_size` to `intermediate_size`, + and then projected back to `hidden_size`. Typically `intermediate_size` is larger than `hidden_size`. + Defaults to `4096`. + position_embedding_type (str, optional): + Type of position embedding. Defaults to "absolute". + initializer_range (float, optional): + The standard deviation of the normal initializer. + Defaults to 0.02. + + .. note:: + A normal_initializer initializes weight matrices as normal distributions. + See :meth:`MegatronBertPretrainedModel.init_weights()` for how weights are initialized in `MegatronBertModel`.
+ + """ + model_type = "megatronbert" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=29056, + hidden_size=1024, + num_hidden_layers=24, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=0, + position_embedding_type="absolute", + # use_cache=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + # self.use_cache = use_cache diff --git a/paddlenlp/transformers/megatronbert/modeling.py b/paddlenlp/transformers/megatronbert/modeling.py index 071dd215b32b..f9e7f276f89e 100644 --- a/paddlenlp/transformers/megatronbert/modeling.py +++ b/paddlenlp/transformers/megatronbert/modeling.py @@ -15,10 +15,17 @@ import math import paddle -from paddle import einsum, nn +import paddle.nn as nn +from paddle import einsum +from ...utils.env import CONFIG_NAME from .. import PretrainedModel, register_base_model from ..activations import get_activation +from .configuration import ( + MegatronBert_PRETRAINED_INIT_CONFIGURATION, + MegatronBert_PRETRAINED_RESOURCE_FILES_MAP, + MegatronBertConfig, +) __all__ = [ "MegatronBertModel", @@ -45,46 +52,13 @@ class MegatronBertPretrainedModel(PretrainedModel): See :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details. 
""" + model_config_file = CONFIG_NAME + resource_files_names = {"model_state": "model_state.pdparams"} - pretrained_init_configuration = { - "megatronbert-cased": { - "attention_probs_dropout_prob": 0.1, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.1, - "hidden_size": 1024, - "initializer_range": 0.02, - "intermediate_size": 4096, - "max_position_embeddings": 512, - "num_attention_heads": 16, - "num_hidden_layers": 24, - "type_vocab_size": 2, - "vocab_size": 29056, - "pad_token_id": 0, - }, - "megatronbert-uncased": { - "attention_probs_dropout_prob": 0.1, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.1, - "hidden_size": 1024, - "initializer_range": 0.02, - "intermediate_size": 4096, - "max_position_embeddings": 512, - "num_attention_heads": 16, - "num_hidden_layers": 24, - "type_vocab_size": 2, - "vocab_size": 30592, - "pad_token_id": 0, - }, - } - pretrained_resource_files_map = { - "model_state": { - "megatronbert-cased": "http://bj.bcebos.com/paddlenlp/models/transformers/" - "megatron-bert/megatronbert-cased/model_state.pdparams", - "megatronbert-uncased": "http://bj.bcebos.com/paddlenlp/models/transformers/" - "megatron-bert/megatronbert-cased/model_state.pdparams", - } - } + pretrained_init_configuration = MegatronBert_PRETRAINED_INIT_CONFIGURATION + pretrained_resource_files_map = MegatronBert_PRETRAINED_RESOURCE_FILES_MAP base_model_prefix = "megatronbert" + config_class = MegatronBertConfig def init_weights(self, layer): """Initialization hook""" @@ -107,25 +81,16 @@ def init_weights(self, layer): class MegatronBertEmbeddings(nn.Layer): """Construct the embeddings from word, position and token_type embeddings.""" - def __init__( - self, - vocab_size=29056, - hidden_size=1024, - pad_token_id=0, - type_vocab_size=2, - max_position_embeddings=512, - hidden_dropout_prob=0.1, - position_embedding_type="absolute", - ): + def __init__(self, config: MegatronBertConfig): super(MegatronBertEmbeddings, self).__init__() - self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id) - self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) - self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) - self.dropout = nn.Dropout(hidden_dropout_prob) + self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.register_buffer("position_ids", paddle.arange(end=max_position_embeddings).expand((1, -1))) - self.position_embedding_type = position_embedding_type + self.register_buffer("position_ids", paddle.arange(end=config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = config.position_embedding_type def forward( self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 @@ -157,28 +122,21 @@ def forward( class MegatronBertSelfAttention(nn.Layer): - def __init__( - self, - hidden_size=1024, - num_attention_heads=16, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - position_embedding_type=None, - ): + def __init__(self, config: MegatronBertConfig): super(MegatronBertSelfAttention, self).__init__() - self.num_attention_heads = num_attention_heads - self.attention_head_size = int(hidden_size / num_attention_heads) + self.num_attention_heads = 
config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query = nn.Linear(hidden_size, self.all_head_size) - self.key = nn.Linear(hidden_size, self.all_head_size) - self.value = nn.Linear(hidden_size, self.all_head_size) + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) - self.dropout = nn.Dropout(attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = config.position_embedding_type if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": - self.max_position_embeddings = max_position_embeddings - self.distance_embedding = nn.Embedding(2 * max_position_embeddings - 1, self.attention_head_size) + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) def transpose_for_scores(self, x): new_x_shape = x.shape[:-1] + [self.num_attention_heads, self.attention_head_size] @@ -232,14 +190,10 @@ def forward(self, hidden_states, attention_mask=None): class MegatronBertSelfOutput(nn.Layer): - def __init__( - self, - hidden_size=1024, - hidden_dropout_prob=0.1, - ): + def __init__(self, config: MegatronBertConfig): super(MegatronBertSelfOutput, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.dropout = nn.Dropout(hidden_dropout_prob) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, residual): hidden_states = self.dense(hidden_states) @@ -248,24 +202,11 @@ def forward(self, hidden_states, residual): class MegatronBertAttention(nn.Layer): - def __init__( - self, - hidden_size=1024, - num_attention_heads=16, - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - position_embedding_type=None, - ): + def __init__(self, config: MegatronBertConfig): super(MegatronBertAttention, self).__init__() - self.layer_norm = nn.LayerNorm(hidden_size, epsilon=layer_norm_eps) - self.self = MegatronBertSelfAttention( - num_attention_heads=num_attention_heads, - attention_probs_dropout_prob=attention_probs_dropout_prob, - max_position_embeddings=max_position_embeddings, - position_embedding_type=position_embedding_type, - ) - self.output = MegatronBertSelfOutput(hidden_size=hidden_size, hidden_dropout_prob=hidden_dropout_prob) + self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.self = MegatronBertSelfAttention(config) + self.output = MegatronBertSelfOutput(config) self.pruned_heads = set() def forward(self, hidden_states, attention_mask=None): @@ -277,10 +218,10 @@ def forward(self, hidden_states, attention_mask=None): class MegatronBertIntermediate(nn.Layer): - def __init__(self, hidden_size, intermediate_size, hidden_act): + def __init__(self, config: MegatronBertConfig): super(MegatronBertIntermediate, self).__init__() - self.dense = nn.Linear(hidden_size, intermediate_size) - self.intermediate_act_fn = get_activation(hidden_act) + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = 
get_activation(config.hidden_act) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) @@ -289,10 +230,10 @@ def forward(self, hidden_states): class MegatronBertOutput(nn.Layer): - def __init__(self, intermediate_size, hidden_dropout_prob=0.1, hidden_size=1024): + def __init__(self, config: MegatronBertConfig): super(MegatronBertOutput, self).__init__() - self.dense = nn.Linear(intermediate_size, hidden_size) - self.dropout = nn.Dropout(hidden_dropout_prob) + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) @@ -301,35 +242,13 @@ def forward(self, hidden_states, input_tensor): class MegatronBertLayer(nn.Layer): - def __init__( - self, - hidden_size=1024, - hidden_act="gelu", - num_attention_heads=16, - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - intermediate_size=4096, - position_embedding_type=None, - ): + def __init__(self, config: MegatronBertConfig): super(MegatronBertLayer, self).__init__() self.seq_len_dim = 1 - self.attention = MegatronBertAttention( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - hidden_dropout_prob=hidden_dropout_prob, - attention_probs_dropout_prob=attention_probs_dropout_prob, - max_position_embeddings=max_position_embeddings, - position_embedding_type=position_embedding_type, - ) - - self.layer_norm = nn.LayerNorm(hidden_size, epsilon=layer_norm_eps) - self.intermediate = MegatronBertIntermediate( - hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act - ) - self.output = MegatronBertOutput( - intermediate_size, hidden_dropout_prob=hidden_dropout_prob, hidden_size=hidden_size - ) + self.attention = MegatronBertAttention(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.intermediate = MegatronBertIntermediate(config) + self.output = MegatronBertOutput(config) def forward(self, hidden_states, attention_mask=None): self_attention_outputs = self.attention(hidden_states, attention_mask) @@ -350,38 +269,13 @@ def feed_forward_chunk(self, attention_output): class MegatronBertEncoder(nn.Layer): - def __init__( - self, - hidden_size=1024, - hidden_act="gelu", - num_attention_heads=16, - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - intermediate_size=4096, - position_embedding_type=None, - num_hidden_layers=24, - ): + def __init__(self, config: MegatronBertConfig): super(MegatronBertEncoder, self).__init__() - self.layer = nn.LayerList( - [ - MegatronBertLayer( - hidden_size=hidden_size, - hidden_act=hidden_act, - num_attention_heads=num_attention_heads, - hidden_dropout_prob=hidden_dropout_prob, - attention_probs_dropout_prob=attention_probs_dropout_prob, - max_position_embeddings=max_position_embeddings, - intermediate_size=intermediate_size, - position_embedding_type=position_embedding_type, - ) - for _ in range(num_hidden_layers) - ] - ) + self.layer = nn.LayerList([MegatronBertLayer(config) for _ in range(config.num_hidden_layers)]) # The final layer norm. We removed the 1st LN, moved LN to each hidden layer and this one # is simply the final LN (Transformer's BERT has it attached to each hidden layer). 
- self.layer_norm = nn.LayerNorm(hidden_size, epsilon=layer_norm_eps) + self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) def forward(self, hidden_states, attention_mask=None): for i, layer_module in enumerate(self.layer): @@ -396,9 +290,9 @@ def forward(self, hidden_states, attention_mask=None): class MegatronBertPooler(nn.Layer): - def __init__(self, hidden_size=1024): + def __init__(self, config: MegatronBertConfig): super(MegatronBertPooler, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): @@ -423,96 +317,20 @@ class MegatronBertModel(MegatronBertPretrainedModel): and refer to the Paddle documentation for all matter related to general usage and behavior. Args: - vocab_size (int): - Vocabulary size of `inputs_ids` in `MegatronBertModel`. Also is the vocab size of token embedding matrix. - Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling `MegatronBert`. - hidden_size (int, optional): - Dimensionality of the encoder layer and pooler layer. Defaults to `1024`. - pad_token_id (int, optional): - The index of padding token in the token vocabulary. - Defaults to `0`. - type_vocab_size (int, optional): - The vocabulary size of `token_type_ids`. - Defaults to `2`. - hidden_act (str, optional): - The non-linear activation function in the feed-forward layer. - ``"gelu"``, ``"relu"`` and any other paddle supported activation functions - are supported. Defaults to `"gelu"`. - attention_probs_dropout_prob (float, optional): - The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention target. - Defaults to `0.1`. - num_attention_heads (int, optional): - Number of attention heads for each attention layer in the Transformer encoder. - Defaults to `16`. - num_hidden_layers (int, optional): - Number of hidden layers in the Transformer encoder. Defaults to `24`. - max_position_embeddings (int, optional): - The maximum value of the dimensionality of position encoding, which dictates the maximum supported length of an input - sequence. Defaults to `512`. - hidden_dropout_prob (float, optional): - The dropout probability for all fully connected layers in the embeddings and encoder. - Defaults to `0.1`. - intermediate_size (int, optional): - Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors - to ff layers are firstly projected from `hidden_size` to `intermediate_size`, - and then projected back to `hidden_size`. Typically `intermediate_size` is larger than `hidden_size`. - Defaults to `4096`. - position_embedding_type (str, optional): - Type of position embedding. Defaults to "absolute" - initializer_range (float, optional): - The standard deviation of the normal initializer. - Defaults to 0.02. - - .. note:: - A normal_initializer initializes weight matrices as normal distributions. - See :meth:`MegatronBertPretrainedModel.init_weights()` for how weights are initialized in `MegatronBertModel`. - + config (:class:`MegatronBertConfig`): + An instance of MegatronBertConfig used to construct MegatronBertModel.
""" - def __init__( - self, - vocab_size=29056, - hidden_size=1024, - pad_token_id=0, - type_vocab_size=2, - hidden_act="gelu", - attention_probs_dropout_prob=0.1, - num_attention_heads=16, - num_hidden_layers=24, - max_position_embeddings=512, - hidden_dropout_prob=0.1, - intermediate_size=4096, - position_embedding_type="absolute", - initializer_range=0.02, - ): - super(MegatronBertModel, self).__init__() - - self.num_hidden_layers = num_hidden_layers - self.pad_token_id = pad_token_id - self.initializer_range = initializer_range - self.embeddings = MegatronBertEmbeddings( - vocab_size=vocab_size, - hidden_size=hidden_size, - pad_token_id=pad_token_id, - type_vocab_size=type_vocab_size, - max_position_embeddings=max_position_embeddings, - hidden_dropout_prob=hidden_dropout_prob, - position_embedding_type=position_embedding_type, - ) - self.encoder = MegatronBertEncoder( - hidden_size=hidden_size, - hidden_act=hidden_act, - num_attention_heads=num_attention_heads, - hidden_dropout_prob=hidden_dropout_prob, - attention_probs_dropout_prob=attention_probs_dropout_prob, - max_position_embeddings=max_position_embeddings, - intermediate_size=intermediate_size, - position_embedding_type=position_embedding_type, - num_hidden_layers=num_hidden_layers, - ) - - self.pooler = MegatronBertPooler(hidden_size=hidden_size) + def __init__(self, config: MegatronBertConfig): + super(MegatronBertModel, self).__init__(config) + self.num_hidden_layers = config.num_hidden_layers + self.pad_token_id = config.pad_token_id + self.initializer_range = config.initializer_range + self.embeddings = MegatronBertEmbeddings(config) + self.encoder = MegatronBertEncoder(config) + self.pooler = MegatronBertPooler(config) # Initialize weights and apply final processing self.apply(self.init_weights) @@ -620,10 +438,10 @@ class MegatronBertForQuestionAnswering(MegatronBertPretrainedModel): """ - def __init__(self, megatronbert): - super(MegatronBertForQuestionAnswering, self).__init__() - self.megatronbert = megatronbert - self.qa_outputs = nn.Linear(self.megatronbert.config["hidden_size"], 2) + def __init__(self, config: MegatronBertConfig): + super(MegatronBertForQuestionAnswering, self).__init__(config) + self.megatronbert = MegatronBertModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) # Initialize weights and apply final processing self.apply(self.init_weights) @@ -703,13 +521,13 @@ class MegatronBertForSequenceClassification(MegatronBertPretrainedModel): The number of labels. 
""" - def __init__(self, megatronbert, num_labels): - super(MegatronBertForSequenceClassification, self).__init__() - self.num_labels = num_labels + def __init__(self, config: MegatronBertConfig): + super(MegatronBertForSequenceClassification, self).__init__(config) + self.num_labels = config.num_labels - self.megatronbert = megatronbert - self.dropout = nn.Dropout(self.megatronbert.config["hidden_dropout_prob"]) - self.classifier = nn.Linear(self.megatronbert.config["hidden_size"], num_labels) + self.megatronbert = MegatronBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.apply(self.init_weights) @@ -756,11 +574,11 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, attent class MegatronBertPredictionHeadTransform(nn.Layer): - def __init__(self, hidden_size, hidden_act): + def __init__(self, config: MegatronBertConfig): super(MegatronBertPredictionHeadTransform, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.transform_act_fn = get_activation(hidden_act) - self.layer_norm = nn.LayerNorm(hidden_size, epsilon=layer_norm_eps) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.transform_act_fn = get_activation(config.hidden_act) + self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) @@ -770,16 +588,19 @@ def forward(self, hidden_states): class MegatronBertLMPredictionHead(nn.Layer): - def __init__(self, hidden_size, vocab_size, hidden_act): + def __init__(self, config: MegatronBertConfig): super(MegatronBertLMPredictionHead, self).__init__() - self.transform = MegatronBertPredictionHeadTransform(hidden_size, hidden_act) + self.transform = MegatronBertPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. 
+ self.decoder_weight = self.create_parameter( - shape=[vocab_size, hidden_size], dtype=self.transform.weight.dtype, is_bias=False + shape=[config.vocab_size, config.hidden_size], dtype=self.transform.dense.weight.dtype, is_bias=False + ) + self.decoder_bias = self.create_parameter( + shape=[config.vocab_size], dtype=self.decoder_weight.dtype, is_bias=True ) - self.decoder_bias = self.create_parameter(shape=[vocab_size], dtype=self.decoder_weight.dtype, is_bias=True) def forward(self, hidden_states): hidden_states = self.transform(hidden_states) @@ -788,11 +609,9 @@ def forward(self, hidden_states): class MegatronBertOnlyMLMHead(nn.Layer): - def __init__(self, hidden_size, vocab_size, hidden_act): + def __init__(self, config: MegatronBertConfig): super(MegatronBertOnlyMLMHead, self).__init__() - self.predictions = MegatronBertLMPredictionHead( - hidden_size=hidden_size, vocab_size=vocab_size, hidden_act=hidden_act - ) + self.predictions = MegatronBertLMPredictionHead(config) def forward(self, sequence_output): prediction_scores = self.predictions(sequence_output) @@ -800,9 +619,9 @@ def forward(self, sequence_output): class MegatronBertOnlyNSPHead(nn.Layer): - def __init__(self, hidden_size): + def __init__(self, config: MegatronBertConfig): super(MegatronBertOnlyNSPHead, self).__init__() - self.seq_relationship = nn.Linear(hidden_size, 2) + self.seq_relationship = nn.Linear(config.hidden_size, 2) def forward(self, pooled_output): seq_relationship_score = self.seq_relationship(pooled_output) @@ -810,12 +629,10 @@ def forward(self, pooled_output): class MegatronBertPreTrainingHeads(nn.Layer): - def __init__(self, hidden_size, vocab_size, hidden_act): + def __init__(self, config: MegatronBertConfig): super(MegatronBertPreTrainingHeads, self).__init__() - self.predictions = MegatronBertLMPredictionHead( - hidden_size=hidden_size, vocab_size=vocab_size, hidden_act=hidden_act - ) - self.seq_relationship = nn.Linear(hidden_size, 2) + self.predictions = MegatronBertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) def forward(self, sequence_output, pooled_output): prediction_scores = self.predictions(sequence_output) @@ -833,15 +650,11 @@ class MegatronBertForPreTraining(MegatronBertPretrainedModel): """ - def __init__(self, megatronbert): - super(MegatronBertForPreTraining, self).__init__() + def __init__(self, config: MegatronBertConfig): + super(MegatronBertForPreTraining, self).__init__(config) - self.megatronbert = megatronbert - self.cls = MegatronBertPreTrainingHeads( - hidden_size=self.megatronbert.config["hidden_size"], - vocab_size=self.megatronbert.config["vocab_size"], - hidden_act=self.megatronbert.config["hidden_act"], - ) + self.megatronbert = MegatronBertModel(config) + self.cls = MegatronBertPreTrainingHeads(config) # Initialize weights and apply final processing self.apply(self.init_weights) @@ -907,15 +720,11 @@ class MegatronBertForCausalLM(MegatronBertPretrainedModel): """ - def __init__(self, megatronbert): - super(MegatronBertForCausalLM, self).__init__() + def __init__(self, config: MegatronBertConfig): + super(MegatronBertForCausalLM, self).__init__(config) - self.megatronbert = megatronbert - self.cls = MegatronBertOnlyMLMHead( - hidden_size=self.megatronbert.config["hidden_size"], - vocab_size=self.megatronbert.config["vocab_size"], - hidden_act=self.megatronbert.config["hidden_act"], - ) + self.megatronbert = MegatronBertModel(config) + self.cls = MegatronBertOnlyMLMHead(config) # Initialize weights and apply final processing 
self.apply(self.init_weights) @@ -971,15 +780,11 @@ class MegatronBertForMaskedLM(MegatronBertPretrainedModel): """ - def __init__(self, megatronbert): - super(MegatronBertForMaskedLM, self).__init__() + def __init__(self, config: MegatronBertConfig): + super(MegatronBertForMaskedLM, self).__init__(config) - self.megatronbert = megatronbert - self.cls = MegatronBertOnlyMLMHead( - hidden_size=self.megatronbert.config["hidden_size"], - vocab_size=self.megatronbert.config["vocab_size"], - hidden_act=self.megatronbert.config["hidden_act"], - ) + self.megatronbert = MegatronBertModel(config) + self.cls = MegatronBertOnlyMLMHead(config) # Initialize weights and apply final processing self.apply(self.init_weights) @@ -1042,11 +847,11 @@ class MegatronBertForNextSentencePrediction(MegatronBertPretrainedModel): An instance of :class:`MegatronBertModel`. """ - def __init__(self, megatronbert): - super(MegatronBertForNextSentencePrediction, self).__init__() + def __init__(self, config: MegatronBertConfig): + super(MegatronBertForNextSentencePrediction, self).__init__(config) - self.megatronbert = megatronbert - self.cls = MegatronBertOnlyNSPHead(hidden_size=self.megatronbert.config["hidden_size"]) + self.megatronbert = MegatronBertModel(config) + self.cls = MegatronBertOnlyNSPHead(config) # Initialize weights and apply final processing self.apply(self.init_weights) @@ -1102,12 +907,12 @@ class MegatronBertForMultipleChoice(MegatronBertPretrainedModel): An instance of :class:`MegatronBertModel`. """ - def __init__(self, megatronbert): - super(MegatronBertForMultipleChoice, self).__init__() + def __init__(self, config: MegatronBertConfig): + super(MegatronBertForMultipleChoice, self).__init__(config) - self.megatronbert = megatronbert - self.dropout = nn.Dropout(self.megatronbert.config["hidden_dropout_prob"]) - self.classifier = nn.Linear(self.megatronbert.config["hidden_size"], 1) + self.megatronbert = MegatronBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) # Initialize weights and apply final processing self.apply(self.init_weights) @@ -1174,12 +979,12 @@ class MegatronBertForTokenClassification(MegatronBertPretrainedModel): The number of labels. 
""" - def __init__(self, megatronbert, num_labels): - super(MegatronBertForTokenClassification, self).__init__() - self.num_labels = num_labels - self.megatronbert = megatronbert - self.dropout = nn.Dropout(self.megatronbert.config["hidden_dropout_prob"]) - self.classifier = nn.Linear(self.megatronbert.config["hidden_size"], self.num_labels) + def __init__(self, config: MegatronBertConfig): + super(MegatronBertForTokenClassification, self).__init__(config) + self.num_labels = config.num_labels + self.megatronbert = MegatronBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, self.num_labels) self.apply(self.init_weights) def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None): diff --git a/paddlenlp/transformers/megatronbert/tokenizer.py b/paddlenlp/transformers/megatronbert/tokenizer.py index a5bc8080458d..24f7c24266d9 100644 --- a/paddlenlp/transformers/megatronbert/tokenizer.py +++ b/paddlenlp/transformers/megatronbert/tokenizer.py @@ -98,4 +98,5 @@ def __init__( pad_token=pad_token, cls_token=cls_token, mask_token=mask_token, + **kwargs, ) diff --git a/tests/transformers/megatronbert/__init__.py b/tests/transformers/megatronbert/__init__.py new file mode 100644 index 000000000000..97043fd7ba68 --- /dev/null +++ b/tests/transformers/megatronbert/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/transformers/megatronbert/test_modeling.py b/tests/transformers/megatronbert/test_modeling.py new file mode 100644 index 000000000000..3220f8fdff38 --- /dev/null +++ b/tests/transformers/megatronbert/test_modeling.py @@ -0,0 +1,359 @@ +# # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. 
+ +import unittest + +import paddle + +from paddlenlp.transformers import ( + MegatronBertConfig, + MegatronBertForCausalLM, + MegatronBertForMaskedLM, + MegatronBertForMultipleChoice, + MegatronBertForNextSentencePrediction, + MegatronBertForPreTraining, + MegatronBertForQuestionAnswering, + MegatronBertForSequenceClassification, + MegatronBertForTokenClassification, + MegatronBertModel, + MegatronBertPretrainedModel, +) + +from ...testing_utils import slow +from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask + + +class MegatronBertModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + pad_token_id=0, + type_sequence_label_size=2, + use_relative_position=True, + num_labels=3, + num_choices=4, + num_classes=3, + scope=None, + return_dict=False, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.pad_token_id = pad_token_id + self.type_sequence_label_size = type_sequence_label_size + self.use_relative_position = use_relative_position + self.num_classes = num_classes + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.return_dict = return_dict + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + + if self.parent.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return MegatronBertConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + 
pad_token_id=self.pad_token_id, + use_relative_position=self.use_relative_position, + num_class=self.num_classes, + num_labels=self.num_labels, + num_choices=self.num_choices, + ) + + def create_and_check_model( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = MegatronBertModel(config) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + result = model(input_ids, token_type_ids=token_type_ids) + result = model(input_ids) + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.hidden_size]) + self.parent.assertEqual(result[1].shape, [self.batch_size, self.hidden_size]) + + def create_and_check_for_multiple_choice( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = MegatronBertForMultipleChoice(config) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand([-1, self.num_choices, -1]) + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand([-1, self.num_choices, -1]) + multiple_choice_input_mask = input_mask.unsqueeze(1).expand([-1, self.num_choices, -1]) + result = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_input_mask, + token_type_ids=multiple_choice_token_type_ids, + ) + if paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.num_choices]) + + def create_and_check_for_question_answering( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = MegatronBertForQuestionAnswering(config) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + ) + start_logits, end_logits = result[0], result[1] + + self.parent.assertEqual(start_logits.shape, [self.batch_size, self.seq_length]) + self.parent.assertEqual(end_logits.shape, [self.batch_size, self.seq_length]) + + def create_and_check_for_sequence_classification( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = MegatronBertForSequenceClassification(config) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + if paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.num_classes]) + + def create_and_check_for_next_sentence_prediction( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = MegatronBertForNextSentencePrediction(config) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + if paddle.is_tensor(result): + result = [result] + self.parent.assertEqual(result[0].shape, [self.batch_size, 2]) + + def create_and_check_for_token_classification( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = MegatronBertForTokenClassification(config) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + ) + if paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.num_classes]) + + def create_and_check_for_causal_lm( + self, + config: MegatronBertConfig, + input_ids, + 
token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = MegatronBertForCausalLM(config) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + self.parent.assertEqual(result[0].shape, [self.seq_length, self.vocab_size]) + + def create_and_check_for_masked_lm( + self, + config: MegatronBertConfig, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = MegatronBertForMaskedLM(config) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + self.parent.assertEqual(result[0].shape, [self.seq_length, self.vocab_size]) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + +class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase): + base_model_class = MegatronBertModel + return_dict: bool = False + use_labels: bool = False + test_resize_embeddings: bool = False + + all_model_classes = ( + MegatronBertModel, + MegatronBertForQuestionAnswering, + MegatronBertForSequenceClassification, + MegatronBertForNextSentencePrediction, + MegatronBertForCausalLM, + MegatronBertForPreTraining, + MegatronBertForMaskedLM, + MegatronBertForMultipleChoice, + MegatronBertForTokenClassification, + ) + + def setUp(self): + self.model_tester = MegatronBertModelTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_next_sentence_prediction(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_next_sentence_prediction(*config_and_inputs) + + def test_for_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in list(MegatronBertPretrainedModel.pretrained_init_configuration)[:1]: + model = MegatronBertModel.from_pretrained(model_name) + self.assertIsNotNone(model) diff --git a/tests/transformers/megatronbert/test_tokenizer.py b/tests/transformers/megatronbert/test_tokenizer.py new file mode 100644 index 
000000000000..fc1d69cd7e5b --- /dev/null +++ b/tests/transformers/megatronbert/test_tokenizer.py @@ -0,0 +1,138 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +from paddlenlp.transformers import MegatronBertTokenizer + +from ...testing_utils import slow +from ...transformers.test_tokenizer_common import TokenizerTesterMixin + + +class MegatronBertTokenizerTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = MegatronBertTokenizer + space_between_special_tokens = True + test_seq2seq = False + + def setUp(self): + super().setUp() + + vocab_tokens = [ + "[UNK]", + "[CLS]", + "[SEP]", + "[PAD]", + "[MASK]", + "want", + "##want", + "##ed", + "wa", + "un", + "runn", + "##ing", + ",", + "low", + "lowest", + ] + + self.vocab_file = os.path.join(self.tmpdirname, MegatronBertTokenizer.resource_files_names["vocab_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + def get_input_output_texts(self, tokenizer): + input_text = "UNwant\u00E9d,running" + output_text = "unwanted, running" + return input_text, output_text + + def test_full_tokenizer(self): + tokenizer = self.tokenizer_class(self.vocab_file) + + tokens = tokenizer.tokenize("UNwant\u00E9d,running") + self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) + self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11]) + + def test_clean_text(self): + tokenizer = self.get_tokenizer() + # Example taken from the issue https://github.com/huggingface/tokenizers/issues/340 + self.assertListEqual([tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]]) + + @slow + def test_sequence_builders(self): + tokenizer = self.tokenizer_class.from_pretrained("megatronbert-uncased") + + text = tokenizer.encode("sequence builders", return_token_type_ids=None, add_special_tokens=False)["input_ids"] + text_2 = tokenizer.encode("multi-sequence build", return_token_type_ids=None, add_special_tokens=False)[ + "input_ids" + ] + + encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) + + assert encoded_sentence == [101] + text + [102] + assert encoded_pair == [101] + text + [102] + text_2 + [102] + + def test_offsets_with_special_characters(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + sentence = f"A, naïve {tokenizer.mask_token} AllenNLP sentence." 
+ tokens = tokenizer.encode( + sentence, + return_attention_mask=False, + return_token_type_ids=False, + return_offsets_mapping=True, + add_special_tokens=True, + ) + + do_lower_case = tokenizer.do_lower_case if hasattr(tokenizer, "do_lower_case") else False + expected_results = ( + [ + ((0, 0), tokenizer.cls_token), + ((0, 1), "A"), + ((1, 2), ","), + ((3, 5), "na"), + ((5, 6), "##ï"), + ((6, 8), "##ve"), + ((9, 15), tokenizer.mask_token), + ((16, 21), "Allen"), + ((21, 23), "##NL"), + ((23, 24), "##P"), + ((25, 33), "sentence"), + ((33, 34), "."), + ((0, 0), tokenizer.sep_token), + ] + if not do_lower_case + else [ + ((0, 0), tokenizer.cls_token), + ((0, 1), "a"), + ((1, 2), ","), + ((3, 8), "naive"), + ((9, 15), tokenizer.mask_token), + ((16, 21), "allen"), + ((21, 23), "##nl"), + ((23, 24), "##p"), + ((25, 33), "sentence"), + ((33, 34), "."), + ((0, 0), tokenizer.sep_token), + ] + ) + + self.assertEqual( + [e[1] for e in expected_results], tokenizer.convert_ids_to_tokens(tokens["input_ids"]) + ) + self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"])
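
The sketch below is a usage illustration, not part of the patch: it shows how the config-driven constructors introduced above are meant to be called. The tiny hyperparameters and variable names are made up for the example, and the shape comments follow the assertions in test_modeling.py above.

import paddle

from paddlenlp.transformers import (
    MegatronBertConfig,
    MegatronBertForSequenceClassification,
    MegatronBertModel,
)

# Build a deliberately small config; every field has a default, so only the
# overrides need to be passed. Extra kwargs such as num_labels are kept on the
# config, which is how the task heads pick them up.
config = MegatronBertConfig(
    vocab_size=1024,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
    num_labels=3,
)

# The base model and the task heads are now built from the config instead of
# keyword arguments or a pre-built backbone instance.
model = MegatronBertModel(config)
model.eval()

input_ids = paddle.randint(low=1, high=config.vocab_size, shape=[2, 16])
outputs = model(input_ids)
print(outputs[0].shape)  # [2, 16, 64] -> sequence output
print(outputs[1].shape)  # [2, 64]     -> pooled output

classifier = MegatronBertForSequenceClassification(config)  # head builds its own megatronbert backbone
classifier.eval()
logits = classifier(input_ids)
print(logits.shape)  # [2, 3], per the sequence-classification test above

Loading the released weights is unchanged from the user's point of view: MegatronBertModel.from_pretrained("megatronbert-cased") or "megatronbert-uncased" now resolves its defaults through MegatronBert_PRETRAINED_INIT_CONFIGURATION and MegatronBert_PRETRAINED_RESOURCE_FILES_MAP defined in configuration.py.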