Add train code of stylegan2 (PaddlePaddle#149)
* add stylegan model
LielinJiang committed Jan 22, 2021
1 parent e13e1c1 commit 530a6a8
Showing 11 changed files with 711 additions and 285 deletions.
71 changes: 71 additions & 0 deletions configs/stylegan_v2_256_ffhq.yaml
@@ -0,0 +1,71 @@
total_iters: 800000
output_dir: output_dir

model:
  name: StyleGAN2Model
  generator:
    name: StyleGANv2Generator
    size: 256
    style_dim: 512
    n_mlp: 8
  discriminator:
    name: StyleGANv2Discriminator
    size: 256
  gan_criterion:
    name: GANLoss
    gan_mode: logistic
    loss_weight: !!float 1
  # r1 regularization for discriminator
  r1_reg_weight: 10.
  # path length regularization for generator
  path_batch_shrink: 2.
  path_reg_weight: 2.
  params:
    gen_iters: 4
    disc_iters: 16

dataset:
  train:
    name: SingleDataset
    dataroot: data/ffhq/images256x256/
    num_workers: 3
    batch_size: 3
    preprocess:
      - name: LoadImageFromFile
        key: A
      - name: Transforms
        input_keys: [A]
        pipeline:
          - name: RandomHorizontalFlip
          - name: Transpose
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]

lr_scheduler:
  name: MultiStepDecay
  learning_rate: 0.002
  milestones: [600000]
  gamma: 0.5

optimizer:
  optimG:
    name: Adam
    beta1: 0.0
    beta2: 0.792
    net_names:
      - gen
  optimD:
    name: Adam
    net_names:
      - disc
    beta1: 0.0
    beta2: 0.9317647058823529

log_config:
  interval: 50
  visiual_interval: 500

snapshot_config:
  interval: 5000
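Two details of this config are worth noting. PaddleGAN configs like this one are normally launched through the repository's training entry point (e.g. python -u tools/main.py --config-file configs/stylegan_v2_256_ffhq.yaml). The gen_iters: 4 / disc_iters: 16 pair matches the lazy-regularization intervals of the original StyleGAN2 (path-length penalty every 4 generator steps, R1 penalty every 16 discriminator steps), and the unusual Adam beta2 values are consistent with scaling 0.99 by the ratio n/(n+1): 0.99 × 4/5 = 0.792 and 0.99 × 16/17 = 0.9317647058823529. The sketch below shows the implied cadence; how StyleGAN2Model actually consumes these fields is an assumption here, not something this diff confirms.

# Illustrative sketch of the lazy-regularization cadence suggested by
# gen_iters/disc_iters (assumption: they act like StyleGAN2's
# g_reg_every/d_reg_every intervals).
gen_iters, disc_iters = 4, 16

for step in range(1, disc_iters + 1):
    penalties = []
    if step % gen_iters == 0:
        penalties.append('path-length (G)')
    if step % disc_iters == 0:
        penalties.append('r1 (D)')
    if penalties:
        print(f'step {step:2d}: {", ".join(penalties)}')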
10 changes: 6 additions & 4 deletions ppgan/datasets/preprocess/transforms.py
@@ -59,19 +59,21 @@ def __call__(self, datas):
        data = tuple(data)
        for transform in self.transforms:
            data = transform(data)

            if hasattr(transform, 'params') and isinstance(
                    transform.params, dict):
                datas.update(transform.params)

-        if len(self.input_keys) > 1:
-            for i, k in enumerate(self.input_keys):
-                datas[k] = data[i]
-        else:
-            datas[k] = data
+        if self.output_keys is not None:
+            for i, k in enumerate(self.output_keys):
+                datas[k] = data[i]
+            return datas
+
+        for i, k in enumerate(self.input_keys):
+            datas[k] = data[i]

        return datas
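The behavioral change: the transformed tuple used to be written back under input_keys unconditionally (via an else branch that relied on a loop variable k leaking from earlier code); now, when output_keys is set, results are stored under those keys and the inputs are left untouched. A toy, stand-alone sketch of the new write-back branch (these names are illustrative, not the PaddleGAN API):

def write_back(datas, data, input_keys, output_keys=None):
    # mirrors the logic added to Transforms.__call__
    if output_keys is not None:
        for i, k in enumerate(output_keys):
            datas[k] = data[i]
        return datas
    for i, k in enumerate(input_keys):
        datas[k] = data[i]
    return datas

print(write_back({'A': 'raw'}, ('aug',), ['A']))             # {'A': 'aug'}
print(write_back({'A': 'raw'}, ('aug',), ['A'], ['A_out']))  # {'A': 'raw', 'A_out': 'aug'}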


2 changes: 1 addition & 1 deletion ppgan/datasets/single_dataset.py
@@ -27,7 +27,7 @@ def __init__(self, dataroot, preprocess):
            dataroot (str): Directory of dataset.
            preprocess (list[dict]): A sequence of data preprocess config.
        """
-        super(SingleDataset).__init__(self, preprocess)
+        super(SingleDataset, self).__init__(preprocess)
        self.dataroot = dataroot
        self.data_infos = self.prepare_data_infos()
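The one-line fix here is subtle: super(SingleDataset) with no second argument returns an unbound super object, so the old call never reached the base dataset's __init__ (the arguments self, preprocess get interpreted as super's own type/object parameters and raise a TypeError at runtime). A minimal reproduction with toy class names:

class Base:
    def __init__(self, preprocess):
        self.preprocess = preprocess

class Child(Base):
    def __init__(self, dataroot, preprocess):
        # broken, old-style call: super(Child).__init__(self, preprocess)
        super(Child, self).__init__(preprocess)  # correct two-argument form
        self.dataroot = dataroot

c = Child('data/ffhq', [])
assert c.preprocess == []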

15 changes: 14 additions & 1 deletion ppgan/engine/trainer.py
@@ -123,6 +123,8 @@ def __init__(self, cfg):
        self.batch_id = 0
        self.global_steps = 0
        self.weight_interval = cfg.snapshot_config.interval
+        if self.by_epoch:
+            self.weight_interval *= self.iters_per_epoch
        self.log_interval = cfg.log_config.interval
        self.visual_interval = cfg.log_config.visiual_interval
        if self.by_epoch:
@@ -143,6 +145,17 @@ def distributed_data_parallel(self):
        for net_name, net in self.model.nets.items():
            self.model.nets[net_name] = paddle.DataParallel(net, strategy)

+    def learning_rate_scheduler_step(self):
+        if isinstance(self.model.lr_scheduler, dict):
+            for lr_scheduler in self.model.lr_scheduler.values():
+                lr_scheduler.step()
+        elif isinstance(self.model.lr_scheduler,
+                        paddle.optimizer.lr.LRScheduler):
+            self.model.lr_scheduler.step()
+        else:
+            raise ValueError(
+                'lr scheduler must be a dict or an instance of LRScheduler')
+
    def train(self):
        reader_cost_averager = TimeAverager()
        batch_cost_averager = TimeAverager()
@@ -179,7 +192,7 @@ def train(self):
            if self.current_iter % self.visual_interval == 0:
                self.visual('visual_train')

-            self.model.lr_scheduler.step()
+            self.learning_rate_scheduler_step()

            if self.validate_interval > -1 and self.current_iter % self.validate_interval == 0:
                self.test()
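The new helper exists because a model may keep a single LRScheduler or a dict with one scheduler per network (a GAN typically schedules generator and discriminator separately). A stand-alone sketch of the two cases it dispatches on (illustrative names; the real trainer reads self.model.lr_scheduler):

import paddle

def scheduler_step(lr_scheduler):
    # same dispatch as Trainer.learning_rate_scheduler_step
    if isinstance(lr_scheduler, dict):
        for sched in lr_scheduler.values():
            sched.step()
    elif isinstance(lr_scheduler, paddle.optimizer.lr.LRScheduler):
        lr_scheduler.step()
    else:
        raise ValueError('lr scheduler must be a dict or an LRScheduler')

single = paddle.optimizer.lr.MultiStepDecay(0.002, milestones=[600000], gamma=0.5)
per_net = {'gen': paddle.optimizer.lr.MultiStepDecay(0.002, [600000]),
           'disc': paddle.optimizer.lr.MultiStepDecay(0.002, [600000])}
scheduler_step(single)
scheduler_step(per_net)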
1 change: 1 addition & 0 deletions ppgan/models/__init__.py
@@ -22,5 +22,6 @@
from .ugatit_model import UGATITModel
from .dc_gan_model import DCGANModel
from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel
+from .styleganv2_model import StyleGAN2Model
from .wav2lip_model import Wav2LipModel
from .wav2lip_hq_model import Wav2LipModelHq
102 changes: 60 additions & 42 deletions ppgan/models/discriminators/discriminator_styleganv2.py
@@ -35,22 +35,22 @@ def __init__(
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Upfirdn2dBlur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2D(
                in_channel,
@@ -59,41 +59,58 @@
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
-            )
-        )
+            ))

        if activate:
            layers.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*layers)


class ResBlock(nn.Layer):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

-        self.skip = ConvLayer(
-            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
-        )
+        self.skip = ConvLayer(in_channel,
+                              out_channel,
+                              1,
+                              downsample=True,
+                              activate=False,
+                              bias=False)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out
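The / math.sqrt(2) on the residual sum is variance-preserving: if the two branches are roughly independent with equal variance, adding them doubles the variance, and the factor undoes that so activation scale stays stable across stacked blocks. A quick numeric check (NumPy here purely for illustration):

import numpy as np

rng = np.random.default_rng(0)
a, b = rng.standard_normal((2, 1_000_000))
print(np.var(a + b))                 # ~2.0
print(np.var((a + b) / np.sqrt(2)))  # ~1.0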




+# temporarily solve the pow double-grad problem
+def var(x, axis=None, unbiased=True, keepdim=False, name=None):
+    u = paddle.mean(x, axis, True, name)
+    out = paddle.sum((x - u) * (x - u), axis, keepdim=keepdim, name=name)
+
+    n = paddle.cast(paddle.numel(x), x.dtype) \
+        / paddle.cast(paddle.numel(out), x.dtype)
+    if unbiased:
+        one_const = paddle.ones([1], x.dtype)
+        n = paddle.where(n > one_const, n - 1., one_const)
+    out /= n
+    return out
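This hand-rolled var exists because the discriminator's R1 penalty needs second-order gradients, and (as the comment says) pow's double-grad was problematic in Paddle at the time; composing the same statistic from mean/sum/multiply sidesteps it. A numerical sanity check against NumPy (a sketch; assumes a working Paddle install, and exact double-grad behavior depends on the Paddle version this commit targeted):

import numpy as np
import paddle

x = paddle.to_tensor(np.random.rand(4, 3, 8, 8).astype('float32'))
print(var(x, axis=0, unbiased=False).numpy()[0, 0])  # uses the helper above
print(np.var(x.numpy(), axis=0)[0, 0])               # should match to float error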


@DISCRIMINATORS.register()
class StyleGANv2Discriminator(nn.Layer):
    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        channels = {
            4: 512,
            8: 512,
@@ -105,47 +122,48 @@ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        convs = [ConvLayer(3, channels[size], 1)]

        log_size = int(math.log(size, 2))

        in_channel = channels[size]

        for i in range(log_size, 2, -1):
-            out_channel = channels[2 ** (i - 1)]
+            out_channel = channels[2**(i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
        self.final_linear = nn.Sequential(
-            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
+            EqualLinear(channels[4] * 4 * 4,
+                        channels[4],
+                        activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        out = self.convs(input)

        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
-        stddev = out.reshape((
-            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
-        ))
-        stddev = paddle.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+        stddev = out.reshape((group, -1, self.stddev_feat,
+                              channel // self.stddev_feat, height, width))
+        stddev = paddle.sqrt(var(stddev, 0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
        stddev = stddev.tile((group, 1, height, width))
        out = paddle.concat([out, stddev], 1)

        out = self.final_conv(out)

        out = out.reshape((batch, -1))
        out = self.final_linear(out)

        return out
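Note the minibatch standard-deviation trick in forward: per-group feature stddevs are averaged into one extra channel and concatenated before final_conv, giving the discriminator a direct signal about sample diversity (a ProGAN/StyleGAN defense against mode collapse). A smoke-test sketch, assuming this module imports cleanly in a working PaddlePaddle environment:

import paddle

disc = StyleGANv2Discriminator(size=256)
fake = paddle.randn([4, 3, 256, 256])  # batch of 4 matches stddev_group
score = disc(fake)
print(score.shape)  # [4, 1]: one realism logit per image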

