This codebase provides an implementation of the diffusion model composition experiments on images. The diffusion model implementation is based on Song et al.'s PyTorch tutorial notebook.
Create and activate a conda environment from the provided `environment.yml`:

```bash
conda env create --name diffusion_sculpting --file environment.yml
conda activate diffusion_sculpting
```
The repository is organized as follows:

- `models`: implementations of the various score models and classifiers
- `samplers`: an implementation of the PC sampler
- `custom_datasets.py`: implements the various datasets used to train the base models
- `sample_individual_model.py`: implements sample generation from one of the base models
- `sample_composition.py`: implements sample generation from a composition of 2 base models
- `sample_3way_composition.py`: implements sample generation from a composition of 3 base models
- `train_diffusion.py`: implements the training procedure for the base models
- `train_classifier.py`: implements the training procedure for a classifier used to sample from a composition of 2 base models
- `train_3way_classifier.py`: trains a classifier for classifying the first two observations in a composition of 3 base models
- `train_3way_conditional_classifier.py`: trains a classifier for classifying the third observation in a composition of 3 base models
The 3-way composition is the experimental setting reported in the main paper. This experiment composes 3 base models, which can be trained with:

```bash
python3 -m train_diffusion
```
This script contains a variable `GEN_IDX` that determines which base model to train. Run the script three times with this variable set to `"MN1"`, `"MN2"`, and `"MN3"` respectively.
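If you prefer not to edit the file by hand, a small driver along the following lines should work. This is a sketch, not part of the repo: it assumes `GEN_IDX` is assigned as a literal at module level in `train_diffusion.py` (e.g. `GEN_IDX = "MN1"`).

```python
# Hypothetical driver: rewrites the GEN_IDX assignment in train_diffusion.py
# and launches one training run per base model. Assumes GEN_IDX is assigned
# at module level as a literal, e.g. GEN_IDX = "MN1".
import re
import subprocess

def train_base_models(values):
    for value in values:
        with open("train_diffusion.py") as f:
            src = f.read()
        # Replace the first GEN_IDX assignment with the desired literal.
        src = re.sub(r"^GEN_IDX\s*=.*$", f"GEN_IDX = {value!r}",
                     src, count=1, flags=re.MULTILINE)
        with open("train_diffusion.py", "w") as f:
            f.write(src)
        subprocess.run(["python3", "-m", "train_diffusion"], check=True)

train_base_models(["MN1", "MN2", "MN3"])
```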
In this setting we compose the three base models based on three observations. However, not all observations are classified by the same classifier: we use two classifiers, one to classify the first two observations, and a second to classify the third conditioned on the first two.
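Concretely, the two classifiers factor the likelihood of the observations as log p(y_1, y_2 | x_t) + log p(y_3 | x_t, y_1, y_2). The sketch below shows the resulting guidance gradient under hypothetical classifier signatures; the repo's actual interfaces may differ.

```python
# Sketch of the factored guidance term, NOT the repo's actual API.
# Assumed (hypothetical) signatures:
#   logp12(x, t)        -> log p(y1, y2 | x_t), shape [batch, K1, K2]
#   logp3(x, t, y1, y2) -> log p(y3 | x_t, y1, y2), shape [batch, K3]
import torch

def guidance(x, t, logp12, logp3, y1, y2, y3):
    """Gradient of log p(y1, y2 | x) + log p(y3 | x, y1, y2) w.r.t. x."""
    x = x.detach().requires_grad_(True)
    logp = logp12(x, t)[:, y1, y2] + logp3(x, t, y1, y2)[:, y3]
    return torch.autograd.grad(logp.sum(), x)[0]
```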
Train the first classifier with

```bash
python3 -m train_3way_classifier
```

Once that has finished, train the second classifier with

```bash
python3 -m train_3way_conditional_classifier
```
Once the classifiers have been trained, we can generate samples from the resulting composition by running

```bash
python3 -m sample_3way_composition
```

By default this generates samples from the composition corresponding to `y_1=1, y_2=2, y_3=3`. You can change this by editing the definition of the composition in the code. The composition is defined in two steps: first we construct a `BinaryDiffusionComposition` with the first two observations, then a `ConditionalDiffusionComposition` with the third observation.
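For illustration, the two-step construction might look roughly like the following. The argument names beyond the observations are hypothetical; check `sample_3way_composition.py` for the constructors' actual signatures.

```python
# Hypothetical sketch only: the non-observation arguments are guesses,
# not the repo's API. The pairwise composition handles (y1, y2); the
# conditional composition layers the third observation on top of it.
pairwise = BinaryDiffusionComposition(score_model1, score_model2,
                                      classifier, y1=1, y2=2)
composition = ConditionalDiffusionComposition(pairwise, score_model3,
                                              conditional_classifier, y3=3)
```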
These two experiments (ColorMNIST and subdigits) are simplifications of the experiment above; their results are discussed in the appendix of the paper. Each composes 2 base models, which can be trained with:

```bash
python3 -m train_diffusion
```
This script contains a variable `GEN_IDX` that determines which base model to train. For the ColorMNIST instance, run it twice with `GEN_IDX` set to `"M1"` and `"M2"` respectively; for subdigits, set `GEN_IDX` to `1` and `2` respectively.
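The hypothetical `train_base_models` helper sketched in the 3-way section covers both settings, since it writes the value's Python literal into the script:

```python
train_base_models(["M1", "M2"])  # ColorMNIST: string identifiers
train_base_models([1, 2])        # subdigits: integer identifiers
```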
In this setting we compose the two base models based on two observations. Train the classifier with

```bash
python3 -m train_classifier
```
To train the classifier correctly, we need to ensure that the correct score models are loaded. For ColorMNIST, please ensure that `score_model1` and `score_model2` are loaded from checkpoints `gen_M1_ckpt_195.pth` and `gen_M2_ckpt_195.pth` respectively. For subdigits, change these checkpoints to `gen_1_ckpt_195.pth` and `gen_2_ckpt_195.pth` respectively, and set the `input_channels` variable in the `ScoreNet` constructors to `1`.
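For reference, the loading step might look roughly like this. The import path and checkpoint layout are assumptions (the models follow Song et al.'s tutorial style), so defer to the actual code in `train_classifier.py`.

```python
# Rough sketch with assumed import path and checkpoint layout; the actual
# ScoreNet constructor in the repo may take additional arguments.
import torch
from models import ScoreNet  # hypothetical import path

input_channels = 3  # ColorMNIST; set to 1 for subdigits
score_model1 = ScoreNet(input_channels=input_channels)
score_model2 = ScoreNet(input_channels=input_channels)
score_model1.load_state_dict(torch.load("gen_M1_ckpt_195.pth"))  # subdigits: gen_1_ckpt_195.pth
score_model2.load_state_dict(torch.load("gen_M2_ckpt_195.pth"))  # subdigits: gen_2_ckpt_195.pth
```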
Once the classifier has been trained, we can generate samples from the resulting composition by running

```bash
python3 -m sample_composition
```

As before, the composition is defined by a `BinaryDiffusionComposition` object whose constructor takes the observations defining the composition as inputs. By default the code is set up for the ColorMNIST condition: `score_model1` and `score_model2` are loaded from checkpoints `gen_M1_ckpt_195.pth` and `gen_M2_ckpt_195.pth` respectively. For subdigits, change these checkpoints to `gen_1_ckpt_195.pth` and `gen_2_ckpt_195.pth` and set the `INPUT_CHANNELS` global variable to `1`.
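As a sketch, changing the observations might look like this; everything other than the observation arguments is a hypothetical name, so consult `sample_composition.py` for the real signature.

```python
# Hypothetical sketch only: the non-observation arguments are guesses,
# not the repo's API. The observations (e.g. y1=1, y2=2) define the composition.
composition = BinaryDiffusionComposition(score_model1, score_model2,
                                         classifier, y1=1, y2=2)
```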