Swav (#239)
* swav

* swav

* tests

* tests

* tests

* param vals

* swav

* tests

* tests

* tests

* tests

* pep8

* changed datamodule import

* changed datamodule import

* docs and fix finetune

* swav

* tests

* tests

* tests

* param vals

* tests

* pep8

* changed datamodule import

* changed datamodule import

* docs and fix finetune

* script tests

* passing tests

* passing tests

* replaced datamodule

* replaced datamodule

* replaced datamodule

* resnet

* resnet

* resnet

* swav

* imagenet

* cifar10

* cifar10

* cifar10

* update for v1

* min req

* min req

* tests

* Apply suggestions from code review

* Apply suggestions from code review

* req

* imports

* imports

* imports

* imports

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
3 people authored Oct 19, 2020
1 parent d32d3eb commit 2d57918
Showing 12 changed files with 1,577 additions and 14 deletions.
159 changes: 150 additions & 9 deletions docs/source/self_supervised_models.rst
@@ -205,15 +205,15 @@ CIFAR-10 pretrained model::
 Pre-training:
 
 .. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/cpc-cifar10-val.png
-    :width: 200
+    :width: 400
     :alt: pretraining validation loss
 
 |
 Fine-tuning:
 
 .. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/online-finetuning-cpc-cifar10.png
-    :width: 200
+    :width: 400
     :alt: online finetuning accuracy
 
 |
@@ -234,15 +234,15 @@ STL-10 pretrained model::
 Pre-training:
 
 .. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/cpc-stl10-val.png
-    :width: 200
+    :width: 400
     :alt: pretraining validation loss
 
 |
 Fine-tuning:
 
 .. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/online-finetuning-cpc-stl10.png
-    :width: 200
+    :width: 400
     :alt: online finetuning accuracy
 
 |
@@ -263,15 +263,15 @@ ImageNet pretrained model::
 Pre-training:
 
 .. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpcv2_weights/cpc-imagenet-val.png
-    :width: 200
+    :width: 400
     :alt: pretraining validation loss
 
 |
 Fine-tuning:
 
 .. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpcv2_weights/online-finetuning-cpc-imagenet.png
-    :width: 200
+    :width: 400
     :alt: online finetuning accuracy
 
 |
@@ -371,19 +371,19 @@ CIFAR-10 pretrained model::
 Pre-training:
 
 .. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp2_acc_867/val_loss.png
-    :width: 200
+    :width: 400
     :alt: pretraining validation loss
 
 |
 Fine-tuning (Single layer MLP, 1024 hidden units):
 
 .. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp2_acc_867/val_acc.png
-    :width: 200
+    :width: 400
     :alt: finetuning validation accuracy
 
 .. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp2_acc_867/test_acc.png
-    :width: 200
+    :width: 400
     :alt: finetuning test accuracy
 
 |
@@ -408,3 +408,144 @@ SimCLR API

.. autoclass:: pl_bolts.models.self_supervised.SimCLR
    :noindex:

---------

SwAV
^^^^

PyTorch Lightning implementation of `SwAV <https://arxiv.org/abs/2006.09882>`_,
adapted from the `official implementation <https://github.com/facebookresearch/swav>`_.

Paper authors: Mathilde Caron, Ishan Misra, Julien Mairal, Priya Goyal, Piotr Bojanowski, Armand Joulin.

Implementation adapted by:

- `Ananya Harsh Jha <https://github.com/ananyahjha93>`_

To Train::

    import pytorch_lightning as pl
    from pl_bolts.models.self_supervised import SwAV
    from pl_bolts.datamodules import STL10DataModule
    from pl_bolts.models.self_supervised.swav.transforms import (
        SwAVTrainDataTransform, SwAVEvalDataTransform
    )
    from pl_bolts.transforms.dataset_normalizations import stl10_normalization

    # data: pretrain on the mixed (unlabeled + labeled) STL-10 splits
    batch_size = 128
    dm = STL10DataModule(data_dir='.', batch_size=batch_size)
    dm.train_dataloader = dm.train_dataloader_mixed
    dm.val_dataloader = dm.val_dataloader_mixed

    dm.train_transforms = SwAVTrainDataTransform(
        normalize=stl10_normalization()
    )

    dm.val_transforms = SwAVEvalDataTransform(
        normalize=stl10_normalization()
    )

    # model
    model = SwAV(
        gpus=1,
        num_samples=dm.num_unlabeled_samples,
        datamodule=dm,
        batch_size=batch_size
    )

    # fit
    trainer = pl.Trainer(precision=16)
    trainer.fit(model)
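
The training transform implements SwAV's multi-crop augmentation, and its crop setup can be overridden. A hedged sketch of doing so (the argument names mirror the multi-crop scheme of the official implementation; treat the exact names and values as assumptions, not documented defaults)::

    # assumed multi-crop arguments -- verify against the transform's signature
    dm.train_transforms = SwAVTrainDataTransform(
        normalize=stl10_normalization(),
        size_crops=[96, 36],   # global and local crop resolutions
        nmb_crops=[2, 4],      # two global views + four local views
        gaussian_blur=True,
        jitter_strength=1.
    )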

STL-10 baseline
*****************

The original paper does not provide baselines on STL-10.

.. list-table:: STL-10 implementation results
   :widths: 18 15 25 15 10 20 20 20 10
   :header-rows: 1

   * - Implementation
     - test acc
     - Encoder
     - Optimizer
     - Batch
     - Queue used
     - Epochs
     - Hardware
     - LR
   * - Ours
     - `86.72 <https://tensorboard.dev/experiment/w2pq3bPPSxC4VIm5udhA2g/>`_
     - SwAV resnet50
     - `LARS <https://pytorch-lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
     - 128
     - No
     - 100 (~9 hr)
     - 1 V100 (16GB)
     - 1e-3

|

- `Pre-training tensorboard link <https://tensorboard.dev/experiment/68jet8o4RdK34u5kUXLedg/>`_

STL-10 pretrained model::

    from pl_bolts.models.self_supervised import SwAV

    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/epoch%3D96.ckpt'
    swav = SwAV.load_from_checkpoint(weight_path, strict=False)

    swav.freeze()
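
A minimal sketch of using the frozen model as a feature extractor; this assumes the module's forward pass runs the ResNet backbone, and depending on configuration the output may be the embedding alone or a tuple that also carries prototype scores::

    import torch

    # a random batch of STL-10 sized images (96x96)
    x = torch.rand(4, 3, 96, 96)

    with torch.no_grad():
        feats = swav(x)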

|
Pre-training:

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/pretraining-val-loss.png
    :width: 400
    :alt: pretraining validation loss

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/online-finetuning-val-acc.png
    :width: 400
    :alt: online finetuning validation acc

|
Fine-tuning (Single layer MLP, 1024 hidden units):

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/fine-tune-val-acc.png
    :width: 400
    :alt: finetuning validation accuracy

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/fine-tune-val-loss.png
    :width: 400
    :alt: finetuning validation loss

|
To reproduce::

    # pretrain
    python swav_module.py \
        --online_ft \
        --gpus 1 \
        --lars_wrapper \
        --batch_size 128 \
        --learning_rate 1e-3 \
        --gaussian_blur \
        --queue_length 0 \
        --jitter_strength 1. \
        --nmb_prototypes 512

    # finetune
    python swav_finetuner.py --ckpt_path path/to/epoch=xyz.ckpt
SwAV API
********

.. autoclass:: pl_bolts.models.self_supervised.SwAV
    :noindex:
1 change: 1 addition & 0 deletions pl_bolts/models/self_supervised/__init__.py
@@ -25,3 +25,4 @@
 from pl_bolts.models.self_supervised.moco.moco2_module import MocoV2
 from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR
 from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
+from pl_bolts.models.self_supervised.swav.swav_module import SwAV
2 changes: 1 addition & 1 deletion pl_bolts/models/self_supervised/amdim/datasets.py
@@ -1,5 +1,5 @@
-from warnings import warn
 from typing import Optional
+from warnings import warn
 
 from torch.utils.data import random_split
8 changes: 8 additions & 0 deletions pl_bolts/models/self_supervised/swav/__init__.py
@@ -0,0 +1,8 @@
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.swav_online_eval import SwavOnlineEvaluator
from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.models.self_supervised.swav.transforms import (
    SwAVEvalDataTransform,
    SwAVTrainDataTransform,
    SwAVFinetuneTransform
)
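
With these exports in place, the subpackage's public names can be imported directly. A short sketch using only the names added above:

    from pl_bolts.models.self_supervised.swav import (
        SwAV,
        SwavOnlineEvaluator,
        SwAVTrainDataTransform,
        SwAVEvalDataTransform,
        SwAVFinetuneTransform,
        resnet18,
        resnet50,
    )
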
91 changes: 91 additions & 0 deletions pl_bolts/models/self_supervised/swav/swav_finetuner.py
@@ -0,0 +1,91 @@
import os
from argparse import ArgumentParser

import pytorch_lightning as pl

from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform
from pl_bolts.transforms.dataset_normalizations import stl10_normalization, imagenet_normalization


def cli_main():  # pragma: no-cover
    from pl_bolts.datamodules import STL10DataModule, ImagenetDataModule

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--dataset', type=str, help='stl10, imagenet', default='stl10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    parser.add_argument('--data_path', type=str, help='path to the dataset', default=os.getcwd())

    parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu")
    parser.add_argument("--num_workers", default=16, type=int, help="num of workers per GPU")
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(
            data_dir=args.data_path,
            batch_size=args.batch_size,
            num_workers=args.num_workers
        )

        # fine-tune on the labeled split (pretraining used the unlabeled split)
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = 0

        dm.train_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(),
            input_height=dm.size()[-1],
            eval_transform=False
        )
        dm.val_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True
        )

        # skip the first maxpool for the smaller STL-10 images
        args.maxpool1 = False
    elif args.dataset == 'imagenet':
        dm = ImagenetDataModule(
            data_dir=args.data_path,
            batch_size=args.batch_size,
            num_workers=args.num_workers
        )

        dm.train_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(),
            input_height=dm.size()[-1],
            eval_transform=False
        )
        dm.val_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True
        )

        args.num_samples = 0
        args.maxpool1 = True
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    # load_from_checkpoint is a classmethod: the checkpoint supplies
    # both the hyperparameters and the pretrained weights
    backbone = SwAV(
        gpus=args.gpus,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        datamodule=dm,
        maxpool1=args.maxpool1
    ).load_from_checkpoint(args.ckpt_path, strict=False)

    tuner = SSLFineTuner(backbone, in_features=2048, num_classes=dm.num_classes, hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(
        args, gpus=args.gpus, precision=16, early_stop_callback=True
    )
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)


if __name__ == '__main__':
    cli_main()
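
A sketch of invoking the new script on STL-10 (the checkpoint and data paths are placeholders; all flags appear in the argparse setup above, with --gpus coming from the Trainer arguments):

    python swav_finetuner.py \
        --dataset stl10 \
        --ckpt_path path/to/swav_checkpoint.ckpt \
        --data_path path/to/stl10 \
        --batch_size 128 \
        --gpus 1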