Cooperative learning of energy-based model and variational auto-encoder
This repository contains a TensorFlow implementation of the paper "Learning energy-based model with variational auto-encoder as amortized sampler" (AAAI 2021).
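The "amortized sampler" idea in the title can be illustrated with a toy sketch. This is not the repository's code. It assumes that the energy model revises each VAE sample with Langevin dynamics, x <- x - (s^2 / 2) * dE/dx + s * noise. The quadratic energy, the `toy_decoder` stand-in for the VAE decoder, and all numbers are hypothetical. Reading `--des_step_size` and `--des_sample_steps` in the training commands as the step size s and the number of steps is also an assumption.

```python
import math
import random


def energy_grad(x, mu=0.0, sigma=1.0):
    """Gradient of a toy quadratic energy E(x) = sum((x_i - mu)^2) / (2 * sigma^2).

    Hypothetical stand-in for the learned energy-based model."""
    return [(xi - mu) / sigma ** 2 for xi in x]


def toy_decoder(z):
    """Hypothetical stand-in for the VAE decoder: maps a latent code to an
    initial sample that is deliberately offset from the energy minimum."""
    return [2.0 + zi for zi in z]


def langevin_revise(x, step_size, n_steps, rng, add_noise=True):
    """Apply n_steps Langevin updates to x:
    x <- x - (step_size**2 / 2) * dE/dx + step_size * N(0, 1) noise."""
    x = list(x)
    for _ in range(n_steps):
        grad = energy_grad(x)
        x = [
            xi - 0.5 * step_size ** 2 * gi
            + (step_size * rng.gauss(0.0, 1.0) if add_noise else 0.0)
            for xi, gi in zip(x, grad)
        ]
    return x


def amortized_sample(dim, step_size, n_steps, rng, add_noise=True):
    """The toy 'VAE' proposes an initial sample; the energy model then revises it."""
    z = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # latent code z ~ N(0, I)
    x0 = toy_decoder(z)
    return x0, langevin_revise(x0, step_size, n_steps, rng, add_noise)


def norm(v):
    """Euclidean length of a list of floats."""
    return math.sqrt(sum(vi * vi for vi in v))


if __name__ == "__main__":
    rng = random.Random(0)
    # Noise switched off so the effect of the gradient term is easy to see.
    x0, x = amortized_sample(dim=4, step_size=0.5, n_steps=50, rng=rng, add_noise=False)
    print("distance to minimum before: %.3f, after: %.3f" % (norm(x0), norm(x)))
```

In the paper's cooperative scheme the revised samples are, roughly, used both to update the energy model and as targets that the VAE learns from. The sketch leaves out all learning and only shows the sampling path.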
Set up a conda environment with Python 2.7, TensorFlow (GPU build) 1.12.0 and Pillow:
conda create --name coopvaebm python=2.7
conda activate coopvaebm
conda install tensorflow-gpu==1.12.0
conda install Pillow
Download the pretrained CIFAR-10 checkpoint from here. The test commands below load it from pretrained/checkpoints/cifar/model.ckpt-3000.
To train a model from scratch:
(1) CIFAR-10 dataset (default settings)
python main.py --test False
(2) MNIST dataset
python main.py --net_type 'mnist' --category 'mnist' --image_size 28 --num_channel 1 --batch_size 400 --nTileRow 20 --nTileCol 20 --des_step_size 0.001 --des_sample_steps 50 --vae_lr 0.0001 --weight_latent_loss 3
(3) Fashion-MNIST dataset
python main.py --net_type 'mnist' --category 'mnist-fashion' --image_size 28 --num_channel 1 --batch_size 400 --nTileRow 20 --nTileCol 20 --des_step_size 0.001 --des_sample_steps 50 --vae_lr 0.0001 --weight_latent_loss 3
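The MNIST and Fashion-MNIST commands above share every flag except `--category`. A small helper can make that explicit and avoid copy-paste drift between the two runs. `build_mnist_train_command` and `SHARED_MNIST_FLAGS` are hypothetical names. The flag values are copied verbatim from the commands above, with the shell quotes around `'mnist'` dropped because the shell strips them anyway.

```python
# Flags shared verbatim by the MNIST and Fashion-MNIST commands,
# in the same order; only --category differs between the two runs.
SHARED_MNIST_FLAGS = [
    ("--image_size", "28"),
    ("--num_channel", "1"),
    ("--batch_size", "400"),
    ("--nTileRow", "20"),
    ("--nTileCol", "20"),
    ("--des_step_size", "0.001"),
    ("--des_sample_steps", "50"),
    ("--vae_lr", "0.0001"),
    ("--weight_latent_loss", "3"),
]


def build_mnist_train_command(category):
    """Return the argv list for training on an MNIST-style dataset
    (hypothetical helper; category is 'mnist' or 'mnist-fashion')."""
    argv = ["python", "main.py", "--net_type", "mnist", "--category", category]
    for flag, value in SHARED_MNIST_FLAGS:
        argv += [flag, value]
    return argv


if __name__ == "__main__":
    for category in ("mnist", "mnist-fashion"):
        # No argument contains shell metacharacters, so a plain join is safe to paste.
        print(" ".join(build_mnist_train_command(category)))
```

The returned list can also be passed directly to `subprocess.call` from the repository root instead of being printed.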
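To test with the pretrained CIFAR-10 checkpoint, run synthesis (--test_type 'syn') or interpolation (--test_type 'inter'):
python main.py --test True --test_type 'syn' --ckpt 'pretrained/checkpoints/cifar/model.ckpt-3000'
python main.py --test True --test_type 'inter' --ckpt 'pretrained/checkpoints/cifar/model.ckpt-3000'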
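The `--ckpt` value above is a TensorFlow 1.x checkpoint prefix, not a single file. A V2-format checkpoint is stored as `<prefix>.index` plus one or more `<prefix>.data-XXXXX-of-YYYYY` shards, usually alongside `<prefix>.meta`. The sketch below assumes the downloaded checkpoint uses this format and checks that those files are in place before launching the test runs. `checkpoint_files_present` and `build_test_command` are hypothetical helpers, and the script needs Python 3.4+ for `glob.escape`.

```python
import glob
import os

# Checkpoint prefix used by the test commands in this README.
DEFAULT_CKPT = "pretrained/checkpoints/cifar/model.ckpt-3000"


def checkpoint_files_present(prefix):
    """Return True if a TensorFlow 1.x (V2-format) checkpoint exists at `prefix`.

    Checks for <prefix>.index and at least one <prefix>.data-*-of-* shard;
    <prefix>.meta (the saved graph) is not required by this check."""
    has_index = os.path.isfile(prefix + ".index")
    has_data = len(glob.glob(glob.escape(prefix) + ".data-*-of-*")) > 0
    return has_index and has_data


def build_test_command(test_type, ckpt=DEFAULT_CKPT):
    """Return the argv list for a test run, e.g. test_type 'syn' or 'inter'."""
    return ["python", "main.py", "--test", "True", "--test_type", test_type, "--ckpt", ckpt]


if __name__ == "__main__":
    if not checkpoint_files_present(DEFAULT_CKPT):
        print("checkpoint not found: expected %s.index and %s.data-*" % (DEFAULT_CKPT, DEFAULT_CKPT))
    else:
        for test_type in ("syn", "inter"):
            print(" ".join(build_test_command(test_type)))
```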
@inproceedings{xie2021learning,
title={Learning energy-based model with variational auto-encoder as amortized sampler},
author={Xie, Jianwen and Zheng, Zilong and Li, Ping},
booktitle={The Thirty-Fifth AAAI Conference on Artificial Intelligence (AAAI)},
year={2021}
}