Skip to content

Commit

Permalink
add stargan-v2 test and train code (PaddlePaddle#165)
Browse files Browse the repository at this point in the history
* add configuration options
e.g. python tools/main.py --c configs/stylegan_v2_256_ffhq.yaml -o total_iters=1 log_config.visiual_interval=1

* add stargan-v2 test and train code

* code normalization

* modify FAN code

* add munch in requirements.txt
  • Loading branch information
lyl120117 authored Mar 2, 2021
1 parent 0dab3be commit 8ece3c2
Show file tree
Hide file tree
Showing 15 changed files with 1,449 additions and 0 deletions.
141 changes: 141 additions & 0 deletions configs/starganv2_afhq.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
epochs: 200
output_dir: output_dir

# Model definition. Anchors (&NAME) declared in this stanza are reused through
# aliases (*NAME) by the sub-networks below and by the dataset pipelines.
model:
  name: StarGANv2Model
  latent_dim: &LATENT_DIM 16
  lambda_sty: 1
  lambda_ds: 2
  lambda_cyc: 1
  generator:
    name: StarGANv2Generator
    img_size: &IMAGE_SIZE 256
    w_hpf: 0
    style_dim: &STYLE_DIM 64
  style:
    name: StarGANv2Style
    img_size: *IMAGE_SIZE
    style_dim: *STYLE_DIM
    num_domains: &NUM_DOMAINS 3
  mapping:
    name: StarGANv2Mapping
    latent_dim: *LATENT_DIM
    style_dim: *STYLE_DIM
    num_domains: *NUM_DOMAINS
  discriminator:
    name: StarGANv2Discriminator
    img_size: *IMAGE_SIZE
    num_domains: *NUM_DOMAINS

# Train and test datasets. Each `preprocess` list loads the named images and
# then runs a shared transform pipeline over all of `input_keys`.
dataset:
  train:
    name: StarGANv2Dataset
    dataroot: data/stargan-v2/afhq/train
    is_train: true
    num_workers: 8
    batch_size: 4
    preprocess:
      - name: LoadImageFromFile
        key: src
      - name: LoadImageFromFile
        key: ref
      - name: LoadImageFromFile
        key: ref2
      - name: Transforms
        input_keys: [src, ref, ref2]
        pipeline:
          - name: RandomResizedCropProb
            prob: 0.9
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            scale: [0.8, 1.0]
            ratio: [0.9, 1.1]
            interpolation: 'bilinear'
            keys: [image, image, image]
          - name: Resize
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            interpolation: 'bilinear'
            keys: [image, image, image]
          - name: RandomHorizontalFlip
            prob: 0.5
            keys: [image, image, image]
          - name: Transpose
            keys: [image, image, image]
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: [image, image, image]

  test:
    name: StarGANv2Dataset
    dataroot: data/stargan-v2/afhq/val
    is_train: false
    num_workers: 8
    batch_size: 16
    test_count: 16
    preprocess:
      - name: LoadImageFromFile
        key: src
      - name: LoadImageFromFile
        key: ref
      - name: Transforms
        input_keys: [src, ref]
        pipeline:
          - name: Resize
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            interpolation: 'bicubic'  # cv2.INTER_CUBIC
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: [image, image]

# Learning-rate schedule settings.
lr_scheduler:
  name: LinearDecay
  learning_rate: 0.0001
  start_epoch: 100
  decay_epochs: 100
  # will be obtained from the real dataset
  iters_per_epoch: 365

# One optimizer entry per network; `net_names` lists the network(s) each
# entry applies to.
optimizer:
  generator:
    name: Adam
    net_names:
      - generator
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001
  style_encoder:
    name: Adam
    net_names:
      - style_encoder
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001
  mapping_network:
    name: Adam
    net_names:
      - mapping_network
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001
  discriminator:
    name: Adam
    net_names:
      - discriminator
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001

validate:
  interval: 5000
  save_img: false

log_config:
  interval: 5
  # NOTE(review): the spelling `visiual_interval` is presumably what the
  # consumer reads; confirm before renaming.
  visiual_interval: 100

snapshot_config:
  interval: 5
144 changes: 144 additions & 0 deletions configs/starganv2_celeba_hq.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
epochs: 200
output_dir: output_dir

# Model definition. Anchors (&NAME) declared in this stanza are reused through
# aliases (*NAME) by the sub-networks below and by the dataset pipelines.
model:
  name: StarGANv2Model
  latent_dim: &LATENT_DIM 16
  lambda_sty: 1
  lambda_ds: 1
  lambda_cyc: 1
  generator:
    name: StarGANv2Generator
    img_size: &IMAGE_SIZE 256
    w_hpf: 1
    style_dim: &STYLE_DIM 64
  style:
    name: StarGANv2Style
    img_size: *IMAGE_SIZE
    style_dim: *STYLE_DIM
    num_domains: &NUM_DOMAINS 2
  mapping:
    name: StarGANv2Mapping
    latent_dim: *LATENT_DIM
    style_dim: *STYLE_DIM
    num_domains: *NUM_DOMAINS
  fan:
    name: FAN
    fname_pretrained: models/stargan-v2/wing.pdparams
  discriminator:
    name: StarGANv2Discriminator
    img_size: *IMAGE_SIZE
    num_domains: *NUM_DOMAINS

# Train and test datasets. Each `preprocess` list loads the named images and
# then runs a shared transform pipeline over all of `input_keys`.
dataset:
  train:
    name: StarGANv2Dataset
    dataroot: data/stargan-v2/celeba_hq/train/
    is_train: true
    num_workers: 8
    batch_size: 4
    preprocess:
      - name: LoadImageFromFile
        key: src
      - name: LoadImageFromFile
        key: ref
      - name: LoadImageFromFile
        key: ref2
      - name: Transforms
        input_keys: [src, ref, ref2]
        pipeline:
          - name: RandomResizedCropProb
            prob: 0.9
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            scale: [0.8, 1.0]
            ratio: [0.9, 1.1]
            interpolation: 'bilinear'
            keys: [image, image, image]
          - name: Resize
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            interpolation: 'bilinear'
            keys: [image, image, image]
          - name: RandomHorizontalFlip
            prob: 0.5
            keys: [image, image, image]
          - name: Transpose
            keys: [image, image, image]
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: [image, image, image]

  test:
    name: StarGANv2Dataset
    dataroot: data/stargan-v2/celeba_hq/val/
    is_train: false
    num_workers: 8
    batch_size: 16
    test_count: 16
    preprocess:
      - name: LoadImageFromFile
        key: src
      - name: LoadImageFromFile
        key: ref
      - name: Transforms
        input_keys: [src, ref]
        pipeline:
          - name: Resize
            size: [*IMAGE_SIZE, *IMAGE_SIZE]
            interpolation: 'bicubic'  # cv2.INTER_CUBIC
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [127.5, 127.5, 127.5]
            std: [127.5, 127.5, 127.5]
            keys: [image, image]

# Learning-rate schedule settings.
lr_scheduler:
  name: LinearDecay
  learning_rate: 0.0001
  start_epoch: 100
  decay_epochs: 100
  # will be obtained from the real dataset
  iters_per_epoch: 365

# One optimizer entry per network; `net_names` lists the network(s) each
# entry applies to.
optimizer:
  generator:
    name: Adam
    net_names:
      - generator
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001
  style_encoder:
    name: Adam
    net_names:
      - style_encoder
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001
  mapping_network:
    name: Adam
    net_names:
      - mapping_network
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001
  discriminator:
    name: Adam
    net_names:
      - discriminator
    beta1: 0.0
    beta2: 0.99
    weight_decay: 0.0001

validate:
  interval: 5000
  save_img: false

log_config:
  interval: 5
  # NOTE(review): the spelling `visiual_interval` is presumably what the
  # consumer reads; confirm before renaming.
  visiual_interval: 100

snapshot_config:
  interval: 5
1 change: 1 addition & 0 deletions ppgan/datasets/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from .common_vision_dataset import CommonVisionDataset
from .animeganv2_dataset import AnimeGANV2Dataset
from .wav2lip_dataset import Wav2LipDataset
from .starganv2_dataset import StarGANv2Dataset
19 changes: 19 additions & 0 deletions ppgan/datasets/preprocess/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,25 @@ def _apply_image(self, image):
return image


@TRANSFORMS.register()
class RandomResizedCropProb(T.RandomResizedCrop):
    """Random-resized crop that is applied only with probability ``prob``.

    With probability ``prob`` the image goes through the parent class's
    ``_apply_image``; otherwise it is returned unchanged.

    Args:
        prob (float): probability of using random-resized cropping.
        size (int|list|tuple): cropped size, forwarded to the parent class.
        scale: forwarded unchanged to the parent class.
        ratio: forwarded unchanged to the parent class.
        interpolation: forwarded unchanged to the parent class.
        keys (optional): stored on ``self.keys`` when given. When omitted,
            whatever the parent constructor set is left in place.
    """
    def __init__(self, prob, size, scale, ratio, interpolation, keys=None):
        super().__init__(size, scale, ratio, interpolation)
        self.prob = prob
        # Only override when the caller supplied keys, so that a parent-set
        # default is not clobbered with None.
        # NOTE(review): assumes the parent constructor initialises
        # ``self.keys`` -- confirm against paddle.vision.transforms.
        if keys is not None:
            self.keys = keys

    def _apply_image(self, image):
        """Crop ``image`` with probability ``self.prob``; else return it as is."""
        if random.random() < self.prob:
            image = super()._apply_image(image)
        return image


@TRANSFORMS.register()
class Add(T.BaseTransform):
def __init__(self, value, keys=None):
Expand Down
Loading

0 comments on commit 8ece3c2

Please sign in to comment.