Fix tests by overriding _prepare_for_class
ankrgyl committed Aug 30, 2022
1 parent 499d3ea commit fc163aa
Showing 3 changed files with 56 additions and 1 deletion.
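Why the override is needed: the shared ModelTesterMixin builds labeled inputs through a _prepare_for_class hook, and the base implementation only adds labels for classes registered in the auto-model mappings. LayoutLMForQuestionAnswering is not in MODEL_FOR_QUESTION_ANSWERING_MAPPING, so label-dependent common tests received no start_positions/end_positions and failed. A minimal sketch of the consuming side, simplified from tests/test_modeling_common.py (the real test has a few more guards):

    def test_training(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True
        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.train()
            # return_labels=True routes through the override added below, which
            # fills in start_positions/end_positions for the QA head.
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            loss = model(**inputs).loss
            loss.backward()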
2 changes: 2 additions & 0 deletions src/transformers/utils/fx.py
@@ -147,6 +147,7 @@ def _generate_supported_model_class_names(
    "GPT2DoubleHeadsModel",
    "Speech2Text2Decoder",
    "TrOCRDecoder",
    "LayoutLMForQuestionAnswering",
    # TODO: add support for them as it should be quite easy to do so (small blocking issues).
    # XLNetForQuestionAnswering,
]
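Registering the class name above lets transformers' torch.fx helper build a dummy-input graph for it. A minimal sketch of what this enables, assuming a standard LayoutLM checkpoint (the name below is illustrative):

    from transformers import LayoutLMForQuestionAnswering
    from transformers.utils.fx import symbolic_trace

    model = LayoutLMForQuestionAnswering.from_pretrained("microsoft/layoutlm-base-uncased")
    # input_names must be a subset of the model's forward signature.
    traced = symbolic_trace(model, input_names=["input_ids", "bbox", "attention_mask"])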
@@ -690,6 +691,7 @@ def _generate_dummy_input(
            inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
        elif model_class_name in [
            *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
            "LayoutLMForQuestionAnswering",
            "XLNetForQuestionAnswering",
        ]:
            inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
30 changes: 29 additions & 1 deletion tests/models/layoutlm/test_modeling_layoutlm.py
@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import copy
import unittest

from transformers import LayoutLMConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device

from ...test_configuration_common import ConfigTester
@@ -27,6 +28,9 @@
    import torch

    from transformers import (
        MODEL_FOR_MASKED_LM_MAPPING,
        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        LayoutLMForMaskedLM,
        LayoutLMForQuestionAnswering,
        LayoutLMForSequenceClassification,
@@ -269,6 +273,30 @@ def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = copy.deepcopy(inputs_dict)
        if return_labels:
            if model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
                inputs_dict["labels"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
            elif model_class in [
                *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
                *get_values(MODEL_FOR_MASKED_LM_MAPPING),
            ]:
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
            elif model_class.__name__ == "LayoutLMForQuestionAnswering":
                inputs_dict["start_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
                inputs_dict["end_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )

        return inputs_dict
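With this override in place, common tests that call _prepare_for_class(..., return_labels=True) receive valid start_positions and end_positions for LayoutLMForQuestionAnswering, which the base ModelTesterMixin implementation cannot supply because the class is not in MODEL_FOR_QUESTION_ANSWERING_MAPPING; the __name__ check keeps the dispatch mapping-free.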


def prepare_layoutlm_batch_inputs():
# Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on:
25 changes: 25 additions & 0 deletions tests/models/layoutlm/test_modeling_tf_layoutlm.py
@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import numpy as np

from transformers import LayoutLMConfig, is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow

from ...test_configuration_common import ConfigTester
@@ -27,6 +29,11 @@
if is_tf_available():
    import tensorflow as tf

    from transformers import (
        TF_MODEL_FOR_MASKED_LM_MAPPING,
        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
    )
    from transformers.models.layoutlm.modeling_tf_layoutlm import (
        TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
        TFLayoutLMForMaskedLM,
@@ -256,6 +263,24 @@ def test_model_from_pretrained(self):
            model = TFLayoutLMModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = copy.deepcopy(inputs_dict)
        if return_labels:
            if model_class in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
                inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
            elif model_class in [
                *get_values(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
                *get_values(TF_MODEL_FOR_MASKED_LM_MAPPING),
            ]:
                inputs_dict["labels"] = tf.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
                )
            elif model_class.__name__ == "TFLayoutLMForQuestionAnswering":
                inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
                inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)

        return inputs_dict
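The TF override is the same dispatch with framework-appropriate tensors: tf.zeros(..., dtype=tf.int32) stands in for the PyTorch torch.zeros(..., dtype=torch.long) calls, since the TF loss computations accept int32 position indices.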


def prepare_layoutlm_batch_inputs():
# Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on: