
Commit

update swav to override optimizer_step with optimizer.step(closure=optimizer_closure) (#323)

* update swav to override optimizer_step with optimizer.step(closure=optimizer_closure)

* added defaults from optimizer_step def

* imports

* removed TODOs

* imports

* optimizer_step cleanup

* imports
ananyahjha93 authored Nov 2, 2020
1 parent 94ddcae commit 978fa1c
Showing 2 changed files with 29 additions and 20 deletions.
43 changes: 25 additions & 18 deletions pl_bolts/models/self_supervised/swav/swav_module.py
@@ -11,9 +11,12 @@
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.optim.optimizer import Optimizer
 
-from pl_bolts.models.self_supervised.swav.swav_resnet import resnet50, resnet18
+from typing import Callable, Optional
+from pytorch_lightning.utilities import AMPType
 
+from pl_bolts.models.self_supervised.swav.swav_resnet import resnet50, resnet18
 from pl_bolts.transforms.dataset_normalizations import stl10_normalization, cifar10_normalization
 from pl_bolts.optimizers.lars_scheduling import LARSWrapper

@@ -321,15 +324,15 @@ def configure_optimizers(self):
 
     def optimizer_step(
         self,
-        epoch,
-        batch_idx,
-        optimizer,
-        optimizer_idx,
-        second_order_closure=None,
-        on_tpu=False,
-        using_native_amp=False,
-        using_lbfgs=False
-    ):
+        epoch: int,
+        batch_idx: int,
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        optimizer_closure: Optional[Callable] = None,
+        on_tpu: bool = False,
+        using_native_amp: bool = False,
+        using_lbfgs: bool = False,
+    ) -> None:
         # warm-up + decay schedule placed here since LARSWrapper is not optimizer class
         # adjust LR of optim contained within LARSWrapper
         if self.lars_wrapper:
@@ -340,14 +343,18 @@ def optimizer_step(
                 param_group["lr"] = self.lr_schedule[self.trainer.global_step]
 
         # log LR (LearningRateLogger callback doesn't work with LARSWrapper)
-        learning_rate = {'learning_rate': self.lr_schedule[self.trainer.global_step]}
-        self.logger.log_metrics(learning_rate, step=self.trainer.global_step)
-
-        # from lightning implementation
-        if using_native_amp:
-            self.trainer.scaler.step(optimizer)
-        else:
-            optimizer.step()
+        self.log('learning_rate', self.lr_schedule[self.trainer.global_step], on_step=True, on_epoch=False)
+
+        super().optimizer_step(
+            epoch=epoch,
+            batch_idx=batch_idx,
+            optimizer=optimizer,
+            optimizer_idx=optimizer_idx,
+            optimizer_closure=optimizer_closure,
+            on_tpu=on_tpu,
+            using_native_amp=using_native_amp,
+            using_lbfgs=using_lbfgs,
+        )
 
     def sinkhorn(self, Q, nmb_iters):
         with torch.no_grad():
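In plain terms, the module-side change stops the hook from calling optimizer.step() (or trainer.scaler.step()) by hand and instead defers to the parent hook, which passes optimizer_closure through to optimizer.step(closure=...) and keeps AMP/TPU/LBFGS handling inside Lightning. The LR is still written directly into optimizer.param_groups because, as the comments note, LARSWrapper is not an optimizer class and the LearningRateLogger callback cannot see through it. Below is a minimal, self-contained sketch of that pattern; the ScheduledLRModule class and its linear lr_schedule are illustrative stand-ins, not code from this repository.

```python
# Illustrative sketch of the hook pattern used in this commit (not the SwAV module itself).
# Assumes a precomputed per-step LR schedule stored on the module as self.lr_schedule.
from typing import Callable, Optional

import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
from torch.optim.optimizer import Optimizer


class ScheduledLRModule(pl.LightningModule):  # hypothetical example module
    def __init__(self, total_steps: int = 1000, base_lr: float = 1e-3):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        # linear decay schedule: one LR value per training step
        self.lr_schedule = np.linspace(base_lr, 0.0, total_steps)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=float(self.lr_schedule[0]))

    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: Optimizer,
        optimizer_idx: int,
        optimizer_closure: Optional[Callable] = None,
        on_tpu: bool = False,
        using_native_amp: bool = False,
        using_lbfgs: bool = False,
    ) -> None:
        # overwrite the LR of every param group from the precomputed schedule
        lr = float(self.lr_schedule[self.trainer.global_step])
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        # log the LR manually instead of relying on an LR-logging callback
        self.log("learning_rate", lr, on_step=True, on_epoch=False)

        # delegate the actual step so Lightning calls
        # optimizer.step(closure=optimizer_closure) with AMP/TPU/LBFGS handling
        super().optimizer_step(
            epoch=epoch,
            batch_idx=batch_idx,
            optimizer=optimizer,
            optimizer_idx=optimizer_idx,
            optimizer_closure=optimizer_closure,
            on_tpu=on_tpu,
            using_native_amp=using_native_amp,
            using_lbfgs=using_lbfgs,
        )
```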
6 changes: 4 additions & 2 deletions tests/models/self_supervised/test_models.py
@@ -131,5 +131,7 @@ def test_swav(tmpdir):
         gpus=0, fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir, max_steps=3
     )
 
-    results = trainer.fit(model, datamodule)
-    assert results == 1
+    trainer.fit(model, datamodule)
+    loss = trainer.progress_bar_dict['loss']
+
+    assert float(loss) > 0
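The test-side change swaps the assert on trainer.fit()'s return value for a check on the last loss reported to the progress bar. A rough sketch of that pattern, assuming model and datamodule are built as in the elided part of the test (run_smoke_test is a hypothetical name, not from the repo):

```python
# Rough sketch of the revised assertion style; `model` and `datamodule` are assumed
# to be constructed earlier, exactly as in the surrounding (elided) test code.
import pytorch_lightning as pl


def run_smoke_test(model, datamodule, tmpdir):  # hypothetical helper name
    trainer = pl.Trainer(
        gpus=0, fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir, max_steps=3
    )
    trainer.fit(model, datamodule)

    # instead of asserting on fit()'s return value, read the last loss shown
    # in the progress bar and check it is a positive number
    loss = float(trainer.progress_bar_dict['loss'])
    assert loss > 0
```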
