
Swav #239

Merged: 55 commits from swav into master on Oct 19, 2020
Changes from all commits

Commits (55)
ee1a43c
swav
ananyahjha93 Sep 16, 2020
53d47d9
swav
ananyahjha93 Sep 16, 2020
19955ae
Merge branch 'swav' of https://github.com/PyTorchLightning/pytorch-li…
ananyahjha93 Sep 16, 2020
ab4ee2c
tests
ananyahjha93 Sep 16, 2020
a50bcbc
tests
ananyahjha93 Sep 16, 2020
2f88311
tests
ananyahjha93 Sep 16, 2020
6fea7ce
param vals
ananyahjha93 Sep 17, 2020
c0402f7
swav
ananyahjha93 Sep 16, 2020
ff7efe8
tests
ananyahjha93 Sep 16, 2020
09d2f89
tests
ananyahjha93 Sep 16, 2020
e774c7b
tests
ananyahjha93 Sep 16, 2020
982736e
origin pull
ananyahjha93 Sep 17, 2020
d618959
tests
ananyahjha93 Sep 17, 2020
f5750b6
pep8
ananyahjha93 Sep 17, 2020
c29de83
changed datamodule import
ananyahjha93 Sep 17, 2020
af0dcc6
changed datamodule import
ananyahjha93 Sep 17, 2020
5129b0b
docs and fix finetune
ananyahjha93 Sep 22, 2020
f8ca002
swav
ananyahjha93 Sep 16, 2020
c5941ee
tests
ananyahjha93 Sep 16, 2020
3832cd8
tests
ananyahjha93 Sep 16, 2020
d281407
tests
ananyahjha93 Sep 16, 2020
9649ad6
param vals
ananyahjha93 Sep 17, 2020
2efa376
tests
ananyahjha93 Sep 17, 2020
4c0162f
pep8
ananyahjha93 Sep 17, 2020
7d158b3
changed datamodule import
ananyahjha93 Sep 17, 2020
74d3580
changed datamodule import
ananyahjha93 Sep 17, 2020
12fb1d2
docs and fix finetune
ananyahjha93 Sep 22, 2020
20774fa
script tests
ananyahjha93 Sep 23, 2020
28d2209
Merge branch 'swav' of https://github.com/PyTorchLightning/pytorch-li…
ananyahjha93 Sep 23, 2020
5bdd08c
passing tests
ananyahjha93 Sep 23, 2020
06b6c0d
passing tests
ananyahjha93 Sep 23, 2020
8fd2cf3
replaced datamodule
ananyahjha93 Sep 23, 2020
2e6e378
replaced datamodule
ananyahjha93 Sep 23, 2020
06cb0ff
replaced datamodule
ananyahjha93 Sep 23, 2020
69ae62b
resnet
ananyahjha93 Sep 25, 2020
4fd6325
resnet
ananyahjha93 Sep 25, 2020
1834b4f
resnet
ananyahjha93 Sep 25, 2020
2ea8960
swav]
ananyahjha93 Sep 29, 2020
5fdc295
imagenet
ananyahjha93 Oct 13, 2020
fe7b400
Merge branch 'master' into swav
ananyahjha93 Oct 19, 2020
bf1fb87
cifar10
ananyahjha93 Oct 19, 2020
fbec8af
cifar10
ananyahjha93 Oct 19, 2020
62d4124
cifar10
ananyahjha93 Oct 19, 2020
4817431
update for v1
ananyahjha93 Oct 19, 2020
9a43a24
min req
ananyahjha93 Oct 19, 2020
0b02fed
min req
ananyahjha93 Oct 19, 2020
b6e3955
tests
ananyahjha93 Oct 19, 2020
b32fde3
Apply suggestions from code review
Borda Oct 19, 2020
0b1ad24
Apply suggestions from code review
Borda Oct 19, 2020
1f8f080
req
Borda Oct 19, 2020
a3b2dd0
imports
Borda Oct 19, 2020
aae0b48
imports
ananyahjha93 Oct 19, 2020
f3d4f28
Merge branch 'swav' of https://github.com/PyTorchLightning/pytorch-li…
ananyahjha93 Oct 19, 2020
1538f16
imports
ananyahjha93 Oct 19, 2020
1ccfbaa
imports
ananyahjha93 Oct 19, 2020
docs/source/self_supervised_models.rst (150 additions, 9 deletions)

@@ -205,15 +205,15 @@ CIFAR-10 pretrained model::
Pre-training:

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/cpc-cifar10-val.png
- :width: 200
+ :width: 400
:alt: pretraining validation loss

|

Fine-tuning:

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-cifar10-v4-exp3/online-finetuning-cpc-cifar10.png
- :width: 200
+ :width: 400
:alt: online finetuning accuracy

|
@@ -234,15 +234,15 @@ STL-10 pretrained model::
Pre-training:

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/cpc-stl10-val.png
- :width: 200
+ :width: 400
:alt: pretraining validation loss

|

Fine-tuning:

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpc-stl10-v0-exp3/online-finetuning-cpc-stl10.png
- :width: 200
+ :width: 400
:alt: online finetuning accuracy

|
@@ -263,15 +263,15 @@ ImageNet pretrained model::
Pre-training:

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpcv2_weights/cpc-imagenet-val.png
- :width: 200
+ :width: 400
:alt: pretraining validation loss

|

Fine-tuning:

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/cpc/cpcv2_weights/online-finetuning-cpc-imagenet.png
- :width: 200
+ :width: 400
:alt: online finetuning accuracy

|
@@ -371,19 +371,19 @@ CIFAR-10 pretrained model::
Pre-training:

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp2_acc_867/val_loss.png
- :width: 200
+ :width: 400
:alt: pretraining validation loss

|

Fine-tuning (Single layer MLP, 1024 hidden units):

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp2_acc_867/val_acc.png
- :width: 200
+ :width: 400
:alt: finetuning validation accuracy

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp2_acc_867/test_acc.png
- :width: 200
+ :width: 400
:alt: finetuning test accuracy

|
@@ -408,3 +408,144 @@ SimCLR API

.. autoclass:: pl_bolts.models.self_supervised.SimCLR
:noindex:

---------

SwAV
^^^^

PyTorch Lightning implementation of `SwAV <https://arxiv.org/abs/2006.09882>`_, adapted from the `official implementation <https://github.com/facebookresearch/swav>`_.

Paper authors: Mathilde Caron, Ishan Misra, Julien Mairal, Priya Goyal, Piotr Bojanowski, Armand Joulin.

Implementation adapted by:

- `Ananya Harsh Jha <https://github.com/ananyahjha93>`_

To train::

import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SwAV
from pl_bolts.datamodules import STL10DataModule
from pl_bolts.models.self_supervised.swav.transforms import (
    SwAVTrainDataTransform, SwAVEvalDataTransform
)
from pl_bolts.transforms.dataset_normalizations import stl10_normalization

# data
batch_size = 128
dm = STL10DataModule(data_dir='.', batch_size=batch_size)
dm.train_dataloader = dm.train_dataloader_mixed
dm.val_dataloader = dm.val_dataloader_mixed

dm.train_transforms = SwAVTrainDataTransform(
    normalize=stl10_normalization()
)

dm.val_transforms = SwAVEvalDataTransform(
    normalize=stl10_normalization()
)

# model
model = SwAV(
    gpus=1,
    num_samples=dm.num_unlabeled_samples,
    datamodule=dm,
    batch_size=batch_size
)

# fit on 1 GPU with 16-bit precision (matching the gpus=1 given to the model)
trainer = pl.Trainer(gpus=1, precision=16)
trainer.fit(model)
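
The train transform builds several augmented views ("crops") of each image, which SwAV compares through its prototype assignments. A minimal sketch of inspecting the transform output; it assumes the transform accepts a PIL image and returns one tensor per crop, with crop counts and sizes taken from the transform defaults::

from PIL import Image

from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform
from pl_bolts.transforms.dataset_normalizations import stl10_normalization

transform = SwAVTrainDataTransform(normalize=stl10_normalization())

# stand-in for a single 96x96 RGB STL-10 image
img = Image.new('RGB', (96, 96))

# inspect how many crops come back and their shapes
crops = transform(img)
print(len(crops), [tuple(c.shape) for c in crops])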

STL-10 baseline
*****************

The original paper does not provide baselines on STL-10.

.. list-table:: STL-10 implementation results
:widths: 18 15 25 15 10 20 20 20 10
:header-rows: 1

* - Implementation
- Test acc
- Encoder
- Optimizer
- Batch
- Queue used
- Epochs
- Hardware
- LR
* - Ours
- `86.72 <https://tensorboard.dev/experiment/w2pq3bPPSxC4VIm5udhA2g/>`_
- SwAV resnet50
- `LARS <https://pytorch-lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
- 128
- No
- 100 (~9 hr)
- 1 V100 (16GB)
- 1e-3

|

- `Pre-training tensorboard link <https://tensorboard.dev/experiment/68jet8o4RdK34u5kUXLedg/>`_

STL-10 pretrained model::

from pl_bolts.models.self_supervised import SwAV

weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/checkpoints/epoch%3D96.ckpt'
swav = SwAV.load_from_checkpoint(weight_path, strict=False)

swav.freeze()
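
Once frozen, the checkpoint can serve as a feature extractor. A minimal sketch, assuming the module's forward pass maps an image batch to embeddings (check the SwAV API below for the exact return signature)::

import torch

# a random stand-in batch of STL-10-sized images (3x96x96)
x = torch.rand(4, 3, 96, 96)

with torch.no_grad():
    feats = swav(x)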

|

Pre-training:

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/pretraining-val-loss.png
:width: 400
:alt: pretraining validation loss

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/online-finetuning-val-acc.png
:width: 400
:alt: online finetuning validation acc

|

Fine-tuning (Single layer MLP, 1024 hidden units):

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/fine-tune-val-acc.png
:width: 400
:alt: finetuning validation accuracy

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/fine-tune-val-loss.png
:width: 400
:alt: finetuning validation loss

|

To reproduce::

# pretrain
python swav_module.py \
    --online_ft \
    --gpus 1 \
    --lars_wrapper \
    --batch_size 128 \
    --learning_rate 1e-3 \
    --gaussian_blur \
    --queue_length 0 \
    --jitter_strength 1. \
    --nmb_prototypes 512

# finetune
python swav_finetuner.py \
    --ckpt_path path/to/epoch=xyz.ckpt

SwAV API
********

.. autoclass:: pl_bolts.models.self_supervised.SwAV
:noindex:
pl_bolts/models/self_supervised/__init__.py (1 addition)

@@ -25,3 +25,4 @@
from pl_bolts.models.self_supervised.moco.moco2_module import MocoV2
from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
+ from pl_bolts.models.self_supervised.swav.swav_module import SwAV
pl_bolts/models/self_supervised/amdim/datasets.py (1 addition, 1 deletion)

@@ -1,5 +1,5 @@
- from warnings import warn
from typing import Optional
+ from warnings import warn

from torch.utils.data import random_split

pl_bolts/models/self_supervised/swav/__init__.py (8 additions, new file)

@@ -0,0 +1,8 @@
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.swav_online_eval import SwavOnlineEvaluator
from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.models.self_supervised.swav.transforms import (
    SwAVEvalDataTransform,
    SwAVTrainDataTransform,
    SwAVFinetuneTransform
)
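
These re-exports keep import paths short; a small usage sketch, using only the names exported above::

from pl_bolts.models.self_supervised.swav import (
    SwAV,
    SwAVTrainDataTransform,
    SwAVEvalDataTransform
)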
pl_bolts/models/self_supervised/swav/swav_finetuner.py (91 additions, new file)

@@ -0,0 +1,91 @@
import os
from argparse import ArgumentParser

import pytorch_lightning as pl

from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform
from pl_bolts.transforms.dataset_normalizations import stl10_normalization, imagenet_normalization


def cli_main():  # pragma: no-cover
    from pl_bolts.datamodules import STL10DataModule, ImagenetDataModule

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--dataset', type=str, help='stl10, imagenet', default='stl10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    parser.add_argument('--data_path', type=str, help='path to dataset', default=os.getcwd())

    parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu")
    parser.add_argument("--num_workers", default=16, type=int, help="num of workers per GPU")
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(
            data_dir=args.data_path,
            batch_size=args.batch_size,
            num_workers=args.num_workers
        )

        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = 0

        dm.train_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(),
            input_height=dm.size()[-1],
            eval_transform=False
        )
        dm.val_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True
        )

        args.maxpool1 = False
    elif args.dataset == 'imagenet':
        dm = ImagenetDataModule(
            data_dir=args.data_path,
            batch_size=args.batch_size,
            num_workers=args.num_workers
        )

        dm.train_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(),
            input_height=dm.size()[-1],
            eval_transform=False
        )
        dm.val_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True
        )

        args.num_samples = 0
        args.maxpool1 = True
    else:
        raise NotImplementedError("other datasets have not been implemented yet")

    # load the pretrained SwAV backbone; extra kwargs are forwarded to the constructor
    backbone = SwAV.load_from_checkpoint(
        args.ckpt_path,
        strict=False,
        gpus=args.gpus,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        datamodule=dm,
        maxpool1=args.maxpool1
    )

    tuner = SSLFineTuner(backbone, in_features=2048, num_classes=dm.num_classes, hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(
        args, gpus=args.gpus, precision=16, early_stop_callback=True
    )
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)


if __name__ == '__main__':
    cli_main()
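
A hedged example invocation of this script, mirroring the reproduce block in the docs above (paths are placeholders)::

python swav_finetuner.py \
    --dataset stl10 \
    --data_path ./data \
    --ckpt_path path/to/epoch=xyz.ckpt \
    --gpus 1 \
    --batch_size 128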