diff --git a/docs/model_zoo/transformers.rst b/docs/model_zoo/transformers.rst index 329e5118e9ea..5d37f2f8f159 100644 --- a/docs/model_zoo/transformers.rst +++ b/docs/model_zoo/transformers.rst @@ -447,7 +447,6 @@ Transformer预训练模型汇总 Transformer预训练模型适用任务汇总 ------------------------------------ - +--------------------+-------------------------+----------------------+--------------------+-----------------+-----------------+ | Model | Sequence Classification | Token Classification | Question Answering | Text Generation | Multiple Choice | +====================+=========================+======================+====================+=================+=================+ @@ -457,7 +456,7 @@ Transformer预训练模型适用任务汇总 +--------------------+-------------------------+----------------------+--------------------+-----------------+-----------------+ |BERT_ | ✅ | ✅ | ✅ | ❌ | ✅ | +--------------------+-------------------------+----------------------+--------------------+-----------------+-----------------+ -|BigBird_ | ✅ | ❌ | ❌ | ❌ | ❌ | +|BigBird_ | ✅ | ✅ | ✅ | ❌ | ✅ | +--------------------+-------------------------+----------------------+--------------------+-----------------+-----------------+ |ConvBert_ | ✅ | ✅ | ✅ | ✅ | ✅ | +--------------------+-------------------------+----------------------+--------------------+-----------------+-----------------+ diff --git a/paddlenlp/transformers/bigbird/modeling.py b/paddlenlp/transformers/bigbird/modeling.py index db2e0ac51477..14d608575d3f 100644 --- a/paddlenlp/transformers/bigbird/modeling.py +++ b/paddlenlp/transformers/bigbird/modeling.py @@ -11,26 +11,60 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np +import math +import numpy as np import paddle from paddle.nn import Linear, Dropout, LayerNorm, LayerList, Layer import paddle.nn.functional as F import paddle.nn as nn + from ..attention_utils import _convert_param_attr_to_list, MultiHeadAttention, \ AttentionRegistry from .. import PretrainedModel, register_base_model __all__ = [ - 'BigBirdModel', - 'BigBirdPretrainedModel', - 'BigBirdForPretraining', - 'BigBirdPretrainingCriterion', - 'BigBirdForSequenceClassification', - 'BigBirdPretrainingHeads', + 'BigBirdModel', 'BigBirdPretrainedModel', 'BigBirdForPretraining', + 'BigBirdPretrainingCriterion', 'BigBirdForSequenceClassification', + 'BigBirdPretrainingHeads', 'BigBirdForQuestionAnswering', + 'BigBirdForTokenClassification', 'BigBirdForMultipleChoice', + 'BigBirdForMaskedLM', 'BigBirdForCausalLM' ] +def mish(x): + return x * F.tanh(F.softplus(x)) + + +def linear_act(x): + return x + + +def swish(x): + return x * F.sigmoid(x) + + +def gelu_new(x): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 
Also see
+    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1.0 + paddle.tanh(
+        math.sqrt(2.0 / math.pi) * (x + 0.044715 * paddle.pow(x, 3.0))))
+
+
+ACT2FN = {
+    "relu": F.relu,
+    "gelu": F.gelu,
+    "gelu_new": gelu_new,
+    "tanh": F.tanh,
+    "sigmoid": F.sigmoid,
+    "mish": mish,
+    "linear": linear_act,
+    "swish": swish,
+}
+
+
 class TransformerEncoderLayer(Layer):
     def __init__(self,
                  d_model,
@@ -146,6 +180,7 @@ def forward(self,
 
 class BigBirdPooler(Layer):
     """
+    Pool the result of the BigBird encoder.
     """
 
     def __init__(self, hidden_size):
@@ -899,3 +934,432 @@ def forward(self, prediction_scores, seq_relationship_score,
         next_sentence_loss = F.cross_entropy(
             seq_relationship_score, next_sentence_labels, reduction='none')
         return masked_lm_loss + paddle.mean(next_sentence_loss) * scale
+
+
+class BigBirdIntermediate(Layer):
+    def __init__(self, hidden_size, dim_feedforward, activation):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, dim_feedforward)
+        if isinstance(activation, str):
+            self.intermediate_act_fn = ACT2FN[activation]
+        else:
+            self.intermediate_act_fn = activation
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BigBirdOutput(Layer):
+    def __init__(self, hidden_size, dim_feedforward, hidden_dropout_prob):
+        super().__init__()
+        self.dense = nn.Linear(dim_feedforward, hidden_size)
+        self.LayerNorm = nn.LayerNorm(hidden_size)
+        self.dropout = nn.Dropout(hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
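# Editor's sketch (not part of this diff): BigBirdIntermediate and BigBirdOutput
# compose into the standard transformer feed-forward sublayer, with a string
# activation resolved through ACT2FN. The sizes below are illustrative
# assumptions, not values taken from this PR.
import paddle

hidden_size, dim_feedforward = 8, 32
intermediate = BigBirdIntermediate(hidden_size, dim_feedforward, "gelu_new")
ffn_out = BigBirdOutput(hidden_size, dim_feedforward, hidden_dropout_prob=0.1)

x = paddle.randn([2, 5, hidden_size])  # [batch_size, seq_len, hidden_size]
y = ffn_out(intermediate(x), x)        # dense -> act -> dense -> dropout -> residual + LayerNorm
assert y.shape == [2, 5, hidden_size]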
+
+
+class BigBirdForQuestionAnswering(BigBirdPretrainedModel):
+    """
+    BigBird Model with a span classification head on top for extractive question-answering tasks like
+    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and
+    `span end logits`).
+
+    Args:
+        bigbird (:class:`BigBirdModel`):
+            An instance of BigBirdModel.
+        dropout (float, optional):
+            The dropout probability for output of BigBirdModel.
+            If None, use the same value as `hidden_dropout_prob` of `BigBirdModel`
+            instance `bigbird`. Defaults to `None`.
+    """
+
+    def __init__(self, bigbird, dropout=None):
+        super(BigBirdForQuestionAnswering, self).__init__()
+        self.bigbird = bigbird  # allow bigbird to be config
+        self.dropout = nn.Dropout(dropout if dropout is not None else
+                                  self.bigbird.config["hidden_dropout_prob"])
+        self.classifier = nn.Linear(self.bigbird.config["hidden_size"], 2)
+        self.apply(self.init_weights)
+
+    def forward(self,
+                input_ids,
+                token_type_ids=None,
+                attention_mask_list=None,
+                rand_mask_idx_list=None):
+        r"""
+        The BigBirdForQuestionAnswering forward method, overrides the __call__() special method.
+
+        Args:
+            input_ids (Tensor):
+                See :class:`BigBirdModel`.
+            token_type_ids (Tensor, optional):
+                See :class:`BigBirdModel`.
+            attention_mask_list (`List`):
+                See :class:`BigBirdModel`.
+            rand_mask_idx_list (`List`):
+                See :class:`BigBirdModel`.
+
+        Returns:
+            tuple: Returns tuple (`start_logits`, `end_logits`).
+
+            With the fields:
+
+            - `start_logits` (Tensor):
+                A tensor of the input token classification logits, indicating the start position of the labelled span.
+                Its data type should be float32 and its shape is [batch_size, sequence_length].
+
+            - `end_logits` (Tensor):
+                A tensor of the input token classification logits, indicating the end position of the labelled span.
+                Its data type should be float32 and its shape is [batch_size, sequence_length].
+
+        Example:
+            .. code-block::
+
+                import paddle
+                from paddlenlp.transformers.bigbird.modeling import BigBirdForQuestionAnswering
+                from paddlenlp.transformers.bigbird.tokenizer import BigBirdTokenizer
+
+                tokenizer = BigBirdTokenizer.from_pretrained('bigbird-base-uncased')
+                model = BigBirdForQuestionAnswering.from_pretrained('bigbird-base-uncased')
+
+                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
+                inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}
+                outputs = model(**inputs)
+
+                start_logits = outputs[0]
+                end_logits = outputs[1]
+        """
+        sequence_output, _ = self.bigbird(
+            input_ids,
+            token_type_ids=token_type_ids,
+            attention_mask_list=attention_mask_list,
+            rand_mask_idx_list=rand_mask_idx_list)
+
+        logits = self.classifier(sequence_output)
+        logits = paddle.transpose(logits, perm=[2, 0, 1])
+        start_logits, end_logits = paddle.unstack(x=logits, axis=0)
+
+        return start_logits, end_logits
+
+    @staticmethod
+    def prepare_question_mask(q_lengths, maxlen):
+        mask = paddle.arange(0, maxlen).unsqueeze_(0)
+        mask = mask < q_lengths
+        return mask
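# Editor's sketch (not part of this diff): prepare_question_mask builds a
# boolean mask that is True for the first q_lengths[i] positions of row i.
# q_lengths must broadcast against [1, maxlen], i.e. have shape
# [batch_size, 1]; the values below are illustrative assumptions.
import paddle

q_lengths = paddle.to_tensor([[3], [5]])  # question lengths for two examples
mask = BigBirdForQuestionAnswering.prepare_question_mask(q_lengths, maxlen=8)
# row 0 -> True for positions 0..2, row 1 -> True for positions 0..4
assert mask.shape == [2, 8]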
+
+
+class BigBirdForTokenClassification(BigBirdPretrainedModel):
+    """
+    BigBird Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
+    for Named-Entity-Recognition (NER) tasks.
+
+    Args:
+        bigbird (:class:`BigBirdModel`):
+            An instance of BigBirdModel.
+        num_classes (int, optional):
+            The number of classes. Defaults to `2`.
+        dropout (float, optional):
+            The dropout probability for output of BigBird.
+            If None, use the same value as `hidden_dropout_prob` of `BigBirdModel`
+            instance `bigbird`. Defaults to None.
+    """
+
+    def __init__(self, bigbird, num_classes=2, dropout=None):
+        super(BigBirdForTokenClassification, self).__init__()
+        self.num_classes = num_classes
+        self.bigbird = bigbird  # allow bigbird to be config
+        self.dropout = nn.Dropout(dropout if dropout is not None else
+                                  self.bigbird.config["hidden_dropout_prob"])
+        self.classifier = nn.Linear(self.bigbird.config["hidden_size"],
+                                    num_classes)
+        self.apply(self.init_weights)
+
+    def forward(self,
+                input_ids,
+                token_type_ids=None,
+                attention_mask_list=None,
+                rand_mask_idx_list=None):
+        r"""
+        The BigBirdForTokenClassification forward method, overrides the __call__() special method.
+
+        Args:
+            input_ids (Tensor):
+                See :class:`BigBirdModel`.
+            token_type_ids (Tensor, optional):
+                See :class:`BigBirdModel`.
+            attention_mask_list (`List`):
+                See :class:`BigBirdModel`.
+            rand_mask_idx_list (`List`):
+                See :class:`BigBirdModel`.
+
+        Returns:
+            Tensor: Returns tensor `logits`, a tensor of the input token classification logits.
+            Shape as `[batch_size, sequence_length, num_classes]` and dtype as `float32`.
+
+        Example:
+            .. code-block::
+
+                import paddle
+                from paddlenlp.transformers.bigbird.modeling import BigBirdForTokenClassification
+                from paddlenlp.transformers.bigbird.tokenizer import BigBirdTokenizer
+
+                tokenizer = BigBirdTokenizer.from_pretrained('bigbird-base-uncased')
+                model = BigBirdForTokenClassification.from_pretrained('bigbird-base-uncased')
+
+                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
+                inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}
+                outputs = model(**inputs)
+
+                logits = outputs
+        """
+        sequence_output, _ = self.bigbird(
+            input_ids,
+            token_type_ids=token_type_ids,
+            attention_mask_list=attention_mask_list,
+            rand_mask_idx_list=rand_mask_idx_list)
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+        return logits
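# Editor's sketch (not part of this diff): decoding token-classification logits
# into per-token label ids. The shapes are illustrative assumptions matching the
# [batch_size, sequence_length, num_classes] output documented above.
import paddle
import paddle.nn.functional as F

logits = paddle.randn([1, 6, 2])           # stand-in for model(...) output
pred_ids = paddle.argmax(logits, axis=-1)  # [batch_size, sequence_length]
probs = F.softmax(logits, axis=-1)         # per-class probabilities per token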
+
+
+class BigBirdForMultipleChoice(BigBirdPretrainedModel):
+    """
+    BigBird Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+    softmax) e.g. for RocStories/SWAG tasks.
+
+    Args:
+        bigbird (:class:`BigBirdModel`):
+            An instance of BigBirdModel.
+        num_choices (int, optional):
+            The number of choices. Defaults to `2`.
+        dropout (float, optional):
+            The dropout probability for output of BigBird.
+            If None, use the same value as `hidden_dropout_prob` of `BigBirdModel`
+            instance `bigbird`. Defaults to None.
+    """
+
+    def __init__(self, bigbird, num_choices=2, dropout=None):
+        super(BigBirdForMultipleChoice, self).__init__()
+        self.bigbird = bigbird  # allow bigbird to be config
+        self.num_choices = num_choices
+        self.dropout = nn.Dropout(dropout if dropout is not None else
+                                  self.bigbird.config["hidden_dropout_prob"])
+        self.classifier = nn.Linear(self.bigbird.config["hidden_size"], 1)
+        self.apply(self.init_weights)
+
+    def forward(self,
+                input_ids,
+                attention_mask_list=None,
+                rand_mask_idx_list=None):
+        r"""
+        The BigBirdForMultipleChoice forward method, overrides the __call__() special method.
+
+        Args:
+            input_ids (Tensor):
+                See :class:`BigBirdModel` and shape as [batch_size, num_choices, sequence_length].
+            attention_mask_list (`List`):
+                See :class:`BigBirdModel` and shape as [batch_size, num_choices, n_head, sequence_length, sequence_length].
+            rand_mask_idx_list (`List`):
+                See :class:`BigBirdModel`.
+
+        Returns:
+            Tensor: Returns tensor `reshaped_logits`, a tensor of the multiple choice classification logits.
+            Shape as `[batch_size, num_choices]` and dtype as float32.
+
+        Example:
+            .. code-block::
+
+                import paddle
+                from paddlenlp.transformers.bigbird.modeling import BigBirdForMultipleChoice
+                from paddlenlp.transformers.bigbird.tokenizer import BigBirdTokenizer
+
+                tokenizer = BigBirdTokenizer.from_pretrained('bigbird-base-uncased')
+                model = BigBirdForMultipleChoice.from_pretrained('bigbird-base-uncased')
+
+                # duplicate one sequence to form the two choices of a single example
+                inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
+                input_ids = paddle.to_tensor([[inputs["input_ids"], inputs["input_ids"]]])
+                reshaped_logits = model(input_ids)
+        """
+        # input_ids: [bs, num_choices, seq_len]
+        input_ids = input_ids.reshape(shape=(
+            -1, input_ids.shape[-1]))  # flat_input_ids: [bs*num_choices, seq_len]
+
+        if attention_mask_list is not None:
+            attention_mask_list = attention_mask_list.reshape(shape=(
+                -1, *attention_mask_list.shape[2:]))
+
+        if rand_mask_idx_list is not None:
+            rand_mask_idx_list = rand_mask_idx_list.reshape(shape=(
+                -1, *rand_mask_idx_list.shape[2:]))
+
+        _, pooled_output = self.bigbird(
+            input_ids,
+            attention_mask_list=attention_mask_list,
+            rand_mask_idx_list=rand_mask_idx_list)
+
+        pooled_output = self.dropout(pooled_output)
+
+        logits = self.classifier(pooled_output)  # logits: [bs*num_choices, 1]
+        reshaped_logits = logits.reshape(
+            shape=(-1, self.num_choices))  # reshaped_logits: [bs, num_choices]
+
+        return reshaped_logits
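# Editor's sketch (not part of this diff): the shape round-trip performed inside
# BigBirdForMultipleChoice.forward -- flatten the choice dimension before the
# encoder, then regroup the per-sequence scores. Sizes are illustrative
# assumptions.
import paddle

batch_size, num_choices, seq_len = 2, 4, 16
input_ids = paddle.randint(0, 100, [batch_size, num_choices, seq_len])

flat_ids = input_ids.reshape([-1, seq_len])            # [batch_size * num_choices, seq_len]
scores = paddle.randn([batch_size * num_choices, 1])   # stand-in for classifier(pooled_output)
reshaped = scores.reshape([-1, num_choices])           # [batch_size, num_choices]
assert reshaped.shape == [batch_size, num_choices]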
+
+
+class BigBirdForMaskedLM(BigBirdPretrainedModel):
+    """
+    BigBird Model with a `masked language modeling` head on top.
+
+    Args:
+        bigbird (:class:`BigBirdModel`):
+            An instance of :class:`BigBirdModel`.
+
+    """
+
+    def __init__(self, bigbird):
+        super(BigBirdForMaskedLM, self).__init__()
+        self.bigbird = bigbird
+        self.lm_head = BigBirdLMPredictionHead(
+            self.bigbird.config["hidden_size"],
+            self.bigbird.config["vocab_size"],
+            self.bigbird.config["activation"],
+            self.bigbird.embeddings.word_embeddings.weight)
+
+        self.apply(self.init_weights)
+
+    def forward(self,
+                input_ids,
+                attention_mask_list=None,
+                rand_mask_idx_list=None,
+                labels=None):
+        r"""
+
+        Args:
+            input_ids (Tensor):
+                See :class:`BigBirdModel`.
+            attention_mask_list (`List`):
+                See :class:`BigBirdModel`.
+            rand_mask_idx_list (`List`):
+                See :class:`BigBirdModel`.
+            labels (Tensor, optional):
+                The labels for computing the masked language modeling loss. Indices should be in
+                ``[-100, 0, ..., vocab_size]``. Tokens with indices set to ``-100`` are ignored (masked);
+                the loss is only computed for tokens with labels in ``[0, ..., vocab_size]``.
+                Its shape is [batch_size, sequence_length].
+
+        Returns:
+            tuple: Returns tuple (`masked_lm_loss`, `prediction_scores`, `sequence_output`) when `labels`
+            is provided, otherwise tuple (`prediction_scores`, `sequence_output`).
+
+            With the fields:
+
+            - `masked_lm_loss` (Tensor):
+                The masked language modeling loss. Its data type should be float32 and its shape is [1].
+
+            - `prediction_scores` (Tensor):
+                The scores of masked token prediction. Its data type should be float32. Its shape is [batch_size, sequence_length, vocab_size].
+
+            - `sequence_output` (Tensor):
+                Sequence of hidden-states at the last layer of the model. Its data type should be float32. Its shape is `[batch_size, sequence_length, hidden_size]`.
+
+        """
+        sequence_output, _ = self.bigbird(
+            input_ids,
+            attention_mask_list=attention_mask_list,
+            rand_mask_idx_list=rand_mask_idx_list)
+        prediction_scores = self.lm_head(sequence_output)
+
+        masked_lm_loss = None
+
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            masked_lm_loss = loss_fct(
+                prediction_scores.reshape(shape=(
+                    -1, self.bigbird.config["vocab_size"])),
+                labels.reshape(shape=(-1, )))
+            return masked_lm_loss, prediction_scores, sequence_output
+
+        return prediction_scores, sequence_output
+
+
+class BigBirdForCausalLM(BigBirdPretrainedModel):
+    """
+    BigBird Model with a `causal language modeling` head on top (for next-token prediction).
+
+    Args:
+        bigbird (:class:`BigBirdModel`):
+            An instance of :class:`BigBirdModel`.
+
+    """
+
+    def __init__(self, bigbird):
+        super(BigBirdForCausalLM, self).__init__()
+        self.bigbird = bigbird
+        self.lm_head = BigBirdLMPredictionHead(
+            self.bigbird.config["hidden_size"],
+            self.bigbird.config["vocab_size"],
+            self.bigbird.config["activation"],
+            self.bigbird.embeddings.word_embeddings.weight)
+
+        self.apply(self.init_weights)
+
+    def forward(self,
+                input_ids,
+                attention_mask_list=None,
+                rand_mask_idx_list=None,
+                labels=None):
+        r"""
+
+        Args:
+            input_ids (Tensor):
+                See :class:`BigBirdModel`.
+            attention_mask_list (`List`):
+                See :class:`BigBirdModel`.
+            rand_mask_idx_list (`List`):
+                See :class:`BigBirdModel`.
+            labels (Tensor, optional):
+                The labels for computing the left-to-right language modeling loss (next-token prediction).
+                Indices should be in ``[-100, 0, ..., vocab_size]``. Tokens with indices set to ``-100`` are
+                ignored (masked); the loss is only computed for tokens with labels in ``[0, ..., vocab_size]``.
+                Its shape is [batch_size, sequence_length].
+
+        Returns:
+            tuple: Returns tuple (`lm_loss`, `prediction_scores`, `sequence_output`) when `labels`
+            is provided, otherwise tuple (`prediction_scores`, `sequence_output`).
+
+            With the fields:
+
+            - `lm_loss` (Tensor):
+                The causal language modeling loss. Its data type should be float32 and its shape is [1].
+
+            - `prediction_scores` (Tensor):
+                The scores of next token prediction. Its data type should be float32. Its shape is [batch_size, sequence_length, vocab_size].
+
+            - `sequence_output` (Tensor):
+                Sequence of hidden-states at the last layer of the model. Its data type should be float32. Its shape is `[batch_size, sequence_length, hidden_size]`.
+
+        """
+        sequence_output, _ = self.bigbird(
+            input_ids,
+            attention_mask_list=attention_mask_list,
+            rand_mask_idx_list=rand_mask_idx_list)
+        prediction_scores = self.lm_head(sequence_output)
+
+        lm_loss = None
+        if labels is not None:
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_prediction_scores = prediction_scores[:, :-1, :]
+            labels = labels[:, 1:]
+            loss_fct = nn.CrossEntropyLoss()
+            lm_loss = loss_fct(
+                paddle.reshape(shifted_prediction_scores,
+                               [-1, self.bigbird.config['vocab_size']]),
+                paddle.reshape(labels, [-1]))
+            return lm_loss, prediction_scores, sequence_output
+
+        return prediction_scores, sequence_output
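# Editor's sketch (not part of this diff): the next-token shift used in
# BigBirdForCausalLM.forward. Position i of the scores is trained to predict
# token i + 1, so scores drop the last step and labels drop the first.
# Sizes are illustrative assumptions.
import paddle
import paddle.nn as nn

vocab_size = 10
prediction_scores = paddle.randn([2, 5, vocab_size])
labels = paddle.randint(0, vocab_size, [2, 5])

shifted_scores = prediction_scores[:, :-1, :]  # predictions for positions 0..3
shifted_labels = labels[:, 1:]                 # targets are the following tokens
loss = nn.CrossEntropyLoss()(
    paddle.reshape(shifted_scores, [-1, vocab_size]),
    paddle.reshape(shifted_labels, [-1]))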
diff --git a/tests/transformers/bigbird/test_modeling.py b/tests/transformers/bigbird/test_modeling.py
index 5f281b3a8892..9430428944f9 100644
--- a/tests/transformers/bigbird/test_modeling.py
+++ b/tests/transformers/bigbird/test_modeling.py
@@ -13,16 +13,18 @@
 # limitations under the License.
 import copy
+import unittest
+
 import numpy as np
-import os
 import paddle
 
-from paddlenlp.transformers import BigBirdForSequenceClassification, \
-    BigBirdPretrainingCriterion, BigBirdForPretraining, BigBirdModel
-from paddlenlp.transformers import create_bigbird_rand_mask_idx_list
 from common_test import CommonTest
 from util import softmax_with_cross_entropy, slow
-import unittest
+from paddlenlp.transformers import BigBirdForSequenceClassification, \
+    BigBirdPretrainingCriterion, BigBirdForPretraining, BigBirdModel, \
+    BigBirdForQuestionAnswering, BigBirdForTokenClassification, BigBirdForMultipleChoice, \
+    BigBirdForMaskedLM, BigBirdForCausalLM
+from paddlenlp.transformers import create_bigbird_rand_mask_idx_list
 
 
 def create_input_data(config, seed=None):
@@ -111,6 +113,234 @@ def test_forward(self):
         self.check_output_equal(self.expected_shape, output.numpy().shape)
 
 
+class TestBigBirdForQuestionAnswering(CommonTest):
+    def set_input(self):
+        self.config = copy.deepcopy(BigBirdModel.pretrained_init_configuration[
+            'bigbird-base-uncased'])
+        self.config['num_layers'] = 2
+        self.config['vocab_size'] = 1024
+        self.config['attn_dropout'] = 0.0
+        self.config['hidden_dropout_prob'] = 0.0
+        self.config['dim_feedforward'] = 1024
+        self.config['seq_len'] = 1024
+        self.config['batch_size'] = 2
+        self.config['max_position_embeddings'] = 2048
+        self.rand_mask_idx_list, self.input_ids, self.masked_lm_positions = create_input_data(
+            self.config)
+
+    def set_output(self):
+        self.expected_shape1 = (self.config['batch_size'],
+                                self.config['seq_len'])
+        self.expected_shape2 = (self.config['batch_size'],
+                                self.config['seq_len'])
+
+    def setUp(self):
+        self.set_model_class()
+        self.set_input()
+        self.set_output()
+
+    def set_model_class(self):
+        self.TEST_MODEL_CLASS = BigBirdForQuestionAnswering
+
+    def test_forward(self):
+        bigbird = BigBirdModel(**self.config)
+        model = self.TEST_MODEL_CLASS(bigbird)
+        input_ids = paddle.to_tensor(self.input_ids)
+        rand_mask_idx_list = paddle.to_tensor(self.rand_mask_idx_list)
+        start_logits, end_logits = model(
+            input_ids, rand_mask_idx_list=rand_mask_idx_list)
+        self.check_output_equal(self.expected_shape1,
+                                start_logits.numpy().shape)
+        self.check_output_equal(self.expected_shape2, end_logits.numpy().shape)
+
+
+class TestBigBirdForTokenClassification(CommonTest):
+    def set_input(self):
+        self.config = copy.deepcopy(BigBirdModel.pretrained_init_configuration[
+            'bigbird-base-uncased'])
+        self.config['num_layers'] = 2
+        self.config['vocab_size'] = 1024
+        self.config['attn_dropout'] = 0.0
+        self.config['hidden_dropout_prob'] = 0.0
+        self.config['dim_feedforward'] = 1024
+        self.config['seq_len'] = 1024
+        self.config['batch_size'] = 2
+        self.config['max_position_embeddings'] = 2048
+        self.rand_mask_idx_list, self.input_ids, self.masked_lm_positions = create_input_data(
+            self.config)
+        self.num_classes = 2
+
+    def set_output(self):
+        self.expected_shape = (self.config['batch_size'],
+                               self.config['seq_len'], self.num_classes)
+
+    def setUp(self):
+        self.set_model_class()
+        self.set_input()
+        self.set_output()
+
+    def set_model_class(self):
+        self.TEST_MODEL_CLASS = BigBirdForTokenClassification
+
+    def test_forward(self):
+        bigbird = BigBirdModel(**self.config)
+        model = self.TEST_MODEL_CLASS(bigbird, num_classes=self.num_classes)
+        input_ids = paddle.to_tensor(self.input_ids)
+        rand_mask_idx_list = paddle.to_tensor(self.rand_mask_idx_list)
+        output = model(input_ids, rand_mask_idx_list=rand_mask_idx_list)
+        self.check_output_equal(self.expected_shape,
output.numpy().shape) + + +class TestBigBirdForMultipleChoice(CommonTest): + def set_input(self): + self.config = copy.deepcopy(BigBirdModel.pretrained_init_configuration[ + 'bigbird-base-uncased']) + self.config['num_layers'] = 2 + self.config['vocab_size'] = 1024 + self.config['attn_dropout'] = 0.0 + self.config['hidden_dropout_prob'] = 0.0 + self.config['dim_feedforward'] = 1024 + self.config['seq_len'] = 1024 + self.config['batch_size'] = 2 + self.config['max_position_embeddings'] = 2048 + self.rand_mask_idx_list, self.input_ids, self.masked_lm_positions = [], [], [] + self.num_choices = 2 + for i in range(self.num_choices): + rand_mask_idx_list, input_ids, masked_lm_positions = create_input_data( + self.config) + self.rand_mask_idx_list.append(rand_mask_idx_list) + self.input_ids.append(input_ids) + self.masked_lm_positions.append(masked_lm_positions) + self.rand_mask_idx_list = np.array(self.rand_mask_idx_list).swapaxes(0, + 1) + self.input_ids = np.array(self.input_ids).swapaxes(0, 1) + self.masked_lm_positions = np.array(self.masked_lm_positions).swapaxes( + 0, 1) + + def set_output(self): + self.expected_shape = (self.config['batch_size'], self.num_choices) + + def setUp(self): + self.set_model_class() + self.set_input() + self.set_output() + + def set_model_class(self): + self.TEST_MODEL_CLASS = BigBirdForMultipleChoice + + def test_forward(self): + bigbird = BigBirdModel(**self.config) + model = self.TEST_MODEL_CLASS(bigbird, num_choices=self.num_choices) + input_ids = paddle.to_tensor(self.input_ids) + rand_mask_idx_list = paddle.to_tensor(self.rand_mask_idx_list) + output = model(input_ids, rand_mask_idx_list=rand_mask_idx_list) + self.check_output_equal(self.expected_shape, output.numpy().shape) + + +class TestBigBirdForMaskedLM(CommonTest): + def set_input(self): + self.config = copy.deepcopy(BigBirdModel.pretrained_init_configuration[ + 'bigbird-base-uncased']) + self.config['num_layers'] = 2 + self.config['vocab_size'] = 1024 + self.config['attn_dropout'] = 0.0 + self.config['hidden_dropout_prob'] = 0.0 + self.config['dim_feedforward'] = 1024 + self.config['seq_len'] = 1024 + self.config['batch_size'] = 2 + self.config['max_position_embeddings'] = 2048 + self.rand_mask_idx_list, self.input_ids, self.masked_lm_positions = create_input_data( + self.config) + self.labels = np.random.randint( + low=0, + high=self.config['vocab_size'], + size=(self.config["batch_size"], self.config["seq_len"])) + + def set_output(self): + self.expected_shape1 = (1, ) + self.expected_shape2 = (self.config['batch_size'], + self.config['seq_len'], + self.config['vocab_size']) + self.expected_shape3 = (self.config['batch_size'], + self.config['seq_len'], + self.config['hidden_size']) + + def setUp(self): + self.set_model_class() + self.set_input() + self.set_output() + + def set_model_class(self): + self.TEST_MODEL_CLASS = BigBirdForMaskedLM + + def test_forward(self): + bigbird = BigBirdModel(**self.config) + model = self.TEST_MODEL_CLASS(bigbird) + input_ids = paddle.to_tensor(self.input_ids) + rand_mask_idx_list = paddle.to_tensor(self.rand_mask_idx_list) + labels = paddle.to_tensor(self.labels) + masked_lm_loss, prediction_scores, sequence_output = model( + input_ids, rand_mask_idx_list=rand_mask_idx_list, labels=labels) + self.check_output_equal(self.expected_shape1, + masked_lm_loss.numpy().shape) + self.check_output_equal(self.expected_shape2, + prediction_scores.numpy().shape) + self.check_output_equal(self.expected_shape3, + sequence_output.numpy().shape) + + +class 
TestBigBirdForCausalLM(CommonTest): + def set_input(self): + self.config = copy.deepcopy(BigBirdModel.pretrained_init_configuration[ + 'bigbird-base-uncased']) + self.config['num_layers'] = 2 + self.config['vocab_size'] = 1024 + self.config['attn_dropout'] = 0.0 + self.config['hidden_dropout_prob'] = 0.0 + self.config['dim_feedforward'] = 1024 + self.config['seq_len'] = 1024 + self.config['batch_size'] = 2 + self.config['max_position_embeddings'] = 2048 + self.rand_mask_idx_list, self.input_ids, self.masked_lm_positions = create_input_data( + self.config) + self.labels = np.random.randint( + low=0, + high=self.config['vocab_size'], + size=(self.config["batch_size"], self.config["seq_len"])) + + def set_output(self): + self.expected_shape1 = (1, ) + self.expected_shape2 = (self.config['batch_size'], + self.config['seq_len'], + self.config['vocab_size']) + self.expected_shape3 = (self.config['batch_size'], + self.config['seq_len'], + self.config['hidden_size']) + + def setUp(self): + self.set_model_class() + self.set_input() + self.set_output() + + def set_model_class(self): + self.TEST_MODEL_CLASS = BigBirdForCausalLM + + def test_forward(self): + bigbird = BigBirdModel(**self.config) + model = self.TEST_MODEL_CLASS(bigbird) + input_ids = paddle.to_tensor(self.input_ids) + rand_mask_idx_list = paddle.to_tensor(self.rand_mask_idx_list) + labels = paddle.to_tensor(self.labels) + masked_lm_loss, prediction_scores, sequence_output = model( + input_ids, rand_mask_idx_list=rand_mask_idx_list, labels=labels) + self.check_output_equal(self.expected_shape1, + masked_lm_loss.numpy().shape) + self.check_output_equal(self.expected_shape2, + prediction_scores.numpy().shape) + self.check_output_equal(self.expected_shape3, + sequence_output.numpy().shape) + + class TestBigBirdForPretraining(TestBigBirdForSequenceClassification): def set_input(self): self.config = copy.deepcopy(BigBirdModel.pretrained_init_configuration[