diff --git a/pl_bolts/models/gans/basic/basic_gan_module.py b/pl_bolts/models/gans/basic/basic_gan_module.py
index ae7a6aebea..3e2d4220e2 100644
--- a/pl_bolts/models/gans/basic/basic_gan_module.py
+++ b/pl_bolts/models/gans/basic/basic_gan_module.py
@@ -5,20 +5,20 @@
 import torch
 from torch.nn import functional as F
 
-from pl_bolts.datamodules import MNISTDataModule
 from pl_bolts.models.gans.basic.components import Generator, Discriminator
 
 
 class GAN(pl.LightningModule):
 
-    def __init__(self,
-                 datamodule: pl.LightningDataModule = None,
-                 latent_dim: int = 32,
-                 batch_size: int = 100,
-                 learning_rate: float = 0.0002,
-                 data_dir: str = '',
-                 num_workers: int = 8,
-                 **kwargs):
+    def __init__(
+        self,
+        input_channels: int,
+        input_height: int,
+        input_width: int,
+        latent_dim: int = 32,
+        learning_rate: float = 0.0002,
+        **kwargs
+    ):
         """
         Vanilla GAN implementation.
 
@@ -53,24 +53,12 @@ def __init__(self,
 
         # makes self.hparams under the hood and saves to ckpt
         self.save_hyperparameters()
-
-        self._set_default_datamodule(datamodule)
+        self.img_dim = (input_channels, input_height, input_width)
 
         # networks
         self.generator = self.init_generator(self.img_dim)
         self.discriminator = self.init_discriminator(self.img_dim)
 
-    def _set_default_datamodule(self, datamodule):
-        # link default data
-        if datamodule is None:
-            datamodule = MNISTDataModule(
-                data_dir=self.hparams.data_dir,
-                num_workers=self.hparams.num_workers,
-                normalize=True
-            )
-        self.datamodule = datamodule
-        self.img_dim = self.datamodule.size()
-
     def init_generator(self, img_dim):
         generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_dim)
         return generator
@@ -179,44 +167,40 @@ def add_model_specific_args(parent_parser):
                             help="adam: decay of first order momentum of gradient")
         parser.add_argument('--latent_dim', type=int, default=100,
                             help="generator embedding dim")
-        parser.add_argument('--batch_size', type=int, default=64, help="size of the batches")
-        parser.add_argument('--num_workers', type=int, default=8, help="num dataloader workers")
-        parser.add_argument('--data_dir', type=str, default=os.getcwd())
-
         return parser
 
 
-def cli_main():
+def cli_main(args=None):
     from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
-    from pl_bolts.datamodules import STL10DataModule, ImagenetDataModule
+    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, MNISTDataModule, STL10DataModule
 
     pl.seed_everything(1234)
 
     parser = ArgumentParser()
-    parser.add_argument('--dataset', type=str, default='mnist', help='mnist, stl10, imagenet2012')
-
+    parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10, stl10, imagenet")
+    script_args, _ = parser.parse_known_args(args)
+
+    if script_args.dataset == "mnist":
+        dm_cls = MNISTDataModule
+    elif script_args.dataset == "cifar10":
+        dm_cls = CIFAR10DataModule
+    elif script_args.dataset == "stl10":
+        dm_cls = STL10DataModule
+    elif script_args.dataset == "imagenet":
+        dm_cls = ImagenetDataModule
+
+    parser = dm_cls.add_argparse_args(parser)
     parser = pl.Trainer.add_argparse_args(parser)
     parser = GAN.add_model_specific_args(parser)
-    parser = ImagenetDataModule.add_argparse_args(parser)
-    args = parser.parse_args()
-
-    # default is mnist
-    datamodule = None
-    if args.dataset == 'imagenet2012':
-        datamodule = ImagenetDataModule.from_argparse_args(args)
-    elif args.dataset == 'stl10':
-        datamodule = STL10DataModule.from_argparse_args(args)
-
-    gan = GAN(**vars(args), datamodule=datamodule)
-    callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator()]
+    args = parser.parse_args(args)
 
-    trainer = pl.Trainer.from_argparse_args(
-        args,
-        callbacks=callbacks,
-        progress_bar_refresh_rate=10
-    )
-    trainer.fit(gan)
+    dm = dm_cls.from_argparse_args(args)
+    model = GAN(*dm.size(), **vars(args))
+    callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5)]
+    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, progress_bar_refresh_rate=20)
+    trainer.fit(model, dm)
+    return dm, model, trainer
 
 
 if __name__ == '__main__':
-    cli_main()
+    dm, model, trainer = cli_main()
diff --git a/tests/callbacks/test_variational_callbacks.py b/tests/callbacks/test_variational_callbacks.py
index 062ce733a7..ba49540554 100644
--- a/tests/callbacks/test_variational_callbacks.py
+++ b/tests/callbacks/test_variational_callbacks.py
@@ -12,7 +12,7 @@ def __init__(self):
             self.global_step = 1
             self.logger = DummyLogger()
 
-    model = GAN()
+    model = GAN(3, 28, 28)
     cb = LatentDimInterpolator(interpolate_epoch_interval=2)
 
     cb.on_epoch_end(FakeTrainer(), model)
diff --git a/tests/models/test_executable_scripts.py b/tests/models/test_executable_scripts.py
index 7923357b6c..ba683dc1b4 100644
--- a/tests/models/test_executable_scripts.py
+++ b/tests/models/test_executable_scripts.py
@@ -3,14 +3,23 @@
 import pytest
 
 
-@pytest.mark.parametrize('cli_args', ['--max_epochs 1'
-                                      ' --limit_train_batches 3'
-                                      ' --limit_val_batches 3'
-                                      ' --batch_size 3'])
-def test_cli_basic_gan(cli_args):
+@pytest.mark.parametrize(
+    "dataset_name", [
+        pytest.param('mnist', id="mnist"),
+        pytest.param('cifar10', id="cifar10")
+    ]
+)
+def test_cli_basic_gan(dataset_name):
     from pl_bolts.models.gans.basic.basic_gan_module import cli_main
 
-    cli_args = cli_args.split(' ') if cli_args else []
+    cli_args = f"""
+        --dataset {dataset_name}
+        --max_epochs 1
+        --limit_train_batches 3
+        --limit_val_batches 3
+        --batch_size 3
+    """.strip().split()
+
     with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
         cli_main()
diff --git a/tests/models/test_gans.py b/tests/models/test_gans.py
index 50e9db3a90..eaf149a54c 100644
--- a/tests/models/test_gans.py
+++ b/tests/models/test_gans.py
@@ -1,13 +1,19 @@
+import pytest
 import pytorch_lightning as pl
 from pytorch_lightning import seed_everything
 
+from pl_bolts.datamodules import MNISTDataModule, CIFAR10DataModule
 from pl_bolts.models.gans import GAN
 
 
-def test_gan(tmpdir):
+@pytest.mark.parametrize(
+    "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")]
+)
+def test_gan(tmpdir, dm_cls):
     seed_everything()
 
-    model = GAN(data_dir=tmpdir)
+    dm = dm_cls()
+    model = GAN(*dm.size())
     trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
-    trainer.fit(model)
-    trainer.test(model)
+    trainer.fit(model, dm)
+    trainer.test(datamodule=dm)
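
Usage note (reviewer sketch, not part of the patch): with the new optional args parameter and the (dm, model, trainer) return value, cli_main can be driven programmatically as well as from the command line. A minimal sketch, mirroring the flags exercised in test_cli_basic_gan and assuming the chosen datamodule can download its data into the working directory:

from pl_bolts.models.gans.basic.basic_gan_module import cli_main

# Pass the CLI flags as a list instead of relying on sys.argv.
# --batch_size is parsed via the datamodule's add_argparse_args,
# since the GAN module itself no longer defines it.
dm, model, trainer = cli_main([
    "--dataset", "cifar10",
    "--max_epochs", "1",
    "--limit_train_batches", "3",
    "--limit_val_batches", "3",
    "--batch_size", "3",
])

Calling cli_main() with no arguments keeps the old behaviour of reading sys.argv, which is what the mock.patch-based test still relies on.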