Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Latest commit

 

History

History
235 lines (153 loc) · 5.82 KB

finetuning.rst

File metadata and controls

235 lines (153 loc) · 5.82 KB

Finetuning

Finetuning (or transfer-learning) is the process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset.


Terminology

Here are common terms you need to be familiar with:

Terminology
Term Definition
Finetuning The process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset
Transfer learning The common name for finetuning
Backbone The neural network that was pretrained on a different dataset
Head Another neural network (usually smaller) that maps the backbone to your particular dataset
Freeze Disabling gradient updates to a model (ie: not learning)
Unfreeze Enabling gradient updates to a model

Finetuning in Flash

From the :ref:`quick_start` guide.


Finetune strategies

.. testsetup:: strategies

    import flash
    from flash.core.data.utils import download_data
    from flash.image import ImageClassificationData, ImageClassifier

    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

    datamodule = ImageClassificationData.from_files(
        train_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"],
        train_targets=[0],
        batch_size=1,
        num_workers=0,
    )

    model = ImageClassifier(backbone="resnet18", num_classes=2)
    trainer = flash.Trainer(max_epochs=1)

Finetuning is very task specific. Each task encodes the best finetuning practices for that task. However, Flash gives you a few default strategies for finetuning.

Finetuning operates on two things, the model backbone and the head. The backbone is the neural network that was pre-trained. The head is another neural network that bridges between the backbone and your particular dataset.

no_freeze

In this strategy, the backbone and the head are unfrozen from the beginning.

.. testcode:: strategies

    trainer.finetune(model, datamodule, strategy="no_freeze")

.. testoutput:: strategies
    :hide:

    ...

In pseudocode, this looks like:

backbone = Resnet50()
head = nn.Linear(...)

backbone.unfreeze()
head.unfreeze()

train(backbone, head)

freeze

The freeze strategy keeps the backbone frozen throughout.

.. testcode:: strategies

    trainer.finetune(model, datamodule, strategy="freeze")

The pseudocode looks like:

backbone = Resnet50()
head = nn.Linear(...)

# freeze backbone
backbone.freeze()
head.unfreeze()

train(backbone, head)

Advanced strategies

Every finetune strategy can also be customized.

freeze_unfreeze

The freeze_unfreeze strategy keeps the backbone frozen until a certain epoch (provided in a tuple to the strategy argument) after which the backbone will be unfrozen.

For example, to unfreeze after epoch 7:

.. testcode:: strategies

    trainer.finetune(model, datamodule, strategy=("freeze_unfreeze", 7))

Under the hood, the pseudocode looks like:

backbone = Resnet50()
head = nn.Linear(...)

# freeze backbone
backbone.freeze()
head.unfreeze()

train(backbone, head, epochs=10)

# unfreeze after 7 epochs
backbone.unfreeze()

train(backbone, head)

unfreeze_milestones

This strategy allows you to unfreeze part of the backbone at predetermined intervals.

Here's an example where:

  • backbone starts frozen
  • at epoch 3 the last 2 layers unfreeze
  • at epoch 8 the full backbone unfreezes

.. testcode:: strategies

    trainer.finetune(model, datamodule, strategy=("unfreeze_milestones", ((3, 8), 2)))


Under the hood, the pseudocode looks like:

backbone = Resnet50()
head = nn.Linear(...)

# freeze backbone
backbone.freeze()
head.unfreeze()

train(backbone, head, epochs=3)

# unfreeze last 2 layers at epoch 3
backbone.unfreeze_last_layers(2)

train(backbone, head, epochs=8)

# unfreeze the full backbone
backbone.unfreeze()

Custom Strategy

For even more customization, create your own finetuning callback. Learn more about callbacks here.

.. testcode:: strategies

    from flash.core.finetuning import FlashBaseFinetuning


    # Create a finetuning callback
    class FeatureExtractorFreezeUnfreeze(FlashBaseFinetuning):
        def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True):
            # this will set self.attr_names as ["backbone"]
            super().__init__("backbone", train_bn)
            self._unfreeze_epoch = unfreeze_epoch

        def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
            # unfreeze any module you want by overriding this function

            # When ``current_epoch`` is 5, backbone will start to be trained.
            if current_epoch == self._unfreeze_epoch:
                self.unfreeze_and_extend_param_group(
                    pl_module.backbone,
                    optimizer,
                )


    # Pass the callback to trainer.finetune
    trainer.finetune(model, datamodule, strategy=FeatureExtractorFreezeUnfreeze(unfreeze_epoch=5))

Working with DeepSpeed

If you are using DeepSpeed, you can use the following strategies. The usage of the following strategies is the same as listed above, but finetuning with DeepSpeed doesn't yet support the loading and storing of its parameters.

  • freeze_deepspeed
  • no_freeze_deepspeed
  • freeze_unfreeze_deepspeed
  • unfreeze_milestones_deepspeed