add some custom gan models and datasets (PaddlePaddle#95)
* add some custom gan models
HighCWu committed Dec 8, 2020
1 parent fdbc6ae commit b2b881e
Showing 13 changed files with 591 additions and 7 deletions.
67 changes: 67 additions & 0 deletions configs/cond_dcgan_mnist.yaml
@@ -0,0 +1,67 @@
epochs: 200
output_dir: output_dir

model:
  name: GANModel
  generator:
    name: ConditionalDeepConvGenerator
    latent_dim: 128
    output_nc: 1
    size: 28
    ngf: 64
    n_class: 10
  discriminator:
    name: NLayerDiscriminatorWithClassification
    ndf: 16
    n_layers: 3
    input_nc: 1
    norm_type: batch
    n_class: 10
    use_sigmoid: True
  gan_mode: vanilla

dataset:
  train:
    name: CommonVisionDataset
    class_name: MNIST
    dataroot: None
    num_workers: 4
    batch_size: 64
    mode: train
    return_cls: True
    transforms:
      - name: Normalize
        mean: [127.5]
        std: [127.5]
        keys: [image]
  test:
    name: CommonVisionDataset
    class_name: MNIST
    dataroot: None
    num_workers: 0
    batch_size: 64
    mode: test
    transforms:
      - name: Normalize
        mean: [127.5]
        std: [127.5]
        keys: [image]
    return_cls: True

optimizer:
  name: Adam
  beta1: 0.5

lr_scheduler:
  name: linear
  learning_rate: 0.0002
  start_epoch: 100
  decay_epochs: 100

log_config:
  interval: 100
  visiual_interval: 500

snapshot_config:
  interval: 5
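
The discriminator above is configured with use_sigmoid: True and gan_mode: vanilla, i.e. the standard non-saturating binary cross-entropy GAN objective. As with the other configs in this repository, the file is presumably launched with python -u tools/main.py --config-file configs/cond_dcgan_mnist.yaml. Below is a minimal, self-contained Paddle sketch of the loss pairing that vanilla mode refers to; it is illustrative only, since the actual criterion lives inside GANModel:

import paddle

bce = paddle.nn.BCELoss()

def discriminator_loss(d_real, d_fake):
    # D outputs probabilities here because the config sets use_sigmoid: True
    return bce(d_real, paddle.ones_like(d_real)) + bce(d_fake, paddle.zeros_like(d_fake))

def generator_loss(d_fake):
    # non-saturating generator loss: push D(G(z)) towards "real"
    return bce(d_fake, paddle.ones_like(d_fake))

# hypothetical patch-level discriminator scores for a batch of 4 samples
d_real = paddle.uniform([4, 1, 1, 1], min=0.01, max=0.99)
d_fake = paddle.uniform([4, 1, 1, 1], min=0.01, max=0.99)
print(float(discriminator_loss(d_real, d_fake)), float(generator_loss(d_fake)))
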
65 changes: 65 additions & 0 deletions configs/wgan_mnist.yaml
@@ -0,0 +1,65 @@
epochs: 200
output_dir: output_dir

model:
  name: GANModel
  generator:
    name: DeepConvGenerator
    latent_dim: 128
    output_nc: 1
    size: 28
    ngf: 64
  discriminator:
    name: NLayerDiscriminator
    ndf: 16
    n_layers: 3
    input_nc: 1
    norm_type: instance
  gan_mode: wgan
  n_critic: 5

dataset:
  train:
    name: CommonVisionDataset
    class_name: MNIST
    dataroot: None
    num_workers: 4
    batch_size: 64
    mode: train
    return_cls: False
    transforms:
      - name: Normalize
        mean: [127.5]
        std: [127.5]
        keys: [image]
  test:
    name: CommonVisionDataset
    class_name: MNIST
    dataroot: None
    num_workers: 0
    batch_size: 64
    mode: test
    transforms:
      - name: Normalize
        mean: [127.5]
        std: [127.5]
        keys: [image]
    return_cls: False

optimizer:
  name: Adam
  beta1: 0.5

lr_scheduler:
  name: linear
  learning_rate: 0.0002
  start_epoch: 100
  decay_epochs: 100

log_config:
  interval: 100
  visiual_interval: 500

snapshot_config:
  interval: 5
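
Compared with the conditional config above, this one switches gan_mode to wgan and adds n_critic: 5, meaning the critic is updated five times per generator step. The following is a generic sketch of the Wasserstein critic objective and the classic weight clipping that wgan mode implies; the exact training loop (and clip value, if any) is decided inside GANModel:

import paddle

def critic_loss(d_real, d_fake):
    # Wasserstein critic loss (minimised): approximates -(E[D(real)] - E[D(fake)])
    return -paddle.mean(d_real) + paddle.mean(d_fake)

def generator_loss(d_fake):
    return -paddle.mean(d_fake)

def clip_critic_weights(critic, clip_value=0.01):
    # original WGAN keeps the critic roughly Lipschitz by clipping its weights;
    # 0.01 is the value from the WGAN paper, not necessarily what GANModel uses
    with paddle.no_grad():
        for p in critic.parameters():
            p.set_value(paddle.clip(p, -clip_value, clip_value))

# per generator step: run n_critic critic updates (clipping after each),
# then take a single generator update with generator_loss.
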
1 change: 1 addition & 0 deletions ppgan/datasets/__init__.py
@@ -17,4 +17,5 @@
from .paired_dataset import PairedDataset
from .sr_image_dataset import SRImageDataset
from .makeup_dataset import MakeupDataset
+from .common_vision_dataset import CommonVisionDataset
from .animeganv2_dataset import AnimeGANV2Dataset
66 changes: 66 additions & 0 deletions ppgan/datasets/common_vision_dataset.py
@@ -0,0 +1,66 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle

from .builder import DATASETS
from .base_dataset import BaseDataset
from .transforms.builder import build_transforms


@DATASETS.register()
class CommonVisionDataset(BaseDataset):
    """
    Dataset for using paddle vision default datasets
    """
    def __init__(self, cfg):
        """Initialize this dataset class.
        Args:
            cfg (dict) -- stores all the experiment flags
        """
        super(CommonVisionDataset, self).__init__(cfg)

        dataset_cls = getattr(paddle.vision.datasets, cfg.pop('class_name'))
        transform = build_transforms(cfg.pop('transforms', None))
        self.return_cls = cfg.pop('return_cls', True)

        param_dict = {}
        param_names = list(dataset_cls.__init__.__code__.co_varnames)
        if 'transform' in param_names:
            param_dict['transform'] = transform
        for name in param_names:
            if name in cfg:
                param_dict[name] = cfg.get(name)

        self.dataset = dataset_cls(**param_dict)

    def __getitem__(self, index):
        return_dict = {}
        return_list = self.dataset[index]
        if isinstance(return_list, (tuple, list)):
            if len(return_list) == 2:
                return_dict['img'] = return_list[0]
                if self.return_cls:
                    return_dict['class_id'] = np.asarray(return_list[1])
            else:
                return_dict['img'] = return_list[0]
        else:
            return_dict['img'] = return_list

        return return_dict

    def __len__(self):
        return len(self.dataset)
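
The notable trick in this wrapper is the keyword filtering: it inspects the wrapped paddle.vision dataset's __init__ through __code__.co_varnames (which lists the argument names, followed by any local variables) and forwards only the config keys that the dataset actually accepts. Here is a standalone sketch of the same idea against paddle.vision.datasets.MNIST, independent of ppgan; the cfg dict is made up for illustration:

import paddle

cfg = {'mode': 'train', 'download': True, 'batch_size': 64}  # hypothetical leftover config keys
dataset_cls = paddle.vision.datasets.MNIST

# co_varnames lists __init__'s argument names (followed by its local variables)
param_names = list(dataset_cls.__init__.__code__.co_varnames)

param_dict = {}
for name in param_names:
    if name in cfg:
        param_dict[name] = cfg[name]  # 'batch_size' is not an MNIST argument, so it is dropped

mnist = dataset_cls(**param_dict)  # equivalent to MNIST(mode='train', download=True)
print(len(mnist))                  # 60000 training samples
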
2 changes: 1 addition & 1 deletion ppgan/datasets/single_dataset.py
@@ -57,7 +57,7 @@ def __len__(self):
        return len(self.A_paths)

    def get_path_by_indexs(self, indexs):
-        if isinstance(indexs, paddle.Variable):
+        if isinstance(indexs, paddle.Tensor):
            indexs = indexs.numpy()
        current_paths = []
        for index in indexs:
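
This one-line change tracks the Paddle 2.x API: dygraph tensors are instances of paddle.Tensor rather than the old paddle.Variable. A quick check, assuming Paddle 2.x is installed:

import paddle

print(isinstance(paddle.to_tensor([1, 2]), paddle.Tensor))  # True on Paddle 2.x
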
1 change: 1 addition & 0 deletions ppgan/models/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from .base_model import BaseModel
+from .gan_model import GANModel
from .cycle_gan_model import CycleGANModel
from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel
2 changes: 1 addition & 1 deletion ppgan/models/discriminators/__init__.py
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .nlayers import NLayerDiscriminator
from .nlayers import NLayerDiscriminator, NLayerDiscriminatorWithClassification
from .discriminator_ugatit import UGATITDiscriminator
from .dcdiscriminator import DCDiscriminator
from .discriminator_animegan import AnimeDiscriminator
26 changes: 24 additions & 2 deletions ppgan/models/discriminators/nlayers.py
@@ -27,14 +27,15 @@
@DISCRIMINATORS.register()
class NLayerDiscriminator(nn.Layer):
    """Defines a PatchGAN discriminator"""
-    def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance'):
+    def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance', use_sigmoid=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_type (str) -- normalization layer type
+            use_sigmoid (bool) -- whether to apply a sigmoid to the final output
        """
        super(NLayerDiscriminator, self).__init__()
        norm_layer = build_norm_layer(norm_type)

@@ -139,7 +140,28 @@ def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance'):
        ]  # output 1 channel prediction map

        self.model = nn.Sequential(*sequence)
+        self.final_act = F.sigmoid if use_sigmoid else (lambda x: x)

    def forward(self, input):
        """Standard forward."""
-        return self.model(input)
+        return self.final_act(self.model(input))
+
+
+@DISCRIMINATORS.register()
+class NLayerDiscriminatorWithClassification(NLayerDiscriminator):
+    def __init__(self, input_nc, n_class=10, **kwargs):
+        input_nc = input_nc + n_class
+        super(NLayerDiscriminatorWithClassification, self).__init__(input_nc, **kwargs)
+
+        self.n_class = n_class
+
+    def forward(self, x, class_id):
+        if self.n_class > 0:
+            class_id = (class_id % self.n_class).detach()
+            class_id = F.one_hot(class_id, self.n_class).astype('float32')
+            class_id = class_id.reshape([x.shape[0], -1, 1, 1])
+            class_id = class_id.tile([1, 1, *x.shape[2:]])
+            x = paddle.concat([x, class_id], 1)
+
+        return super(NLayerDiscriminatorWithClassification, self).forward(x)
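
The conditional variant widens the first convolution by n_class channels; at forward time it one-hot encodes the class ids, tiles the encoding to the spatial size of the input and concatenates it channel-wise before running the plain PatchGAN. A quick usage sketch with the same hyper-parameters as cond_dcgan_mnist.yaml (assumes ppgan and paddle are installed):

import paddle
from ppgan.models.discriminators import NLayerDiscriminatorWithClassification

disc = NLayerDiscriminatorWithClassification(
    input_nc=1, n_class=10, ndf=16, n_layers=3,
    norm_type='batch', use_sigmoid=True)  # mirrors cond_dcgan_mnist.yaml

imgs = paddle.randn([4, 1, 28, 28])    # a fake batch of MNIST-sized images
labels = paddle.randint(0, 10, [4])    # integer class ids
scores = disc(imgs, labels)            # one-hot map tiled to [4, 10, 28, 28], concat -> 11 input channels
print(scores.shape)                    # patch-level real/fake probabilities (sigmoid already applied)
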

(The remaining changed files in this commit are not shown here.)
