This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Fix unfreeze strategies with onecyclelr and reduced lr (#1329)
ethanwharris authored May 6, 2022
1 parent 48a2500 commit 1ed7938
Showing 4 changed files with 52 additions and 10 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -26,6 +26,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a bug where a loaded `TabularClassifier` or `TabularRegressor` checkpoint could not be served ([#1324](https://github.com/PyTorchLightning/lightning-flash/pull/1324))
 
+- Fixed a bug where the `freeze_unfreeze` and `unfreeze_milestones` finetuning strategies could not be used in tandem with a `onecyclelr` LR scheduler ([#1329](https://github.com/PyTorchLightning/lightning-flash/pull/1329))
+
+- Fixed a bug where the backbone learning rate would be divided by 10 when unfrozen if using the `freeze_unfreeze` or `unfreeze_milestones` strategies ([#1329](https://github.com/PyTorchLightning/lightning-flash/pull/1329))
+
 ## [0.7.4] - 2022-04-27
 
 ### Fixed
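
The two entries above come down to one behaviour of PyTorch Lightning's `BaseFinetuning.unfreeze_and_add_param_group`: it appends a brand-new parameter group whose learning rate is divided by `initial_denom_lr` (10 by default), and a `OneCycleLR` scheduler only knows about the groups that existed when it was constructed. Below is a minimal sketch of that failure mode outside of Flash, assuming a recent PyTorch; the model and numbers are illustrative and the exact exception may vary by version.

from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import OneCycleLR

head = nn.Linear(16, 2)       # stands in for the task head
backbone = nn.Linear(16, 16)  # stands in for the (initially frozen) backbone

optimizer = SGD(head.parameters(), lr=0.1)
scheduler = OneCycleLR(optimizer, max_lr=1e-3, epochs=50, steps_per_epoch=10)

optimizer.step()
scheduler.step()  # fine while only the original param group exists

# What the old strategies effectively did at the unfreeze epoch: append a new
# param group for the backbone, with the learning rate divided by 10.
optimizer.add_param_group({"params": backbone.parameters(), "lr": 0.1 / 10})
print(optimizer.param_groups[1]["lr"])  # 0.01 -- the unexpected reduction

optimizer.step()
scheduler.step()  # typically raises KeyError: OneCycleLR stored no schedule
                  # bounds (initial_lr / max_lr / min_lr) for the new group
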
2 changes: 1 addition & 1 deletion docs/source/general/finetuning.rst
@@ -228,7 +228,7 @@ For even more customization, create your own finetuning callback. Learn more abo
 
             # When ``current_epoch`` is 5, backbone will start to be trained.
             if current_epoch == self._unfreeze_epoch:
-                self.unfreeze_and_add_param_group(
+                self.unfreeze_and_extend_param_group(
                     pl_module.backbone,
                     optimizer,
                 )
19 changes: 16 additions & 3 deletions flash/core/finetuning.py
@@ -103,6 +103,19 @@ def freeze_before_training(self, pl_module: Union[Module, Iterable[Union[Module,
                 modules = [modules]
             self.freeze(modules=modules, train_bn=self.train_bn)
 
+    def unfreeze_and_extend_param_group(
+        self,
+        modules: Union[Module, Iterable[Union[Module, Iterable]]],
+        optimizer: Optimizer,
+        train_bn: bool = True,
+    ) -> None:
+        self.make_trainable(modules)
+
+        params = self.filter_params(modules, train_bn=train_bn, requires_grad=True)
+        params = self.filter_on_optimizer(optimizer, params)
+        if params:
+            optimizer.param_groups[0]["params"].extend(params)
+
     def _freeze_unfreeze_function(
         self,
         pl_module: Union[Module, Iterable[Union[Module, Iterable]]],
@@ -117,7 +130,7 @@ def _freeze_unfreeze_function(
 
         modules = self._get_modules_to_freeze(pl_module=pl_module)
         if modules is not None:
-            self.unfreeze_and_add_param_group(
+            self.unfreeze_and_extend_param_group(
                 modules=modules,
                 optimizer=optimizer,
                 train_bn=self.train_bn,
@@ -140,15 +153,15 @@ def _unfreeze_milestones_function(
             # unfreeze num_layers last layers
 
             backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[-num_layers:]
-            self.unfreeze_and_add_param_group(
+            self.unfreeze_and_extend_param_group(
                 modules=backbone_modules,
                 optimizer=optimizer,
                 train_bn=self.train_bn,
             )
         elif epoch == unfreeze_milestones[1]:
             # unfreeze remaining layers
             backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[:-num_layers]
-            self.unfreeze_and_add_param_group(
+            self.unfreeze_and_extend_param_group(
                 modules=backbone_modules,
                 optimizer=optimizer,
                 train_bn=self.train_bn,
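
The new `unfreeze_and_extend_param_group` above reuses parameter group 0 instead of creating a new one, so an attached scheduler keeps working and the learning rate is left untouched. For a custom strategy written directly against PyTorch Lightning's `BaseFinetuning` (as in the docs snippet earlier), the same pattern looks roughly like the sketch below; the callback name, the `pl_module.backbone` attribute, and the PL 1.6-style `finetune_function` signature are assumptions for illustration.

from pytorch_lightning.callbacks import BaseFinetuning


class ExtendOnUnfreeze(BaseFinetuning):
    """Illustrative callback: unfreeze the backbone into the existing param group."""

    def __init__(self, unfreeze_epoch: int = 2, train_bn: bool = True):
        super().__init__()
        self.unfreeze_epoch = unfreeze_epoch
        self.train_bn = train_bn

    def freeze_before_training(self, pl_module):
        self.freeze(pl_module.backbone, train_bn=self.train_bn)

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        if current_epoch != self.unfreeze_epoch:
            return
        self.make_trainable(pl_module.backbone)
        params = self.filter_params(pl_module.backbone, train_bn=self.train_bn, requires_grad=True)
        params = self.filter_on_optimizer(optimizer, params)
        if params:
            # Extend the group the scheduler already tracks rather than calling
            # unfreeze_and_add_param_group, which would create a new group with
            # lr / initial_denom_lr.
            optimizer.param_groups[0]["params"].extend(params)
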
37 changes: 31 additions & 6 deletions tests/core/test_finetuning.py
@@ -155,20 +155,45 @@ def test_finetuning_with_none_return_type(strategy):
 
 @pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.")
 @pytest.mark.parametrize(
-    ("strategy", "checker_class", "checker_class_data"),
+    ("strategy", "lr_scheduler", "checker_class", "checker_class_data"),
     [
-        ("no_freeze", None, {}),
-        ("freeze", FreezeStrategyChecking, {}),
-        (("freeze_unfreeze", 2), FreezeUnfreezeStrategyChecking, {"check_epoch": 2}),
+        ("no_freeze", None, None, {}),
+        ("freeze", None, FreezeStrategyChecking, {}),
+        (("freeze_unfreeze", 2), None, FreezeUnfreezeStrategyChecking, {"check_epoch": 2}),
         (
             ("unfreeze_milestones", ((1, 3), 1)),
+            None,
             UnfreezeMilestonesStrategyChecking,
             {"check_epochs": [1, 3], "num_layers": 1},
         ),
+        (
+            "no_freeze",
+            ("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
+            None,
+            {},
+        ),
+        (
+            "freeze",
+            ("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
+            FreezeStrategyChecking,
+            {},
+        ),
+        (
+            ("freeze_unfreeze", 2),
+            ("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
+            FreezeUnfreezeStrategyChecking,
+            {"check_epoch": 2},
+        ),
+        (
+            ("unfreeze_milestones", ((1, 3), 1)),
+            ("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
+            UnfreezeMilestonesStrategyChecking,
+            {"check_epochs": [1, 3], "num_layers": 1},
+        ),
     ],
 )
-def test_finetuning(tmpdir, strategy, checker_class, checker_class_data):
-    task = TestTaskWithFinetuning(loss_fn=F.nll_loss)
+def test_finetuning(tmpdir, strategy, lr_scheduler, checker_class, checker_class_data):
+    task = TestTaskWithFinetuning(loss_fn=F.nll_loss, lr_scheduler=lr_scheduler, optimizer="sgd", learning_rate=0.1)
     callbacks = [] if checker_class is None else checker_class(dirpath=tmpdir, **checker_class_data)
     trainer = flash.Trainer(max_epochs=5, limit_train_batches=10, callbacks=callbacks)
     ds = DummyDataset()
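
At the user level, the combination exercised by the new test cases looks roughly like this; the `ImageClassifier` task, backbone, and data paths are placeholder choices, but the `lr_scheduler` tuple and the `strategy` argument mirror what the parametrization above feeds through `TestTaskWithFinetuning`.

import flash
from flash.image import ImageClassificationData, ImageClassifier

datamodule = ImageClassificationData.from_folders(train_folder="data/train", batch_size=4)

model = ImageClassifier(
    num_classes=datamodule.num_classes,
    backbone="resnet18",
    optimizer="sgd",
    learning_rate=0.1,
    # Scheduler config in Flash's (name, scheduler kwargs, lightning kwargs) form,
    # matching the tests above.
    lr_scheduler=("onecyclelr", {"max_lr": 1e-3, "epochs": 50, "steps_per_epoch": 10}, {"interval": "step"}),
)

trainer = flash.Trainer(max_epochs=50, limit_train_batches=10)
# Before this fix, this combination either crashed inside OneCycleLR at the
# unfreeze epoch or silently trained the backbone at a tenth of the expected LR.
trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 2))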
