Finetuning (or transfer-learning) is the process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset.
Here are common terms you need to be familiar with:
.. list-table::
   :widths: 20 80
   :header-rows: 1

   * - Term
     - Definition
   * - Finetuning
     - The process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset.
   * - Transfer learning
     - The common name for finetuning.
   * - Backbone
     - The neural network that was pretrained on a different dataset.
   * - Head
     - Another neural network (usually smaller) that maps the backbone to your particular dataset.
   * - Freeze
     - Disabling gradient updates to a model (i.e. not learning).
   * - Unfreeze
     - Enabling gradient updates to a model.
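To make ``freeze`` and ``unfreeze`` concrete, here is a minimal plain-PyTorch sketch (not the Flash API) of what disabling and re-enabling gradient updates looks like:

.. code-block:: python

    import torchvision.models as models

    backbone = models.resnet18()

    # "Freeze": disable gradient updates so the backbone's weights stay fixed
    for param in backbone.parameters():
        param.requires_grad = False

    # "Unfreeze": re-enable gradient updates so the backbone learns again
    for param in backbone.parameters():
        param.requires_grad = True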
The examples below build on the :ref:`quick_start` guide.
.. testsetup:: strategies

    import flash
    from flash.core.data.utils import download_data
    from flash.image import ImageClassificationData, ImageClassifier

    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

    datamodule = ImageClassificationData.from_files(
        train_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"],
        train_targets=[0],
        batch_size=1,
        num_workers=0,
    )
    model = ImageClassifier(backbone="resnet18", num_classes=2)
    trainer = flash.Trainer(max_epochs=1)
Finetuning is very task-specific, and each task encodes the best finetuning practices for that task. On top of these, Flash gives you a few default strategies for finetuning.

Finetuning operates on two things: the model backbone and the head. The backbone is the neural network that was pretrained on a different dataset. The head is another (usually smaller) neural network that bridges between the backbone and your particular dataset.
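As a rough illustration (plain PyTorch rather than the Flash API, with resnet18's 512 output features and 2 target classes assumed for the example), a finetuning model is often just a pretrained backbone followed by a small task-specific head:

.. code-block:: python

    import torch.nn as nn
    import torchvision.models as models

    # Backbone: a ResNet with its original classification layer removed
    # (load pretrained weights as appropriate for your torchvision version)
    resnet = models.resnet18()
    backbone = nn.Sequential(*list(resnet.children())[:-1], nn.Flatten())

    # Head: a small network mapping the 512 backbone features to your classes
    head = nn.Linear(512, 2)  # 2 target classes assumed here

    model = nn.Sequential(backbone, head)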
In the ``no_freeze`` strategy, the backbone and the head are unfrozen from the beginning.
.. testcode:: strategies

    trainer.finetune(model, datamodule, strategy="no_freeze")
.. testoutput:: strategies
    :hide:

    ...
In pseudocode, this looks like:
.. code-block:: python

    backbone = Resnet50()
    head = nn.Linear(...)

    backbone.unfreeze()
    head.unfreeze()

    train(backbone, head)
The ``freeze`` strategy keeps the backbone frozen throughout training, so only the head is trained.
.. testcode:: strategies

    trainer.finetune(model, datamodule, strategy="freeze")
The pseudocode looks like:
.. code-block:: python

    backbone = Resnet50()
    head = nn.Linear(...)

    # freeze backbone
    backbone.freeze()
    head.unfreeze()

    train(backbone, head)
Every finetuning strategy can also be customized.

The ``freeze_unfreeze`` strategy keeps the backbone frozen until a certain epoch (provided in a tuple to the ``strategy`` argument), after which the backbone is unfrozen.
For example, to unfreeze after epoch 7:
.. testcode:: strategies

    trainer.finetune(model, datamodule, strategy=("freeze_unfreeze", 7))
Under the hood, the pseudocode looks like:
.. code-block:: python

    backbone = Resnet50()
    head = nn.Linear(...)

    # freeze backbone
    backbone.freeze()
    head.unfreeze()

    train(backbone, head, epochs=7)

    # unfreeze backbone after 7 epochs
    backbone.unfreeze()

    train(backbone, head)
The ``unfreeze_milestones`` strategy allows you to unfreeze parts of the backbone at predetermined epochs (milestones).
Here's an example where:
- backbone starts frozen
- at epoch 3 the last 2 layers unfreeze
- at epoch 8 the full backbone unfreezes
.. testcode:: strategies

    trainer.finetune(model, datamodule, strategy=("unfreeze_milestones", ((3, 8), 2)))
Under the hood, the pseudocode looks like:
.. code-block:: python

    backbone = Resnet50()
    head = nn.Linear(...)

    # freeze backbone
    backbone.freeze()
    head.unfreeze()

    train(backbone, head, epochs=3)

    # unfreeze the last 2 layers of the backbone at epoch 3
    backbone.unfreeze_last_layers(2)

    train(backbone, head, epochs=8)

    # unfreeze the full backbone at epoch 8
    backbone.unfreeze()

    train(backbone, head)
For even more customization, create your own finetuning callback. Learn more about callbacks in the PyTorch Lightning documentation.
.. testcode:: strategies

    from flash.core.finetuning import FlashBaseFinetuning

    # Create a finetuning callback
    class FeatureExtractorFreezeUnfreeze(FlashBaseFinetuning):
        def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True):
            # this will set self.attr_names as ["backbone"]
            super().__init__("backbone", train_bn)
            self._unfreeze_epoch = unfreeze_epoch

        def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
            # unfreeze any module you want by overriding this function
            # when ``current_epoch`` is 5, the backbone will start to be trained
            if current_epoch == self._unfreeze_epoch:
                self.unfreeze_and_extend_param_group(
                    pl_module.backbone,
                    optimizer,
                )

    # Pass the callback to trainer.finetune
    trainer.finetune(model, datamodule, strategy=FeatureExtractorFreezeUnfreeze(unfreeze_epoch=5))
If you are using DeepSpeed, you can use the following strategies. Their usage is the same as the strategies listed above, but finetuning with DeepSpeed doesn't yet support the loading and storing of its parameters.
- ``freeze_deepspeed``
- ``no_freeze_deepspeed``
- ``freeze_unfreeze_deepspeed``
- ``unfreeze_milestones_deepspeed``
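For example, a minimal usage sketch (assuming DeepSpeed is installed and the trainer is configured to run with DeepSpeed):

.. code-block:: python

    # Sketch only: assumes DeepSpeed is installed and the trainer runs with a
    # DeepSpeed plugin/strategy.
    trainer.finetune(model, datamodule, strategy="freeze_deepspeed")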