-
Notifications
You must be signed in to change notification settings - Fork 321
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* swav * swav * tests * tests * tests * param vals * swav * tests * tests * tests * tests * pep8 * changed datamodule import * changed datamodule import * docs and fix finetune * swav * tests * tests * tests * param vals * tests * pep8 * changed datamodule import * changed datamodule import * docs and fix finetune * script tests * passing tests * passing tests * replaced datamodule * replaced datamodule * replaced datamodule * resnet * resnet * resnet * swav] * imagenet * cifar10 * cifar10 * cifar10 * update for v1 * min req * min req * tests * Apply suggestions from code review * Apply suggestions from code review * req * imports * imports * imports * imports Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
- Loading branch information
1 parent
d32d3eb
commit 2d57918
Showing
12 changed files
with
1,577 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.swav_online_eval import SwavOnlineEvaluator
from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.models.self_supervised.swav.transforms import (
    SwAVEvalDataTransform,
    SwAVFinetuneTransform,
    SwAVTrainDataTransform,
)

# Explicit public API of the swav package.
__all__ = [
    "SwAV",
    "SwavOnlineEvaluator",
    "resnet18",
    "resnet50",
    "SwAVEvalDataTransform",
    "SwAVFinetuneTransform",
    "SwAVTrainDataTransform",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import os | ||
from argparse import ArgumentParser | ||
|
||
import pytorch_lightning as pl | ||
|
||
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner | ||
from pl_bolts.models.self_supervised.swav.swav_module import SwAV | ||
from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform | ||
from pl_bolts.transforms.dataset_normalizations import stl10_normalization, imagenet_normalization | ||
|
||
|
||
def cli_main():  # pragma: no-cover
    """Fine-tune a pretrained SwAV backbone on STL-10 or ImageNet.

    Loads SwAV weights from ``--ckpt_path``, wraps the backbone in an
    ``SSLFineTuner`` (linear head: ``hidden_dim=None``) and trains it with
    16-bit precision and early stopping, then runs the test loop.

    Raises:
        NotImplementedError: if ``--dataset`` is not ``stl10`` or ``imagenet``.
    """
    from pl_bolts.datamodules import STL10DataModule, ImagenetDataModule

    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    # NOTE: only 'stl10' and 'imagenet' are handled below.
    parser.add_argument('--dataset', type=str, help='stl10, imagenet', default='stl10')
    parser.add_argument('--ckpt_path', type=str, help='path to ckpt')
    parser.add_argument('--data_path', type=str, help='path to dataset', default=os.getcwd())

    parser.add_argument('--batch_size', default=128, type=int, help='batch size per gpu')
    parser.add_argument('--num_workers', default=16, type=int, help='num of workers per GPU')
    args = parser.parse_args()

    if args.dataset == 'stl10':
        dm = STL10DataModule(
            data_dir=args.data_path,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )

        # Fine-tune on the labeled split only (STL-10 is mostly unlabeled).
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = 0

        dm.train_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(),
            input_height=dm.size()[-1],
            eval_transform=False,
        )
        dm.val_transforms = SwAVFinetuneTransform(
            normalize=stl10_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True,
        )

        # Small images: the backbone is expected to skip the stem max-pool.
        args.maxpool1 = False
    elif args.dataset == 'imagenet':
        dm = ImagenetDataModule(
            data_dir=args.data_path,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )

        dm.train_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(),
            input_height=dm.size()[-1],
            eval_transform=False,
        )
        dm.val_transforms = SwAVFinetuneTransform(
            normalize=imagenet_normalization(),
            input_height=dm.size()[-1],
            eval_transform=True,
        )

        args.num_samples = 0
        args.maxpool1 = True
    else:
        raise NotImplementedError("other datasets have not been implemented till now")

    # ``load_from_checkpoint`` is a classmethod: the original code constructed a
    # throw-away SwAV instance only to discard it, since the loaded model's
    # hyper-parameters come from the checkpoint itself. Call it on the class.
    # (args.num_samples / args.maxpool1 above are kept for arg-namespace
    # compatibility.) ``strict=False`` tolerates missing/extra state-dict keys.
    backbone = SwAV.load_from_checkpoint(args.ckpt_path, strict=False)

    # in_features=2048 matches the resnet50 feature dimension used by SwAV.
    tuner = SSLFineTuner(backbone, in_features=2048, num_classes=dm.num_classes, hidden_dim=None)
    trainer = pl.Trainer.from_argparse_args(
        args, gpus=args.gpus, precision=16, early_stop_callback=True
    )
    trainer.fit(tuner, dm)

    trainer.test(datamodule=dm)


if __name__ == '__main__':
    cli_main()
Oops, something went wrong.