Skip to content

Commit

Permalink
Apply yapf
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Clement committed Jan 17, 2021
1 parent 2c0d4d0 commit 01fea8a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 16 deletions.
2 changes: 2 additions & 0 deletions pl_bolts/models/gans/dcgan/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


class DCGANGenerator(nn.Module):

def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None:
"""
Args:
Expand Down Expand Up @@ -49,6 +50,7 @@ def forward(self, noise: torch.Tensor) -> torch.Tensor:


class DCGANDiscriminator(nn.Module):

def __init__(self, feature_maps: int, image_channels: int) -> None:
"""
Args:
Expand Down
26 changes: 11 additions & 15 deletions pl_bolts/models/gans/dcgan/dcgan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,24 +186,20 @@ def cli_main(args=None):
script_args, _ = parser.parse_known_args(args)

if script_args.dataset == "lsun":
transforms = transform_lib.Compose(
[
transform_lib.Resize(script_args.image_size),
transform_lib.CenterCrop(script_args.image_size),
transform_lib.ToTensor(),
transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
transforms = transform_lib.Compose([
transform_lib.Resize(script_args.image_size),
transform_lib.CenterCrop(script_args.image_size),
transform_lib.ToTensor(),
transform_lib.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = LSUN(root=script_args.data_dir, classes=["bedroom_train"], transform=transforms)
image_channels = 3
elif script_args.dataset == "mnist":
transforms = transform_lib.Compose(
[
transform_lib.Resize(script_args.image_size),
transform_lib.ToTensor(),
transform_lib.Normalize((0.5,), (0.5,)),
]
)
transforms = transform_lib.Compose([
transform_lib.Resize(script_args.image_size),
transform_lib.ToTensor(),
transform_lib.Normalize((0.5, ), (0.5, )),
])
dataset = MNIST(root=script_args.data_dir, download=True, transform=transforms)
image_channels = 1

Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_gans.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def test_gan(tmpdir, datadir, dm_cls):


@pytest.mark.parametrize(
"dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")]
"dm_cls", [pytest.param(MNISTDataModule, id="mnist"),
pytest.param(CIFAR10DataModule, id="cifar10")]
)
def test_dcgan(tmpdir, datadir, dm_cls):
seed_everything()
Expand Down

0 comments on commit 01fea8a

Please sign in to comment.