This repository contains the PyTorch
implementation for the submission Diffusion Schrödinger Bridge Matching.
The goal of learning Schrödinger Bridges is to build a bridge between two distributions $\pi_0$ and $\pi_T$:

- Generative modeling: Gaussian $\rightarrow$ Data distribution.
- Data translation: Data distribution 1 $\rightarrow$ Data distribution 2.
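The bridge is represented by a (stochastic) process on the time interval $[0, T]$, whose path measure is denoted $\mathbb{P}$.
Schrödinger bridges not only impose the extremal constraints that the bridge must have the right distributions at times $0$ and $T$, but also require the bridge to minimize the energy of the path.
Minimizing the energy of the path can also be interpreted as minimizing the Kullback-Leibler divergence between the measure of the bridge $\mathbb{P}$ and a reference measure $\mathbb{Q}$.
The solution $\mathbb{P}^\star$ is characterized by the following conditions: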
1. $\mathbb{P}^\star_0 = \pi_0$.
2. $\mathbb{P}^\star_T = \pi_T$.
3. $\mathbb{P}^\star$ is Markov.
4. $\mathbb{P}^\star$ is in the reciprocal class of $\mathbb{Q}$, i.e. $\mathbb{P}^\star_ {|0,T} = \mathbb{Q}_ {|0,T}$ (the measures $\mathbb{P}^\star$ and $\mathbb{Q}$ are the same when conditioned on the initial and terminal conditions).
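
As a sketch in the notation above, the marginal and reciprocal-class conditions can be read off the constrained Kullback-Leibler problem together with the chain rule for the divergence. The endpoint joint law $\mathbb{P}_ {0,T}$ is notation introduced here for illustration only.

$$
\mathbb{P}^\star = \operatorname{argmin}_ {\mathbb{P}} \left\{ \mathrm{KL}(\mathbb{P} \,|\, \mathbb{Q}) \;:\; \mathbb{P}_0 = \pi_0, \ \mathbb{P}_T = \pi_T \right\}
$$

$$
\mathrm{KL}(\mathbb{P} \,|\, \mathbb{Q}) = \mathrm{KL}(\mathbb{P}_ {0,T} \,|\, \mathbb{Q}_ {0,T}) + \mathbb{E}_ {\mathbb{P}_ {0,T}} \left[ \mathrm{KL}(\mathbb{P}_ {|0,T} \,|\, \mathbb{Q}_ {|0,T}) \right]
$$

The marginal constraints only involve the endpoints, so the second term can be set to zero by choosing $\mathbb{P}_ {|0,T} = \mathbb{Q}_ {|0,T}$, which is the reciprocal-class condition.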
The Iterative Proportional Fitting (IPF) procedure alternately projects the measure onto conditions 1 and 2, while conditions 3 and 4 are satisfied by all iterates. The new Iterative Markovian Fitting (IMF) procedure we propose alternately projects onto conditions 3 and 4, while preserving conditions 1 and 2.
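
To make the alternating projections of IPF concrete, here is a minimal sketch on a static, discrete analogue (a coupling matrix over a few states), where IPF reduces to alternately rescaling rows and columns. This is not the code used in this repository: the function name `ipf`, the toy states, and the Gaussian-kernel reference are assumptions chosen for illustration.

```python
import math


def ipf(ref, pi0, piT, n_iters=200):
    """Alternately rescale a positive matrix so that its row sums match pi0
    and its column sums match piT (discrete, static analogue of IPF)."""
    p = [row[:] for row in ref]
    n, m = len(pi0), len(piT)
    for _ in range(n_iters):
        # Projection on the initial marginal constraint: fix the row sums.
        for i in range(n):
            s = sum(p[i])
            p[i] = [v * pi0[i] / s for v in p[i]]
        # Projection on the terminal marginal constraint: fix the column sums.
        col = [sum(p[i][j] for i in range(n)) for j in range(m)]
        p = [[p[i][j] * piT[j] / col[j] for j in range(m)] for i in range(n)]
    return p


# Hypothetical toy problem: three states at each end, Gaussian-kernel reference.
xs = [0.0, 1.0, 2.0]
ref = [[math.exp(-((a - b) ** 2) / 2.0) for b in xs] for a in xs]
pi0 = [0.5, 0.3, 0.2]
piT = [0.2, 0.3, 0.5]

coupling = ipf(ref, pi0, piT)
rows = [sum(r) for r in coupling]
cols = [sum(coupling[i][j] for i in range(3)) for j in range(3)]
```

Each iterate keeps the form of the reference matrix multiplied by row and column scalings, which loosely mirrors the statement that the structural conditions hold for every iterate while only the marginals are being corrected.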
We denote
We refer to our paper for details on the implementation of these projections. The practical algorithm associated with IMF leverages Flow and Bridge Matching. We call this practical algorithm Diffusion Schrödinger Bridge Matching (DSBM).
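
As a rough illustration of the Bridge Matching ingredient, the sketch below samples a point on a Brownian bridge between two endpoints and computes the drift that a network would be regressed onto. It assumes a Brownian reference process with noise scale `sigma`; the function names, the uniform time sampling, and the pure-Python vectors are hypothetical and do not reflect the repository's actual implementation.

```python
import math
import random


def sample_bridge_point(x0, xT, t, T=1.0, sigma=1.0, rng=random):
    """Sample X_t from a Brownian bridge pinned at x0 (time 0) and xT (time T)."""
    mean = [(1 - t / T) * a + (t / T) * b for a, b in zip(x0, xT)]
    std = sigma * math.sqrt(t * (T - t) / T)
    return [m + std * rng.gauss(0.0, 1.0) for m in mean]


def bridge_drift_target(xt, xT, t, T=1.0):
    """Drift of the Brownian bridge toward xT: (xT - X_t) / (T - t)."""
    return [(b - x) / (T - t) for x, b in zip(xt, xT)]


def bridge_matching_loss(drift_fn, pairs, T=1.0, sigma=1.0, rng=random):
    """Monte Carlo estimate of the squared error between drift_fn(t, x) and
    the bridge drift, averaged over the given (x0, xT) pairs."""
    total = 0.0
    for x0, xT in pairs:
        # Stay slightly below T to avoid the singularity of the drift at t = T.
        t = rng.uniform(0.0, T * (1 - 1e-3))
        xt = sample_bridge_point(x0, xT, t, T, sigma, rng)
        target = bridge_drift_target(xt, xT, t, T)
        pred = drift_fn(t, xt)
        total += sum((p - q) ** 2 for p, q in zip(pred, target))
    return total / len(pairs)


rng = random.Random(0)
x0, xT = [0.0, 0.0], [1.0, 2.0]
xt = sample_bridge_point(x0, xT, t=0.25, sigma=0.0, rng=rng)
target = bridge_drift_target(xt, xT, t=0.25)
```

With `sigma=0.0` the bridge collapses to the straight line between the endpoints and the target becomes the constant velocity `(xT - x0) / T`, which is the deterministic Flow Matching limit of the same regression.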
We provide a Singularity container recipe in `bridge.def`, which can be used to set up a Singularity container. Alternatively, a conda environment can be set up manually using the conda installation commands in `bridge.def`.
A self-contained Gaussian experiment benchmark is provided in `DSBM-Gaussian.py`.

- DSB: `python DSBM-Gaussian.py dim=5,20,50 model_name=dsb seed=1,2,3,4,5 inner_iters=10000 -m`
- IMF-b: `python DSBM-Gaussian.py dim=5,20,50 model_name=dsbm first_coupling=ind seed=1,2,3,4,5 inner_iters=10000 fb_sequence=['b'] -m`
- DSBM-IPF: `python DSBM-Gaussian.py dim=5,20,50 model_name=dsbm seed=1,2,3,4,5 inner_iters=10000 -m`
- DSBM-IMF: `python DSBM-Gaussian.py dim=5,20,50 model_name=dsbm first_coupling=ind seed=1,2,3,4,5 inner_iters=10000 -m`
- Rectified Flow: `python DSBM-Gaussian.py dim=5,20,50 model_name=rectifiedflow seed=1,2,3,4,5 inner_iters=10000 fb_sequence=[b] -m`
- SB-CFM: `python DSBM-Gaussian.py dim=5,20,50 model_name=sbcfm seed=1,2,3,4,5 inner_iters=10000 -m`
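
The comma-separated values combined with `-m` in the Gaussian commands look like a Hydra-style multirun sweep. Assuming that is how they are interpreted, each command expands into one run per combination of values. The helper `expand_sweep` below is a hypothetical illustration of that expansion, not part of this repository, and it does not handle list-valued overrides that themselves contain commas.

```python
from itertools import product


def expand_sweep(overrides):
    """Expand comma-separated override values into the Cartesian product
    of single-valued runs (one dict of overrides per run)."""
    keys, choices = [], []
    for item in overrides:
        key, _, value = item.partition("=")
        keys.append(key)
        choices.append(value.split(","))
    return [dict(zip(keys, combo)) for combo in product(*choices)]


runs = expand_sweep(
    ["dim=5,20,50", "model_name=dsb", "seed=1,2,3,4,5", "inner_iters=10000"]
)
print(len(runs))  # 15 runs: 3 dimensions x 5 seeds
print(runs[0])
```

Under this reading, each Gaussian benchmark command launches 15 runs, one for every pair of dimension and seed.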

- DSBM-IPF: `python main.py num_steps=30 num_iter=5000 method=dbdsb gamma_min=0.034 gamma_max=0.034`
- DSBM-IMF: `python main.py num_steps=30 num_iter=5000 method=dbdsb first_num_iter=100000 gamma_min=0.034 gamma_max=0.034 first_coupling=ind`

The dataset can be downloaded and processed using the script at https://github.com/CliMA/diffusion-bridge-downscaling/blob/main/CliMAgen/examples/utils_data.jl, then saved as NumPy arrays in `./data/downscaler`.

- DSBM-IPF: `python main.py dataset=downscaler_transfer num_steps=30 num_iter=5000 gamma_min=0.01 gamma_max=0.01 model=DownscalerUNET`
- DSBM-IMF: `python main.py dataset=downscaler_transfer num_steps=30 num_iter=5000 gamma_min=0.01 gamma_max=0.01 model=DownscalerUNET first_coupling=ind`