A Python script to evaluate and plot the discretized Wasserstein-2 gradient flow starting at an empirical measure with respect to a Maximum-Mean-Discrepancy-regularized f-divergence functional, whose target is an empirical measure as well.
This repository provides the method `MMD_reg_f_div_flow` (from the file `MMD_reg_fDiv_ParticleFlows.py`) used to produce the numerical experiments for the paper
Wasserstein Gradient Flows for Moreau Envelopes of f-Divergences in Reproducing Kernel Hilbert Spaces by Sebastian Neumayer, Viktor Stein, Gabriele Steidl and Nicolaj Rux.
If you use this code, please cite this preprint, preferably like this:
@unpublished{NSSR24,
author = {Neumayer, Sebastian and Stein, Viktor and Steidl, Gabriele and Rux, Nicolaj},
title = {Wasserstein Gradient Flows for {M}oreau Envelopes of $f$-Divergences in Reproducing Kernel {H}ilbert Spaces},
note = {ArXiv preprint},
volume = {arXiv:2402.04613},
year = {2024},
month = {Feb},
url = {https://arxiv.org/abs/2402.04613},
doi = {10.48550/arXiv.2402.04613}
}
The other Python files contain auxiliary functions.
Scripts to exactly reproduce the figures in the preprint will follow soon. An example file is `AlphaComparison.py`.
This code is written and maintained by Viktor Stein. Any comments, feedback, questions and bug reports are welcome! Alternatively you can use the GitHub issue tracker.
- Required packages
- Supported kernels
- Supported $f$-divergences / entropy functions
- Supported targets
This script requires the following Python packages. We tested the code with Python 3.11.7 and the following package versions:
- torch 2.1.2
- scipy 1.12.0
- numpy 1.26.3
- pillow 10.2.0 (if you want to generate a gif of the evolution of the flow)
- matplotlib 3.8.2
- pot 0.9.3 (if you want to evaluate the exact Wasserstein-2 loss along the flow)
- scikit-learn 1.4.1.post1 (provides `sklearn.datasets`, for more targets)
The code is usually also compatible with somewhat earlier or later versions of these packages.
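To compare a local environment against the tested versions listed above, a small check along the following lines may help. It is only a sketch: the helper names `installed_versions` and `report` are hypothetical and not part of this repository, and the PyPI distribution names (in particular `pot` and `scikit-learn`) are assumptions about how the packages were installed.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names, paired with the tested versions from the list above.
TESTED_VERSIONS = {
    "torch": "2.1.2",
    "scipy": "1.12.0",
    "numpy": "1.26.3",
    "pillow": "10.2.0",
    "matplotlib": "3.8.2",
    "pot": "0.9.3",
    "scikit-learn": "1.4.1.post1",
}


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def report(tested):
    """Print one line per package: the tested version next to the installed one."""
    for name, installed in installed_versions(tested).items():
        status = "missing" if installed is None else f"installed {installed}"
        print(f"{name:<14} tested {tested[name]:<12} {status}")


if __name__ == "__main__":
    report(TESTED_VERSIONS)
```

A version mismatch does not necessarily mean the code will fail; the check only makes differences visible.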
The following kernels are all radial and twice differentiable and hence fulfill all assumptions in the paper.
We denote the reLU by
| Kernel | Name | Expression |
|---|---|---|
| inverse multiquadric | `imq` | |
| Gauss | `gauss` | |
| Matérn- | `matern` | |
| Matérn- | `matern2` | |
| | `compact` | |
| Another Spline | `compact2` | |
| inverse log | `inv_log` | |
| inverse quadric | `inv_quad` | |
| student t | `student` | |
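To illustrate what "radial" means for the kernels above, here is a minimal pure-Python sketch of two of them together with the squared MMD they induce between two point clouds. The Gaussian and inverse multiquadric profiles are the textbook forms; how the repository scales them by `sigma` is an assumption, and the helper names `sq_dist`, `gram`, `mean_entry` and `mmd_squared` are hypothetical.

```python
import math


def sq_dist(x, y):
    """Squared Euclidean distance between two points given as sequences."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def gauss(r2, sigma):
    # Textbook Gaussian profile in the squared distance r2 (scaling is an assumption).
    return math.exp(-r2 / (2.0 * sigma**2))


def imq(r2, sigma):
    # Textbook inverse multiquadric profile (scaling is an assumption).
    return 1.0 / math.sqrt(sigma**2 + r2)


def gram(profile, X, Y, sigma):
    """Kernel matrix K[i][j] = profile(|X[i] - Y[j]|^2, sigma)."""
    return [[profile(sq_dist(x, y), sigma) for y in Y] for x in X]


def mean_entry(K):
    """Average of all entries of a rectangular matrix given as a list of rows."""
    return sum(map(sum, K)) / (len(K) * len(K[0]))


def mmd_squared(profile, X, Y, sigma):
    """Squared MMD between the empirical measures supported on X and Y."""
    return (mean_entry(gram(profile, X, X, sigma))
            - 2.0 * mean_entry(gram(profile, X, Y, sigma))
            + mean_entry(gram(profile, Y, Y, sigma)))
```

Because each kernel depends on its arguments only through the squared distance, a single `gram` helper serves every radial profile. The repository itself works with torch tensors; plain lists are used here only to keep the sketch free of dependencies.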
We also implemented the following two "$W_2$-metrizing kernels", which metrize the Wasserstein-2 distance on
| Kernel | Name | Expression |
|---|---|---|
| | `W2_1` | |
| | `W2_2` | |
The following entropy functions each have an infinite recession constant if
| Entropy | Name | Expression |
|---|---|---|
| Kullback-Leibler | `tsallis`, | |
| Tsallis- | `tsallis` | |
| Jeffreys | `jeffreys` | |
| chi | | |
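As a small illustration of how an entropy function generates an $f$-divergence, the sketch below evaluates $D_f(p \mid q) = \sum_i q_i f(p_i / q_i)$ for two discrete probability vectors. The formulas for the Kullback-Leibler and Tsallis entropy functions are common textbook normalizations and are assumed here; the repository may parametrize them differently. With this normalization the Kullback-Leibler entropy is the limit of the Tsallis family as its parameter tends to 1, which would be consistent with the table listing it under the name `tsallis`.

```python
import math


def kl(x):
    """Kullback-Leibler entropy function f(x) = x log x - x + 1, with f(0) = 1."""
    return 1.0 if x == 0.0 else x * math.log(x) - x + 1.0


def tsallis(x, alpha):
    """Tsallis entropy function for alpha > 0, alpha != 1 (assumed normalization)."""
    return (x**alpha - alpha * x + alpha - 1.0) / (alpha - 1.0)


def f_divergence(f, p, q):
    """D_f(p | q) = sum_i q_i * f(p_i / q_i) for probability vectors with all q_i > 0."""
    return sum(qi * f(pi / qi) for pi, qi in zip(p, q))


p = [0.5, 0.3, 0.2]
q = [0.2, 0.5, 0.3]

kl_value = f_divergence(kl, p, q)
# A Tsallis parameter close to 1 should give nearly the same value as KL.
near_kl = f_divergence(lambda x: tsallis(x, 1.0 + 1e-6), p, q)
```

Both entropy functions vanish at 1, so the divergence of a measure from itself is zero.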
Below we list some other implemented entropy functions with finite recession constant. For even more entropy functions, we refer to Table 1 in the above-mentioned preprint.
| Entropy | Name | Expression |
|---|---|---|
| Burg | `reverse_kl` | |
| Jensen-Shannon | `jensen_shannon` | |
| total variation | `tv` | |
| Matusita | `matusita` | |
| Kafka | `kafka` | |
| Marton | `marton` | |
| perimeter | `perimeter` | |
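The recession constant of an entropy function $f$ is the limit of $f(t)/t$ as $t \to \infty$. The following sketch checks numerically that it is finite for two of the functions above and infinite for Kullback-Leibler. The closed forms used for total variation, $|x - 1|$, and for the Burg entropy, $x - 1 - \log x$, are the usual textbook ones and are an assumption about what the repository implements; `recession_ratio` is a hypothetical helper.

```python
import math


def total_variation(x):
    """Assumed form of the total variation entropy function."""
    return abs(x - 1.0)


def burg(x):
    """Assumed form of the Burg entropy function (reverse Kullback-Leibler)."""
    return x - 1.0 - math.log(x)


def kl(x):
    """Kullback-Leibler entropy function, for comparison."""
    return x * math.log(x) - x + 1.0


def recession_ratio(f, t):
    """f(t) / t, which approaches the recession constant as t grows."""
    return f(t) / t


for t in (1e2, 1e4, 1e8):
    print(t,
          recession_ratio(total_variation, t),  # tends to 1
          recession_ratio(burg, t),             # tends to 1
          recession_ratio(kl, t))               # grows like log(t)
```

For total variation and Burg the ratio settles near 1, while for Kullback-Leibler it keeps growing logarithmically.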
- `bananas`: the two parabolas in the gif at the top
- `circles`: three circles
- `cross`: four versions of Neal's funnel arranged in a cross shape
- `GMM`: two identical Gaussians, which have a symmetry axis at $y = - x$
- `four_wells`: a sum of four Gaussians, which don't have a symmetry axis. The initial measure is initialized at one of the Gaussians.
- `swiss_role_2d`

We also include some target measures from `sklearn.datasets`: `moons`, `annulus` and the three-dimensional data sets `swiss_role_3d` and `s_curve`.
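For orientation, point clouds of this kind can be drawn from scikit-learn as sketched below. How the repository actually builds these targets (noise level, scaling, random seed) is not stated here, so the function `sample_target`, its defaults and the mapping from target names to scikit-learn generators are assumptions made for illustration only.

```python
from sklearn.datasets import make_moons, make_s_curve, make_swiss_roll


def sample_target(name, n_samples=900, noise=0.05, seed=0):
    """Draw a point cloud for one of a few scikit-learn targets (hypothetical helper)."""
    if name == "moons":
        # Two interleaving half circles in the plane; the class labels are discarded.
        points, _labels = make_moons(n_samples=n_samples, noise=noise, random_state=seed)
    elif name == "swiss_role_3d":
        # Three-dimensional swiss roll; the position along the roll is discarded.
        points, _position = make_swiss_roll(n_samples=n_samples, noise=noise, random_state=seed)
    elif name == "s_curve":
        # Three-dimensional S-shaped surface; the position along the curve is discarded.
        points, _position = make_s_curve(n_samples=n_samples, noise=noise, random_state=seed)
    else:
        raise ValueError(f"unknown target: {name}")
    return points


moons = sample_target("moons")        # array of shape (900, 2)
s_curve = sample_target("s_curve")    # array of shape (900, 3)
```

Each generator returns a NumPy array with one row per sample, which can then be converted to a torch tensor if needed.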
I am still working on improving the speed of this script; the bottleneck is the L-BFGS-B solver, which runs on the CPU. Currently, running the simulation for 50000 steps (exact parameters: tsallis-divergence, alpha=3, lambd=1.0, tau=0.001, kernel = IMQ, sigma = 0.5, N = 900, target_name = bananas) takes less than 12 minutes on a CUDA 7.5 GPU with 12 GB of RAM.
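To give a feeling for what one step of such a particle simulation does, here is a deliberately simplified sketch. It is not the repository's `MMD_reg_f_div_flow`: the regularized $f$-divergence is replaced by the plain squared MMD with a Gaussian kernel in one dimension, so no L-BFGS-B solve is needed, and `tau` is assumed to be the explicit Euler step size. All function names are hypothetical.

```python
import math


def k(x, y, sigma):
    """Gaussian kernel in one dimension."""
    return math.exp(-((x - y) ** 2) / (2.0 * sigma**2))


def dk(x, y, sigma):
    """Derivative of k(x, y) with respect to its first argument."""
    return -(x - y) / sigma**2 * k(x, y, sigma)


def mmd_squared(X, Y, sigma):
    """Squared MMD between the empirical measures supported on X and Y."""
    kxx = sum(k(a, b, sigma) for a in X for b in X) / len(X) ** 2
    kxy = sum(k(a, b, sigma) for a in X for b in Y) / (len(X) * len(Y))
    kyy = sum(k(a, b, sigma) for a in Y for b in Y) / len(Y) ** 2
    return kxx - 2.0 * kxy + kyy


def euler_step(X, Y, tau, sigma):
    """One explicit Euler step of the particle flow of the functional MMD^2 / 2."""
    new_X = []
    for x in X:
        # Gradient of the witness function at x: particles repel each other
        # and are attracted by the target points.
        grad = (sum(dk(x, a, sigma) for a in X) / len(X)
                - sum(dk(x, b, sigma) for b in Y) / len(Y))
        new_X.append(x - tau * grad)
    return new_X


def run_flow(X, Y, steps, tau, sigma):
    """Apply `steps` explicit Euler steps to the particles X with fixed target Y."""
    for _ in range(steps):
        X = euler_step(X, Y, tau, sigma)
    return X


initial = [-1.0 + 0.1 * i for i in range(10)]   # particles in [-1, -0.1]
target = [1.0 + 0.1 * i for i in range(10)]     # target points in [1, 1.9]
final = run_flow(initial, target, steps=200, tau=0.1, sigma=1.0)
```

Every step costs on the order of N times (N + M) kernel evaluations for N particles and M target points, which is why large runs benefit from a GPU.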