Revision of BYOL module and tests (#874)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: otaj <ota@lightning.ai>
3 people authored Sep 28, 2022
1 parent 6f58d71 commit d8ff64f
Showing 4 changed files with 166 additions and 136 deletions.
180 changes: 83 additions & 97 deletions pl_bolts/models/self_supervised/byol/byol_module.py
@@ -4,33 +4,36 @@

import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from torch import Tensor
from torch.nn import functional as F
from torch.optim import Adam

from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
from pl_bolts.models.self_supervised.byol.models import SiameseArm
from pl_bolts.models.self_supervised.byol.models import MLP, SiameseArm
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pl_bolts.utils.stability import under_review


@under_review()
class BYOL(LightningModule):
"""PyTorch Lightning implementation of Bootstrap Your Own Latent (BYOL_)_
Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \
Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \
Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.
Args:
learning_rate (float, optional): optimizer learning rate. Defaults to 0.2.
weight_decay (float, optional): optimizer weight decay. Defaults to 1.5e-6.
warmup_epochs (int, optional): number of epochs for scheduler warmup. Defaults to 10.
max_epochs (int, optional): maximum number of epochs for scheduler. Defaults to 1000.
base_encoder (Union[str, torch.nn.Module], optional): base encoder architecture. Defaults to "resnet50".
encoder_out_dim (int, optional): base encoder output dimension. Defaults to 2048.
projector_hidden_dim (int, optional): projector MLP hidden dimension. Defaults to 4096.
projector_out_dim (int, optional): projector MLP output dimension. Defaults to 256.
initial_tau (float, optional): initial value of target decay rate used. Defaults to 0.996.
Model implemented by:
- `Annika Brundyn <https://github.com/annikabrundyn>`_
.. warning:: Work in progress. This implementation is still being verified.
TODOs:
- verify on CIFAR-10
- verify on STL-10
- pre-train on imagenet
Example::
model = BYOL(num_classes=10)
@@ -42,11 +45,6 @@ class BYOL(LightningModule):
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)
Train::
trainer = Trainer()
trainer.fit(model)
CLI command::
# cifar10
@@ -65,87 +63,82 @@ class BYOL(LightningModule):

def __init__(
self,
num_classes,
learning_rate: float = 0.2,
weight_decay: float = 1.5e-6,
input_height: int = 32,
batch_size: int = 32,
num_workers: int = 0,
warmup_epochs: int = 10,
max_epochs: int = 1000,
base_encoder: Union[str, torch.nn.Module] = "resnet50",
encoder_out_dim: int = 2048,
projector_hidden_size: int = 4096,
projector_hidden_dim: int = 4096,
projector_out_dim: int = 256,
**kwargs
):
"""
Args:
datamodule: The datamodule
learning_rate: the learning rate
weight_decay: optimizer weight decay
input_height: image input height
batch_size: the batch size
num_workers: number of workers
warmup_epochs: num of epochs for scheduler warm up
max_epochs: max epochs for scheduler
base_encoder: the base encoder module or resnet name
encoder_out_dim: output dimension of base_encoder
projector_hidden_size: hidden layer size of projector MLP
projector_out_dim: output size of projector MLP
"""
initial_tau: float = 0.996,
**kwargs: Any,
) -> None:

super().__init__()
self.save_hyperparameters(ignore="base_encoder")

self.online_network = SiameseArm(base_encoder, encoder_out_dim, projector_hidden_size, projector_out_dim)
self.online_network = SiameseArm(base_encoder, encoder_out_dim, projector_hidden_dim, projector_out_dim)
self.target_network = deepcopy(self.online_network)
self.weight_callback = BYOLMAWeightUpdate()
self.predictor = MLP(projector_out_dim, projector_hidden_dim, projector_out_dim)

def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> None:
# Add callback for user automatically since it's key to BYOL weight update
self.weight_callback = BYOLMAWeightUpdate(initial_tau=initial_tau)

def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
"""Add callback to perform exponential moving average weight update on target network."""
self.weight_callback.on_train_batch_end(self.trainer, self, outputs, batch, batch_idx)

def forward(self, x):
y, _, _ = self.online_network(x)
return y
def forward(self, x: Tensor) -> Tensor:
"""Returns the encoded representation of a view.
def shared_step(self, batch, batch_idx):
imgs, y = batch
img_1, img_2 = imgs[:2]
Args:
x (Tensor): sample to be encoded
"""
return self.online_network.encode(x)

# Image 1 to image 2 loss
y1, z1, h1 = self.online_network(img_1)
with torch.no_grad():
y2, z2, h2 = self.target_network(img_2)
loss_a = -2 * F.cosine_similarity(h1, z2).mean()
def training_step(self, batch: Any, batch_idx: int) -> Tensor:
"""Complete training loop."""
return self._shared_step(batch, batch_idx, "train")

# Image 2 to image 1 loss
y1, z1, h1 = self.online_network(img_2)
with torch.no_grad():
y2, z2, h2 = self.target_network(img_1)
# L2 normalize
loss_b = -2 * F.cosine_similarity(h1, z2).mean()
def validation_step(self, batch: Any, batch_idx: int) -> Tensor:
"""Complete validation loop."""
return self._shared_step(batch, batch_idx, "val")

# Final loss
total_loss = loss_a + loss_b
def _shared_step(self, batch: Any, batch_idx: int, step: str) -> Tensor:
"""Shared evaluation step for training and validation loop."""
imgs, _ = batch
img1, img2 = imgs[:2]

return loss_a, loss_b, total_loss
# Calculate similarity loss in each direction
loss_12 = self.calculate_loss(img1, img2)
loss_21 = self.calculate_loss(img2, img1)

def training_step(self, batch, batch_idx):
loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
# Calculate total loss
total_loss = loss_12 + loss_21

# log results
self.log_dict({"1_2_loss": loss_a, "2_1_loss": loss_b, "train_loss": total_loss})
# Log losses
if step == "train":
self.log_dict({"train_loss_12": loss_12, "train_loss_21": loss_21, "train_loss": total_loss})
elif step == "val":
self.log_dict({"val_loss_12": loss_12, "val_loss_21": loss_21, "val_loss": total_loss})
else:
raise ValueError(f"Step '{step}' is invalid. Must be 'train' or 'val'.")

return total_loss

def validation_step(self, batch, batch_idx):
loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
def calculate_loss(self, v_online: Tensor, v_target: Tensor) -> Tensor:
"""Calculates similarity loss between the online network prediction of target network projection.
# log results
self.log_dict({"1_2_loss": loss_a, "2_1_loss": loss_b, "val_loss": total_loss})

return total_loss
Args:
v_online (Tensor): Online network view
v_target (Tensor): Target network view
"""
_, z1 = self.online_network(v_online)
h1 = self.predictor(z1)
with torch.no_grad():
_, z2 = self.target_network(v_target)
loss = -2 * F.cosine_similarity(h1, z2).mean()
return loss

def configure_optimizers(self):
optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
@@ -155,30 +148,23 @@ def configure_optimizers(self):
return [optimizer], [scheduler]

@staticmethod
def add_model_specific_args(parent_parser):
def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--online_ft", action="store_true", help="run online finetuner")
parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "imagenet2012", "stl10"])

(args, _) = parser.parse_known_args()
args = parser.parse_args([])

# Data
parser.add_argument("--data_dir", type=str, default=".")
parser.add_argument("--num_workers", default=8, type=int)
if "max_epochs" in args:
parser.set_defaults(max_epochs=1000)
else:
parser.add_argument("--max_epochs", type=int, default=1000)

# optim
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--learning_rate", type=float, default=1e-3)
parser.add_argument("--learning_rate", type=float, default=0.2)
parser.add_argument("--weight_decay", type=float, default=1.5e-6)
parser.add_argument("--warmup_epochs", type=float, default=10)

# Model
parser.add_argument("--warmup_epochs", type=int, default=10)
parser.add_argument("--meta_dir", default=".", type=str, help="path to meta.bin for imagenet")

return parser


@under_review()
def cli_main():
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
@@ -188,23 +174,19 @@ def cli_main():

parser = ArgumentParser()

# trainer args
parser = Trainer.add_argparse_args(parser)

# model args
parser = BYOL.add_model_specific_args(parser)
args = parser.parse_args()
parser = CIFAR10DataModule.add_dataset_specific_args(parser)
parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "imagenet2012", "stl10"])

# pick data
dm = None
args = parser.parse_args()

# init default datamodule
# Initialize datamodule
if args.dataset == "cifar10":
dm = CIFAR10DataModule.from_argparse_args(args)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
args.num_classes = dm.num_classes

elif args.dataset == "stl10":
dm = STL10DataModule.from_argparse_args(args)
dm.train_dataloader = dm.train_dataloader_mixed
@@ -214,20 +196,24 @@
dm.train_transforms = SimCLRTrainDataTransform(h)
dm.val_transforms = SimCLREvalDataTransform(h)
args.num_classes = dm.num_classes

elif args.dataset == "imagenet2012":
dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
(c, h, w) = dm.dims
dm.train_transforms = SimCLRTrainDataTransform(h)
dm.val_transforms = SimCLREvalDataTransform(h)
args.num_classes = dm.num_classes
else:
raise ValueError(
f"{args.dataset} is not a valid dataset. Dataset must be 'cifar10', 'stl10', or 'imagenet2012'."
)

model = BYOL(**args.__dict__)
# Initialize BYOL module
model = BYOL(**vars(args))

# finetune in real-time
online_eval = SSLOnlineEvaluator(dataset=args.dataset, z_dim=2048, num_classes=dm.num_classes)

trainer = Trainer.from_argparse_args(args, max_steps=300000, callbacks=[online_eval])
trainer = Trainer.from_argparse_args(args, callbacks=[online_eval])

trainer.fit(model, datamodule=dm)
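For orientation, the loss the revised module computes can be sketched end to end outside of Lightning: each direction takes the online prediction of one view against a stop-gradient target projection of the other view, the two directions are summed, and the target weights then track the online weights by an exponential moving average. The sketch below is a minimal illustration with made-up stand-in networks and dimensions, not the library's classes.

```python
import torch
from torch import nn
from torch.nn import functional as F

# Stand-in networks; the linear layers and dimensions are illustrative
# assumptions, not the ResNet encoder / MLP heads used by the module.
encoder = nn.Linear(8, 16)
projector = nn.Linear(16, 4)
predictor = nn.Linear(4, 4)
target_encoder = nn.Linear(8, 16)
target_projector = nn.Linear(16, 4)


def directional_loss(view_online: torch.Tensor, view_target: torch.Tensor) -> torch.Tensor:
    """Negative cosine similarity between the online prediction and the
    stop-gradient target projection, scaled by -2 as in calculate_loss."""
    h = predictor(projector(encoder(view_online)))
    with torch.no_grad():  # the target branch never receives gradients
        z_target = target_projector(target_encoder(view_target))
    return -2 * F.cosine_similarity(h, z_target).mean()


img1, img2 = torch.randn(2, 4, 8)  # two augmented views of a batch of 4 samples
total_loss = directional_loss(img1, img2) + directional_loss(img2, img1)

# After each training batch the target branch follows the online branch by an
# exponential moving average with decay tau (0.996 by default, as in the module).
tau = 0.996
with torch.no_grad():
    for p_target, p_online in zip(target_encoder.parameters(), encoder.parameters()):
        p_target.mul_(tau).add_((1.0 - tau) * p_online)
```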

85 changes: 60 additions & 25 deletions pl_bolts/models/self_supervised/byol/models.py
@@ -1,43 +1,78 @@
from torch import nn
from typing import Tuple, Union

from torch import Tensor, nn

from pl_bolts.utils.self_supervised import torchvision_ssl_encoder
from pl_bolts.utils.stability import under_review


@under_review()
class MLP(nn.Module):
def __init__(self, input_dim=2048, hidden_size=4096, output_dim=256):
"""MLP architecture used as projectors in online and target networks and predictors in the online network.
Args:
input_dim (int, optional): Input dimension. Defaults to 2048.
hidden_dim (int, optional): Hidden layer dimension. Defaults to 4096.
output_dim (int, optional): Output dimension. Defaults to 256.
Note:
Default values for input, hidden, and output dimensions are based on values used in BYOL.
"""

def __init__(self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256) -> None:

super().__init__()
self.output_dim = output_dim
self.input_dim = input_dim

self.model = nn.Sequential(
nn.Linear(input_dim, hidden_size, bias=False),
nn.BatchNorm1d(hidden_size),
nn.Linear(input_dim, hidden_dim, bias=False),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, output_dim, bias=True),
nn.Linear(hidden_dim, output_dim, bias=True),
)

def forward(self, x):
x = self.model(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self.model(x)


@under_review()
class SiameseArm(nn.Module):
def __init__(self, encoder="resnet50", encoder_out_dim=2048, projector_hidden_size=4096, projector_out_dim=256):
"""SiameseArm consolidates the encoder and projector networks of BYOL's symmetric architecture into a single
class.
Args:
encoder (Union[str, nn.Module], optional): Online and target network encoder architecture.
Defaults to "resnet50".
encoder_out_dim (int, optional): Output dimension of encoder. Defaults to 2048.
projector_hidden_dim (int, optional): Online and target network projector network hidden dimension.
Defaults to 4096.
projector_out_dim (int, optional): Online and target network projector network output dimension.
Defaults to 256.
"""

def __init__(
self,
encoder: Union[str, nn.Module] = "resnet50",
encoder_out_dim: int = 2048,
projector_hidden_dim: int = 4096,
projector_out_dim: int = 256,
) -> None:

super().__init__()

if isinstance(encoder, str):
encoder = torchvision_ssl_encoder(encoder)
# Encoder
self.encoder = encoder
# Projector
self.projector = MLP(encoder_out_dim, projector_hidden_size, projector_out_dim)
# Predictor
self.predictor = MLP(projector_out_dim, projector_hidden_size, projector_out_dim)

def forward(self, x):
self.encoder = torchvision_ssl_encoder(encoder)
else:
self.encoder = encoder

self.projector = MLP(encoder_out_dim, projector_hidden_dim, projector_out_dim)

def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
y = self.encoder(x)[0]
z = self.projector(y)
h = self.predictor(z)
return y, z, h
return y, z

def encode(self, x: Tensor) -> Tensor:
"""Returns the encoded representation of a view. This method does not calculate the projection as in the
forward method.
Args:
x (Tensor): sample to be encoded
"""
return self.encoder(x)[0]
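The updated interfaces in models.py can be smoke-tested as follows, assuming the revised pl_bolts is importable: SiameseArm.forward now returns only the representation and projection, encode() returns the representation alone, and the prediction head is built by the caller as a separate MLP. The batch size and image size are arbitrary choices for the example.

```python
import torch

from pl_bolts.models.self_supervised.byol.models import MLP, SiameseArm

# Defaults follow the revised signatures: a "resnet50" encoder with 2048-d
# output and a 4096 -> 256 projector MLP.
arm = SiameseArm(encoder="resnet50", encoder_out_dim=2048, projector_hidden_dim=4096, projector_out_dim=256)
predictor = MLP(input_dim=256, hidden_dim=4096, output_dim=256)

x = torch.randn(2, 3, 32, 32)      # small arbitrary batch of image-shaped inputs
y, z = arm(x)                      # forward() now returns (representation, projection)
h = predictor(z)                   # prediction is computed by the standalone MLP
print(y.shape, z.shape, h.shape)   # expected: [2, 2048], [2, 256], [2, 256]

y_only = arm.encode(x)             # encode() skips the projector entirely
print(y_only.shape)                # expected: [2, 2048]
```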