# Learning Disentangled Representation by Exploiting Pretrained Generative Models: A Contrastive Learning View
Xuanchi Ren*, Tao Yang*, Yuwang Wang and Wenjun Zeng
ICLR 2022
* indicates equal contribution
✅ Update StyleGAN2
✅ Update SNGAN
✅ Evaluation
NOTE: The code base for VAE and Glow is not easy to merge into this repo. If you are interested in them, please contact me!
In this repo, we propose an unsupervised and model-agnostic method: Disentanglement via Contrast (DisCo) in the Variation Space. DisCo discovers disentangled directions in the latent space and extracts disentangled representations from images with contrastive learning. DisCo achieves state-of-the-art disentanglement given pretrained non-disentangled generative models, including GANs, VAEs, and Flows.
## Qualitative Results

NOTE: The following results are obtained in a completely unsupervised manner. More results (including VAE and Flow) are presented in the Appendix of the paper.
Discovered directions for each pretrained model:

**FFHQ StyleGAN2**: Pose, Smile, Race, Oldness, Overexpose, Hair

**Shapes3D StyleGAN2**: Wall Color, Floor Color, Object Color, Pose

**Car3D StyleGAN2**: Azimuth, Yaw

**Anime SNGAN**: Pose, Natureness, Glass, Tone
## Quantitative Results

NOTE: DisCo achieves state-of-the-art disentanglement in terms of the MIG and DCI metrics on Shapes3D, Car3D, and MPI3D.
## Setup

- NVIDIA GPU + CUDA CuDNN
- Python 3
- Clone the repository:

  ```bash
  git clone https://github.com/xrenaa/DisCo.git
  cd DisCo
  ```
- Dependencies (To Do):

  We recommend running this repository with Anaconda. The main dependencies are `PyTorch 1.3.1` and `CUDA 10.1/10.2`. If you encounter dependency issues, please refer to the issues of https://github.com/rosinality/stylegan2-pytorch. An environment sketch follows this list.
- Docker:

  Alternatively, you can use `Docker` to run the code. We provide the image `thomasyt/gan-disc` for easy use; a run sketch also follows this list.
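A minimal Anaconda environment sketch, using the versions from the dependency note above. The environment name and the exact pip pins (including the `torchvision` version) are assumptions; adjust them to your CUDA setup.

```bash
# Sketch of an environment matching the stated dependencies
# (the env name "disco" and the torchvision pin are assumptions).
conda create -n disco python=3.7 -y
conda activate disco
pip install torch==1.3.1 torchvision==0.4.2
```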
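And a sketch of running the provided Docker image; the `/workspace` mount point is an assumption, and the `--gpus` flag requires the NVIDIA container toolkit.

```bash
# Pull the image provided by the authors and start an interactive container,
# mounting the current checkout (the mount path is an assumption).
docker pull thomasyt/gan-disc
docker run --gpus all -it -v "$(pwd)":/workspace thomasyt/gan-disc
```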
## Pretrained Models

Please download the pretrained models from the following links and put them at the corresponding paths.
Path | Description |
---|---|
shapes3d_StyleGAN | StyleGAN2 models pretrained on Shapes3D: checkpoints `0.pt` to `4.pt`. Corresponding path: `./pretrained_weights/shapes3d/`. |
cars3d_StyleGAN | StyleGAN2 models pretrained on Cars3D: checkpoints `0.pt` to `4.pt`. Corresponding path: `./pretrained_weights/cars3d/`. |
mpi3d_StyleGAN | StyleGAN2 models pretrained on MPI3D: checkpoints `0.pt` to `4.pt`. Corresponding path: `./pretrained_weights/mpi3d/`. |
shapes3d_VAE | VAE models pretrained on Shapes3D: checkpoints `VAE_0` to `VAE_4`. Corresponding path: `./pretrained_weights/shapes3d/`. |
cars3d_VAE | VAE models pretrained on Cars3D: checkpoints `VAE_0` to `VAE_4`. Corresponding path: `./pretrained_weights/cars3d/`. |
mpi3d_VAE | VAE models pretrained on MPI3D: checkpoints `VAE_0` to `VAE_4`. Corresponding path: `./pretrained_weights/mpi3d/`. |
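As a sanity check, a sketch that lists the expected layout inferred from the paths in the table above:

```bash
# Each dataset folder should contain 0.pt to 4.pt (StyleGAN2)
# and, if you downloaded them, VAE_0 to VAE_4 (VAE).
ls pretrained_weights/shapes3d pretrained_weights/cars3d pretrained_weights/mpi3d
```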
For SNGAN, you can run the following script to download the weights for `MNIST` and `Anime`:

```bash
python ./pretrained_weights/download.py
```
## Training

To train the models, make sure you have downloaded the required pretrained weights and put them at the correct paths.
```bash
python train.py \
  --G stylegan \
  --dataset 0 \
  --exp_name your_name \
  --B 32 \
  --N 32 \
  --K 64
```
For `--dataset`, you can choose `0` for Shapes3D, `1` for MPI3D, or `2` for Cars3D.
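For example, a sketch that launches one run per StyleGAN2 dataset; the `exp_name` values are placeholders:

```bash
# One DisCo run per StyleGAN2 dataset (ids 0/1/2 as documented above).
for ds in 0 1 2; do
  python train.py --G stylegan --dataset "$ds" --exp_name "stylegan_ds${ds}" \
    --B 32 --N 32 --K 64
done
```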
```bash
python train.py \
  --G sngan \
  --dataset 5 \
  --exp_name your_name \
  --B 32 \
  --N 32 \
  --K 64
```
For `--dataset`, you can choose `5` for MNIST or `6` for Anime.
## Evaluation

- Dependencies: For evaluation, you will need `tensorflow` and `gin-config`.
- Download the datasets (except for Shapes3D):
  ```bash
  cd data
  ./dlib_download_data.sh
  ```
  For Shapes3D, you will first need to download the file `3dshapes.h5` from its Google Cloud Storage bucket and put it under the `data` directory (a placement sketch follows this section).
- Run the evaluation:

  ```bash
  python evaluate.py --dataset 0 --exp_name your_name
  ```
For `--dataset`, you can choose `0` for Shapes3D, `1` for MPI3D, or `2` for Cars3D (evaluation is supported only on these datasets). The results will be saved in the same directory as the checkpoint.
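For the manually downloaded Shapes3D file, a placement sketch; the `~/Downloads` source path is a placeholder for wherever your browser saved it:

```bash
# Move 3dshapes.h5 into the data directory used by the evaluation scripts.
mkdir -p data
mv ~/Downloads/3dshapes.h5 data/
```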
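And a sketch that evaluates a run on every supported dataset; `your_name` must match the `--exp_name` used at training time:

```bash
# Evaluate on Shapes3D (0), MPI3D (1), and Cars3D (2).
for ds in 0 1 2; do
  python evaluate.py --dataset "$ds" --exp_name your_name
done
```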
## Acknowledgements

The navigator and SNGAN are based on https://github.com/anvoynov/GANLatentDiscovery.
The StyleGAN2 code is based on https://github.com/rosinality/stylegan2-pytorch.
The disentanglement metrics are based on https://github.com/google-research/disentanglement_lib.
## Citation

```bibtex
@inproceedings{ren2022DisCo,
  title     = {Learning Disentangled Representation by Exploiting Pretrained Generative Models: A Contrastive Learning View},
  author    = {Ren, Xuanchi and Yang, Tao and Wang, Yuwang and Zeng, Wenjun},
  booktitle = {ICLR},
  year      = {2022}
}
```