diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index 011d3ec8f85f..e5df9cfcc7b9 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -110,6 +110,7 @@ from .mobilebert.tokenizer import * from .mpnet.modeling import * from .mpnet.tokenizer import * +from .nezha.configuration import * from .nezha.modeling import * from .nezha.tokenizer import * from .ppminilm.modeling import * diff --git a/paddlenlp/transformers/nezha/configuration.py b/paddlenlp/transformers/nezha/configuration.py new file mode 100644 index 000000000000..5dc02196c284 --- /dev/null +++ b/paddlenlp/transformers/nezha/configuration.py @@ -0,0 +1,190 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" NeZha model configuration""" +from __future__ import annotations + +from typing import Dict + +from ..configuration_utils import PretrainedConfig + +__all__ = ["NEZHA_PRETRAINED_INIT_CONFIGURATION", "NeZhaConfig", "NEZHA_PRETRAINED_RESOURCE_FILES_MAP"] + +NEZHA_PRETRAINED_INIT_CONFIGURATION = { + "nezha-base-chinese": { + "vocab_size": 21128, + "hidden_size": 768, + "num_hidden_layers": 12, + "num_attention_heads": 12, + "intermediate_size": 3072, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "max_relative_position": 64, + "type_vocab_size": 2, + "initializer_range": 0.02, + "use_relative_position": True, + }, + "nezha-large-chinese": { + "vocab_size": 21128, + "hidden_size": 1024, + "num_hidden_layers": 24, + "num_attention_heads": 16, + "intermediate_size": 4096, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "max_relative_position": 64, + "type_vocab_size": 2, + "initializer_range": 0.02, + "use_relative_position": True, + }, + "nezha-base-wwm-chinese": { + "vocab_size": 21128, + "hidden_size": 768, + "num_hidden_layers": 12, + "num_attention_heads": 12, + "intermediate_size": 3072, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "max_relative_position": 64, + "type_vocab_size": 2, + "initializer_range": 0.02, + "use_relative_position": True, + }, + "nezha-large-wwm-chinese": { + "vocab_size": 21128, + "hidden_size": 1024, + "num_hidden_layers": 24, + "num_attention_heads": 16, + "intermediate_size": 4096, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "max_relative_position": 64, + "type_vocab_size": 2, + "initializer_range": 0.02, + "use_relative_position": True, + }, +} +NEZHA_PRETRAINED_RESOURCE_FILES_MAP = { + "model_state": { + "nezha-base-chinese": "https://bj.bcebos.com/paddlenlp/models/transformers/nezha/nezha-base-chinese.pdparams", + "nezha-large-chinese": "https://bj.bcebos.com/paddlenlp/models/transformers/nezha/nezha-large-chinese.pdparams", + 
"nezha-base-wwm-chinese": "https://bj.bcebos.com/paddlenlp/models/transformers/nezha/nezha-base-wwm-chinese.pdparams", + "nezha-large-wwm-chinese": "https://bj.bcebos.com/paddlenlp/models/transformers/nezha/nezha-large-wwm-chinese.pdparams", + } +} + + +class NeZhaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`NezhaModel`]. It is used to instantiate an Nezha + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Nezha + [sijunhe/nezha-cn-base](https://huggingface.co/sijunhe/nezha-cn-base) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, optional, defaults to 21128): + Vocabulary size of the NEZHA model. Defines the different tokens that can be represented by the + *inputs_ids* passed to the forward method of [`NezhaModel`]. + embedding_size (`int`, optional, defaults to 128): + Dimensionality of vocabulary embeddings. + hidden_size (`int`, optional, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, optional, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, optional, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, optional, defaults to 3072): + The dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, optional, defaults to "gelu"): + The non-linear activation function (function or string) in the encoder and pooler. + hidden_dropout_prob (`float`, optional, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, optional, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, optional, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, optional, defaults to 2): + The vocabulary size of the *token_type_ids* passed into [`NezhaModel`]. + initializer_range (`float`, optional, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, optional, defaults to 1e-12): + The epsilon used by the layer normalization layers. + classifier_dropout (`float`, optional, defaults to 0.1): + The dropout ratio for attached classifiers. + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. 
+    Example:
+    ```python
+    >>> from paddlenlp.transformers import NeZhaConfig, NeZhaModel
+    >>> # Initializing a NeZha configuration
+    >>> configuration = NeZhaConfig()
+    >>> # Initializing a model (with random weights) from that configuration
+    >>> model = NeZhaModel(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+    attribute_map: Dict[str, str] = {"dropout": "classifier_dropout", "num_classes": "num_labels"}
+    pretrained_init_configuration = NEZHA_PRETRAINED_INIT_CONFIGURATION
+    model_type = "nezha"
+
+    def __init__(
+        self,
+        vocab_size=21128,
+        embedding_size=128,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        intermediate_size=3072,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.1,
+        attention_probs_dropout_prob=0.1,
+        max_position_embeddings=512,
+        max_relative_position=64,
+        use_relative_position=True,
+        type_vocab_size=2,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        classifier_dropout=0.1,
+        pad_token_id=0,
+        bos_token_id=2,
+        eos_token_id=3,
+        use_cache=True,
+        **kwargs
+    ):
+        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+        self.vocab_size = vocab_size
+        self.embedding_size = embedding_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.intermediate_size = intermediate_size
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.max_relative_position = max_relative_position
+        self.use_relative_position = use_relative_position
+        self.type_vocab_size = type_vocab_size
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.classifier_dropout = classifier_dropout
+        self.use_cache = use_cache
diff --git a/paddlenlp/transformers/nezha/modeling.py b/paddlenlp/transformers/nezha/modeling.py
index d35a387127cd..499f5431e732 100644
--- a/paddlenlp/transformers/nezha/modeling.py
+++ b/paddlenlp/transformers/nezha/modeling.py
@@ -17,20 +17,26 @@
 import copy
 import math
-import numpy as np
+import numpy as np
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddlenlp.transformers import PretrainedModel, register_base_model
+from ...utils.env import CONFIG_NAME
+from .configuration import (
+    NEZHA_PRETRAINED_INIT_CONFIGURATION,
+    NEZHA_PRETRAINED_RESOURCE_FILES_MAP,
+    NeZhaConfig,
+)
+
 __all__ = [
     "NeZhaModel",
     "NeZhaPretrainedModel",
     "NeZhaForPretraining",
     "NeZhaForSequenceClassification",
-    "NeZhaPretrainingHeads",
    "NeZhaForTokenClassification",
     "NeZhaForQuestionAnswering",
     "NeZhaForMultipleChoice",
@@ -77,36 +83,28 @@ def gelu_new(x):
 class NeZhaAttention(nn.Layer):
-    def __init__(
-        self,
-        hidden_size=768,
-        num_attention_heads=12,
-        hidden_dropout_prob=0.1,
-        attention_probs_dropout_prob=0.1,
-        max_relative_position=64,
-        layer_norm_eps=1e-12,
-    ):
+    def __init__(self, config: NeZhaConfig):
         super(NeZhaAttention, self).__init__()
-        if hidden_size % num_attention_heads != 0:
+        if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
-                "The hidden size (%d) is not a multiple of the number of attention "
-                "heads (%d)" % (hidden_size, num_attention_heads)
+                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+                f"heads ({config.num_attention_heads})"
             )
-        self.num_attention_heads = num_attention_heads
-        self.attention_head_size = int(hidden_size / num_attention_heads)
+        self.num_attention_heads =
config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query = nn.Linear(hidden_size, self.all_head_size) - self.key = nn.Linear(hidden_size, self.all_head_size) - self.value = nn.Linear(hidden_size, self.all_head_size) + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) self.relative_positions_embeddings = self.generate_relative_positions_embeddings( - length=512, depth=self.attention_head_size, max_relative_position=max_relative_position + length=512, depth=self.attention_head_size, max_relative_position=config.max_relative_position ) - self.attention_dropout = nn.Dropout(attention_probs_dropout_prob) + self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.dense = nn.Linear(hidden_size, hidden_size) - self.layer_norm = nn.LayerNorm(hidden_size, epsilon=layer_norm_eps) - self.output_dropout = nn.Dropout(hidden_dropout_prob) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.output_dropout = nn.Dropout(config.hidden_dropout_prob) def generate_relative_positions_embeddings(self, length, depth, max_relative_position=127): vocab_size = max_relative_position * 2 + 1 @@ -201,32 +199,15 @@ def forward(self, hidden_states, attention_mask): class NeZhaLayer(nn.Layer): - def __init__( - self, - hidden_size=768, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_relative_position=64, - layer_norm_eps=1e-12, - ): + def __init__(self, config: NeZhaConfig): super(NeZhaLayer, self).__init__() self.seq_len_dim = 1 - self.layer_norm = nn.LayerNorm(hidden_size, epsilon=layer_norm_eps) - self.attention = NeZhaAttention( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - hidden_dropout_prob=hidden_dropout_prob, - attention_probs_dropout_prob=attention_probs_dropout_prob, - max_relative_position=max_relative_position, - layer_norm_eps=layer_norm_eps, - ) - self.ffn = nn.Linear(hidden_size, intermediate_size) - self.ffn_output = nn.Linear(intermediate_size, hidden_size) - self.activation = ACT2FN[hidden_act] - self.dropout = nn.Dropout(hidden_dropout_prob) + self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.attention = NeZhaAttention(config) + self.ffn = nn.Linear(config.hidden_size, config.intermediate_size) + self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size) + self.activation = ACT2FN[config.hidden_act] + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, attention_mask=None): attention_output, layer_att = self.attention(hidden_states, attention_mask) @@ -242,30 +223,10 @@ def forward(self, hidden_states, attention_mask=None): class NeZhaEncoder(nn.Layer): - def __init__( - self, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_relative_position=64, - layer_norm_eps="1e-12", - ): + def __init__(self, config: NeZhaConfig): super(NeZhaEncoder, self).__init__() - layer = NeZhaLayer( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - 
intermediate_size=intermediate_size, - hidden_act=hidden_act, - hidden_dropout_prob=hidden_dropout_prob, - attention_probs_dropout_prob=attention_probs_dropout_prob, - max_relative_position=max_relative_position, - layer_norm_eps=layer_norm_eps, - ) - self.layer = nn.LayerList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) + layer = NeZhaLayer(config) + self.layer = nn.LayerList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) def forward(self, hidden_states, attention_mask): all_encoder_layers = [] @@ -279,26 +240,18 @@ def forward(self, hidden_states, attention_mask): class NeZhaEmbeddings(nn.Layer): - def __init__( - self, - vocab_size, - hidden_size=768, - hidden_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - use_relative_position=True, - ): + def __init__(self, config: NeZhaConfig): super(NeZhaEmbeddings, self).__init__() - self.use_relative_position = use_relative_position + self.use_relative_position = config.use_relative_position - self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - if not use_relative_position: - self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + if not self.use_relative_position: + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) - self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) - self.layer_norm = nn.LayerNorm(hidden_size) - self.dropout = nn.Dropout(hidden_dropout_prob) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, token_type_ids=None): seq_length = input_ids.shape[1] @@ -325,9 +278,9 @@ def forward(self, input_ids, token_type_ids=None): class NeZhaPooler(nn.Layer): - def __init__(self, hidden_size): + def __init__(self, config: NeZhaConfig): super(NeZhaPooler, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): @@ -348,78 +301,14 @@ class NeZhaPretrainedModel(PretrainedModel): See :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details. 
""" - pretrained_init_configuration = { - "nezha-base-chinese": { - "vocab_size": 21128, - "hidden_size": 768, - "num_hidden_layers": 12, - "num_attention_heads": 12, - "intermediate_size": 3072, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.1, - "attention_probs_dropout_prob": 0.1, - "max_position_embeddings": 512, - "max_relative_position": 64, - "type_vocab_size": 2, - "initializer_range": 0.02, - "use_relative_position": True, - }, - "nezha-large-chinese": { - "vocab_size": 21128, - "hidden_size": 1024, - "num_hidden_layers": 24, - "num_attention_heads": 16, - "intermediate_size": 4096, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.1, - "attention_probs_dropout_prob": 0.1, - "max_position_embeddings": 512, - "max_relative_position": 64, - "type_vocab_size": 2, - "initializer_range": 0.02, - "use_relative_position": True, - }, - "nezha-base-wwm-chinese": { - "vocab_size": 21128, - "hidden_size": 768, - "num_hidden_layers": 12, - "num_attention_heads": 12, - "intermediate_size": 3072, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.1, - "attention_probs_dropout_prob": 0.1, - "max_position_embeddings": 512, - "max_relative_position": 64, - "type_vocab_size": 2, - "initializer_range": 0.02, - "use_relative_position": True, - }, - "nezha-large-wwm-chinese": { - "vocab_size": 21128, - "hidden_size": 1024, - "num_hidden_layers": 24, - "num_attention_heads": 16, - "intermediate_size": 4096, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.1, - "attention_probs_dropout_prob": 0.1, - "max_position_embeddings": 512, - "max_relative_position": 64, - "type_vocab_size": 2, - "initializer_range": 0.02, - "use_relative_position": True, - }, - } - pretrained_resource_files_map = { - "model_state": { - "nezha-base-chinese": "https://bj.bcebos.com/paddlenlp/models/transformers/nezha/nezha-base-chinese.pdparams", - "nezha-large-chinese": "https://bj.bcebos.com/paddlenlp/models/transformers/nezha/nezha-large-chinese.pdparams", - "nezha-base-wwm-chinese": "https://bj.bcebos.com/paddlenlp/models/transformers/nezha/nezha-base-wwm-chinese.pdparams", - "nezha-large-wwm-chinese": "https://bj.bcebos.com/paddlenlp/models/transformers/nezha/nezha-large-wwm-chinese.pdparams", - } - } + model_config_file = CONFIG_NAME + config_class = NeZhaConfig + resource_files_names = {"model_state": "model_state.pdparams"} base_model_prefix = "nezha" + pretrained_init_configuration = NEZHA_PRETRAINED_INIT_CONFIGURATION + pretrained_resource_files_map = NEZHA_PRETRAINED_RESOURCE_FILES_MAP + def init_weights(self, layer): """Initialization hook""" if isinstance(layer, (nn.Linear, nn.Embedding)): @@ -429,9 +318,7 @@ def init_weights(self, layer): layer.weight.set_value( paddle.tensor.normal( mean=0.0, - std=self.initializer_range - if hasattr(self, "initializer_range") - else self.nezha.config["initializer_range"], + std=self.config.initializer_range, shape=layer.weight.shape, ) ) @@ -503,48 +390,15 @@ class NeZhaModel(NeZhaPretrainedModel): """ - def __init__( - self, - vocab_size, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - max_relative_position=64, - layer_norm_eps=1e-12, - use_relative_position=True, - ): - super(NeZhaModel, self).__init__() - self.initializer_range = initializer_range - - self.embeddings = NeZhaEmbeddings( - vocab_size=vocab_size, - hidden_size=hidden_size, - 
hidden_dropout_prob=hidden_dropout_prob, - max_position_embeddings=max_position_embeddings, - type_vocab_size=type_vocab_size, - use_relative_position=use_relative_position, - ) + def __init__(self, config: NeZhaConfig): + super(NeZhaModel, self).__init__(config) + self.initializer_range = config.initializer_range - self.encoder = NeZhaEncoder( - hidden_size=hidden_size, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - intermediate_size=intermediate_size, - hidden_act=hidden_act, - hidden_dropout_prob=hidden_dropout_prob, - attention_probs_dropout_prob=attention_probs_dropout_prob, - max_relative_position=max_relative_position, - layer_norm_eps=layer_norm_eps, - ) + self.embeddings = NeZhaEmbeddings(config) + + self.encoder = NeZhaEncoder(config) - self.pooler = NeZhaPooler(hidden_size) + self.pooler = NeZhaPooler(config) self.apply(self.init_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None): @@ -627,14 +481,16 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None): class NeZhaLMPredictionHead(nn.Layer): - def __init__(self, hidden_size, vocab_size, hidden_act, embedding_weights=None, layer_norm_eps=1e-12): + def __init__(self, config: NeZhaConfig, embedding_weights=None): super(NeZhaLMPredictionHead, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.activation = ACT2FN[hidden_act] - self.layer_norm = nn.LayerNorm(hidden_size, epsilon=layer_norm_eps) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = ACT2FN[config.hidden_act] + self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) self.decoder_weight = embedding_weights - self.decoder_bias = self.create_parameter(shape=[vocab_size], dtype=self.decoder_weight.dtype, is_bias=True) + self.decoder_bias = self.create_parameter( + shape=[config.vocab_size], dtype=self.decoder_weight.dtype, is_bias=True + ) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) @@ -664,10 +520,10 @@ class NeZhaPretrainingHeads(nn.Layer): """ - def __init__(self, hidden_size, vocab_size, hidden_act, embedding_weights=None): + def __init__(self, config: NeZhaConfig, embedding_weights=None): super(NeZhaPretrainingHeads, self).__init__() - self.predictions = NeZhaLMPredictionHead(hidden_size, vocab_size, hidden_act, embedding_weights) - self.seq_relationship = nn.Linear(hidden_size, 2) + self.predictions = NeZhaLMPredictionHead(config=config, embedding_weights=embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 2) def forward(self, sequence_output, pooled_output): """ @@ -710,13 +566,11 @@ class NeZhaForPretraining(NeZhaPretrainedModel): """ - def __init__(self, nezha): - super(NeZhaForPretraining, self).__init__() - self.nezha = nezha + def __init__(self, config: NeZhaConfig): + super(NeZhaForPretraining, self).__init__(config) + self.nezha = NeZhaModel(config) self.cls = NeZhaPretrainingHeads( - self.nezha.config["hidden_size"], - self.nezha.config["vocab_size"], - self.nezha.config["hidden_act"], + config, self.nezha.embeddings.word_embeddings.weight, ) @@ -766,7 +620,7 @@ def forward( if masked_lm_labels is not None and next_sentence_label is not None: loss_fct = nn.CrossEntropyLoss(ignore_index=-1) masked_lm_loss = loss_fct( - prediction_scores.reshape((-1, self.nezha.config["vocab_size"])), masked_lm_labels.reshape((-1,)) + prediction_scores.reshape((-1, self.nezha.config.vocab_size)), masked_lm_labels.reshape((-1,)) ) next_sentence_loss = 
loss_fct(seq_relationship_score.reshape((-1, 2)), next_sentence_label.reshape((-1,)))
             total_loss = masked_lm_loss + next_sentence_loss
@@ -774,7 +628,7 @@ def forward(
         elif masked_lm_labels is not None:
             loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(
-                prediction_scores.reshape((-1, self.nezha.config["vocab_size"])), masked_lm_labels.reshape((-1,))
+                prediction_scores.reshape((-1, self.nezha.config.vocab_size)), masked_lm_labels.reshape((-1,))
             )
             total_loss = masked_lm_loss
             return total_loss
@@ -788,18 +642,14 @@ class NeZhaForQuestionAnswering(NeZhaPretrainedModel):
     and `span_end_logits`, designed for question-answering tasks like SQuAD.
     Args:
-        nezha (:class:`NeZhaModel`):
-            An instance of NeZhaModel.
-        dropout (float, optional):
-            The dropout probability for output of NeZha.
-            If None, use the same value as `hidden_dropout_prob` of `NeZhaModel`
-            instance `nezha`. Defaults to `None`.
+        config (:class:`NeZhaConfig`):
+            An instance of NeZhaConfig used to construct NeZhaForQuestionAnswering.
     """
-    def __init__(self, nezha, dropout=None):
-        super(NeZhaForQuestionAnswering, self).__init__()
-        self.nezha = nezha
-        self.classifier = nn.Linear(self.nezha.config["hidden_size"], 2)
+    def __init__(self, config: NeZhaConfig):
+        super(NeZhaForQuestionAnswering, self).__init__(config)
+        self.nezha = NeZhaModel(config)
+        self.classifier = nn.Linear(config.hidden_size, 2)
         self.apply(self.init_weights)
     def forward(self, input_ids, token_type_ids=None, attention_mask=None):
@@ -860,22 +710,18 @@ class NeZhaForSequenceClassification(NeZhaPretrainedModel):
     sequence classification/regression tasks like GLUE tasks.
     Args:
-        nezha (:class:`NeZhaModel`):
-            An instance of NeZhaModel.
-        num_classes (int, optional):
-            The number of classes. Defaults to `2`.
-        dropout (float, optional):
-            The dropout probability for output of NeZha.
-            If None, use the same value as `hidden_dropout_prob` of `NeZhaModel`
-            instance `nezha`. Defaults to None.
+        config (:class:`NeZhaConfig`):
+            An instance of NeZhaConfig used to construct NeZhaForSequenceClassification.
     """
-    def __init__(self, nezha, num_classes=2, dropout=None):
-        super(NeZhaForSequenceClassification, self).__init__()
-        self.num_classes = num_classes
-        self.nezha = nezha
-        self.dropout = nn.Dropout(dropout if dropout is not None else self.nezha.config["hidden_dropout_prob"])
-        self.classifier = nn.Linear(self.nezha.config["hidden_size"], num_classes)
+    def __init__(self, config: NeZhaConfig):
+        super(NeZhaForSequenceClassification, self).__init__(config)
+        self.nezha = NeZhaModel(config)
+        self.num_labels = config.num_labels
+        self.dropout = nn.Dropout(
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+        )
+        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
         self.apply(self.init_weights)
     def forward(self, input_ids, token_type_ids=None, attention_mask=None):
@@ -925,22 +771,18 @@ class NeZhaForTokenClassification(NeZhaPretrainedModel):
     designed for token classification tasks like NER tasks.
     Args:
-        nezha (:class:`NeZhaModel`):
-            An instance of NeZhaModel.
-        num_classes (int, optional):
-            The number of classes. Defaults to `2`.
-        dropout (float, optional):
-            The dropout probability for output of NeZha.
-            If None, use the same value as `hidden_dropout_prob` of `NeZhaModel`
-            instance `nezha`. Defaults to `None`.
+        config (:class:`NeZhaConfig`):
+            An instance of NeZhaConfig used to construct NeZhaForTokenClassification.
""" - def __init__(self, nezha, num_classes=2, dropout=None): - super(NeZhaForTokenClassification, self).__init__() - self.num_classes = num_classes - self.nezha = nezha - self.dropout = nn.Dropout(dropout if dropout is not None else self.nezha.config["hidden_dropout_prob"]) - self.classifier = nn.Linear(self.nezha.config["hidden_size"], num_classes) + def __init__(self, config: NeZhaConfig): + super(NeZhaForTokenClassification, self).__init__(config) + self.nezha = NeZhaModel(config) + self.num_labels = config.num_labels + self.dropout = nn.Dropout( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.classifier = nn.Linear(config.hidden_size, self.num_labels) self.apply(self.init_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None): @@ -989,22 +831,18 @@ class NeZhaForMultipleChoice(NeZhaPretrainedModel): designed for multiple choice tasks like RocStories/SWAG tasks. Args: - nezha (:class:`NeZhaModel`): - An instance of NeZhaModel. - num_choices (int, optional): - The number of choices. Defaults to `2`. - dropout (float, optional): - The dropout probability for output of NeZha. - If None, use the same value as `hidden_dropout_prob` of `NeZhaModel` - instance `nezha`. Defaults to `None`. + config (:class:`BertConfig`): + An instance of BertConfig used to construct BertForMultipleChoice. """ - def __init__(self, nezha, num_choices=2, dropout=None): - super(NeZhaForMultipleChoice, self).__init__() - self.num_choices = num_choices - self.nezha = nezha - self.dropout = nn.Dropout(dropout if dropout is not None else self.nezha.config["hidden_dropout_prob"]) - self.classifier = nn.Linear(self.nezha.config["hidden_size"], 1) + def __init__(self, config: NeZhaConfig): + super(NeZhaForMultipleChoice, self).__init__(config) + self.nezha = NeZhaModel(config) + self.num_choices = config.num_choices + self.dropout = nn.Dropout( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.classifier = nn.Linear(config.hidden_size, 1) self.apply(self.init_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None): @@ -1027,9 +865,9 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None): # input_ids: [bs, num_choice, seq_l] input_ids = input_ids.reshape((-1, input_ids.shape[-1])) # flat_input_ids: [bs*num_choice,seq_l] - if token_type_ids: + if token_type_ids is not None: token_type_ids = token_type_ids.reshape((-1, token_type_ids.shape[-1])) - if attention_mask: + if attention_mask is not None: attention_mask = attention_mask.reshape((-1, attention_mask.shape[-1])) _, pooled_output = self.nezha(input_ids, token_type_ids, attention_mask) diff --git a/paddlenlp/transformers/nezha/tokenizer.py b/paddlenlp/transformers/nezha/tokenizer.py index 3fd1a430f860..0c7f8c642803 100644 --- a/paddlenlp/transformers/nezha/tokenizer.py +++ b/paddlenlp/transformers/nezha/tokenizer.py @@ -13,14 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy -import io -import json import os -import six -import unicodedata -from paddlenlp.transformers import PretrainedTokenizer, BasicTokenizer, WordpieceTokenizer +from paddlenlp.transformers import ( + BasicTokenizer, + PretrainedTokenizer, + WordpieceTokenizer, +) __all__ = ["NeZhaTokenizer"] @@ -298,3 +297,8 @@ def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_spe if token_ids_1 is not None: return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] return [1] + ([0] * len(token_ids_0)) + [1] + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab diff --git a/tests/transformers/nezha/__init__.py b/tests/transformers/nezha/__init__.py new file mode 100644 index 000000000000..595add0aed9e --- /dev/null +++ b/tests/transformers/nezha/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/transformers/nezha/test_modeling.py b/tests/transformers/nezha/test_modeling.py new file mode 100644 index 000000000000..eb238c15369d --- /dev/null +++ b/tests/transformers/nezha/test_modeling.py @@ -0,0 +1,304 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import paddle + +from paddlenlp.transformers import ( + NeZhaConfig, + NeZhaForMultipleChoice, + NeZhaForPretraining, + NeZhaForQuestionAnswering, + NeZhaForSequenceClassification, + NeZhaForTokenClassification, + NeZhaModel, + NeZhaPretrainedModel, +) + +from ...testing_utils import slow +from ..test_modeling_common import ( + ModelTesterMixin, + ids_tensor, + random_attention_mask, +) + + +class NeZhaModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + pad_token_id=0, + type_sequence_label_size=2, + use_relative_position=True, + num_labels=3, + num_choices=4, + num_classes=3, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.pad_token_id = pad_token_id + self.type_sequence_label_size = type_sequence_label_size + self.use_relative_position = use_relative_position + self.num_classes = num_classes + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + + if self.parent.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return NeZhaConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + use_relative_position=self.use_relative_position, + num_class=self.num_classes, + num_labels=self.num_labels, + num_choices=self.num_choices, + ) + + def create_and_check_model( + self, 
+ config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = NeZhaModel(config) + model.eval() + result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + result = model(input_ids, token_type_ids=token_type_ids) + result = model(input_ids) + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.hidden_size]) + self.parent.assertEqual(result[1].shape, [self.batch_size, self.hidden_size]) + + def create_and_check_for_multiple_choice( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = NeZhaForMultipleChoice(config) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand([-1, self.num_choices, -1]) + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand([-1, self.num_choices, -1]) + multiple_choice_input_mask = input_mask.unsqueeze(1).expand([-1, self.num_choices, -1]) + result = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_input_mask, + token_type_ids=multiple_choice_token_type_ids, + ) + if paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.num_choices]) + + def create_and_check_for_question_answering( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = NeZhaForQuestionAnswering(config) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + ) + start_logits, end_logits = result[0], result[1] + + self.parent.assertEqual(start_logits.shape, [self.batch_size, self.seq_length]) + self.parent.assertEqual(end_logits.shape, [self.batch_size, self.seq_length]) + + def create_and_check_for_sequence_classification( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = NeZhaForSequenceClassification(config) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + ) + if paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.num_classes]) + + def create_and_check_for_token_classification( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = NeZhaForTokenClassification(config) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + ) + if paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.num_classes]) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + +class NeZhaModelTest(ModelTesterMixin, unittest.TestCase): + base_model_class = NeZhaModel + return_dict: bool = False + use_labels: bool = False + test_resize_embeddings: bool = False + + all_model_classes = ( + NeZhaModel, + NeZhaForMultipleChoice, + NeZhaForPretraining, + NeZhaForQuestionAnswering, + NeZhaForSequenceClassification, + NeZhaForTokenClassification, + ) + + def setUp(self): + 
self.model_tester = NeZhaModelTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) + + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in list(NeZhaPretrainedModel.pretrained_init_configuration)[:1]: + model = NeZhaModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transformers/nezha/test_tokenizer.py b/tests/transformers/nezha/test_tokenizer.py new file mode 100644 index 000000000000..bbd00129daf3 --- /dev/null +++ b/tests/transformers/nezha/test_tokenizer.py @@ -0,0 +1,297 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import unittest + +from paddlenlp.transformers import BasicTokenizer, NeZhaTokenizer, WordpieceTokenizer + +from ...testing_utils import slow +from ...transformers.test_tokenizer_common import ( + TokenizerTesterMixin, + filter_non_english, +) + + +class NeZhaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = NeZhaTokenizer + space_between_special_tokens = True + from_pretrained_filter = filter_non_english + test_seq2seq = True + + def setUp(self): + super().setUp() + + vocab_tokens = [ + "[UNK]", + "[CLS]", + "[SEP]", + "[PAD]", + "[MASK]", + "want", + "##want", + "##ed", + "wa", + "un", + "runn", + "##ing", + ",", + "low", + "lowest", + ] + + self.vocab_file = os.path.join(self.tmpdirname, NeZhaTokenizer.resource_files_names["vocab_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + def get_input_output_texts(self, tokenizer): + input_text = "UNwant\u00E9d,running" + output_text = "unwanted, running" + return input_text, output_text + + def test_full_tokenizer(self): + tokenizer = self.tokenizer_class(self.vocab_file) + + tokens = tokenizer.tokenize("UNwant\u00E9d,running") + self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) + self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11]) + + def test_fast_and_python_full_tokenizer(self): + if not self.test_fast_tokenizer: + return + + tokenizer = self.get_tokenizer() + tokenizer_fast = self.get_fast_tokenizer() + + sequence = "UNwant\u00E9d,running" + tokens = tokenizer.tokenize(sequence) + tokens_fast = tokenizer_fast.tokenize(sequence) + self.assertListEqual(tokens, tokens_fast) + + ids = tokenizer.encode(sequence, add_special_tokens=False)["input_ids"] + ids_fast = tokenizer_fast.encode(sequence, add_special_tokens=False)["input_ids"] + self.assertListEqual(ids, ids_fast) + + ids = tokenizer.encode(sequence)["input_ids"] + ids_fast = tokenizer_fast.encode(sequence)["input_ids"] + self.assertListEqual(ids, ids_fast) + + tokenizer = self.get_tokenizer(do_lower_case=True) + tokenizer_fast = self.get_fast_tokenizer(do_lower_case=True) + + tokens = tokenizer.tokenize(sequence) + tokens_fast = tokenizer_fast.tokenize(sequence) + self.assertListEqual(tokens, tokens_fast) + + ids = tokenizer.encode(sequence, add_special_tokens=False)["input_ids"] + ids_fast = tokenizer_fast.encode(sequence, add_special_tokens=False)["input_ids"] + self.assertListEqual(ids, ids_fast) + + ids = tokenizer.encode(sequence)["input_ids"] + ids_fast = tokenizer_fast.encode(sequence)["input_ids"] + self.assertListEqual(ids, ids_fast) + + def test_chinese(self): + tokenizer = BasicTokenizer() + + self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"]) + + def test_basic_tokenizer_lower(self): + tokenizer = BasicTokenizer(do_lower_case=True) + + self.assertListEqual( + tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["hello", "!", "how", "are", "you", "?"] + ) + self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"]) + + def test_basic_tokenizer_lower_strip_accents_false(self): + tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False) + + self.assertListEqual( + tokenizer.tokenize(" \tHäLLo!how \n Are yoU? 
"), ["hällo", "!", "how", "are", "you", "?"] + ) + self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"]) + + def test_basic_tokenizer_lower_strip_accents_true(self): + tokenizer = BasicTokenizer(do_lower_case=True) + + self.assertListEqual( + tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"] + ) + self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"]) + + def test_basic_tokenizer_lower_strip_accents_default(self): + tokenizer = BasicTokenizer(do_lower_case=True) + + self.assertListEqual( + tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"] + ) + self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"]) + + def test_basic_tokenizer_no_lower(self): + tokenizer = BasicTokenizer(do_lower_case=False) + + self.assertListEqual( + tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"] + ) + + def test_basic_tokenizer_no_lower_strip_accents_false(self): + tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False) + + self.assertListEqual( + tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HäLLo", "!", "how", "Are", "yoU", "?"] + ) + + def test_basic_tokenizer_no_lower_strip_accents_true(self): + tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True) + + self.assertListEqual( + tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"] + ) + + def test_basic_tokenizer_respects_never_split_tokens(self): + tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"]) + + self.assertListEqual( + tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"] + ) + + def test_wordpiece_tokenizer(self): + vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"] + + vocab = {} + for (i, token) in enumerate(vocab_tokens): + vocab[token] = i + tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") + + self.assertListEqual(tokenizer.tokenize(""), []) + + self.assertListEqual(tokenizer.tokenize("unwanted running"), ["un", "##want", "##ed", "runn", "##ing"]) + + self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) + + def test_clean_text(self): + tokenizer = self.get_tokenizer() + + # Example taken from the issue https://github.com/huggingface/tokenizers/issues/340 + self.assertListEqual([tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]]) + + @slow + def test_sequence_builders(self): + tokenizer = self.tokenizer_class.from_pretrained("ernie-1.0") + + text = tokenizer.encode("sequence builders", return_token_type_ids=None, add_special_tokens=False)["input_ids"] + text_2 = tokenizer.encode("multi-sequence build", return_token_type_ids=None, add_special_tokens=False)[ + "input_ids" + ] + + encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) + print(encoded_sentence) + assert encoded_sentence == [1] + text + [2] + assert encoded_pair == [1] + text + [2] + text_2 + [2] + + def test_offsets_with_special_characters(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_fast = self.fast_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + sentence = f"A, naïve 
{tokenizer.mask_token} AllenNLP sentence." + tokens = tokenizer.encode( + sentence, + return_attention_mask=False, + return_token_type_ids=False, + return_offsets_mapping=True, + add_special_tokens=True, + ) + + tokens_fast = tokenizer_fast.encode( + sentence, + return_attention_mask=False, + return_token_type_ids=False, + return_offsets_mapping=True, + add_special_tokens=True, + ) + + do_lower_case = tokenizer.do_lower_case if hasattr(tokenizer, "do_lower_case") else False + expected_results = ( + [ + ((0, 0), tokenizer.cls_token), + ((0, 1), "A"), + ((1, 2), ","), + ((3, 5), "na"), + ((5, 6), "##ï"), + ((6, 8), "##ve"), + ((9, 15), tokenizer.mask_token), + ((16, 21), "Allen"), + ((21, 23), "##NL"), + ((23, 24), "##P"), + ((25, 33), "sentence"), + ((33, 34), "."), + ((0, 0), tokenizer.sep_token), + ] + if not do_lower_case + else [ + ((0, 0), tokenizer.cls_token), + ((0, 1), "a"), + ((1, 2), ","), + ((3, 5), "na"), + ((5, 8), "##ive"), + ((9, 15), tokenizer.mask_token), + ((16, 21), "allen"), + ((21, 22), "##n"), + ((22, 24), "##lp"), + ((25, 27), "se"), + ((27, 29), "##nt"), + ((29, 33), "##ence"), + ((33, 34), "."), + ((0, 0), tokenizer.sep_token), + ] + ) + + self.assertEqual( + [e[1] for e in expected_results], tokenizer.convert_ids_to_tokens(tokens["input_ids"]) + ) + self.assertEqual( + [e[1] for e in expected_results], tokenizer_fast.convert_ids_to_tokens(tokens_fast["input_ids"]) + ) + self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"]) + self.assertEqual([e[0] for e in expected_results], tokens_fast["offset_mapping"]) + + def test_change_tokenize_chinese_chars(self): + list_of_commun_chinese_char = ["的", "人", "有"] + text_with_chinese_char = "".join(list_of_commun_chinese_char) + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + + kwargs["tokenize_chinese_chars"] = True + tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_fast = self.fast_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + ids_without_spe_char_p = tokenizer.encode( + text_with_chinese_char, return_token_type_ids=None, add_special_tokens=False + )["input_ids"] + ids_without_spe_char_fast = tokenizer_fast.encode( + text_with_chinese_char, return_token_type_ids=None, add_special_tokens=False + )["input_ids"] + + tokens_without_spe_char_p = tokenizer.convert_ids_to_tokens(ids_without_spe_char_p) + tokens_without_spe_char_fast = tokenizer.convert_ids_to_tokens(ids_without_spe_char_fast) + + # it is expected that each Chinese character is not preceded by "##" + self.assertListEqual(tokens_without_spe_char_p, list_of_commun_chinese_char) + self.assertListEqual(tokens_without_spe_char_fast, list_of_commun_chinese_char)
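As a quick orientation for reviewers, here is a minimal usage sketch of the config-driven API this patch introduces. It relies only on the classes added above (`NeZhaConfig`, `NeZhaModel`, `NeZhaForSequenceClassification`); the toy dimensions and shapes are illustrative and not part of the patch.

```python
import paddle

from paddlenlp.transformers import (
    NeZhaConfig,
    NeZhaForSequenceClassification,
    NeZhaModel,
)

# Build a small, randomly initialized backbone straight from a config
# (previously NeZhaModel took a long list of keyword arguments).
config = NeZhaConfig(
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=512,
    use_relative_position=True,
)
model = NeZhaModel(config)

input_ids = paddle.randint(low=1, high=config.vocab_size, shape=[2, 16])
sequence_output, pooled_output = model(input_ids)
print(sequence_output.shape)  # [2, 16, 128]

# Task heads are constructed from the config as well; `num_labels` replaces the
# old `num_classes` constructor argument, and the config's attribute_map keeps
# the legacy `num_classes`/`dropout` names pointing at
# `num_labels`/`classifier_dropout`.
clf = NeZhaForSequenceClassification(
    NeZhaConfig(
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=512,
        num_labels=3,
    )
)
logits = clf(input_ids)
print(logits.shape)  # [2, 3]
```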