We provide a PyTorch implementation of "Multi-marginal Wasserstein GAN".
News: a web demo showing the translation performance is available at DEMO.
Multi-marginal Wasserstein GAN
Jiezhang Cao *, Langyuan Mo *, Yifan Zhang, Kui Jia, Chunhua Shen, Mingkui Tan *
Advances in Neural Information Processing Systems (NeurIPS), 2019
- Python 3
- PyTorch 1.0.1
- tensorboardX (optional; requires TensorFlow)
- other dependencies listed in requirements.txt
- Clone this repo:
  git clone https://github.com/deepmo24/MWGAN.git
  cd MWGAN
- Install PyTorch and the other dependencies:
  - pip install -r requirements.txt
  - tensorboardX and TensorFlow are optional; they are only needed for logging.
- Download the CelebA dataset and the corresponding attribute labels.
  - Link: Dropbox or BaiduNetdisk
  - Put the data in the ./data/celeba directory.
- Construct the facial attribute translation dataset with five domains (Black_Hair, Blond_Hair, Eyeglasses, Mustache, Pale_Skin):
  python data_process.py --process celeba --source_attr Black_Hair
- Construct the edge -> celeba dataset with four domains (Edge, Black_Hair, Blond_Hair, Brown_Hair).
  - Organize the data using the folder structure described here.
  - Get the Black_Hair and Blond_Hair data from the facial attribute translation dataset constructed above.
  - Get the Brown_Hair data:
    python data_process.py --process celeba \
      --selected_attrs Brown_Hair --target_dir data/Edge2Celeba
  - Get the Edge data:
    python data_process.py --process edge \
      --source_dir data/Edge2Celeba --target_dir data/Edge2Celeba \
      --selected_attrs Black_Hair Blond_Hair Brown_Hair --select_nums 15000
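The dataset steps above select CelebA images by their attribute labels. As an illustration only (this is not the logic of data_process.py), the sketch below parses the standard list_attr_celeba.txt format (an image count, a line of attribute names, then one row per image with 1/-1 values) and keeps the images that carry one attribute. The `exclusive` rule, which drops images also labeled with a conflicting attribute such as a second hair color, is an assumption, and `parse_celeba_attrs`, `select_domain` and `SAMPLE_ATTRS` are hypothetical names.

```python
from typing import Dict, Iterable, List

# Tiny made-up excerpt in the list_attr_celeba.txt layout, for demonstration.
SAMPLE_ATTRS = """3
Black_Hair Blond_Hair Eyeglasses
000001.jpg  1 -1 -1
000002.jpg -1  1  1
000003.jpg  1  1 -1
"""


def parse_celeba_attrs(text: str) -> Dict[str, Dict[str, bool]]:
    """Parse attribute-file content into {filename: {attribute: present}}."""
    lines = text.strip().splitlines()
    names = lines[1].split()  # line 0 is the image count, line 1 the attribute names
    table: Dict[str, Dict[str, bool]] = {}
    for line in lines[2:]:
        fields = line.split()  # columns may be separated by one or two spaces
        if not fields:
            continue
        filename, values = fields[0], fields[1:]
        table[filename] = {n: v == "1" for n, v in zip(names, values)}
    return table


def select_domain(table: Dict[str, Dict[str, bool]], attr: str,
                  exclusive: Iterable[str] = ()) -> List[str]:
    """Hypothetical rule: keep images with `attr` set and every `exclusive` attribute unset."""
    exclusive = list(exclusive)
    return sorted(
        name for name, attrs in table.items()
        if attrs[attr] and not any(attrs[e] for e in exclusive)
    )


if __name__ == "__main__":
    table = parse_celeba_attrs(SAMPLE_ATTRS)
    # 000003.jpg is labeled with both hair colors, so the exclusive rule drops it.
    print(select_domain(table, "Black_Hair", exclusive=["Blond_Hair"]))  # ['000001.jpg']
```

Whether data_process.py applies any such exclusivity rule is not stated here; check the script before relying on the sketch's behavior.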
To train MWGAN on the facial attribute translation task:
python main.py --num_domains 5 --batch_size 16 \
--data_root data/Celeba5domain/train --src_domain Black_Hair \
--result_root results_celeba \
--lambda_cls 1 --lambda_info 20 --lambda_idt 10
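The three weights in the command above (--lambda_cls 1, --lambda_info 20, --lambda_idt 10) presumably scale auxiliary generator losses. A minimal sketch, assuming the StarGAN-style convention in which each weight multiplies one term that is added to the adversarial loss; `generator_objective` and its loss arguments are hypothetical stand-ins, and main.py may define or combine its losses differently.

```python
import math


def generator_objective(adv_loss, cls_loss, info_loss, idt_loss,
                        lambda_cls, lambda_info, lambda_idt):
    """Hypothetical objective: adversarial loss plus weighted auxiliary terms.

    Works with plain floats or PyTorch scalar tensors, since it only uses + and *.
    """
    return (adv_loss
            + lambda_cls * cls_loss
            + lambda_info * info_loss
            + lambda_idt * idt_loss)


if __name__ == "__main__":
    # Weights from the facial attribute training command; loss values are made up.
    total = generator_objective(0.5, 0.2, 0.1, 0.05,
                                lambda_cls=1, lambda_info=20, lambda_idt=10)
    print(math.isclose(total, 3.2))  # True
```

Under this assumption, raising a weight such as --lambda_info makes the optimizer trade adversarial quality for that auxiliary term.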
To train MWGAN on the edge->celeba task:
python main.py --num_domains 4 --batch_size 16 \
--data_root data/Edge2Celeba/train --src_domain Edge \
--result_root results_edge \
--lambda_cls 10 --lambda_info 10 --cls_loss BCE
- If you do not have tensorboardX and TensorFlow installed, add --use_tensorboard false to the command.
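Logging needs both tensorboardX and TensorFlow. If you launch main.py from a Python script, a check like the one below can decide whether to append --use_tensorboard false. It is a sketch: `tensorboard_flags` is a hypothetical helper, and it only tests whether the packages are findable with the standard-library `importlib.util.find_spec` (without importing them), not whether they work.

```python
import importlib.util
from typing import List


def tensorboard_flags() -> List[str]:
    """Extra main.py arguments that disable logging when its dependencies are missing.

    Assumption from the requirements list: logging needs both tensorboardX
    and TensorFlow to be installed.
    """
    available = all(importlib.util.find_spec(name) is not None
                    for name in ("tensorboardX", "tensorflow"))
    return [] if available else ["--use_tensorboard", "false"]


if __name__ == "__main__":
    print(tensorboard_flags())
```

The returned list can be appended to the argument list of either training command before running it.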
To test MWGAN on the facial attribute translation task:
python main.py --mode test --num_domains 5 --batch_size 16 \
--data_root data/Celeba5domain/test --src_domain Black_Hair \
--result_root results_celeba
To test MWGAN on the edge->celeba task:
python main.py --mode test --num_domains 4 --batch_size 16 \
--data_root data/Edge2Celeba/test --src_domain Edge \
--result_root results_edge
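Training and testing both read images from --data_root, so a quick sanity check before launching can catch missing or empty domain folders. This sketch assumes a hypothetical layout of one subfolder per domain, named after the domain (for example data/Celeba5domain/test/Black_Hair/*.jpg); the authoritative structure is the folder-structure description referenced in the dataset steps, and `count_domain_images`, `check_domains` and the accepted image suffixes are assumptions.

```python
from pathlib import Path
from typing import Dict, List

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumed image formats


def count_domain_images(data_root: str) -> Dict[str, int]:
    """Count image files in each immediate subfolder of data_root (one folder per domain assumed)."""
    root = Path(data_root)
    if not root.is_dir():
        raise FileNotFoundError(f"data_root not found: {root}")
    return {
        d.name: sum(1 for f in d.iterdir()
                    if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES)
        for d in sorted(root.iterdir())
        if d.is_dir()
    }


def check_domains(data_root: str, num_domains: int, src_domain: str) -> List[str]:
    """Return a list of problems; an empty list means the assumed layout looks consistent."""
    counts = count_domain_images(data_root)
    problems = []
    if len(counts) != num_domains:
        problems.append(f"expected {num_domains} domain folders, found {len(counts)}")
    if src_domain not in counts:
        problems.append(f"no folder for source domain {src_domain!r}")
    problems += [f"domain {name!r} has no images" for name, n in counts.items() if n == 0]
    return problems


if __name__ == "__main__":
    try:
        issues = check_domains("data/Celeba5domain/test", 5, "Black_Hair")
    except FileNotFoundError as err:
        issues = [str(err)]
    print("\n".join(issues) or "layout looks consistent")
```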
Download the pretrained models.
- Link: Dropbox or BaiduNetdisk
- For facial attribute translation: put the pretrained models in the ./celeba_pretrained/models directory.
- For edge->celeba: put the pretrained models in the ./edge_pretrained/models directory.
To test MWGAN using the pretrained models:
python main.py --mode test --num_domains 5 --batch_size 16 \
--data_root data/Celeba5domain/test --src_domain Black_Hair \
--result_root celeba_pretrained
or
python main.py --mode test --num_domains 4 --batch_size 16 \
--data_root data/Edge2Celeba/test --src_domain Edge \
--result_root edge_pretrained
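The training and testing commands above differ only in a few values. A small helper, hypothetical and not part of the repository, can assemble them as argument lists from the values shown in this README; passing `result_root` points a test run at the pretrained folders instead. `TASKS` and `build_command` are illustrative names, and only flags that appear in the commands above are used.

```python
from typing import List, Optional

# Values copied from the train/test commands in this README.
TASKS = {
    "celeba": {
        "num_domains": 5, "data_dir": "data/Celeba5domain",
        "src_domain": "Black_Hair", "result_root": "results_celeba",
        "train_extra": ["--lambda_cls", "1", "--lambda_info", "20", "--lambda_idt", "10"],
    },
    "edge": {
        "num_domains": 4, "data_dir": "data/Edge2Celeba",
        "src_domain": "Edge", "result_root": "results_edge",
        "train_extra": ["--lambda_cls", "10", "--lambda_info", "10", "--cls_loss", "BCE"],
    },
}


def build_command(task: str, mode: str = "train", result_root: Optional[str] = None,
                  batch_size: int = 16) -> List[str]:
    """Return the main.py argument list for `task` in `mode` ("train" or "test")."""
    if mode not in ("train", "test"):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
    cfg = TASKS[task]
    cmd = ["python", "main.py"]
    if mode == "test":
        cmd += ["--mode", "test"]
    cmd += [
        "--num_domains", str(cfg["num_domains"]),
        "--batch_size", str(batch_size),
        "--data_root", f"{cfg['data_dir']}/{mode}",  # train/ or test/ split
        "--src_domain", cfg["src_domain"],
        "--result_root", result_root or cfg["result_root"],
    ]
    if mode == "train":
        cmd += cfg["train_extra"]  # loss settings are only passed when training
    return cmd


if __name__ == "__main__":
    print(" ".join(build_command("edge", "test", result_root="edge_pretrained")))
    # python main.py --mode test --num_domains 4 --batch_size 16 --data_root data/Edge2Celeba/test --src_domain Edge --result_root edge_pretrained
```

To actually launch a run from the repository root, the list can be passed to `subprocess.run(cmd, check=True)` from the standard library.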
If this work is useful for your research, please cite our paper:
@InProceedings{cao2019mwgan,
title = {Multi-marginal Wasserstein GAN},
author = {Cao, Jiezhang and Mo, Langyuan and Zhang, Yifan and Jia, Kui and Shen, Chunhua and Tan, Mingkui},
booktitle = {Advances in Neural Information Processing Systems},
year = {2019}
}
Our code is inspired by StarGAN.