Skip to content

Commit

Permalink
current
Browse files Browse the repository at this point in the history
  • Loading branch information
ananyahjha93 committed Oct 2, 2020
1 parent e7e0062 commit 76cedfe
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
27 changes: 24 additions & 3 deletions pl_bolts/models/self_supervised/swav/swav_finetuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.transforms.dataset_normalizations import stl10_normalization
from pl_bolts.transforms.dataset_normalizations import stl10_normalization, imagenet_normalization
from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform


Expand Down Expand Up @@ -48,19 +48,40 @@ def cli_main(): # pragma: no-cover
)

args.maxpool1 = False
elif args.dataset == 'imagenet':
dm = ImagenetDataModule(
data_dir=args.data_path,
batch_size=args.batch_size,
num_workers=args.num_workers
)

dm.train_transforms = SwAVFinetuneTransform(
normalize=imagenet_normalization(),
input_height=dm.size()[-1],
eval_transform=False
)
dm.val_transforms = SwAVFinetuneTransform(
normalize=imagenet_normalization(),
input_height=dm.size()[-1],
eval_transform=True
)

args.num_samples = 0
args.maxpool1 = True
else:
raise NotImplementedError("other datasets have not been implemented till now")

backbone = SwAV(
gpus=1,
num_samples=args.num_samples,
batch_size=args.batch_size,
datamodule=dm
datamodule=dm,
maxpool1=args.maxpool1
).load_from_checkpoint(args.ckpt_path, strict=False)

tuner = SSLFineTuner(backbone, in_features=2048, num_classes=dm.num_classes, hidden_dim=None)
trainer = pl.Trainer.from_argparse_args(
args, gpus=1, precision=16, early_stop_callback=True
args, gpus=4, precision=16, early_stop_callback=True
)
trainer.fit(tuner, dm)

Expand Down
16 changes: 16 additions & 0 deletions pl_bolts/models/self_supervised/swav/weights_convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import torch
from collections import OrderedDict

swav_imagenet = torch.load('swav_imagenet.pth.tar')

new_state_dict = OrderedDict()

for key in swav_imagenet.keys():
if 'prototype' in key:
continue
new_state_dict[key.replace('module.', 'model.')] = swav_imagenet[key]

stl10_save = torch.load("epoch=96.ckpt")
stl10_save['state_dict'] = new_state_dict

torch.save(stl10_save, 'swav_imagenet.ckpt')

0 comments on commit 76cedfe

Please sign in to comment.