From 496452fe89dc19fa866bfc53ebc7960f092b3aad Mon Sep 17 00:00:00 2001 From: Atharva Phatak Date: Thu, 27 Oct 2022 22:04:56 -0400 Subject: [PATCH 1/5] minor dcgan-import fix --- pl_bolts/models/gans/dcgan/components.py | 6 ++---- pl_bolts/models/gans/dcgan/dcgan_module.py | 12 +++++------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/pl_bolts/models/gans/dcgan/components.py b/pl_bolts/models/gans/dcgan/components.py index fe6ccfd8a8..3ef431961d 100644 --- a/pl_bolts/models/gans/dcgan/components.py +++ b/pl_bolts/models/gans/dcgan/components.py @@ -1,10 +1,8 @@ # Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py -from torch import Tensor, nn +from torch import Tensor +import torch.nn as nn -from pl_bolts.utils.stability import under_review - -@under_review() class DCGANGenerator(nn.Module): def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None: """ diff --git a/pl_bolts/models/gans/dcgan/dcgan_module.py b/pl_bolts/models/gans/dcgan/dcgan_module.py index 6b80eb95c8..c48064d0e8 100644 --- a/pl_bolts/models/gans/dcgan/dcgan_module.py +++ b/pl_bolts/models/gans/dcgan/dcgan_module.py @@ -3,13 +3,13 @@ import torch from pytorch_lightning import LightningModule, Trainer, seed_everything -from torch import Tensor, nn +from torch import Tensor +import torch.nn as nn from torch.utils.data import DataLoader from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler from pl_bolts.models.gans.dcgan.components import DCGANDiscriminator, DCGANGenerator from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg if _TORCHVISION_AVAILABLE: @@ -19,7 +19,6 @@ warn_missing_pkg("torchvision") -@under_review() class DCGAN(LightningModule): """DCGAN implementation. @@ -80,10 +79,10 @@ def _get_discriminator(self) -> nn.Module: def _weights_init(m): classname = m.__class__.__name__ if classname.find("Conv") != -1: - torch.nn.init.normal_(m.weight, 0.0, 0.02) + nn.init.normal_(m.weight, 0.0, 0.02) elif classname.find("BatchNorm") != -1: - torch.nn.init.normal_(m.weight, 1.0, 0.02) - torch.nn.init.zeros_(m.bias) + nn.init.normal_(m.weight, 1.0, 0.02) + nn.init.zeros_(m.bias) def configure_optimizers(self): lr = self.hparams.learning_rate @@ -173,7 +172,6 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: return parser -@under_review() def cli_main(args=None): seed_everything(1234) From 32e13bd8f47ff7750776048101db0bc2e829229d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Oct 2022 02:08:17 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_bolts/models/gans/dcgan/components.py | 2 +- pl_bolts/models/gans/dcgan/dcgan_module.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/models/gans/dcgan/components.py b/pl_bolts/models/gans/dcgan/components.py index 3ef431961d..5069f8bf56 100644 --- a/pl_bolts/models/gans/dcgan/components.py +++ b/pl_bolts/models/gans/dcgan/components.py @@ -1,6 +1,6 @@ # Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py -from torch import Tensor import torch.nn as nn +from torch import Tensor class DCGANGenerator(nn.Module): diff --git a/pl_bolts/models/gans/dcgan/dcgan_module.py b/pl_bolts/models/gans/dcgan/dcgan_module.py index c48064d0e8..40580f94b4 100644 --- a/pl_bolts/models/gans/dcgan/dcgan_module.py +++ b/pl_bolts/models/gans/dcgan/dcgan_module.py @@ -2,9 +2,9 @@ from typing import Any import torch +import torch.nn as nn from pytorch_lightning import LightningModule, Trainer, seed_everything from torch import Tensor -import torch.nn as nn from torch.utils.data import DataLoader from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler From f9a16613684c4395bb125f583a780047bb0a634a Mon Sep 17 00:00:00 2001 From: Atharva Phatak Date: Thu, 27 Oct 2022 22:13:40 -0400 Subject: [PATCH 3/5] fix under_review import --- pl_bolts/models/gans/dcgan/components.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pl_bolts/models/gans/dcgan/components.py b/pl_bolts/models/gans/dcgan/components.py index 5069f8bf56..2b0627902f 100644 --- a/pl_bolts/models/gans/dcgan/components.py +++ b/pl_bolts/models/gans/dcgan/components.py @@ -47,8 +47,6 @@ def _make_gen_block( def forward(self, noise: Tensor) -> Tensor: return self.gen(noise) - -@under_review() class DCGANDiscriminator(nn.Module): def __init__(self, feature_maps: int, image_channels: int) -> None: """ From 45ea572837d65c33a20d34bb2a679de5f990b776 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Oct 2022 02:14:09 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_bolts/models/gans/dcgan/components.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pl_bolts/models/gans/dcgan/components.py b/pl_bolts/models/gans/dcgan/components.py index 2b0627902f..d1d52779bc 100644 --- a/pl_bolts/models/gans/dcgan/components.py +++ b/pl_bolts/models/gans/dcgan/components.py @@ -47,6 +47,7 @@ def _make_gen_block( def forward(self, noise: Tensor) -> Tensor: return self.gen(noise) + class DCGANDiscriminator(nn.Module): def __init__(self, feature_maps: int, image_channels: int) -> None: """ From 148fc9426647cd5526b6c473b21d8f040e0e7e5c Mon Sep 17 00:00:00 2001 From: Atharva Phatak Date: Fri, 28 Oct 2022 15:25:47 -0400 Subject: [PATCH 5/5] update docstring --- pl_bolts/models/gans/dcgan/dcgan_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/models/gans/dcgan/dcgan_module.py b/pl_bolts/models/gans/dcgan/dcgan_module.py index 40580f94b4..92389118cd 100644 --- a/pl_bolts/models/gans/dcgan/dcgan_module.py +++ b/pl_bolts/models/gans/dcgan/dcgan_module.py @@ -27,7 +27,7 @@ class DCGAN(LightningModule): from pl_bolts.models.gans import DCGAN m = DCGAN() - Trainer(gpus=2).fit(m) + Trainer(accelerator="gpu", devices=2).fit(m) Example CLI::