Skip to content

Commit

Permalink
Add docstrings for ModelTrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
alanakbik committed Dec 4, 2024
1 parent 39d4952 commit 5fa9035
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
5 changes: 5 additions & 0 deletions flair/trainers/plugins/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def __init__(self, *, plugins: Sequence[PluginArgument] = []) -> None:

@property
def plugins(self):
"""Returns all plugins attached to this instance as a list of :class:`BasePlugin`.
Returns:
List of :class:`BasePlugin` instances attached to this `Pluggable`.
"""
return self._plugins

def append_plugin(self, plugin):
Expand Down
27 changes: 23 additions & 4 deletions flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,22 @@


class ModelTrainer(Pluggable):
"""Use this class to train a Flair model.
The ModelTrainer is initialized using a :class:`flair.nn.Model` (the architecture you want to train) and a
:class:`flair.data.Corpus` (the labeled data you use to train and evaluate the model). It offers two main training
functions for the two main modes of training a model: (1) :func:`train`, which is used to train a model from scratch or
to fit a classification head on a frozen transformer language model. (2) :func:`fine_tune`, which is used if you
do not freeze the transformer language model and rather fine-tune it for a specific task.
Additionally, there is also a `train_custom` method that allows you to fully customize the training run.
ModelTrainer inherits from :class:`flair.trainers.plugins.base.Pluggable` and thus uses a plugin system to inject
specific functionality into the training process. You can add any number of plugins to the above-mentioned training
modes. For instance, if you want to use an annealing scheduler during training, you can add the
:class:`flair.trainers.plugins.functional.AnnealingPlugin` plugin to the train command.
"""

valid_events = {
"after_setup",
"before_training_epoch",
Expand All @@ -59,11 +75,14 @@ class ModelTrainer(Pluggable):
}

def __init__(self, model: flair.nn.Model, corpus: Corpus) -> None:
"""Initialize a model trainer.
"""Initialize a model trainer by passing a :class:`flair.nn.Model` (the architecture you want to train) and a
:class:`flair.data.Corpus` (the labeled data you use to train and evaluate the model).
Args:
model: The model that you want to train. The model should inherit from flair.nn.Model # noqa: E501
corpus: The dataset used to train the model, should be of type Corpus
model: The model that you want to train. The model should inherit from :class:`flair.nn.Model`. So for
instance you should pass a :class:`flair.models.TextClassifier` if you want to train a text classifier,
or :class:`flair.models.SequenceLabeler` if you want to train an RNN-based sequence labeler.
corpus: The dataset (of type :class:`flair.data.Corpus`) used to train the model.
"""
super().__init__()
self.model: flair.nn.Model = model
Expand Down Expand Up @@ -346,7 +365,7 @@ def train_custom(
plugins: Optional[list[TrainerPlugin]] = None,
**kwargs,
) -> dict:
"""Trains any class that implements the flair.nn.Model interface.
"""Trains any class that implements the :class:`flair.nn.Model` interface.
Args:
base_path: Main path to which all output during training is logged and models are saved
Expand Down

0 comments on commit 5fa9035

Please sign in to comment.