Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mobilebert model #1160

Merged
merged 19 commits into from
Dec 8, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 233 additions & 40 deletions paddlenlp/transformers/mobilebert/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,13 @@
"MobileBertForQuestionAnswering",
]


def gelu_new(x):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1.0 + paddle.tanh(
math.sqrt(2.0 / math.pi) * (x + 0.044715 * paddle.pow(x, 3.0))))


ACT2FN = {
"relu": F.relu,
"gelu": F.gelu,
"gelu_new": gelu_new,
"tanh": F.tanh,
Copy link
Contributor

@FrostML FrostML Dec 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不再支持近似计算的 gelu 和 tanh 的原因是?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

模型没用到

}

MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = ["google/mobilebert-uncased"]


class NoNorm(nn.Layer): #paddle
class NoNorm(nn.Layer):
def __init__(self, feat_size, eps=None):
super().__init__()
if isinstance(feat_size, int):
Expand Down Expand Up @@ -258,7 +244,6 @@ def forward(self, hidden_states):
class OutputBottleneck(nn.Layer):
def __init__(
self,
# config
true_hidden_size=128,
hidden_size=512,
normalization_type="no_norm",
Expand Down Expand Up @@ -334,7 +319,6 @@ def forward(self, hidden_states):
class Bottleneck(nn.Layer):
def __init__(
self,
# config
key_query_shared_bottleneck=True,
use_bottleneck_attention=False,
hidden_size=512,
Expand Down Expand Up @@ -722,8 +706,11 @@ def forward(self, sequence_output, pooled_output):

class MobileBertPreTrainedModel(PretrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
An abstract class for pretrained MobileBert models. It provides MobileBert related
`model_config_file`, `resource_files_names`, `pretrained_resource_files_map`,
`pretrained_init_configuration`, `base_model_prefix` for downloading and
loading pretrained models.
See :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details.
"""

model_config_file = "model_config.json"
Expand Down Expand Up @@ -861,7 +848,66 @@ def forward(
@register_base_model
class MobileBertModel(MobileBertPreTrainedModel):
"""
https://arxiv.org/pdf/2004.02984.pdf
The bare MobileBert Model transformer outputting raw hidden-states.
This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
Refer to the superclass documentation for the generic methods.
This model is also a Paddle `paddle.nn.Layer <https://www.paddlepaddle.org.cn/documentation
/docs/en/api/paddle/fluid/dygraph/layers/Layer_en.html>`__ subclass. Use it as a regular Paddle Layer
and refer to the Paddle documentation for all matter related to general usage and behavior.
Args:
vocab_size (int):
Vocabulary size of `inputs_ids` in `MobileBertModel`. Also is the vocab size of token embedding matrix.
Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling `MobileBertModel`.
embedding_size (int, optional):
Embedding dimensionality of lookup_table in the embedding layer. Defaults to `128`.
hidden_size (int, optional):
Dimensionality of the embedding layer, encoder layer and pooler layer. Defaults to `512`.
true_hidden_size (int, optional):
Dimensionality of input_tensor in self attention layer. Defaults to `128`.
use_bottleneck_attention (bool, optional):
Using bottleneck to value tensor in self attention layer. Defaults to `False`.
key_query_shared_bottleneck (bool, optional):
Key and query shared bottleneck layer. Defaults to `True`.
num_hidden_layers (int, optional):
Number of hidden layers in the Transformer encoder. Defaults to `24`.
num_attention_heads (int, optional):
Number of attention heads for each attention layer in the Transformer encoder.
Defaults to `4`.
intermediate_size (int, optional):
Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors
to ff layers are firstly projected from `hidden_size` to `intermediate_size`,
and then projected back to `hidden_size`. Typically `intermediate_size` is larger than `hidden_size`.
Defaults to `512`.
hidden_act (str, optional):
The non-linear activation function in the feed-forward layer.
``"gelu"``, ``"relu"`` and any other paddle supported activation functions
are supported. Defaults to `"relu"`.
hidden_dropout_prob (float, optional):
The dropout probability for all fully connected layers in the embeddings and encoder.
Defaults to `0.1`.
attention_probs_dropout_prob (float, optional):
The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention target.
Defaults to `0.1`.
max_position_embeddings (int, optional):
The maximum value of the dimensionality of position encoding, which dictates the maximum supported length of an input
sequence. Defaults to `512`.
type_vocab_size (int, optional):
The vocabulary size of `token_type_ids`.
Defaults to `2`.
initializer_range (float, optional):
The standard deviation of the normal initializer.
Defaults to 0.02.
.. note::
A normal_initializer initializes weight matrices as normal distributions.
See :meth:`MobileBertPreTrainedModel.init_weights()` for how weights are initialized in `MobileBertModel`.
pad_token_id (int, optional):
The index of padding token in the token vocabulary.
Defaults to `1`.
add_pooling_layer (bool, optional):
Adding the pooling Layer after the encoder layer. Defaults to `True`.
classifier_activation (bool, optional):
Using the non-linear activation function in the pooling layer. Defaults to `False`.

"""

def __init__(
Expand Down Expand Up @@ -980,14 +1026,76 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):

def forward(
self,
input_ids=None,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_hidden_states=None,
output_attentions=None, ):
r'''
The MobileBertModel forward method, overrides the `__call__()` special method.
Args:
input_ids (Tensor):
Indices of input sequence tokens in the vocabulary. They are
numerical representations of tokens that build the input sequence.
Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
token_type_ids (Tensor, optional):
Segment token indices to indicate different portions of the inputs.
Selected in the range ``[0, type_vocab_size - 1]``.
If `type_vocab_size` is 2, which means the inputs have two portions.
Indices can either be 0 or 1:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
Defaults to `None`, which means we don't add segment embeddings.
position_ids(Tensor, optional):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
max_position_embeddings - 1]``.
Shape as `(batch_size, num_tokens)` and dtype as int64. Defaults to `None`.
attention_mask (Tensor, optional):
Mask used in multi-head attention to avoid performing attention on to some unwanted positions,
usually the paddings or the subsequent positions.
Its data type can be int, float and bool.
When the data type is bool, the `masked` tokens have `False` values and the others have `True` values.
When the data type is int, the `masked` tokens have `0` values and the others have `1` values.
When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values.
It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`.
Defaults to `None`, which means nothing needed to be prevented attention to.
head_mask (:obj:`paddle.Tensor` with shape :obj:`[num_heads]` or :obj:`[num_hidden_layers x num_heads]`, `optional`):
The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). Defaults to `None`.
output_hidden_states (bool, optional):
Whether to return the output of each hidden layers.
Defaults to `None`.
output_attentions (bool, optional):
Whether to return the output of each self attention layers.
Defaults to `None`.
Returns:
tuple: Returns tuple (`sequence_output`, `pooled_output`) or (`encoder_outputs`, `pooled_output`).
With the fields:
- `sequence_output` (Tensor):
Sequence of hidden-states at the last layer of the model.
It's data type should be float32 and its shape is [batch_size, sequence_length, hidden_size].
- `pooled_output` (Tensor):
The output of first token (`[CLS]`) in sequence.
We "pool" the model by simply taking the hidden state corresponding to the first token.
Its data type should be float32 and its shape is [batch_size, hidden_size].
- `encoder_outputs` (List(Tensor)):
A list of Tensor containing hidden-states of the model at each hidden layer in the Transformer encoder.
The length of the list is `num_hidden_layers`.
Each Tensor has a data type of float32 and its shape is [batch_size, sequence_length, hidden_size].
Example:
.. code-block::
import paddle
from paddlenlp.transformers import MobileBertModel, MobileBertTokenizer
tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
model = MobileBertModel.from_pretrained('google/mobilebert-uncased')
inputs = tokenizer("Hello, my dog is cute")
inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
output = model(**inputs)
'''

output_attentions = output_attentions is not None
output_hidden_states = (output_hidden_states is not None)

Expand All @@ -1003,8 +1111,6 @@ def forward(
raise ValueError(
"You have to specify either input_ids or inputs_embeds")

# device = input_ids.device if input_ids is not None else inputs_embeds.device

if attention_mask is None:
attention_mask = paddle.ones(input_shape, dtype=input_ids.dtype)
if token_type_ids is None:
Expand Down Expand Up @@ -1041,7 +1147,17 @@ def forward(


class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
def __init__(self, mobilebert, num_labels=3):
"""
MobileBert Model with a linear layer on top of the output layer,
designed for sequence classification/regression tasks like GLUE tasks.
Args:
mobilebert (:class:`MobileBertModel`):
An instance of MobileBert.
num_classes (int, optional):
The number of classes. Defaults to `2`.
"""

def __init__(self, mobilebert, num_labels=2):
super(MobileBertForSequenceClassification, self).__init__()
self.num_labels = num_labels
self.mobilebert = mobilebert
Expand All @@ -1056,15 +1172,48 @@ def __init__(self, mobilebert, num_labels=3):
self.init_weights()

def forward(self,
input_ids=None,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
do_compare=False):
output_hidden_states=None):
r"""
The MobileBertForSequenceClassification forward method, overrides the __call__() special method.
Args:
input_ids (Tensor):
See :class:`MobileBertModel`.
token_type_ids (Tensor, optional):
See :class:`MobileBertModel`.
position_ids(Tensor, optional):
See :class:`MobileBertModel`.
head_mask (Tensor, optional):
See :class:`MobileBertModel`.
attention_mask (Tensor, optional):
See :class:`MobileBertModel`.
inputs_embeds (Tensor, optional):
See :class:`MobileBertModel`.
output_attentions (bool, optional):
See :class:`MobileBertModel`.
output_hidden_states (bool, optional):
See :class:`MobileBertModel`.
Returns:
Tensor: Returns tensor `logits`, a tensor of the input text classification logits.
Shape as `[batch_size, num_classes]` and dtype as float32.
Example:
.. code-block::
import paddle
from paddlenlp.transformers import MobileBertForSequenceClassification, MobileBertTokenizer
tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
model = MobileBertForSequenceClassification.from_pretrained('google/mobilebert-uncased', num_classes=2)
inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
logits = model(**inputs)
print(logits.shape)
# [1, 2]
"""

outputs = self.mobilebert(
input_ids,
Expand All @@ -1080,13 +1229,19 @@ def forward(self,

pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
if do_compare:
return logits, outputs[0]
else:
return logits

return logits


class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
"""
MobileBert Model with a linear layer on top of the hidden-states output to compute `span_start_logits`
and `span_end_logits`, designed for question-answering tasks like SQuAD.
Args:
mobilebert (:class:`MobileBert`):
An instance of MobileBert.
"""

def __init__(self, mobilebert):
super(MobileBertForQuestionAnswering, self).__init__()
self.num_labels = 2
Expand All @@ -1098,7 +1253,7 @@ def __init__(self, mobilebert):

def forward(
self,
input_ids=None,
input_ids,
attention_mask=None,
token_type_ids=None,
position_ids=None,
Expand All @@ -1109,14 +1264,52 @@ def forward(
output_attentions=None,
output_hidden_states=None, ):
r"""
start_positions (:obj:`paddle.Tensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
end_positions (:obj:`paddle.Tensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
The MobileBertForQuestionAnswering forward method, overrides the __call__() special method.
Args:
input_ids (Tensor):
See :class:`MobileBertModel`.
token_type_ids (Tensor, optional):
See :class:`MobileBertModel`.
position_ids(Tensor, optional):
See :class:`MobileBertModel`.
head_mask (Tensor, optional):
See :class:`MobileBertModel`.
attention_mask (Tensor, optional):
See :class:`MobileBertModel`.
inputs_embeds (Tensor, optional):
See :class:`MobileBertModel`.
output_attentions (bool, optional):
See :class:`MobileBertModel`.
output_hidden_states (bool, optional):
See :class:`MobileBertModel`.
start_positions (Tensor, optional):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
end_positions (Tensor, optional):
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
sequence are not taken into account for computing the loss.
Returns:
tuple: Returns tuple (`start_logits`, `end_logits`).
With the fields:
- `start_logits` (Tensor):
A tensor of the input token classification logits, indicates the start position of the labelled span.
Its data type should be float32 and its shape is [batch_size, sequence_length].
- `end_logits` (Tensor):
A tensor of the input token classification logits, indicates the end position of the labelled span.
Its data type should be float32 and its shape is [batch_size, sequence_length].
Example:
.. code-block::
import paddle
from paddlenlp.transformers import MobileBertForQuestionAnswering, MobileBertTokenizer
tokenizer = MobileBertTokenizer.from_pretrained('bert-base-cased')
model = MobileBertForQuestionAnswering.from_pretrained('bert-base-cased')
inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!", "PaddlePaddle and PaddleNLP")
inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
outputs = model(**inputs)
start_logits = outputs[0]
end_logits = outputs[1]
"""
outputs = self.mobilebert(
input_ids,
Expand Down
2 changes: 1 addition & 1 deletion paddlenlp/transformers/mobilebert/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class MobileBertTokenizer(BertTokenizer):
Construct a MobileBERT tokenizer.
:class:`~paddlenlp.transformers.MobileBertTokenizer is identical to :class:`~paddlenlp.transformers.BertTokenizer` and runs end-to-end
tokenization: punctuation splitting and wordpiece.
Refer to superclass :class:`~transformers.BertTokenizer` for usage examples and documentation concerning
Refer to superclass :class:`~~paddlenlp.transformers.BertTokenizer` for usage examples and documentation concerning
parameters.
"""
resource_files_names = {"vocab_file": "vocab.txt"}
Expand Down