Styleformer: Transformer based Generative Adversarial Networks with Style Vector (arXiv:2106.07023)

PaddlePaddle training/validation code and pretrained models for Styleformer.

The official PyTorch implementation is here.

This implementation is developed by PaddleViT.

Figure: Styleformer Model Overview

Update

Update (2021-08-17): Code is released and ported weights are uploaded.

Models Zoo

Model               | FID   | Image Size | Crop_pct | Interpolation | Link
--------------------|-------|------------|----------|---------------|--------------------
styleformer_cifar10 |  2.73 | 32         | 1.0      | lanczos       | google/baidu(ztky)
styleformer_stl10   | 15.65 | 48         | 1.0      | lanczos       | google/baidu(i973)
styleformer_celeba  |  3.32 | 64         | 1.0      | lanczos       | google/baidu(fh5s)
styleformer_lsun    |  9.68 | 128        | 1.0      | lanczos       | google/baidu(158t)

*The results are evaluated on the Cifar10, STL10, Celeba and LSUN-church datasets using the fid50k_full metric.
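
For reference, FID measures the Fréchet distance between two Gaussians fitted to Inception features of real and generated images. Below is a minimal numpy/scipy sketch of that distance, for illustration only; the repo's evaluation code computes the fid50k_full statistics itself, and mu/sigma here stand for the feature means and covariances.

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    # Frechet distance between N(mu1, sigma1) and N(mu2, sigma2):
    # ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    if not np.isfinite(covmean).all():
        # add a small diagonal offset for numerical stability
        offset = np.eye(sigma1.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset), disp=False)
    covmean = covmean.real  # drop tiny imaginary parts caused by numerical error
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))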

Notebooks

We provide a few notebooks in aistudio to help you get started:

(coming soon)

Requirements

Data

The STL10, Celeba and LSUN-church datasets are used, in the following folder structure:

│STL10/
├── train_X.bin
│── train_y.bin
├── test_X.bin
│── test_y.bin
│── unlabeled.bin
│Celeba/
├──img_align_celeba/
│  ├── 000017.jpg
│  │── 000019.jpg
│  ├── 000026.jpg
│  │── ......
│LSUNchurch/
├──church_outdoor_train_lmdb/
│  ├── data.mdb
│  │── lock.mdb
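
For reference, each STL10 .bin file is a raw stream of uint8 values, one 3x96x96 column-major image per record. Below is a minimal numpy sketch of reading such a file; the repo's dataset pipeline handles this internally, so this is only for inspecting the data.

import numpy as np

def read_stl10_images(bin_path):
    # each record is 3 channels x 96 x 96 uint8 values, column-major per channel
    data = np.fromfile(bin_path, dtype=np.uint8)
    images = data.reshape(-1, 3, 96, 96)
    # transpose to (N, H, W, C) row-major images
    return images.transpose(0, 3, 2, 1)

imgs = read_stl10_images('./STL10/train_X.bin')
print(imgs.shape)  # (5000, 96, 96, 3) for the labeled training split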

Usage

To use the model with pretrained weights, download the .pdparams weight file and change the related file paths in the following python scripts. The model config files are located in ./configs/.

For example, assuming the downloaded weight file is stored in ./cifar10.pdparams, the styleformer_cifar10 model can be used in Python as follows:

import paddle
from config import get_config
from generator import Generator
# config files in ./configs/
config = get_config('./configs/styleformer_cifar10.yaml')
# build model
model = Generator(config)
# load pretrained weights from the downloaded .pdparams file
model_state_dict = paddle.load('./cifar10.pdparams')
model.set_dict(model_state_dict)
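
As a quick sanity check after loading (illustrative only, not part of the repo's scripts), you can count the loaded parameters and switch the model to inference mode; actual sampling is handled by generate.py as described below.

import numpy as np
# count parameters of the generator built above
num_params = sum(int(np.prod(p.shape)) for p in model.parameters())
print(f'Generator parameters: {num_params:,}')
model.eval()  # inference mode before generating samples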

Generate Sample Images

To generate sample images from pretrained models, download the pretrained weights, and run the following script using command line:

sh run_generate.sh

or

python generate.py \
  -cfg='./configs/styleformer_cifar10.yaml' \
  -num_out_images=16 \
  -out_folder='./images_cifar10' \
  -pretrained='./cifar10.pdparams'

The output images are stored in the path given by -out_folder.
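
If you prefer a single contact sheet over individual files, a small PIL sketch such as the one below can tile the outputs. It assumes generate.py writes individual .png files into the -out_folder path (adjust the glob pattern otherwise); make_grid, cols and tile are names used only for this example.

from pathlib import Path
from PIL import Image

def make_grid(folder, out_path, cols=4, tile=32):
    # tile all PNGs in `folder` into a cols-wide grid image
    paths = sorted(Path(folder).glob('*.png'))
    rows = (len(paths) + cols - 1) // cols
    grid = Image.new('RGB', (cols * tile, rows * tile))
    for i, p in enumerate(paths):
        img = Image.open(p).convert('RGB').resize((tile, tile))
        grid.paste(img, ((i % cols) * tile, (i // cols) * tile))
    grid.save(out_path)

make_grid('./images_cifar10', './cifar10_grid.png', cols=4, tile=32)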

Evaluation

To evaluate Styleformer model performance on Cifar10 with a single GPU, run the following script using command line:

sh run_eval.sh

or

CUDA_VISIBLE_DEVICES=0 \
python main_single_gpu.py \
    -cfg='./configs/styleformer_cifar10.yaml' \
    -dataset='cifar10' \
    -batch_size=32 \
    -eval \
    -pretrained='./cifar10'

Run evaluation using multiple GPUs:
sh run_eval_multi.sh

or

CUDA_VISIBLE_DEVICES=0,1,2,3 \
python main_multi_gpu.py \
    -cfg='./configs/styleformer_cifar10.yaml' \
    -dataset='cifar10' \
    -batch_size=32 \
    -eval \
    -pretrained='./cifar10'

Training

To train the Styleformer model on Cifar10 with a single GPU, run the following script using command line:

sh run_train.sh

or

CUDA_VISIBLE_DEVICES=0 \
python main_single_gpu.py \
    -cfg='./configs/styleformer_cifar10.yaml' \
    -dataset='cifar10' \
    -batch_size=32 \
    -pretrained='./cifar10'

Run training using multiple GPUs:
sh run_train_multi.sh

or

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python main_multi_gpu.py \
    -cfg='./configs/styleformer_cifar10.yaml' \
    -dataset='cifar10' \
    -batch_size=32 \
    -pretrained='./cifar10'
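
Conceptually, the training scripts alternate discriminator and generator updates with a StyleGAN2-style non-saturating logistic loss. The toy Paddle sketch below shows one such update for orientation only: gen, disc, the optimizers and z_dim are placeholders rather than the repo's classes, and the actual loop also applies regularization and other details handled inside main_single_gpu.py / main_multi_gpu.py.

import paddle
import paddle.nn.functional as F

def gan_step(gen, disc, real_images, opt_g, opt_d, z_dim=512):
    # one illustrative non-saturating GAN update (not the repo's exact loop)
    batch = real_images.shape[0]

    # discriminator update: push real scores up, fake scores down
    z = paddle.randn([batch, z_dim])
    fake_images = gen(z).detach()
    d_loss = F.softplus(disc(fake_images)).mean() + F.softplus(-disc(real_images)).mean()
    opt_d.clear_grad()
    d_loss.backward()
    opt_d.step()

    # generator update: push scores of freshly generated images up
    z = paddle.randn([batch, z_dim])
    g_loss = F.softplus(-disc(gen(z))).mean()
    opt_g.clear_grad()
    g_loss.backward()
    opt_g.step()
    return float(d_loss), float(g_loss)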

Visualization of Generated Images

Generated Images after Training

Figure: Generated images from CelebA (left) and LSUN-church (right) datasets

Generated Images during Training

(coming soon)

Reference

@article{park2021styleformer,
      title={Styleformer: Transformer based Generative Adversarial Networks with Style Vector}, 
      author={Jeeseung Park and Younggeun Kim},
      year={2021},
      eprint={2106.07023},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}