Deep Generalized Schrödinger Bridge
[NeurIPS 2022 Oral]
Official PyTorch implementation of the paper "Deep Generalized Schrödinger Bridge (DeepGSB)", which introduces a new class of diffusion models as a scalable numerical solver for Mean-Field Games (MFGs) with hard distributional constraints, e.g., population modeling and opinion depolarization.
[Animations: Population modeling (crowd navigation) | Opinion depolarization]
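As a rough orientation, the MFGs targeted here can be written in the generic form below. This is a simplified sketch using the symbols that appear in the file structure (drift f, diffusion sigma, state cost F, boundary distributions p0 and pT); the density notation \rho_t is our own, and the paper should be consulted for the exact formulation and assumptions.

```latex
\min_{u}\; \mathbb{E}\left[\int_0^T \Big(\tfrac{1}{2}\lVert u(X_t,t)\rVert^2 + F(X_t,\rho_t)\Big)\,dt\right]
\quad \text{s.t.} \quad
dX_t = \big(f(X_t,t) + \sigma\,u(X_t,t)\big)\,dt + \sigma\,dW_t,
\qquad X_0 \sim p_0,\quad X_T \sim p_T
```

Here \rho_t denotes the marginal density of X_t, and F is the mean-field interaction cost that couples each agent to the population. The terminal condition X_T ~ p_T is imposed exactly rather than as a soft terminal penalty, which is what "hard distributional constraints" refers to.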
This repo is co-maintained by Guan-Horng Liu, Tianrong Chen, and Oswin So. Contact us if you have any questions! If you find this library useful, please cite ⬇️
@inproceedings{liu2022deep,
title={Deep Generalized Schr{\"o}dinger Bridge},
author={Liu, Guan-Horng and Chen, Tianrong and So, Oswin and Theodorou, Evangelos A},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
Install the dependencies with Anaconda and activate the environment deepgsb with:
conda env create --file requirements.yaml
conda activate deepgsb
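To sanity-check the activated environment, a small script along these lines can report whether PyTorch is importable and whether CUDA is visible. This is a minimal sketch; check_environment is a hypothetical helper, not part of the repo, and it only assumes that requirements.yaml installs PyTorch.

```python
import importlib.util


def check_environment(packages=("torch",)):
    """Report which packages can be found in the active environment."""
    # find_spec returns None when a top-level package is not installed.
    status = {name: importlib.util.find_spec(name) is not None for name in packages}
    if status.get("torch"):
        import torch  # imported only when it is actually installed

        status["cuda"] = torch.cuda.is_available()
    return status


if __name__ == "__main__":
    print(check_environment())
```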
The repo contains two classes of Mean-Field Games, namely:
- population modeling: GMM, Vneck, Stunnel
- opinion depolarization: opinion, opinion-1k (dim=1000)
The commands to reproduce results similar to those shown in our paper can be found in run.sh. Results, checkpoints, and TensorBoard log files are saved to the folders results/, checkpoint/, and runs/, respectively.
bash run.sh <problem> # <problem> can be {GMM, Vneck, Stunnel, opinion, opinion-1k}
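If you want to sweep all five problems from Python, a minimal sketch could look like the following. It assumes it is launched from the repo root where run.sh lives; run_all and PROBLEMS are hypothetical names, not part of the repo. By default it performs a dry run that only returns the commands.

```python
import shlex
import subprocess

# Problem names accepted by run.sh, as listed in this README.
PROBLEMS = ("GMM", "Vneck", "Stunnel", "opinion", "opinion-1k")


def run_all(problems=PROBLEMS, dry_run=True):
    """Build (and optionally execute) `bash run.sh <problem>` for each problem."""
    # Validate every name up front so a typo cannot abort a half-finished sweep.
    unknown = [p for p in problems if p not in PROBLEMS]
    if unknown:
        raise ValueError(f"unknown problem(s) {unknown}; expected one of {list(PROBLEMS)}")
    commands = [["bash", "run.sh", p] for p in problems]
    if not dry_run:
        for cmd in commands:
            # check=True stops the sweep as soon as one run exits with an error.
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    for cmd in run_all():
        print(shlex.join(cmd))
```

Pass dry_run=False to actually launch the runs one after another.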
You can visualize the trained DeepGSB policies as GIF animations with:
python make_animation.py --load <path to checkpoint npz> --name <gif name>
We briefly document the file structure below to ease integrating DeepGSB into your workflow.
deepgsb/
├── deepgsb.py # the DeepGSB MFG solver
├── sb_policy.py # the parametrized Schrödinger Bridge policy
├── loss_lib.py # all loss functions (IPF/KL, TD, FK/grad)
├── eval_metrics.py # all logging metrics (Wasserstein, etc)
├── replay_buffer.py
└── util.py
mfg/
├── mfg.py # the Mean-Field Game environment
├── constraint.py # the distributional boundary constraint (p0, pT)
├── state_cost.py # all mean-field interaction state costs (F)
├── sde.py # the associated stochastic processes (f, sigma)
├── opinion_lib.py # all utilities for opinion depolarization MFG
├── plotting.py
└── util.py
models/ # the deep networks for parametrizing SB policy
configs/ # the configurations for each MFG
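To illustrate how the pieces in mfg/sde.py (f, sigma) and the SB policy fit together, here is a minimal Euler-Maruyama sketch of the controlled SDE dX_t = (f + sigma * u) dt + sigma dW_t. It is not the repo's implementation: simulate_controlled_sde and the hand-written control below are hypothetical, sigma is assumed to be a constant scalar, and in DeepGSB the control u is a learned network rather than a closed-form function.

```python
import numpy as np


def simulate_controlled_sde(x0, f, sigma, policy, T=1.0, n_steps=100, rng=None):
    """Euler-Maruyama simulation of dX_t = (f(X_t, t) + sigma * u(X_t, t)) dt + sigma dW_t.

    x0:     (batch, dim) array of initial samples (e.g. drawn from p0).
    f:      drift, f(x, t) -> (batch, dim) array.
    sigma:  constant scalar diffusion coefficient (a simplifying assumption).
    policy: control, u(x, t) -> (batch, dim) array.
    Returns the trajectory with shape (n_steps + 1, batch, dim).
    """
    rng = np.random.default_rng() if rng is None else rng
    dt = T / n_steps
    x = np.asarray(x0, dtype=float)
    traj = [x]
    for k in range(n_steps):
        t = k * dt
        drift = f(x, t) + sigma * policy(x, t)
        # Brownian increment dW ~ N(0, dt * I).
        dw = rng.normal(scale=np.sqrt(dt), size=x.shape)
        x = x + drift * dt + sigma * dw
        traj.append(x)
    return np.stack(traj)


def zero_drift(x, t):
    return np.zeros_like(x)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x0 = rng.normal(size=(512, 2))  # stand-in for samples from p0
    target = np.array([3.0, 0.0])

    def toward_target(x, t):
        # Hypothetical hand-written control that pulls samples toward `target`.
        return target - x

    traj = simulate_controlled_sde(x0, zero_drift, 0.5, toward_target, rng=rng)
    print(traj.shape)  # (101, 512, 2)
    print(traj[-1].mean(axis=0))
```

Stacking the whole trajectory keeps intermediate states available, which is convenient for evaluating a running state cost F along the path.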