This is the codebase for Elucidating The Design Space of Classifier-Guided Diffusion Generation, accepted as a poster at ICLR 2024.
This repository contains three main folders, which target off-the-shelf classifier guidance for DDPM, EDM, and DiT, respectively.
export PYTHONPATH="${PYTHONPATH}:...{Folder_Path}/EluCD-main/DDPM"
pip install blobfile
pip install mpi4py
Place all pre-trained diffusion models, classifier models, and the reference batch in the ./DDPM/pretrained_models folder:
For DDPM, the ImageNet 128x128 diffusion model and the fine-tuned classifier are taken from the DDPM diffusion model.
We use the off-the-shelf PyTorch ResNet classifiers ResNet50 and ResNet101 for guided sampling.
For FID evaluation, use the ImageNet 128x128 reference batch.
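The guidance step that these classifiers plug into can be sketched as follows. This is a minimal illustration under stated assumptions, not the repository's code: it uses the standard classifier-guidance update of Dhariwal and Nichol, which shifts the mean of each reverse step by s·Σ·∇ log p(y | x_t), and it replaces the ResNet with a hypothetical toy linear classifier so that the gradient of log p(y | x) has a closed form. The function names and the `scale` argument (analogous to --classifier_scale) are illustrative, and how the repository feeds noisy samples to an off-the-shelf classifier is not modeled here.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over class logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def linear_logits(W: Matrix, b: Sequence[float], x: Sequence[float]) -> List[float]:
    """Hypothetical toy classifier standing in for the ResNet: z_k = W[k] . x + b[k]."""
    return [sum(w * xi for w, xi in zip(row, x)) + bk for row, bk in zip(W, b)]


def grad_log_prob(W: Matrix, b: Sequence[float], x: Sequence[float], y: int) -> List[float]:
    """Gradient of log p(y | x) w.r.t. x for the toy classifier.

    log p(y | x) = z_y - logsumexp(z), so the gradient is W[y] - sum_k p_k * W[k].
    """
    p = softmax(linear_logits(W, b, x))
    return [
        W[y][i] - sum(p[k] * W[k][i] for k in range(len(W)))
        for i in range(len(x))
    ]


def guided_mean(mu: Sequence[float], var: Sequence[float], W: Matrix,
                b: Sequence[float], x_t: Sequence[float], y: int,
                scale: float) -> List[float]:
    """Classifier-guided reverse-step mean: mu + scale * Sigma * grad log p(y | x_t).

    Sigma is assumed diagonal and given by its entries `var`; `scale` plays the
    role of a guidance scale such as --classifier_scale (an assumption).
    """
    g = grad_log_prob(W, b, x_t, y)
    return [m + scale * v * gi for m, v, gi in zip(mu, var, g)]


if __name__ == "__main__":
    W = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]  # 3 classes, 2-D inputs
    b = [0.0, 0.0, 0.0]
    x_t = [0.2, -0.1]
    mu, var = [0.15, -0.05], [0.01, 0.01]
    print(guided_mean(mu, var, W, b, x_t, y=0, scale=1.0))
```

With a real ResNet the gradient would come from autograd through the network instead of a closed form; the linear stand-in only keeps the sketch dependency-free.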
First, go to the ./DDPM folder, which contains all files for off-the-shelf classifier guidance of the DDPM diffusion model.
You can directly run ./DDPM/guided_sample.sh.
All model checkpoints are stored in the ./DDPM/pretrained_models/ folder.
For FID evaluation, use ./pytorch-fid-master/src/evaluation_image.sh and replace the filename with the name of the sample folder you created.
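For reference, FID compares Gaussian fits to Inception-v3 features of the generated samples (mean μ_g, covariance Σ_g) and of the reference batch (μ_r, Σ_r); lower is better. This is the standard definition, which the bundled pytorch-fid scripts presumably implement:

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)
$$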
Run ./guided_sample.sh to generate samples with off-the-shelf classifier-guided sampling:
- Off-the-shelf ResNet101-guided DDPM:
CLASSIFIER_FLAGS="--image_size 128 --classifier_attention_resolutions 32,16,8 --classifier_depth 2 --classifier_width 128 --classifier_pool attention --classifier_resblock_updown True --classifier_use_scale_shift_norm True --classifier_type resnet101 --classifier_scale 1.0 --softplus_beta 3.0 --joint_temperature 1.0 --margin_temperature_discount 0.5 --gamma_factor 0.3 --classifier_use_fp16 True"
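The --softplus_beta flag suggests that the ReLU activations of the off-the-shelf ResNet are replaced with a smooth Softplus so that its input gradients are better behaved for guidance; this reading is an assumption. The sketch below only illustrates the Softplus formula used by torch.nn.Softplus(beta=...), log(1 + exp(βx))/β, and its derivative, the sigmoid of βx. The helper names are hypothetical.

```python
import math

# Hypothetical helpers illustrating Softplus(beta); assumed to be what
# --softplus_beta controls, not taken from the repository.


def softplus(x: float, beta: float = 1.0) -> float:
    """log(1 + exp(beta * x)) / beta, the formula used by torch.nn.Softplus.

    Computed stably as (max(z, 0) + log1p(exp(-|z|))) / beta with z = beta * x,
    so large |x| does not overflow.
    """
    z = beta * x
    return (max(z, 0.0) + math.log1p(math.exp(-abs(z)))) / beta


def softplus_grad(x: float, beta: float = 1.0) -> float:
    """d/dx softplus(x, beta) = sigmoid(beta * x): continuous, unlike ReLU's step."""
    z = beta * x
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)  # z < 0, so this cannot overflow
    return e / (1.0 + e)


def relu(x: float) -> float:
    return max(x, 0.0)


if __name__ == "__main__":
    for beta in (1.0, 3.0, 10.0):
        # The gap to ReLU is largest at x = 0, where it equals log(2) / beta.
        print(beta, softplus(0.0, beta), softplus_grad(0.0, beta))
```

Larger β brings Softplus closer to ReLU while keeping a continuous derivative. PyTorch additionally reverts to the identity when βx exceeds its threshold argument; the stable form above makes that unnecessary in this sketch.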
This table summarizes our conditional ImageNet 128x128 generation results with 250 steps of guided DDPM sampling:
ImageNet 128x128 | FID | Precision | Recall |
---|---|---|---|
Diffusion Baseline | 5.91 | 0.70 | 0.65 |
Diffusion Fine-tuned Classifier Guided | 2.97 | 0.78 | 0.59 |
Diffusion Classifier-Free | 2.43 | - | - |
Diffusion ResNet50 Guided (Ours) | 2.36 | 0.77 | 0.60 |
Diffusion ResNet101 Guided (Ours) | 2.19 | 0.79 | 0.58 |
Place all pre-trained diffusion models, classifier models, and the reference batch in the ./EDM/pretrained_models folder:
For EDM, the ImageNet 64x64 diffusion model is taken from the EDM diffusion model.
We use the off-the-shelf PyTorch ResNet classifiers ResNet50 and ResNet101 for guided sampling.
For FID evaluation, use the ImageNet 64x64 reference batch.
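In EDM's parameterization the denoiser D(x; σ) determines the score through ∇_x log p(x; σ) = (D(x; σ) - x)/σ². Adding a classifier term ∇_x log p_φ(y | x; σ) with guidance scale s to this score is therefore equivalent to shifting the denoiser output, as below. This is a sketch of one standard way to express classifier guidance for EDM, not necessarily the exact form used in ./EDM (for example, which input the off-the-shelf classifier sees may differ); with s = 1 it reduces to Bayes' rule for the conditional score.

$$
\frac{D(x;\sigma) - x}{\sigma^2} + s\,\nabla_x \log p_\phi(y \mid x;\sigma)
= \frac{\tilde{D}(x;\sigma,y) - x}{\sigma^2},
\qquad
\tilde{D}(x;\sigma,y) = D(x;\sigma) + s\,\sigma^2\,\nabla_x \log p_\phi(y \mid x;\sigma)
$$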
First, go to the ./EDM folder, which contains all files for off-the-shelf classifier guidance of the EDM diffusion model.
You can directly run ./EDM/guided_sample.sh.
All model checkpoints are stored in the ./EDM/pretrained_models/ folder.
For FID evaluation, use ./pytorch-fid-master/src/evaluation_image.sh, replace the filename with the name of the sample folder you created, and use the ImageNet 64x64 reference batch.
Run ./guided_sample.sh to generate samples with off-the-shelf classifier-guided sampling.
This table summarizes our conditional ImageNet 64x64 generation results for guided EDM sampling with different numbers of sampling steps:
ImageNet 64x64 | Classifier | FID | Steps |
---|---|---|---|
EDM baseline | - | 2.35 | 36 |
EDM Res101 guided | Off-the-Shelf | 2.22 | 36 |
EDM baseline | - | 2.54 | 18 |
EDM Res101 guided | Off-the-Shelf | 2.35 | 18 |
EDM baseline | - | 3.64 | 10 |
EDM Res101 guided | Off-the-Shelf | 3.38 | 10 |
We will release the code for off-the-shelf classifier-guided sampling with DiT soon.