From 99cf7d22c3cbbe81c905e560828cd1cb0f7df02c Mon Sep 17 00:00:00 2001
From: Austin
Date: Tue, 28 Jun 2022 13:10:29 -0700
Subject: [PATCH] refactor bert and gpt (#1130)

Converts the Composer GPT-2 and BERT models to use a HuggingFaceModel base
class compatible with any pretrained model off the HF Hub. Converts model
creation to factory functions and moves all logic out of the hparams classes.
---
 .../gated_linear_units/gated_linear_units.py  |  11 +-
 .../seq_length_warmup/seq_length_warmup.py    |   8 +-
 composer/models/__init__.py                   |   7 +-
 composer/models/bert/__init__.py              |   5 +-
 composer/models/bert/bert_hparams.py          | 174 +++++------
 composer/models/bert/model.py                 | 271 ++++++++++++------
 composer/models/gpt2/__init__.py              |   4 +-
 composer/models/gpt2/gpt2_hparams.py          |  83 +++---
 composer/models/gpt2/model.py                 | 144 ++++++----
 composer/models/huggingface.py                |  60 +++-
 composer/models/transformer_hparams.py        |  49 ----
 composer/models/transformer_shared.py         | 130 ---------
 composer/yamls/models/glue/cola.yaml          |   1 -
 composer/yamls/models/glue/mnli.yaml          |   1 -
 composer/yamls/models/glue/mrpc.yaml          |   1 -
 composer/yamls/models/glue/qnli.yaml          |   1 -
 composer/yamls/models/glue/qqp.yaml           |   1 -
 composer/yamls/models/glue/rte.yaml           |   1 -
 composer/yamls/models/glue/sst-2.yaml         |   1 -
 composer/yamls/models/glue/stsb.yaml          |   1 -
 examples/huggingface_models.ipynb             |   4 +-
 tests/algorithms/test_fused_layernorm.py      |  17 +-
 tests/algorithms/test_gated_linear_units.py   |  15 +-
 tests/common/datasets.py                      |   6 +-
 tests/common/models.py                        |  10 +-
 tests/fixtures/synthetic_hf_state.py          |  14 +-
 26 files changed, 483 insertions(+), 537 deletions(-)
 delete mode 100644 composer/models/transformer_hparams.py
 delete mode 100644 composer/models/transformer_shared.py

diff --git a/composer/algorithms/gated_linear_units/gated_linear_units.py b/composer/algorithms/gated_linear_units/gated_linear_units.py
index d955fa21c6..8ce0d59025 100644
--- a/composer/algorithms/gated_linear_units/gated_linear_units.py
+++ b/composer/algorithms/gated_linear_units/gated_linear_units.py
@@ -11,7 +11,10 @@
 import torch

+from composer.models.huggingface import HuggingFaceModel
+
 try:
+    from transformers import BertForMaskedLM, BertForSequenceClassification
     from transformers.models.bert.modeling_bert import BertIntermediate, BertOutput
     IS_TRANSFORMERS_INSTALLED = True
 except ImportError as e:
@@ -21,7 +24,6 @@
 from composer.algorithms.warnings import NoEffectWarning
 from composer.core import Algorithm, Event, State
 from composer.loggers import Logger
-from composer.models import BERTModel
 from composer.utils import MissingConditionalImportError, module_surgery

 log = logging.getLogger(__name__)
@@ -79,9 +81,10 @@ def apply_gated_linear_units(model: torch.nn.Module,
     if not IS_TRANSFORMERS_INSTALLED:
         raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers')

-    # ensure that the model is an instance of a BERTModel, since our replacement policy is only defined for BERTs
-    if not isinstance(model, BERTModel):
-        raise TypeError('Gated Linear Units only has a surgery policy defined for instances of BERTModel.')
+    # ensure that the model is a HuggingFaceModel wrapping a BERT, since our replacement policy is only defined for BERTs
+    if not isinstance(model, HuggingFaceModel) or not isinstance(model.model,
+                                                                 (BertForMaskedLM, BertForSequenceClassification)):
+        raise TypeError('Gated Linear Units only has a surgery policy defined for instances of BERT models.')

     if act_fn is None:
         # get the activation functions used
diff --git
a/composer/algorithms/seq_length_warmup/seq_length_warmup.py b/composer/algorithms/seq_length_warmup/seq_length_warmup.py index 3dbaadc127..0cf0e9b398 100644 --- a/composer/algorithms/seq_length_warmup/seq_length_warmup.py +++ b/composer/algorithms/seq_length_warmup/seq_length_warmup.py @@ -15,7 +15,7 @@ from composer.core.time import TimeUnit from composer.core.types import Batch from composer.loggers import Logger -from composer.models import ComposerTransformer +from composer.models import HuggingFaceModel from composer.utils import ensure_tuple __all__ = ['SeqLengthWarmup', 'set_batch_sequence_length'] @@ -163,7 +163,7 @@ def __init__( self.truncate = truncate if self.duration < 0 or self.duration > 1: - raise ValueError(f'Duration must be getween 0 and 1, got: {self.duration}') + raise ValueError(f'Duration must be between 0 and 1, got: {self.duration}') if self.max_seq_length < self.min_seq_length: raise ValueError(f'max_seq_length={self.max_seq_length} must be ' @@ -176,10 +176,10 @@ def match(self, event: Event, state: State) -> bool: def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]: if event == Event.INIT: - if not isinstance(state.model, ComposerTransformer): + if not isinstance(state.model, HuggingFaceModel): raise RuntimeError( textwrap.dedent(f"""\ - {type(self).__name__} requires state.model to be of type {ComposerTransformer.__name__}, not of type {type(state.model)}""" + {type(self).__name__} requires state.model to be of type {HuggingFaceModel.__name__}, not of type {type(state.model)}""" )) self._original_model = state.model diff --git a/composer/models/__init__.py b/composer/models/__init__.py index 41d4eee725..9354056620 100644 --- a/composer/models/__init__.py +++ b/composer/models/__init__.py @@ -12,7 +12,8 @@ from composer.models.base import ComposerModel as ComposerModel from composer.models.bert import BERTForClassificationHparams as BERTForClassificationHparams from composer.models.bert import BERTHparams as BERTHparams -from composer.models.bert import BERTModel as BERTModel +from composer.models.bert import create_bert_classification as create_bert_classification +from composer.models.bert import create_bert_mlm as create_bert_mlm from composer.models.classify_mnist import MNIST_Classifier as MNIST_Classifier from composer.models.classify_mnist import MnistClassifierHparams as MnistClassifierHparams from composer.models.deeplabv3 import ComposerDeepLabV3 as ComposerDeepLabV3 @@ -20,7 +21,7 @@ from composer.models.efficientnetb0 import EfficientNetB0 as EfficientNetB0 from composer.models.efficientnetb0 import EfficientNetB0Hparams as EfficientNetB0Hparams from composer.models.gpt2 import GPT2Hparams as GPT2Hparams -from composer.models.gpt2 import GPT2Model as GPT2Model +from composer.models.gpt2 import create_gpt2 as create_gpt2 from composer.models.huggingface import HuggingFaceModel as HuggingFaceModel from composer.models.initializers import Initializer as Initializer from composer.models.model_hparams import ModelHparams as ModelHparams @@ -33,8 +34,6 @@ from composer.models.tasks import ComposerClassifier as ComposerClassifier from composer.models.timm import Timm as Timm from composer.models.timm import TimmHparams as TimmHparams -from composer.models.transformer_hparams import TransformerHparams as TransformerHparams -from composer.models.transformer_shared import ComposerTransformer as ComposerTransformer from composer.models.unet import UNet as UNet from composer.models.unet import UnetHparams as UnetHparams from 
composer.models.vit_small_patch16 import ViTSmallPatch16 as ViTSmallPatch16
diff --git a/composer/models/bert/__init__.py b/composer/models/bert/__init__.py
index 9264fcc21c..7fb03f6787 100644
--- a/composer/models/bert/__init__.py
+++ b/composer/models/bert/__init__.py
@@ -6,6 +6,7 @@
 from composer.models.bert.bert_hparams import BERTForClassificationHparams as BERTForClassificationHparams
 from composer.models.bert.bert_hparams import BERTHparams as BERTHparams
-from composer.models.bert.model import BERTModel as BERTModel
+from composer.models.bert.model import create_bert_classification as create_bert_classification
+from composer.models.bert.model import create_bert_mlm as create_bert_mlm

-__all__ = ['BERTModel', 'BERTHparams', 'BERTForClassificationHparams']
+__all__ = ['BERTHparams', 'BERTForClassificationHparams', 'create_bert_classification', 'create_bert_mlm']
diff --git a/composer/models/bert/bert_hparams.py b/composer/models/bert/bert_hparams.py
index 0551154323..e6b12bdaac 100644
--- a/composer/models/bert/bert_hparams.py
+++ b/composer/models/bert/bert_hparams.py
@@ -5,130 +5,98 @@
 :class:`.BERTModel`."""

 from dataclasses import dataclass
-from typing import TYPE_CHECKING
+from typing import Dict, Optional

 import yahp as hp

-from composer.models.transformer_hparams import TransformerHparams
-from composer.utils import MissingConditionalImportError
-
-if TYPE_CHECKING:
-    from composer.models.bert import BERTModel
+from composer.core.types import JSON
+from composer.models.model_hparams import ModelHparams

 __all__ = ['BERTForClassificationHparams', 'BERTHparams']


 @dataclass
-class BERTForClassificationHparams(TransformerHparams):
-    """`YAHP `_ classification interface for
-    :class:`.BERTModel`.
+class BERTHparams(ModelHparams):
+    """`YAHP `_ interface for :func:`.create_bert_mlm`.

     Args:
-        pretrained_model_name (str): Pretrained model name to pull from Hugging Face Model Hub.
-        model_config (Dict[str, JSON]): A dictionary providing a HuggingFace model configuration.
-        tokenizer_name (Optional[str]): The tokenizer used for this model,
-            necessary to assert required model inputs.
+        model_config (Dict[str, JSON], optional): A dictionary providing a HuggingFace model configuration.
+        pretrained_model_name (str, optional): Pretrained model name to pull from Hugging Face Model Hub.
         use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights.
-        gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``.
-        num_labels (int, optional): The number of classes in the segmentation task. Default: ``2``.
+        tokenizer_name (str, optional): The tokenizer used for this model,
+            necessary to assert required model inputs. Default: ``None``.
+        gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``.
""" - num_labels: int = hp.optional(doc='The number of possible labels for the task.', default=2) - - def validate(self): - if self.num_labels < 1: - raise ValueError('The number of target labels must be at least one.') - - def initialize_object(self) -> 'BERTModel': - try: - import transformers - except ImportError as e: - raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e - - from composer.models.bert.model import BERTModel - self.validate() - - model_hparams = {'num_labels': self.num_labels} - - if self.model_config: - config = transformers.BertConfig.from_dict(self.model_config, **model_hparams) - elif self.pretrained_model_name is not None: - config = transformers.BertConfig.from_pretrained(self.pretrained_model_name, **model_hparams) - else: - raise ValueError('One of pretrained_model_name or model_config needed.') - config.num_labels = self.num_labels - - # setup the tokenizer in the hparams interface - if self.tokenizer_name is not None: - tokenizer = transformers.BertTokenizer.from_pretrained(self.tokenizer_name) - else: - tokenizer = None - - if self.use_pretrained: - # TODO (Moin): handle the warnings on not using the seq_relationship head - assert transformers.AutoModelForSequenceClassification.from_pretrained is not None, 'from_pretrained should not be None' - model = transformers.AutoModelForSequenceClassification.from_pretrained(self.pretrained_model_name, - **model_hparams) - else: - # an invariant to ensure that we don't lose keys when creating the HF config - for k, v in model_hparams.items(): - assert getattr(config, k) == v - model = transformers.AutoModelForSequenceClassification.from_config( #type: ignore (thirdparty) - config) - - return BERTModel( - module=model, - config=config, #type: ignore (thirdparty) - tokenizer=tokenizer, + model_config: Optional[Dict[str, + JSON]] = hp.optional(doc='A dictionary providing a HuggingFace model configuration.', + default_factory=dict) + pretrained_model_name: Optional[str] = hp.optional(doc='Pretrained model name to pull from Hugging Face Model Hub.', + default=None) + use_pretrained: Optional[bool] = hp.optional('Whether to initialize the model with the pretrained weights.', + default=False) + tokenizer_name: Optional[str] = hp.optional( + 'The tokenizer used for this model, necessary to assert required model inputs.', default=None) + gradient_checkpointing: Optional[bool] = hp.optional('Whether to enable gradient checkpointing.', default=False) + + def initialize_object(self): + from composer.models.bert.model import create_bert_mlm + + # user must specify one of either config or the pretrained model + if not self.pretrained_model_name and self.model_config == {}: + raise Exception('One of pretrained_model_name or model_config needed.') + + if self.use_pretrained and self.model_config: + raise Exception('A model cannot load pretrained weights from configuration.') + + return create_bert_mlm( + model_config=self.model_config, # type: ignore (thirdparty) + pretrained_model_name=self.pretrained_model_name, + use_pretrained=self.use_pretrained, + tokenizer_name=self.tokenizer_name, + gradient_checkpointing=self.gradient_checkpointing, ) @dataclass -class BERTHparams(TransformerHparams): +class BERTForClassificationHparams(ModelHparams): """`YAHP `_ interface for :class:`.BERTModel`. Args: - pretrained_model_name (str): "Pretrained model name to pull from Huggingface Model Hub." + num_labels (int, optional): The number of classes in the classification task. Default: ``2``. 
+ pretrained_model_name (str, optional): Pretrained model name to pull from Hugging Face Model Hub. model_config (Dict[str, JSON]): A dictionary providing a HuggingFace model configuration. - tokenizer_name (str): The tokenizer used for this model, - necessary to assert required model inputs. use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights. + tokenizer_name (Optional[str]): The tokenizer used for this model, + necessary to assert required model inputs. Default ``None``. gradient_checkpointing (bool, optional): Use gradient checkpointing. default: False. """ - - def initialize_object(self) -> 'BERTModel': - try: - import transformers - except ImportError as e: - raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e - - from composer.models.bert.model import BERTModel - self.validate() - - if self.model_config: - config = transformers.BertConfig.from_dict(self.model_config) - elif self.pretrained_model_name is not None: - config = transformers.BertConfig.from_pretrained(self.pretrained_model_name) - else: - raise ValueError('One of pretrained_model_name or model_config needed.') - - # set the number of labels ot the vocab size, used for measuring MLM accuracy - config.num_labels = config.vocab_size - - # setup the tokenizer in the hparams interface - if self.tokenizer_name is not None: - tokenizer = transformers.BertTokenizer.from_pretrained(self.tokenizer_name) - else: - tokenizer = None - - if self.use_pretrained: - # TODO (Moin): handle the warnings on not using the seq_relationship head - assert transformers.AutoModelForMaskedLM.from_pretrained is not None, 'from_pretrained should not be None' - model = transformers.AutoModelForMaskedLM.from_pretrained(self.pretrained_model_name) - else: - model = transformers.AutoModelForMaskedLM.from_config(config) #type: ignore (thirdparty) - - return BERTModel( - module=model, - config=config, #type: ignore (thirdparty) - tokenizer=tokenizer, + num_labels: Optional[int] = hp.optional(doc='The number of possible labels for the task.', default=2) + pretrained_model_name: Optional[str] = hp.optional(doc='Pretrained model name to pull from Hugging Face Model Hub.', + default=None) + model_config: Optional[Dict[str, + JSON]] = hp.optional(doc='A dictionary providing a HuggingFace model configuration.', + default_factory=dict) + use_pretrained: Optional[bool] = hp.optional('Whether to initialize the model with the pretrained weights.', + default=False) + tokenizer_name: Optional[str] = hp.optional( + 'The tokenizer used for this model, necessary to assert required model inputs.', default=None) + gradient_checkpointing: Optional[bool] = hp.optional('Whether to enable gradient checkpointing.', default=False) + + def initialize_object(self): + from composer.models.bert.model import create_bert_classification + + # user must specify one of either config or the pretrained model + if not self.pretrained_model_name and self.model_config == {}: + raise Exception('One of pretrained_model_name or model_config needed.') + + if self.use_pretrained and self.model_config: + raise Exception('A model cannot load pretrained weights from configuration.') + + return create_bert_classification( + num_labels=self.num_labels, + pretrained_model_name=self.pretrained_model_name, + model_config=self.model_config, # type: ignore (thirdparty) + use_pretrained=self.use_pretrained, + tokenizer_name=self.tokenizer_name, + gradient_checkpointing=self.gradient_checkpointing, ) diff --git 
a/composer/models/bert/model.py b/composer/models/bert/model.py
index 75f6c9c1c9..c9bbbc03f2 100644
--- a/composer/models/bert/model.py
+++ b/composer/models/bert/model.py
@@ -5,114 +5,205 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union
+from typing import Optional

-import torch
-from torchmetrics import MeanSquaredError, Metric, MetricCollection
+from torchmetrics import MeanSquaredError
 from torchmetrics.classification.accuracy import Accuracy
 from torchmetrics.classification.matthews_corrcoef import MatthewsCorrCoef
 from torchmetrics.regression.spearman import SpearmanCorrCoef

 from composer.metrics.nlp import BinaryF1Score, LanguageCrossEntropy, MaskedAccuracy
-from composer.models.transformer_shared import ComposerTransformer
+from composer.models.huggingface import HuggingFaceModel
+from composer.utils.import_helpers import MissingConditionalImportError

-if TYPE_CHECKING:
-    import transformers
+__all__ = ['create_bert_mlm', 'create_bert_classification']

-    from composer.core.types import Batch

-__all__ = ['BERTModel']
-
-
-class BERTModel(ComposerTransformer):
+def create_bert_mlm(use_pretrained: Optional[bool] = False,
+                    pretrained_model_name: Optional[str] = None,
+                    model_config: Optional[dict] = None,
+                    tokenizer_name: Optional[str] = None,
+                    gradient_checkpointing: Optional[bool] = False):
     """BERT model based on |:hugging_face:| Transformers.

     For more information, see `Transformers `_.

     Args:
-        module (transformers.BertModel): An instance of BertModel that
-            contains the forward pass function.
-        config (transformers.BertConfig): The BertConfig object that
-            stores information about the model hyperparameters.
-        tokenizer (transformers.BertTokenizer): An instance of BertTokenizer. Necessary to process model inputs.
-
-    To create a BERT model for Language Model pretraining:
+        gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``.
+        use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights. Default: ``False``.
+        model_config (dict, optional): The settings used to create a Hugging Face BertConfig. BertConfig is used to specify the
+            architecture of a Hugging Face model.
+        tokenizer_name (str, optional): Tokenizer name used to preprocess the dataset
+            and validate the model's inputs.
+
+        For example, the default ``bert-base-uncased`` configuration is:
+
+        .. code-block::
+
+            {
+              "_name_or_path": "bert-base-uncased",
+              "architectures": ["BertForMaskedLM"],
+              "attention_probs_dropout_prob": 0.1,
+              "classifier_dropout": null,
+              "gradient_checkpointing": false,
+              "hidden_act": "gelu",
+              "hidden_dropout_prob": 0.1,
+              "hidden_size": 768,
+              "initializer_range": 0.02,
+              "intermediate_size": 3072,
+              "layer_norm_eps": 1e-12,
+              "max_position_embeddings": 512,
+              "model_type": "bert",
+              "num_attention_heads": 12,
+              "num_hidden_layers": 12,
+              "pad_token_id": 0,
+              "position_embedding_type": "absolute",
+              "transformers_version": "4.16.0",
+              "type_vocab_size": 2,
+              "use_cache": true,
+              "vocab_size": 30522
+            }
+
+    To create a BERT model for Masked Language Model pretraining:

 ..
testcode:: - from composer.models import BERTModel - import transformers + from composer.models import create_bert_mlm + model = create_bert_mlm() - config = transformers.BertConfig() - hf_model = transformers.BertLMHeadModel(config=config) - tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased") - model = BERTModel(module=hf_model, config=config, tokenizer=tokenizer) """ + try: + import transformers + except ImportError as e: + raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e + + if not model_config: + model_config = {} + + if not pretrained_model_name: + pretrained_model_name = 'bert-base-uncased' + + if use_pretrained: + model = transformers.AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path=pretrained_model_name, + **model_config) + else: + config = transformers.AutoConfig.from_pretrained(pretrained_model_name, **model_config) + model = transformers.AutoModelForMaskedLM.from_config(config) # type: ignore (thirdparty) + + if gradient_checkpointing: + model.gradient_checkpointing_enable() # type: ignore + + # setup the tokenizer + if tokenizer_name: + tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name) + else: + tokenizer = None + + metrics = [ + LanguageCrossEntropy(ignore_index=-100, vocab_size=model.config.vocab_size), + MaskedAccuracy(ignore_index=-100) + ] + return HuggingFaceModel(model=model, tokenizer=tokenizer, use_logits=True, metrics=metrics) + + +def create_bert_classification(num_labels: Optional[int] = 2, + use_pretrained: Optional[bool] = False, + pretrained_model_name: Optional[str] = None, + model_config: Optional[dict] = None, + tokenizer_name: Optional[str] = None, + gradient_checkpointing: Optional[bool] = False): + """BERT classification model based on |:hugging_face:| Transformers. - def __init__(self, - module: transformers.BertModel, - config: transformers.BertConfig, - tokenizer: Optional[transformers.BertTokenizer] = None) -> None: - - if tokenizer is None: - model_inputs = {'input_ids', 'attention_mask', 'token_type_ids'} - else: - model_inputs = set(tokenizer.model_input_names) - - super().__init__( - module=module, #type: ignore (thirdparty) - config=config, - model_inputs=model_inputs) - - # we're going to remove the label from the expected inputs - # since we will handle metric calculation with TorchMetrics instead of HuggingFace. - self.model_inputs.remove('labels') - - # When using Evaluators, the validation metrics represent all possible - # validation metrics that can be used with the bert model - # The Evaluator class checks if it's metrics are in the models validation metrics - - ignore_index = -100 - self.val_metrics = [ - Accuracy(), - MeanSquaredError(), - SpearmanCorrCoef(), - BinaryF1Score(), - MatthewsCorrCoef(num_classes=config.num_labels), - LanguageCrossEntropy(ignore_index=ignore_index, vocab_size=config.num_labels), - MaskedAccuracy(ignore_index=ignore_index), - ] - self.train_metrics = [] - - def loss(self, outputs: Mapping, batch: Batch) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - if outputs.get('loss', None) is not None: - return outputs['loss'] - else: - raise NotImplementedError('Calculating loss directly not supported yet.') - - def validate(self, batch: Any) -> Any: - """Runs the validation step. - - Args: - batch (Dict): a dictionary of Dict[str, Tensor] of inputs - that the model expects, as found in :meth:`.ComposerTransformer.get_model_inputs`. 
-
-        Returns:
-            tuple (Tensor, Tensor): with the output from the forward pass and the correct labels.
-                This is fed into directly into the output of :meth:`.ComposerModel.metrics`.
-        """
-        assert self.training is False, 'For validation, model must be in eval mode'
-
-        # temporary hack until eval on multiple datasets is finished
-        labels = batch.pop('labels')
-        output = self.forward(batch)
-        output = output['logits']
-
-        # if we are in the single class case, then remove the classes dimension
-        if output.shape[1] == 1:
-            output = output.squeeze(dim=1)
-
-        return output, labels
-
-    def metrics(self, train: bool = False) -> Union[Metric, MetricCollection]:
-        return MetricCollection(self.train_metrics) if train else MetricCollection(self.val_metrics)
+    For more information, see `Transformers `_.
+
+    Args:
+        num_labels (int, optional): The number of classes in the classification task. Default: ``2``.
+        gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``.
+        use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights. Default: ``False``.
+        model_config (dict, optional): The settings used to create a Hugging Face BertConfig. BertConfig is used to specify the
+            architecture of a Hugging Face model.
+        tokenizer_name (str, optional): Tokenizer name used to preprocess the dataset
+            and validate the model's inputs.
+
+        For example, a ``bert-base-uncased`` configuration for three-way classification is:
+
+        .. code-block::
+
+            {
+              "_name_or_path": "bert-base-uncased",
+              "architectures": [
+                "BertForSequenceClassification"
+              ],
+              "attention_probs_dropout_prob": 0.1,
+              "classifier_dropout": null,
+              "gradient_checkpointing": false,
+              "hidden_act": "gelu",
+              "hidden_dropout_prob": 0.1,
+              "hidden_size": 768,
+              "id2label": {
+                "0": "LABEL_0",
+                "1": "LABEL_1",
+                "2": "LABEL_2"
+              },
+              "initializer_range": 0.02,
+              "intermediate_size": 3072,
+              "label2id": {
+                "LABEL_0": 0,
+                "LABEL_1": 1,
+                "LABEL_2": 2
+              },
+              "layer_norm_eps": 1e-12,
+              "max_position_embeddings": 512,
+              "model_type": "bert",
+              "num_attention_heads": 12,
+              "num_hidden_layers": 12,
+              "pad_token_id": 0,
+              "position_embedding_type": "absolute",
+              "transformers_version": "4.16.0",
+              "type_vocab_size": 2,
+              "use_cache": true,
+              "vocab_size": 30522
+            }
+
+    To create a BERT model for classification:
+
+    .. testcode::
+
+        from composer.models import create_bert_classification
+        model = create_bert_classification(num_labels=3)  # if the task has three classes.
+ + """ + try: + import transformers + except ImportError as e: + raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e + + if not model_config: + model_config = {} + + model_config['num_labels'] = num_labels + + if not pretrained_model_name: + pretrained_model_name = 'bert-base-uncased' + + if use_pretrained: + model = transformers.AutoModelForSequenceClassification.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name, **model_config) + else: + config = transformers.AutoConfig.from_pretrained(pretrained_model_name, **model_config) + model = transformers.AutoModelForSequenceClassification.from_config(config) # type: ignore (thirdparty) + + if gradient_checkpointing: + model.gradient_checkpointing_enable() # type: ignore + + # setup the tokenizer + if tokenizer_name: + tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name) + else: + tokenizer = None + + metrics = [ + Accuracy(), + MeanSquaredError(), + SpearmanCorrCoef(), + BinaryF1Score(), + MatthewsCorrCoef(num_classes=model.config.num_labels) + ] + return HuggingFaceModel(model=model, tokenizer=tokenizer, use_logits=True, metrics=metrics) diff --git a/composer/models/gpt2/__init__.py b/composer/models/gpt2/__init__.py index 80e0f8397e..89711cb464 100644 --- a/composer/models/gpt2/__init__.py +++ b/composer/models/gpt2/__init__.py @@ -9,9 +9,9 @@ """ from composer.models.gpt2.gpt2_hparams import GPT2Hparams as GPT2Hparams -from composer.models.gpt2.model import GPT2Model as GPT2Model +from composer.models.gpt2.model import create_gpt2 as create_gpt2 -__all__ = ['GPT2Model', 'GPT2Hparams'] +__all__ = ['create_gpt2', 'GPT2Hparams'] _metadata = { 'gpt2': { diff --git a/composer/models/gpt2/gpt2_hparams.py b/composer/models/gpt2/gpt2_hparams.py index 66c073904f..30e66cbb9e 100644 --- a/composer/models/gpt2/gpt2_hparams.py +++ b/composer/models/gpt2/gpt2_hparams.py @@ -3,63 +3,54 @@ """`YAHP `_ interface for :class:`.GPT2Model`.""" -import dataclasses -from typing import TYPE_CHECKING +from dataclasses import dataclass +from typing import Dict, Optional -from composer.models.transformer_hparams import TransformerHparams -from composer.utils.import_helpers import MissingConditionalImportError +import yahp as hp -if TYPE_CHECKING: - from composer.models.transformer_shared import ComposerTransformer +from composer.core.types import JSON +from composer.models.model_hparams import ModelHparams __all__ = ['GPT2Hparams'] -@dataclasses.dataclass -class GPT2Hparams(TransformerHparams): +@dataclass +class GPT2Hparams(ModelHparams): """`YAHP `_ interface for :class:`.GPT2Model`. Args: - pretrained_model_name (str): Pretrained model name to pull from Hugging Face Model Hub. - model_config (Dict[str, JSON]): A dictionary providing a HuggingFace model configuration. - tokenizer_name (Optional[str]): The tokenizer used for this model, - necessary to assert required model inputs. + model_config (Dict[str, JSON], optional ): A dictionary providing a HuggingFace model configuration. + pretrained_model_name (str, optional): Pretrained model name to pull from Hugging Face Model Hub. use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights. Default: ``False``. + tokenizer_name (str, optional): The tokenizer used for this model, + necessary to assert required model inputs. Default ``None``. gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``. 
""" - - def initialize_object(self) -> 'ComposerTransformer': - try: - import transformers - except ImportError as e: - raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e - - from composer.models.gpt2.model import GPT2Model - self.validate() - - if self.model_config: - config = transformers.GPT2Config.from_dict(self.model_config) - elif self.pretrained_model_name is not None: - # TODO (Moin): verify that the config is an appropriate instance of GPT2! - config = transformers.GPT2Config.from_pretrained(self.pretrained_model_name) - else: - raise ValueError('One of pretrained_model_name or model_config needed.') - - # setup the tokenizer in the hparams interface - if self.tokenizer_name is not None: - tokenizer = transformers.GPT2Tokenizer.from_pretrained(self.tokenizer_name) - else: - tokenizer = None - - if self.use_pretrained: - assert transformers.AutoModelForCausalLM.from_pretrained is not None, 'from_pretrained should not be None' - model = transformers.AutoModelForCausalLM.from_pretrained(self.pretrained_model_name) - else: - model = transformers.AutoModelForCausalLM.from_config(config) #type: ignore (thirdparty) - - return GPT2Model( - module=model, - config=config, #type: ignore (thirdparty) - tokenizer=tokenizer, + model_config: Optional[Dict[str, + JSON]] = hp.optional(doc='A dictionary providing a HuggingFace model configuration.', + default_factory=dict) + pretrained_model_name: Optional[str] = hp.optional(doc='Pretrained model name to pull from Hugging Face Model Hub.', + default=None) + use_pretrained: Optional[bool] = hp.optional('Whether to initialize the model with the pretrained weights.', + default=False) + tokenizer_name: Optional[str] = hp.optional( + 'The tokenizer used for this model, necessary to assert required model inputs.', default=None) + gradient_checkpointing: Optional[bool] = hp.optional('Whether to enable gradient checkpointing.', default=False) + + def initialize_object(self): + from composer.models.gpt2.model import create_gpt2 + + # user must specify one of either config or the pretrained model + if not self.pretrained_model_name and self.model_config == {}: + raise Exception('One of pretrained_model_name or model_config needed.') + + if self.use_pretrained and self.model_config: + raise Exception('A model cannot load pretrained weights from configuration.') + + return create_gpt2( + model_config=self.model_config, #type: ignore (thirdparty) + pretrained_model_name=self.pretrained_model_name, + use_pretrained=self.use_pretrained, + tokenizer_name=self.tokenizer_name, gradient_checkpointing=self.gradient_checkpointing, ) diff --git a/composer/models/gpt2/model.py b/composer/models/gpt2/model.py index 0e29286d9a..dd912cd0f2 100644 --- a/composer/models/gpt2/model.py +++ b/composer/models/gpt2/model.py @@ -8,78 +8,106 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Union +from typing import Optional -from torch import Tensor -from torchmetrics import Metric, MetricCollection +from composer.metrics.nlp import HFCrossEntropy, Perplexity +from composer.models.huggingface import HuggingFaceModel +from composer.utils.import_helpers import MissingConditionalImportError -from composer.metrics.nlp import Perplexity -from composer.models.transformer_shared import ComposerTransformer +__all__ = ['create_gpt2'] -if TYPE_CHECKING: - import transformers - from composer.core.types import Batch - -__all__ = ['GPT2Model'] - - -class GPT2Model(ComposerTransformer): - 
"""Implements :class:`~composer.models.transformer_shared.ComposerTransformer` to wrap `Hugging Face GPT-2 +def create_gpt2(use_pretrained: Optional[bool] = False, + pretrained_model_name: Optional[str] = None, + model_config: Optional[dict] = None, + tokenizer_name: Optional[str] = None, + gradient_checkpointing: Optional[bool] = False): + """Implements :class:`~composer.models.huggingface.HuggingFaceModel` to wrap `Hugging Face GPT-2 \ transformers `_. Logs training and validation perplexity. From `Language Models are Unsupervised Multitask Learners `_ (Radford et al, 2018). Args: - module (transformers.GPT2Model): The model to wrap with this module. - config (transformers.GPT2Config): The config for the model. - tokenizer (transformers.GPT2Tokenizer): The tokenizer used for this model. Necessary to process model inputs. - gradient_checkpointing (bool, optional): Use gradient checkpointing. default: ``False``. - To create a GPT-2 model for language modeling pretraining: + gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``. + use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights. Default: ``False``. + model_config (dict): A dictionary providing a HuggingFace model configuration. + tokenizer_name (str, optional): Tokenizer name used to preprocess the dataset + and validate the models inputs. + + .. code-block:: + + { + "_name_or_path": "gpt2", + "activation_function": "gelu_new", + "architectures": ["GPT2LMHeadModel"], + "attn_pdrop": 0.1, + "bos_token_id": 50256, + "embd_pdrop": 0.1, + "eos_token_id": 50256, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 768, + "n_head": 12, + "n_inner": null, + "n_layer": 12, + "n_positions": 1024, + "reorder_and_upcast_attn": false, + "resid_pdrop": 0.1, + "scale_attn_by_inverse_layer_idx": false, + "scale_attn_weights": true, + "summary_activation": null, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": true, + "summary_type": "cls_index", + "summary_use_proj": true, + "task_specific_params": { + "text-generation": { + "do_sample": true, + "max_length": 50 } + }, + "transformers_version": "4.16.0", + "use_cache": true, + "vocab_size": 50257 + } + + To create a GPT-2 model for language modeling pretraining: .. 
testcode:: - from composer.models import GPT2Model - import transformers + from composer.models import create_gpt2 + + composer_model = create_gpt2() - config = transformers.GPT2Config() - hf_model = transformers.GPT2LMHeadModel(config=config) # gpt2-small model from huggingface - tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2") - model = GPT2Model(module=hf_model, config=config, tokenizer=tokenizer) """ + try: + import transformers + except ImportError as e: + raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e + + if not model_config: + model_config = {} + + if not pretrained_model_name: + pretrained_model_name = 'gpt2' + + if use_pretrained: + model = transformers.AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=pretrained_model_name, + **model_config) + else: + config = transformers.AutoConfig.from_pretrained(pretrained_model_name, **model_config) + model = transformers.AutoModelForCausalLM.from_config(config) # type: ignore (thirdparty) + + if gradient_checkpointing: + model.gradient_checkpointing_enable() # type: ignore + + # setup the tokenizer + if tokenizer_name: + tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name) + else: + tokenizer = None - def __init__(self, - module: transformers.GPT2Model, - config: transformers.GPT2Config, - tokenizer: Optional[transformers.GPT2Tokenizer] = None, - gradient_checkpointing: bool = False) -> None: - - if tokenizer is None: - model_inputs = {'input_ids', 'attention_mask'} - else: - model_inputs = set(tokenizer.model_input_names) - - super().__init__( - module=module, #type: ignore (thirdparty) - config=config, - model_inputs=model_inputs, - gradient_checkpointing=gradient_checkpointing) - - # If we ever have algorithms that modify the loss function, then this might be a bit inefficient - # because it'll compute the expensive softmax operation twice. - # Instead, we should consider figuring out how to leverage self.train_loss and return the e^self.train_loss. - # Of course, this also depends on the implementation details of algorithms. - self.train_perplexity = Perplexity() - self.val_perplexity = Perplexity() - - def loss(self, outputs: Mapping, batch: Batch) -> Union[Tensor, Sequence[Tensor]]: - if outputs.get('loss', None) is not None: - return outputs['loss'] - else: - raise NotImplementedError('Calculating loss directly not supported yet.') - - def metrics(self, train: bool = False) -> Union[Metric, MetricCollection]: - return MetricCollection([self.train_loss, self.train_perplexity]) if train else MetricCollection( - [self.val_loss, self.val_perplexity]) + return HuggingFaceModel(model=model, tokenizer=tokenizer, metrics=[HFCrossEntropy(), Perplexity()]) diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py index c7bae4f4c4..76e269987e 100644 --- a/composer/models/huggingface.py +++ b/composer/models/huggingface.py @@ -11,6 +11,7 @@ from torchmetrics.collections import MetricCollection from composer.models.base import ComposerModel +from composer.utils.import_helpers import MissingConditionalImportError if TYPE_CHECKING: import transformers @@ -24,8 +25,9 @@ class HuggingFaceModel(ComposerModel): Args: model (transformers.PreTrainedModel): A 🤗 Transformers model. + tokenizer (transformers.PreTrainedTokenizer): Tokenizer used to prepare the dataset and validate model inputs during training. Default ``None``. + use_logits (bool, optional): If True, the model's output logits will be used to calculate validation metrics. 
Otherwise, the model outputs will be passed directly to the metrics. Default: ``False``
         metrics (list[Metric], optional): list of torchmetrics to apply to the output of `validate`. Default: ``None``.
-
     .. warning:: This wrapper is designed to work with 🤗 datasets that define a `labels` column.

     Example:
@@ -39,9 +41,31 @@ class HuggingFaceModel(ComposerModel):
     model = HuggingFaceModel(hf_model)
     """

-    def __init__(self, model: transformers.PreTrainedModel, metrics: Optional[List[Metric]] = None) -> None:
+    def __init__(self,
+                 model: transformers.PreTrainedModel,
+                 tokenizer: Optional[transformers.PreTrainedTokenizer] = None,
+                 use_logits: Optional[bool] = False,
+                 metrics: Optional[List[Metric]] = None) -> None:
+        try:
+            import transformers
+        except ImportError as e:
+            raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e
+
         super().__init__()
         self.model = model
+        self.config = model.config
+
+        # the set of inputs that a model expects, inferred from the model type or
+        # the tokenizer if provided
+        if tokenizer is None:
+            if isinstance(self.model.base_model, transformers.GPT2Model):
+                self.model_inputs = {'input_ids', 'attention_mask'}
+            elif isinstance(self.model.base_model, transformers.BertModel):
+                self.model_inputs = {'input_ids', 'attention_mask', 'token_type_ids'}
+            else:
+                # no tokenizer and an unrecognized model type: don't require any particular inputs
+                self.model_inputs = set()
+        else:
+            self.model_inputs = set(tokenizer.model_input_names)
+
+        self.use_logits = use_logits

         self.train_metrics = None
         self.valid_metrics = None
@@ -52,6 +76,10 @@ def __init__(self, model: transformers.PreTrainedModel, metrics: Optional[List[M
             self.valid_metrics = metric_collection.clone(prefix='val_')

     def forward(self, batch):
+        for key in self.model_inputs:
+            if key not in batch.keys():
+                raise ValueError(f'Batch missing key: {key}')
+
         output = self.model(**batch)  # type: ignore (thirdparty)
         return output

@@ -59,10 +87,30 @@ def loss(self, outputs, batch):
         return outputs['loss']

     def validate(self, batch):
-        labels = batch.pop('labels')
-        output = self.forward(batch)
-        output = output['logits']
-        return output, labels
+        if self.use_logits:
+            labels = batch.pop('labels')
+            output = self.forward(batch)
+            output = output['logits']
+
+            # if we are in the single class case, then remove the classes dimension
+            if output.shape[1] == 1:
+                output = output.squeeze(dim=1)
+
+            return output, labels
+        else:
+            output = self.forward(batch)
+            return output, None

     def metrics(self, train: bool = False):
         return self.train_metrics if train else self.valid_metrics
+
+    def get_model_inputs(self):
+        """Returns the set of inputs that the model expects in the forward pass.
+
+        If an algorithm wants to interact with the model inputs (for instance,
+        popping the labels for a custom loss fn, or adding attention head masks
+        for head pruning), it must access ``self.get_model_inputs()``.
+
+        Returns:
+            model_inputs: The set of keys that are expected in the Mapping used to compute the forward pass.
+ """ + + return self.model_inputs diff --git a/composer/models/transformer_hparams.py b/composer/models/transformer_hparams.py deleted file mode 100644 index 92c0290bea..0000000000 --- a/composer/models/transformer_hparams.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright 2022 MosaicML Composer authors -# SPDX-License-Identifier: Apache-2.0 - -"""YAHP :class:`.hp.Hparams` hyperparameters for ComposerTransformers.""" - -from abc import ABC -from dataclasses import dataclass -from typing import Dict, Optional - -import yahp as hp - -from composer.core.types import JSON -from composer.models.model_hparams import ModelHparams - -__all__ = ['TransformerHparams'] - - -@dataclass -class TransformerHparams(ModelHparams, ABC): - """Defines the necessary hyperparameters for a Transformer base module. - - Args: - pretrained_model_name (Optional[str]): "Pretrained model name to pull from Huggingface Model Hub." - model_config (Dict[str, JSON]): A dictionary providing a HuggingFace model configuration. - tokenizer_name (str): The tokenizer used for this model, - necessary to assert required model inputs. - use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights. Default: ``False`` - gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``. - """ - - tokenizer_name: Optional[str] = hp.optional('Tokenizer name to pull from Huggingface Model Hub.', default=None) - pretrained_model_name: Optional[str] = hp.optional( - doc='Pretrained model name to pull from Huggingface Model Hub.', - default=None, - ) - model_config: Dict[str, JSON] = hp.optional(doc='A dictionary providing a HuggingFace model configuration.', - default_factory=dict) - use_pretrained: bool = hp.optional('Whether to initialize the model with the pretrained weights.', default=False) - gradient_checkpointing: bool = hp.optional('Whether to enable gradient checkpointing.', default=False) - - def validate(self): - if self.pretrained_model_name is None and self.model_config == {}: - raise Exception('One of pretrained_model_name or model_config needed.') - - if self.pretrained_model_name is not None and self.model_config != {}: - raise Exception('Only one of pretrained_model_name or model_config can be provided.') - - if self.use_pretrained and self.model_config: - raise Exception('A model cannot load pretrained weights from configuration.') diff --git a/composer/models/transformer_shared.py b/composer/models/transformer_shared.py deleted file mode 100644 index 6c32db31d4..0000000000 --- a/composer/models/transformer_shared.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2022 MosaicML Composer authors -# SPDX-License-Identifier: Apache-2.0 - -"""The ComposerModel base interface for Transformers.""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING, Mapping, Sequence, Tuple, Union - -from torch import Tensor - -from composer.metrics.nlp import HFCrossEntropy -from composer.models.base import ComposerModel - -if TYPE_CHECKING: - import transformers - - from composer.core.types import Batch - -log = logging.getLogger(__name__) - -__all__ = ['ComposerTransformer'] - - -class ComposerTransformer(ComposerModel): - """The ComposerModel base interface for Transformers. - - Works with `Hugging Face Transformers `_. - - Args: - module (transformers.PreTrainedModel): An instance of PreTrainedModel that - contains the forward pass function. 
- config (transformers.PretrainedConfig): The PretrainedConfig object that - stores information about the model hyperparameters. - model_inputs (set): The dictionary keys that should be required to be fed into the model's forward function. - gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``. - """ - - def __init__(self, - module: transformers.PreTrainedModel, - config: transformers.PretrainedConfig, - model_inputs: set, - gradient_checkpointing: bool = False) -> None: - super().__init__() - self.module = module - self.config = config - log.info('Number of parameters in the model: ' \ - f'{sum(p.numel() for p in module.parameters()):,}') # type: ignore (thirdparty) - log.info('Number of trainable parameters in the model: ' - f'{sum(p.numel() for p in module.parameters() if p.requires_grad):,}') # type: ignore (thirdparty) - - # the set of inputs that a model expects - # if an algorithm modifies the loss function, it must remove "labels" from this set. - self.model_inputs = model_inputs - self.model_inputs.update(set({'labels'})) - - # define metrics for measurements - self.train_loss = HFCrossEntropy() - self.val_loss = HFCrossEntropy() - - if gradient_checkpointing: - self.module.gradient_checkpointing_enable() # type: ignore - - def loss(self, outputs: Mapping, batch: Batch) -> Union[Tensor, Sequence[Tensor]]: - """Computes the loss of the tensor from the output. - - We don't implement this for the generic Transformer abstraction, since loss - functions are model and objective specific. A single model architecture could - use a myriad of loss functions which are better left expressed by the user. - - Args: - outputs (Mapping): The dictionary output from the model. - It could contain the loss as computed by Hugging Face, - or algorithms can pop the labels from the input in case - they modify the loss function. - batch (:class:`~composer.core.types.Batch`): The set of ground truth labels to use to compute the loss against. - - Raises: - NotImplementedError: A model-specific and task-specific loss function must be written. - """ - raise NotImplementedError('A model-specific loss function must be written.') - - def forward(self, batch: Batch) -> Mapping: - """Run the forward pass of the model. - - Args: - batch (~composer.core.types.Batch): A dictionary of Dict[str, Tensor] of inputs that the - model expects, as found in :meth:`.ComposerTransformer.get_model_inputs`. - - Returns: - output: A dictionary of model outputs as a ``Mapping``. It will include the loss if `labels` is passed as an input. - """ - if not isinstance(batch, dict): - raise ValueError(f'Model expects batch to be a dict, got {type(batch)}') - - for key in self.model_inputs: - if key not in batch.keys(): - raise ValueError(f'Batch missing key: {key}') - - output = self.module(**batch) # type: ignore (thirdparty) - return output - - def validate(self, batch: Batch) -> Tuple[Mapping, None]: - """Runs the validation step. - - Args: - batch (~composer.core.types.Batch): a dictionary of Dict[str, Tensor] of inputs - that the model expects, as found in :meth:`.ComposerTransformer.get_model_inputs`. - - Returns: - Tuple[Mapping, None]: A tuple containing the output from the forward pass. - This is fed into directly into the output of :meth:`.ComposerModel.metrics`. 
- """ - assert self.training is False, 'For validation, model must be in eval mode' - output = self.forward(batch) - return output, None - - def get_model_inputs(self): - """Returns a set of inputs that the model expects in the forward pass. - - If an algorithm wants to interact with the model inputs (for instance, - popping the labels for a custom loss fn, or adding attention head masks - for head pruning, it must access self.set_model_inputs(). - - Returns: - model_inputs: The set of keys that are expected in the Mapping used to compute the forward pass. - """ - - return self.model_inputs diff --git a/composer/yamls/models/glue/cola.yaml b/composer/yamls/models/glue/cola.yaml index b4d38ef706..01d16421c2 100644 --- a/composer/yamls/models/glue/cola.yaml +++ b/composer/yamls/models/glue/cola.yaml @@ -23,7 +23,6 @@ model: bert_classification: num_labels: 2 use_pretrained: true - tokenizer_name: bert-base-uncased pretrained_model_name: bert-base-uncased optimizers: decoupled_adamw: diff --git a/composer/yamls/models/glue/mnli.yaml b/composer/yamls/models/glue/mnli.yaml index fac57e2754..8566a31749 100644 --- a/composer/yamls/models/glue/mnli.yaml +++ b/composer/yamls/models/glue/mnli.yaml @@ -35,7 +35,6 @@ model: bert_classification: num_labels: 3 use_pretrained: true - tokenizer_name: bert-base-uncased pretrained_model_name: bert-base-uncased optimizers: decoupled_adamw: diff --git a/composer/yamls/models/glue/mrpc.yaml b/composer/yamls/models/glue/mrpc.yaml index 79c911f800..fc3fbc7517 100644 --- a/composer/yamls/models/glue/mrpc.yaml +++ b/composer/yamls/models/glue/mrpc.yaml @@ -24,7 +24,6 @@ model: bert_classification: num_labels: 2 use_pretrained: true - tokenizer_name: bert-base-uncased pretrained_model_name: bert-base-uncased optimizers: decoupled_adamw: diff --git a/composer/yamls/models/glue/qnli.yaml b/composer/yamls/models/glue/qnli.yaml index 9858f7f16a..921983d21a 100644 --- a/composer/yamls/models/glue/qnli.yaml +++ b/composer/yamls/models/glue/qnli.yaml @@ -23,7 +23,6 @@ model: bert_classification: num_labels: 2 use_pretrained: true - tokenizer_name: bert-base-uncased pretrained_model_name: bert-base-uncased optimizers: decoupled_adamw: diff --git a/composer/yamls/models/glue/qqp.yaml b/composer/yamls/models/glue/qqp.yaml index 292074a4e4..c38bfde8cf 100644 --- a/composer/yamls/models/glue/qqp.yaml +++ b/composer/yamls/models/glue/qqp.yaml @@ -24,7 +24,6 @@ model: bert_classification: num_labels: 2 use_pretrained: true - tokenizer_name: bert-base-uncased pretrained_model_name: bert-base-uncased optimizers: decoupled_adamw: diff --git a/composer/yamls/models/glue/rte.yaml b/composer/yamls/models/glue/rte.yaml index 843fad2898..7b0dd66d77 100644 --- a/composer/yamls/models/glue/rte.yaml +++ b/composer/yamls/models/glue/rte.yaml @@ -23,7 +23,6 @@ model: bert_classification: num_labels: 2 use_pretrained: true - tokenizer_name: bert-base-uncased pretrained_model_name: bert-base-uncased optimizers: decoupled_adamw: diff --git a/composer/yamls/models/glue/sst-2.yaml b/composer/yamls/models/glue/sst-2.yaml index a5fc6005fd..5e53782cfa 100644 --- a/composer/yamls/models/glue/sst-2.yaml +++ b/composer/yamls/models/glue/sst-2.yaml @@ -23,7 +23,6 @@ model: bert_classification: num_labels: 2 use_pretrained: true - tokenizer_name: bert-base-uncased pretrained_model_name: bert-base-uncased optimizers: decoupled_adamw: diff --git a/composer/yamls/models/glue/stsb.yaml b/composer/yamls/models/glue/stsb.yaml index 9ca5037677..f2c998b424 100644 --- a/composer/yamls/models/glue/stsb.yaml +++ 
b/composer/yamls/models/glue/stsb.yaml
@@ -23,7 +23,6 @@ model:
   bert_classification:
     num_labels: 1
     use_pretrained: true
-    tokenizer_name: bert-base-uncased
     pretrained_model_name: bert-base-uncased
 optimizers:
   decoupled_adamw:
diff --git a/examples/huggingface_models.ipynb b/examples/huggingface_models.ipynb
index cae4b847b8..c5b932a19c 100644
--- a/examples/huggingface_models.ipynb
+++ b/examples/huggingface_models.ipynb
@@ -41,7 +41,7 @@
    "metadata": {},
    "source": [
     "## Import Hugging Face pretrained model\n",
-    "First, we import a pretrainer BERT model and tokenizer from the transformers library. We alter the model to output two classes for sentiment classification by setting `num_labels=2`."
+    "First, we import a pretrained BERT model and tokenizer from the transformers library. We alter the model to output two classes for sentiment classification by setting `num_labels=2`."
   ]
  },
 {
@@ -141,7 +141,7 @@
    "\n",
    "metrics = [CrossEntropy(), Accuracy()]\n",
    "# Package as a composer model\n",
-    "composer_model = HuggingFaceModel(model, metrics=metrics)"
+    "composer_model = HuggingFaceModel(model, metrics=metrics, use_logits=True)"
   ]
  },
 {
diff --git a/tests/algorithms/test_fused_layernorm.py b/tests/algorithms/test_fused_layernorm.py
index 2c7d084c7c..b4c8f4a389 100644
--- a/tests/algorithms/test_fused_layernorm.py
+++ b/tests/algorithms/test_fused_layernorm.py
@@ -9,7 +9,6 @@
 from composer.algorithms.fused_layernorm import FusedLayerNorm, apply_fused_layernorm
 from composer.core.event import Event
 from composer.loggers import Logger
-from composer.models import BERTModel
 from tests.common import device
 from tests.fixtures.synthetic_hf_state import make_dataset_configs, synthetic_hf_state_maker

@@ -20,10 +19,13 @@ def synthetic_bert_state():
     return synthetic_hf_state_maker(synthetic_config)


-def assert_is_fln_instance(model: BERTModel):
+def assert_is_fln_instance(model):
     pytest.importorskip('apex')
+    pytest.importorskip('transformers')
     from apex.normalization.fused_layer_norm import FusedLayerNorm as APEXFusedLayerNorm
+    from transformers import BertForMaskedLM, BertForSequenceClassification

+    assert isinstance(model, BertForMaskedLM) or isinstance(model, BertForSequenceClassification)
     # ensure that within the entire model, no PyTorch LayerNorm exists, and at least one APEX FLN does.
for module_class in model.modules():
         assert not isinstance(
@@ -37,17 +39,22 @@ def test_fused_layernorm_functional(synthetic_bert_state: Tuple, device: str):
     state, _, _ = synthetic_bert_state
     apply_fused_layernorm(state.model, state.optimizers)
-    assert_is_fln_instance(state.model)
+    assert_is_fln_instance(state.model.model)


 @device('gpu')
 def test_fused_layernorm_algorithm(synthetic_bert_state: Tuple, empty_logger: Logger, device: str):
+    pytest.importorskip('transformers')
+    from transformers import BertForMaskedLM, BertForSequenceClassification
+
     state, _, _ = synthetic_bert_state
     fused_layernorm = FusedLayerNorm()
     if device == 'gpu':
         state.model = state.model.cuda()  # move the model to gpu

-    assert isinstance(state.model, BERTModel)
+    # state.model is wrapped in a HuggingFaceModel
+    assert isinstance(state.model.model, BertForMaskedLM) or isinstance(state.model.model,
+                                                                        BertForSequenceClassification)
     fused_layernorm.apply(Event.INIT, state, empty_logger)
-    assert_is_fln_instance(state.model)
+    assert_is_fln_instance(state.model.model)
diff --git a/tests/algorithms/test_gated_linear_units.py b/tests/algorithms/test_gated_linear_units.py
index 7a6292a7c7..22df8f6b16 100644
--- a/tests/algorithms/test_gated_linear_units.py
+++ b/tests/algorithms/test_gated_linear_units.py
@@ -11,7 +11,6 @@
 from composer.algorithms.gated_linear_units.gated_linear_unit_layers import BERTGatedFFOutput
 from composer.core.event import Event
 from composer.loggers import Logger
-from composer.models import BERTModel
 from tests.fixtures.synthetic_hf_state import make_dataset_configs, synthetic_hf_state_maker

@@ -56,10 +55,12 @@ def synthetic_bert_state():
     return synthetic_hf_state_maker(synthetic_config)


-def assert_is_glu_instance(model: BERTModel):
+def assert_is_glu_instance(model):
     pytest.importorskip('transformers')
+    from transformers import BertForMaskedLM, BertForSequenceClassification
     from transformers.models.bert.modeling_bert import BertOutput

+    assert isinstance(model, BertForMaskedLM) or isinstance(model, BertForSequenceClassification)
     # ensure that within the entire model, no BertOutput exists, and at least one BERTGatedFFOutput does.
for module_class in model.modules():
         assert not isinstance(
@@ -74,14 +75,18 @@ def test_gated_linear_units_functional(synthetic_bert_state: Tuple):
     state, _, _ = synthetic_bert_state
     apply_gated_linear_units(state.model, state.optimizers)
-    assert_is_glu_instance(state.model)
+    assert_is_glu_instance(state.model.model)


 def test_gated_linear_units_algorithm(synthetic_bert_state: Tuple, empty_logger: Logger):
+    pytest.importorskip('transformers')
+    from transformers import BertForMaskedLM, BertForSequenceClassification
     state, _, _ = synthetic_bert_state
     gated_linear_units = GatedLinearUnits()

-    assert isinstance(state.model, BERTModel)
+    # state.model is wrapped in a HuggingFaceModel
+    assert isinstance(state.model.model, BertForMaskedLM) or isinstance(state.model.model,
+                                                                        BertForSequenceClassification)
     gated_linear_units.apply(Event.INIT, state, empty_logger)
-    assert_is_glu_instance(state.model)
+    assert_is_glu_instance(state.model.model)
diff --git a/tests/common/datasets.py b/tests/common/datasets.py
index 0eaddcb85a..67241ca7cc 100644
--- a/tests/common/datasets.py
+++ b/tests/common/datasets.py
@@ -17,7 +17,6 @@
 from composer.datasets.lm_dataset_hparams import LMDatasetHparams
 from composer.datasets.synthetic_hparams import SyntheticHparamsMixin
 from composer.models import ModelHparams
-from composer.models.transformer_hparams import TransformerHparams
 from tests.common.models import model_hparams_to_tokenizer_family

@@ -124,10 +123,7 @@ def configure_dataset_hparams_for_synthetic(
     dataset_hparams.use_synthetic = True

-    if isinstance(model_hparams, TransformerHparams):
-        if type(model_hparams) not in model_hparams_to_tokenizer_family:
-            raise ValueError(f'Model {type(model_hparams)} is currently not supported for synthetic testing!')
-
+    if model_hparams and type(model_hparams) in model_hparams_to_tokenizer_family:
         tokenizer_family = model_hparams_to_tokenizer_family[type(model_hparams)]
         assert isinstance(dataset_hparams, (GLUEHparams, LMDatasetHparams))
         dataset_hparams.tokenizer_name = tokenizer_family
diff --git a/tests/common/models.py b/tests/common/models.py
index ef9baddd90..b8cbba2dab 100644
--- a/tests/common/models.py
+++ b/tests/common/models.py
@@ -15,9 +15,8 @@
 from composer.models.deeplabv3.deeplabv3_hparams import DeepLabV3Hparams
 from composer.models.gpt2.gpt2_hparams import GPT2Hparams
 from composer.models.model_hparams import ModelHparams
-from composer.models.transformer_hparams import TransformerHparams

-model_hparams_to_tokenizer_family: Dict[Type[TransformerHparams], str] = {
+model_hparams_to_tokenizer_family: Dict[Type[ModelHparams], str] = {
     GPT2Hparams: 'gpt2',
     BERTForClassificationHparams: 'bert',
     BERTHparams: 'bert'
@@ -122,10 +121,8 @@ def initialize_object(self) -> SimpleConvModel:

 def configure_model_hparams_for_synthetic(model_hparams: ModelHparams) -> None:
     # configure Transformer-based models for synthetic testing
-    if isinstance(model_hparams, TransformerHparams):
-        if type(model_hparams) not in model_hparams_to_tokenizer_family:
-            raise ValueError(f'Model {type(model_hparams)} is currently not supported for synthetic testing!')
-
+    if type(model_hparams) in model_hparams_to_tokenizer_family.keys():
+        assert isinstance(model_hparams, (BERTHparams, GPT2Hparams, BERTForClassificationHparams))
         tokenizer_family = model_hparams_to_tokenizer_family[type(model_hparams)]

         # force a non-pretrained model
@@ -134,7 +131,6 @@ def configure_model_hparams_for_synthetic(model_hparams: ModelHparams) -> None:

         # generate
tokenizers and synthetic models tokenizer = generate_synthetic_tokenizer(tokenizer_family=tokenizer_family) - model_hparams.tokenizer_name = None model_hparams.model_config = generate_dummy_model_config(type(model_hparams), tokenizer) # configure DeepLabV3 models for synthetic testing diff --git a/tests/fixtures/synthetic_hf_state.py b/tests/fixtures/synthetic_hf_state.py index db75634f8b..33ba10e0f7 100644 --- a/tests/fixtures/synthetic_hf_state.py +++ b/tests/fixtures/synthetic_hf_state.py @@ -9,7 +9,7 @@ from composer.datasets.dataset_hparams import DataLoaderHparams from composer.datasets.lm_dataset_hparams import LMDatasetHparams from composer.datasets.synthetic_lm import generate_synthetic_tokenizer, synthetic_hf_dataset_builder -from composer.models import BERTHparams, GPT2Hparams, TransformerHparams +from composer.models import BERTHparams, GPT2Hparams, create_bert_mlm, create_gpt2 from tests.common.models import generate_dummy_model_config from tests.datasets import test_synthetic_lm_data @@ -33,18 +33,18 @@ def make_lm_tokenizer(config: dict): def make_dummy_lm(model_name: str, max_position_embeddings: int, tokenizer): - pytest.importorskip('transformers') - class_name = TransformerHparams if model_name == 'gpt2': class_name = GPT2Hparams + model_config = generate_dummy_model_config(class_name, tokenizer) + model_config['max_position_embeddings'] = max_position_embeddings + model = create_gpt2(model_config=model_config) elif model_name == 'bert': class_name = BERTHparams + model_config = generate_dummy_model_config(class_name, tokenizer) + model_config['max_position_embeddings'] = max_position_embeddings + model = create_bert_mlm(model_config=model_config) else: raise ValueError("Model name must be one of 'gpt2' or 'bert'") - model_config = generate_dummy_model_config(class_name, tokenizer) - model_config['max_position_embeddings'] = max_position_embeddings - model = class_name(model_config=model_config).initialize_object() - model.eval() return model
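
Usage note (not part of the diff): a minimal sketch of the factory-function API this patch introduces. It assumes composer and transformers are installed; 'bert-base-uncased' and 'gpt2' are the defaults hard-coded in create_bert_mlm and create_gpt2, and the n_layer override below is only an illustration of passing Hugging Face config fields through model_config::

    from composer.models import create_bert_classification, create_bert_mlm, create_gpt2

    # Randomly initialized BERT masked-LM; defaults to the 'bert-base-uncased' architecture.
    bert_mlm = create_bert_mlm()

    # Pretrained 3-way classifier pulled from the Hugging Face Hub; passing a
    # tokenizer_name lets the wrapper validate batch keys against the tokenizer.
    bert_classifier = create_bert_classification(
        num_labels=3,
        use_pretrained=True,
        pretrained_model_name='bert-base-uncased',
        tokenizer_name='bert-base-uncased',
    )

    # GPT-2 with a config override (model_config values are forwarded to
    # AutoConfig.from_pretrained) and gradient checkpointing enabled.
    gpt2 = create_gpt2(model_config={'n_layer': 6}, gradient_checkpointing=True)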
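
The same wrapper can also be applied to an arbitrary Hub checkpoint directly, mirroring the updated notebook cell (also a sketch: CrossEntropy and Accuracy are the metrics the notebook imports, and the keyword arguments come from the new HuggingFaceModel signature)::

    import transformers
    from torchmetrics import Accuracy

    from composer.metrics import CrossEntropy
    from composer.models import HuggingFaceModel

    hf_model = transformers.AutoModelForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=2)
    tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased')

    # use_logits=True makes validate() pop 'labels' from the batch and feed the
    # output logits to the metrics, matching the behavior of the BERT factories.
    composer_model = HuggingFaceModel(hf_model,
                                      tokenizer=tokenizer,
                                      use_logits=True,
                                      metrics=[CrossEntropy(), Accuracy()])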