Commit
refactor bert and gpt (#1130)
Converts the Composer GPT2 and BERT models to use a HuggingFaceModel base class compatible with any pretrained model from the Hugging Face Hub. Model creation is converted to factory functions, and all construction logic is moved out of the hparams classes.
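For orientation, here is a minimal sketch of how the refactored API might be used. The factory names and their keyword arguments are taken from the hparams classes in this diff; the exact HuggingFaceModel constructor signature is not shown here and is assumed, as is the argument list of create_gpt2.

import transformers

from composer.models import HuggingFaceModel, create_bert_mlm, create_gpt2

# Factory functions replace the old hparams-driven BERTModel/GPT2Model construction.
bert_mlm = create_bert_mlm(
    pretrained_model_name='bert-base-uncased',
    use_pretrained=True,
    tokenizer_name='bert-base-uncased',
)
gpt2 = create_gpt2(pretrained_model_name='gpt2', use_pretrained=True)  # assumed to mirror the BERT factory

# Any other pretrained model off the HF Hub can be wrapped in the new base class directly
# (assumed usage; only the existence of HuggingFaceModel and its .model attribute appears in this diff).
hf_module = transformers.AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased')
wrapped = HuggingFaceModel(hf_module)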
A-Jacobson authored Jun 28, 2022
1 parent 0ca7cee commit 99cf7d2
Showing 26 changed files with 483 additions and 537 deletions.
11 changes: 7 additions & 4 deletions composer/algorithms/gated_linear_units/gated_linear_units.py
@@ -11,7 +11,10 @@

import torch

from composer.models.huggingface import HuggingFaceModel

try:
from transformers import BertForMaskedLM, BertForSequenceClassification
from transformers.models.bert.modeling_bert import BertIntermediate, BertOutput
IS_TRANSFORMERS_INSTALLED = True
except ImportError as e:
@@ -21,7 +24,6 @@
from composer.algorithms.warnings import NoEffectWarning
from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.models import BERTModel
from composer.utils import MissingConditionalImportError, module_surgery

log = logging.getLogger(__name__)
@@ -79,9 +81,10 @@ def apply_gated_linear_units(model: torch.nn.Module,
if not IS_TRANSFORMERS_INSTALLED:
raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers')

# ensure that the model is an instance of a BERTModel, since our replacement policy is only defined for BERTs
if not isinstance(model, BERTModel):
raise TypeError('Gated Linear Units only has a surgery policy defined for instances of BERTModel.')
# ensure that the model is an instance of a BERT model, since our replacement policy is only defined for BERTs
if not isinstance(model, HuggingFaceModel) and not (isinstance(model.model, BertForMaskedLM) or
isinstance(model.model, BertForSequenceClassification)):
raise TypeError('Gated Linear Units only has a surgery policy defined for instances of BERT models.')

if act_fn is None:
# get the activation functions used
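A hedged sketch of what the relaxed check is meant to accept: a HuggingFaceModel whose wrapped transformers module (exposed as .model, as the isinstance checks above imply) is a BERT. The import path for the functional form follows the file shown above.

import transformers

from composer.algorithms.gated_linear_units.gated_linear_units import apply_gated_linear_units
from composer.models import HuggingFaceModel

# Surgery now keys off the wrapped BertForMaskedLM / BertForSequenceClassification,
# not a Composer-specific BERTModel class.
hf_bert = transformers.BertForMaskedLM.from_pretrained('bert-base-uncased')
model = HuggingFaceModel(hf_bert)  # assumed constructor; signature not part of this diff
apply_gated_linear_units(model, optimizers=None)  # swaps BertIntermediate/BertOutput for gated units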
8 changes: 4 additions & 4 deletions composer/algorithms/seq_length_warmup/seq_length_warmup.py
@@ -15,7 +15,7 @@
from composer.core.time import TimeUnit
from composer.core.types import Batch
from composer.loggers import Logger
from composer.models import ComposerTransformer
from composer.models import HuggingFaceModel
from composer.utils import ensure_tuple

__all__ = ['SeqLengthWarmup', 'set_batch_sequence_length']
@@ -163,7 +163,7 @@ def __init__(
self.truncate = truncate

if self.duration < 0 or self.duration > 1:
raise ValueError(f'Duration must be getween 0 and 1, got: {self.duration}')
raise ValueError(f'Duration must be between 0 and 1, got: {self.duration}')

if self.max_seq_length < self.min_seq_length:
raise ValueError(f'max_seq_length={self.max_seq_length} must be '
@@ -176,10 +176,10 @@ def match(self, event: Event, state: State) -> bool:

def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
if event == Event.INIT:
if not isinstance(state.model, ComposerTransformer):
if not isinstance(state.model, HuggingFaceModel):
raise RuntimeError(
textwrap.dedent(f"""\
{type(self).__name__} requires state.model to be of type {ComposerTransformer.__name__}, not of type {type(state.model)}"""
{type(self).__name__} requires state.model to be of type {HuggingFaceModel.__name__}, not of type {type(state.model)}"""
))

self._original_model = state.model
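Correspondingly, SeqLengthWarmup now gates on HuggingFaceModel at Event.INIT. A small sketch of the pairing it expects; the SeqLengthWarmup constructor arguments are assumed from the attributes referenced above (duration, min_seq_length, max_seq_length).

from composer.algorithms import SeqLengthWarmup
from composer.models import HuggingFaceModel, create_bert_mlm

model = create_bert_mlm(use_pretrained=True,
                        pretrained_model_name='bert-base-uncased',
                        tokenizer_name='bert-base-uncased')
assert isinstance(model, HuggingFaceModel)  # any other model type now raises RuntimeError at Event.INIT

# duration is the fraction of training spent warming up; sequence lengths bound the schedule
seq_warmup = SeqLengthWarmup(duration=0.3, min_seq_length=8, max_seq_length=128)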
7 changes: 3 additions & 4 deletions composer/models/__init__.py
@@ -12,15 +12,16 @@
from composer.models.base import ComposerModel as ComposerModel
from composer.models.bert import BERTForClassificationHparams as BERTForClassificationHparams
from composer.models.bert import BERTHparams as BERTHparams
from composer.models.bert import BERTModel as BERTModel
from composer.models.bert import create_bert_classification as create_bert_classification
from composer.models.bert import create_bert_mlm as create_bert_mlm
from composer.models.classify_mnist import MNIST_Classifier as MNIST_Classifier
from composer.models.classify_mnist import MnistClassifierHparams as MnistClassifierHparams
from composer.models.deeplabv3 import ComposerDeepLabV3 as ComposerDeepLabV3
from composer.models.deeplabv3 import DeepLabV3Hparams as DeepLabV3Hparams
from composer.models.efficientnetb0 import EfficientNetB0 as EfficientNetB0
from composer.models.efficientnetb0 import EfficientNetB0Hparams as EfficientNetB0Hparams
from composer.models.gpt2 import GPT2Hparams as GPT2Hparams
from composer.models.gpt2 import GPT2Model as GPT2Model
from composer.models.gpt2 import create_gpt2 as create_gpt2
from composer.models.huggingface import HuggingFaceModel as HuggingFaceModel
from composer.models.initializers import Initializer as Initializer
from composer.models.model_hparams import ModelHparams as ModelHparams
@@ -33,8 +34,6 @@
from composer.models.tasks import ComposerClassifier as ComposerClassifier
from composer.models.timm import Timm as Timm
from composer.models.timm import TimmHparams as TimmHparams
from composer.models.transformer_hparams import TransformerHparams as TransformerHparams
from composer.models.transformer_shared import ComposerTransformer as ComposerTransformer
from composer.models.unet import UNet as UNet
from composer.models.unet import UnetHparams as UnetHparams
from composer.models.vit_small_patch16 import ViTSmallPatch16 as ViTSmallPatch16
5 changes: 3 additions & 2 deletions composer/models/bert/__init__.py
@@ -6,6 +6,7 @@

from composer.models.bert.bert_hparams import BERTForClassificationHparams as BERTForClassificationHparams
from composer.models.bert.bert_hparams import BERTHparams as BERTHparams
from composer.models.bert.model import BERTModel as BERTModel
from composer.models.bert.model import create_bert_classification as create_bert_classification
from composer.models.bert.model import create_bert_mlm as create_bert_mlm

__all__ = ['BERTModel', 'BERTHparams', 'BERTForClassificationHparams']
__all__ = ['BERTHparams', 'BERTForClassificationHparams', 'create_bert_classification', 'create_bert_mlm']
174 changes: 71 additions & 103 deletions composer/models/bert/bert_hparams.py
@@ -5,130 +5,98 @@
:class:`.BERTModel`."""

from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Dict, Optional

import yahp as hp

from composer.models.transformer_hparams import TransformerHparams
from composer.utils import MissingConditionalImportError

if TYPE_CHECKING:
from composer.models.bert import BERTModel
from composer.core.types import JSON
from composer.models.model_hparams import ModelHparams

__all__ = ['BERTForClassificationHparams', 'BERTHparams']


@dataclass
class BERTForClassificationHparams(TransformerHparams):
"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ classification interface for
:class:`.BERTModel`.
class BERTHparams(ModelHparams):
"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for :class:`.BERTModel`.
Args:
pretrained_model_name (str): Pretrained model name to pull from Hugging Face Model Hub.
model_config (Dict[str, JSON]): A dictionary providing a HuggingFace model configuration.
tokenizer_name (Optional[str]): The tokenizer used for this model,
necessary to assert required model inputs.
model_config (Dict[str, JSON], optional): A dictionary providing a HuggingFace model configuration.
pretrained_model_name (str, optional): Pretrained model name to pull from Hugging Face Model Hub.
use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights.
gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``.
num_labels (int, optional): The number of classes in the segmentation task. Default: ``2``.
tokenizer_name (str, optional): The tokenizer used for this model,
necessary to assert required model inputs. Default ``None``.
gradient_checkpointing (bool, optional): Use gradient checkpointing. default: False.
"""
num_labels: int = hp.optional(doc='The number of possible labels for the task.', default=2)

def validate(self):
if self.num_labels < 1:
raise ValueError('The number of target labels must be at least one.')

def initialize_object(self) -> 'BERTModel':
try:
import transformers
except ImportError as e:
raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e

from composer.models.bert.model import BERTModel
self.validate()

model_hparams = {'num_labels': self.num_labels}

if self.model_config:
config = transformers.BertConfig.from_dict(self.model_config, **model_hparams)
elif self.pretrained_model_name is not None:
config = transformers.BertConfig.from_pretrained(self.pretrained_model_name, **model_hparams)
else:
raise ValueError('One of pretrained_model_name or model_config needed.')
config.num_labels = self.num_labels

# setup the tokenizer in the hparams interface
if self.tokenizer_name is not None:
tokenizer = transformers.BertTokenizer.from_pretrained(self.tokenizer_name)
else:
tokenizer = None

if self.use_pretrained:
# TODO (Moin): handle the warnings on not using the seq_relationship head
assert transformers.AutoModelForSequenceClassification.from_pretrained is not None, 'from_pretrained should not be None'
model = transformers.AutoModelForSequenceClassification.from_pretrained(self.pretrained_model_name,
**model_hparams)
else:
# an invariant to ensure that we don't lose keys when creating the HF config
for k, v in model_hparams.items():
assert getattr(config, k) == v
model = transformers.AutoModelForSequenceClassification.from_config( #type: ignore (thirdparty)
config)

return BERTModel(
module=model,
config=config, #type: ignore (thirdparty)
tokenizer=tokenizer,
model_config: Optional[Dict[str,
JSON]] = hp.optional(doc='A dictionary providing a HuggingFace model configuration.',
default_factory=dict)
pretrained_model_name: Optional[str] = hp.optional(doc='Pretrained model name to pull from Hugging Face Model Hub.',
default=None)
use_pretrained: Optional[bool] = hp.optional('Whether to initialize the model with the pretrained weights.',
default=False)
tokenizer_name: Optional[str] = hp.optional(
'The tokenizer used for this model, necessary to assert required model inputs.', default=None)
gradient_checkpointing: Optional[bool] = hp.optional('Whether to enable gradient checkpointing.', default=False)

def initialize_object(self):
from composer.models.bert.model import create_bert_mlm

# user must specify one of either config or the pretrained model
if not self.pretrained_model_name and self.model_config == {}:
raise Exception('One of pretrained_model_name or model_config needed.')

if self.use_pretrained and self.model_config:
raise Exception('A model cannot load pretrained weights from configuration.')

return create_bert_mlm(
model_config=self.model_config, # type: ignore (thirdparty)
pretrained_model_name=self.pretrained_model_name,
use_pretrained=self.use_pretrained,
tokenizer_name=self.tokenizer_name,
gradient_checkpointing=self.gradient_checkpointing,
)


@dataclass
class BERTHparams(TransformerHparams):
class BERTForClassificationHparams(ModelHparams):
"""`YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for :class:`.BERTModel`.
Args:
pretrained_model_name (str): "Pretrained model name to pull from Huggingface Model Hub."
num_labels (int, optional): The number of classes in the classification task. Default: ``2``.
pretrained_model_name (str, optional): Pretrained model name to pull from Hugging Face Model Hub.
model_config (Dict[str, JSON]): A dictionary providing a HuggingFace model configuration.
tokenizer_name (str): The tokenizer used for this model,
necessary to assert required model inputs.
use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights.
tokenizer_name (Optional[str]): The tokenizer used for this model,
necessary to assert required model inputs. Default ``None``.
gradient_checkpointing (bool, optional): Use gradient checkpointing. default: False.
"""

def initialize_object(self) -> 'BERTModel':
try:
import transformers
except ImportError as e:
raise MissingConditionalImportError(extra_deps_group='nlp', conda_package='transformers') from e

from composer.models.bert.model import BERTModel
self.validate()

if self.model_config:
config = transformers.BertConfig.from_dict(self.model_config)
elif self.pretrained_model_name is not None:
config = transformers.BertConfig.from_pretrained(self.pretrained_model_name)
else:
raise ValueError('One of pretrained_model_name or model_config needed.')

# set the number of labels ot the vocab size, used for measuring MLM accuracy
config.num_labels = config.vocab_size

# setup the tokenizer in the hparams interface
if self.tokenizer_name is not None:
tokenizer = transformers.BertTokenizer.from_pretrained(self.tokenizer_name)
else:
tokenizer = None

if self.use_pretrained:
# TODO (Moin): handle the warnings on not using the seq_relationship head
assert transformers.AutoModelForMaskedLM.from_pretrained is not None, 'from_pretrained should not be None'
model = transformers.AutoModelForMaskedLM.from_pretrained(self.pretrained_model_name)
else:
model = transformers.AutoModelForMaskedLM.from_config(config) #type: ignore (thirdparty)

return BERTModel(
module=model,
config=config, #type: ignore (thirdparty)
tokenizer=tokenizer,
num_labels: Optional[int] = hp.optional(doc='The number of possible labels for the task.', default=2)
pretrained_model_name: Optional[str] = hp.optional(doc='Pretrained model name to pull from Hugging Face Model Hub.',
default=None)
model_config: Optional[Dict[str,
JSON]] = hp.optional(doc='A dictionary providing a HuggingFace model configuration.',
default_factory=dict)
use_pretrained: Optional[bool] = hp.optional('Whether to initialize the model with the pretrained weights.',
default=False)
tokenizer_name: Optional[str] = hp.optional(
'The tokenizer used for this model, necessary to assert required model inputs.', default=None)
gradient_checkpointing: Optional[bool] = hp.optional('Whether to enable gradient checkpointing.', default=False)

def initialize_object(self):
from composer.models.bert.model import create_bert_classification

# user must specify one of either config or the pretrained model
if not self.pretrained_model_name and self.model_config == {}:
raise Exception('One of pretrained_model_name or model_config needed.')

if self.use_pretrained and self.model_config:
raise Exception('A model cannot load pretrained weights from configuration.')

return create_bert_classification(
num_labels=self.num_labels,
pretrained_model_name=self.pretrained_model_name,
model_config=self.model_config, # type: ignore (thirdparty)
use_pretrained=self.use_pretrained,
tokenizer_name=self.tokenizer_name,
gradient_checkpointing=self.gradient_checkpointing,
)
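A hedged sketch of driving the same factories through the hparams classes above (fields as declared in this diff; the remaining ModelHparams fields are assumed to keep their defaults):

from composer.models import BERTForClassificationHparams, BERTHparams

# initialize_object() now only validates the fields and delegates to the factory functions.
mlm_hparams = BERTHparams(pretrained_model_name='bert-base-uncased',
                          use_pretrained=True,
                          tokenizer_name='bert-base-uncased')
bert_mlm = mlm_hparams.initialize_object()  # -> create_bert_mlm(...)

cls_hparams = BERTForClassificationHparams(num_labels=3,
                                           pretrained_model_name='bert-base-uncased',
                                           use_pretrained=True,
                                           tokenizer_name='bert-base-uncased')
bert_classifier = cls_hparams.initialize_object()  # -> create_bert_classification(...)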