add better lars scheduling #162

Merged · 10 commits · Aug 22, 2020
5 changes: 3 additions & 2 deletions pl_bolts/datamodules/cifar10_datamodule.py
@@ -8,6 +8,7 @@

from pl_bolts.datamodules.cifar10_dataset import TrialCIFAR10
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
import os


class CIFAR10DataModule(LightningDataModule):
@@ -17,7 +18,7 @@ class CIFAR10DataModule(LightningDataModule):

def __init__(
self,
data_dir,
data_dir: str = None,
val_split: int = 5000,
num_workers: int = 16,
batch_size: int = 32,
@@ -73,11 +74,11 @@ def __init__(
super().__init__(*args, **kwargs)
self.dims = (3, 32, 32)
self.DATASET = CIFAR10
self.data_dir = data_dir
self.val_split = val_split
self.num_workers = num_workers
self.batch_size = batch_size
self.seed = seed
self.data_dir = data_dir if data_dir is not None else os.getcwd()

@property
def num_classes(self):
5 changes: 3 additions & 2 deletions pl_bolts/datamodules/stl10_datamodule.py
@@ -1,3 +1,4 @@
import os
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
@@ -14,7 +15,7 @@ class STL10DataModule(LightningDataModule):  # pragma: no cover

def __init__(
self,
data_dir: str,
data_dir: str = None,
unlabeled_val_split: int = 5000,
train_val_split: int = 500,
num_workers: int = 16,
@@ -63,7 +64,7 @@ def __init__(
"""
super().__init__(*args, **kwargs)
self.dims = (3, 96, 96)
self.data_dir = data_dir
self.data_dir = data_dir if data_dir is not None else os.getcwd()
self.unlabeled_val_split = unlabeled_val_split
self.train_val_split = train_val_split
self.num_workers = num_workers
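
Both datamodule diffs above make data_dir optional and fall back to the current working directory when it is omitted. A minimal sketch of the resulting behavior (illustrative only, not part of the PR; assumes the remaining constructor arguments keep their defaults)::

    import os
    from pl_bolts.datamodules import CIFAR10DataModule

    dm_default = CIFAR10DataModule()                       # data_dir falls back to os.getcwd()
    dm_explicit = CIFAR10DataModule(data_dir="/tmp/cifar10")

    assert dm_default.data_dir == os.getcwd()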
97 changes: 59 additions & 38 deletions pl_bolts/models/self_supervised/byol/byol_module.py
@@ -1,26 +1,27 @@
from copy import deepcopy
import torch
import torch.nn.functional as F
from torch.optim import Adam
import pytorch_lightning as pl
from typing import Any

from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
from pl_bolts.models.self_supervised.simclr.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
from pl_bolts.optimizers.layer_adaptive_scaling import LARS
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pl_bolts.models.self_supervised.byol.models import SiameseArm
from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate, SSLOnlineEvaluator


class BYOL(pl.LightningModule):
def __init__(self,
datamodule: pl.LightningDataModule = None,
num_classes,
data_dir: str = './',
learning_rate: float = 0.2,
weight_decay: float = 15e-6,
input_height: int = 32,
batch_size: int = 32,
num_workers: int = 4,
num_workers: int = 0,
warmup_epochs: int = 10,
max_epochs: int = 1000,
**kwargs):
@@ -42,11 +43,24 @@ def __init__(self,
- verify on STL-10
- pre-train on imagenet

Example:
Example::

>>> from pl_bolts.models.self_supervised import BYOL
...
>>> model = BYOL()
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import BYOL
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.simclr_transforms import (
SimCLREvalDataTransform, SimCLRTrainDataTransform)

# model
model = BYOL(num_classes=10)

# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

trainer = pl.Trainer()
trainer.fit(model, dm)

Train::

@@ -80,23 +94,10 @@ def __init__(self,
super().__init__()
self.save_hyperparameters()

# init default datamodule
if datamodule is None:
datamodule = CIFAR10DataModule(data_dir, num_workers=num_workers, batch_size=batch_size)
datamodule.train_transforms = SimCLRTrainDataTransform(input_height)
datamodule.val_transforms = SimCLREvalDataTransform(input_height)

self.datamodule = datamodule

self.online_network = SiameseArm()
self.target_network = deepcopy(self.online_network)

self.weight_callback = BYOLMAWeightUpdate()

# for finetuning callback
self.z_dim = 2048
self.num_classes = self.datamodule.num_classes

def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
# Add callback for user automatically since it's key to BYOL weight update
self.weight_callback.on_batch_end(self.trainer, self)
@@ -150,7 +151,8 @@ def validation_step(self, batch, batch_idx):
return result

def configure_optimizers(self):
optimizer = LARS(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
optimizer = LARSWrapper(optimizer)
scheduler = LinearWarmupCosineAnnealingLR(
optimizer,
warmup_epochs=self.hparams.warmup_epochs,
@@ -168,7 +170,7 @@ def add_model_specific_args(parent_parser):

# Data
parser.add_argument('--data_dir', type=str, default='.')
parser.add_argument('--num_workers', default=4, type=int)
parser.add_argument('--num_workers', default=0, type=int)

# optim
parser.add_argument('--batch_size', type=int, default=256)
@@ -195,23 +197,42 @@ def add_model_specific_args(parent_parser):
args = parser.parse_args()

# pick data
datamodule = None
if args.dataset == 'stl10':
datamodule = STL10DataModule.from_argparse_args(args)
datamodule.train_dataloader = datamodule.train_dataloader_mixed
datamodule.val_dataloader = datamodule.val_dataloader_mixed
dm = None

(c, h, w) = datamodule.size()
datamodule.train_transforms = SimCLRTrainDataTransform(h)
datamodule.val_transforms = SimCLREvalDataTransform(h)
# init default datamodule
if args.dataset == 'cifar10':
dm = CIFAR10DataModule.from_argparse_args(args)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

elif args.dataset == 'imagenet2012':
datamodule = ImagenetDataModule.from_argparse_args(args, image_size=196)
(c, h, w) = datamodule.size()
datamodule.train_transforms = SimCLRTrainDataTransform(h)
datamodule.val_transforms = SimCLREvalDataTransform(h)
elif args.dataset == 'stl10':
dm = STL10DataModule.from_argparse_args(args)
dm.train_dataloader = dm.train_dataloader_mixed
dm.val_dataloader = dm.val_dataloader_mixed

model = BYOL(**args.__dict__, datamodule=datamodule)
(c, h, w) = dm.size()
dm.train_transforms = SimCLRTrainDataTransform(h)
dm.val_transforms = SimCLREvalDataTransform(h)
args.num_classes = dm.num_classes

trainer = pl.Trainer.from_argparse_args(args, max_steps=10000, callbacks=[SSLOnlineEvaluator()])
trainer.fit(model)
elif args.dataset == 'imagenet2012':
dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
(c, h, w) = dm.size()
dm.train_transforms = SimCLRTrainDataTransform(h)
dm.val_transforms = SimCLREvalDataTransform(h)
args.num_classes = dm.num_classes

model = BYOL(**args.__dict__)

def to_device(batch, device):
(x1, x2), y = batch
x1 = x1.to(device)
y = y.to(device)
return x1, y

# finetune in real-time
online_eval = SSLOnlineEvaluator(z_dim=2048, num_classes=dm.num_classes)
online_eval.to_device = to_device

trainer = pl.Trainer.from_argparse_args(args, max_steps=10000, callbacks=[])
trainer.fit(model, dm)
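
The configure_optimizers change above drops the standalone LARS optimizer in favor of plain Adam wrapped in the new LARSWrapper and scheduled by LinearWarmupCosineAnnealingLR. A standalone sketch of that setup, with a toy module standing in for the BYOL SiameseArm and hyperparameter values taken from the defaults in this diff (illustrative only)::

    import torch
    from torch.optim import Adam
    from pl_bolts.optimizers.lars_scheduling import LARSWrapper
    from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

    model = torch.nn.Linear(32, 32)  # stand-in for the BYOL SiameseArm
    optimizer = LARSWrapper(Adam(model.parameters(), lr=0.2, weight_decay=15e-6))
    scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=1000)

    # inside a LightningModule these would be returned as, e.g.:
    # return [optimizer], [scheduler]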
15 changes: 7 additions & 8 deletions pl_bolts/models/self_supervised/simclr/simclr_module.py
@@ -1,5 +1,6 @@
import pytorch_lightning as pl
import torch
from torch.optim import Adam
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import StepLR
@@ -11,7 +12,7 @@
from pl_bolts.metrics import mean
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
from pl_bolts.models.self_supervised.simclr.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
from pl_bolts.optimizers.layer_adaptive_scaling import LARS
from pl_bolts.optimizers.lars_scheduling import LARSWrapper


class DensenetEncoder(nn.Module):
@@ -232,13 +233,11 @@ def validation_epoch_end(self, outputs: list):
return dict(val_loss=val_loss, log=log, progress_bar=progress_bar)

def configure_optimizers(self):
if self.hparams.optimizer == 'adam':
optimizer = torch.optim.Adam(
self.parameters(), self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
elif self.hparams.optimizer == 'lars':
optimizer = LARS(
self.parameters(), lr=self.hparams.learning_rate, momentum=self.hparams.lars_momentum,
weight_decay=self.hparams.weight_decay, eta=self.hparams.lars_eta)
optimizer = torch.optim.Adam(
self.parameters(), self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)

if self.hparams.optimizer == 'lars':
optimizer = LARSWrapper(optimizer)
else:
raise ValueError(f'Invalid optimizer: {self.optimizer}')
scheduler = StepLR(
2 changes: 1 addition & 1 deletion pl_bolts/optimizers/__init__.py
@@ -1,2 +1,2 @@
from pl_bolts.optimizers.layer_adaptive_scaling import LARS
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
88 changes: 88 additions & 0 deletions pl_bolts/optimizers/lars_scheduling.py
@@ -0,0 +1,88 @@
"""
References:
- https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
- https://arxiv.org/pdf/1708.03888.pdf
- https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py
"""
import torch
from torch.optim import Optimizer


class LARSWrapper(object):
def __init__(self, optimizer, eta=0.02, clip=True, eps=1e-8):
"""
Wrapper that adds LARS scheduling to any optimizer. This helps stability with huge batch sizes.

Args:
optimizer: torch optimizer
eta: LARS coefficient (trust)
clip: True to clip LR
eps: adaptive_lr stability coefficient
"""
self.optim = optimizer
self.eta = eta
self.eps = eps
self.clip = clip

# transfer optim methods
self.state_dict = self.optim.state_dict
self.load_state_dict = self.optim.load_state_dict
self.zero_grad = self.optim.zero_grad
self.add_param_group = self.optim.add_param_group
self.__setstate__ = self.optim.__setstate__
self.__getstate__ = self.optim.__getstate__
self.__repr__ = self.optim.__repr__

@property
def __class__(self):
return Optimizer

@property
def state(self):
return self.optim.state

@property
def param_groups(self):
return self.optim.param_groups

@param_groups.setter
def param_groups(self, value):
self.optim.param_groups = value

@torch.no_grad()
def step(self):
weight_decays = []

for group in self.optim.param_groups:
weight_decay = group.get('weight_decay', 0)
weight_decays.append(weight_decay)

# reset weight decay
group['weight_decay'] = 0

# update the parameters
[self.update_p(p, group, weight_decay) for p in group['params'] if p.grad is not None]

# update the optimizer
self.optim.step()

# return weight decay control to optimizer
for group_idx, group in enumerate(self.optim.param_groups):
group['weight_decay'] = weight_decays[group_idx]

def update_p(self, p, group, weight_decay):
# calculate new norms
p_norm = torch.norm(p.data)
g_norm = torch.norm(p.grad.data)

if p_norm != 0 and g_norm != 0:
# calculate new lr
new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps)

# clip lr
if self.clip:
new_lr = min(new_lr / group['lr'], 1)

# update params with clipped lr
p.grad.data += weight_decay * p.data
p.grad.data *= new_lr
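
The wrapper computes a per-parameter trust ratio, eta * ||p|| / (||g|| + weight_decay * ||p|| + eps), optionally clips it against the group learning rate, and rescales each gradient before delegating the actual update to the wrapped optimizer. A small usage sketch with a toy model and plain SGD (illustrative only, not part of the PR)::

    import torch
    from torch.optim import SGD
    from pl_bolts.optimizers.lars_scheduling import LARSWrapper

    model = torch.nn.Linear(10, 1)
    optimizer = LARSWrapper(SGD(model.parameters(), lr=0.1, weight_decay=1e-4), eta=0.02)

    x, y = torch.randn(8, 10), torch.randn(8, 1)
    for _ in range(3):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()  # trust ratio is applied to each gradient before the inner SGD step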