Model-agnostic variational autoencoder tools.
This module collects a set of generic building blocks for variational-autoencoder (VAE) related tasks.
In what follows, no assumptions are made as to the nature of the latent variables (they may be discrete or continuous). However, the model is assumed to be a "black-box" one, i.e., we consider the most general case without resorting to reparametrization. Special cases that allow reparametrization will be considered/implemented later.
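Concretely, the black-box setting rests on the score-function (log-derivative) identity, which holds for discrete and continuous latents alike and requires only the ability to sample from $q_\phi(z \mid x)$ and evaluate its log-density, i.e., exactly the `sample`/`log_prob` API described below:

$$
\nabla_\phi \, \mathbb{E}_{q_\phi(z \mid x)}\!\left[f(z)\right]
= \mathbb{E}_{q_\phi(z \mid x)}\!\left[f(z)\, \nabla_\phi \log q_\phi(z \mid x)\right],
$$

where $f$ is any function of the latents that does not itself depend on $\phi$; see Mohamed et al. (2019) in the references for a thorough treatment of such estimators.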
We follow a uniform API for the inference network, the generative model, and the loss calculation, allowing for modularity and ease of use. The API is as follows.
The inference network and the generative model essentially represent probability distributions and therefore share the same API, implementing `forward`, `sample`, and `log_prob` methods. Conceptually, this amounts to:
- Calling the object. This performs the forward pass and updates the internal state of the underlying probability distribution.
- Sampling from the probability distribution.
- Log-probability calculation.
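As a minimal sketch of this contract, assuming PyTorch (the class name `InferenceNet`, the Bernoulli latent code, and the single linear layer are illustrative assumptions, not part of the actual modules):

```python
import torch
import torch.nn as nn


class InferenceNet(nn.Module):
    """Sketch of the distribution-like API for q(z|x) with Bernoulli latents."""

    def __init__(self, x_dim, z_dim):
        super().__init__()
        self.net = nn.Linear(x_dim, z_dim)
        self.dist = None  # internal state: the current q(z|x)

    def forward(self, x):
        # 1. Calling the object: run the forward pass and update the
        #    internal state of the underlying probability distribution.
        self.dist = torch.distributions.Bernoulli(logits=self.net(x))
        return self.dist

    def sample(self, sample_shape=torch.Size()):
        # 2. Sampling from the probability distribution.
        return self.dist.sample(sample_shape)

    def log_prob(self, z):
        # 3. Log-probability calculation (summed over latent dimensions).
        return self.dist.log_prob(z).sum(-1)


q = InferenceNet(x_dim=784, z_dim=16)
x = torch.rand(32, 784)
q(x)                   # forward pass: sets the internal distribution
z = q.sample()         # one latent sample per batch element
log_q = q.log_prob(z)  # log q(z|x), shape (32,)
```

A generative-model sketch would look identical, with `forward` consuming latents and the internal distribution representing p(x|z).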
The loss API amounts to simply passing the entire model and a training batch to the loss object. This allows the loss object to drive the model (using the model's API described above) in accordance with its specific forward-pass requirements (e.g., the wake-update forward pass differs from the sleep-update one). The loss instance also collects and stores the loss history during training.
In more detail, the API is:

| Component | API |
|---|---|
| Inference network | `forward` (via calling the object), `sample`, `log_prob`. |
| Generative model | `forward` (via calling the object), `sample`, `log_prob`. |
| Loss | The loss instance is called with the entire model and a training batch; it drives the model's forward passes and stores the loss history. Also, optionally, the loss object collects additional information, such as ... |
| Data format | The shape and format of ... |
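To make the loss contract concrete, here is a minimal sketch of a score-function (REINFORCE-style) negative-ELBO loss in the black-box setting above. The attribute names `model.inference` and `model.generative`, and the assumption that the generative model's `log_prob` returns the joint log p(x, z) for the latents it last consumed, are hypothetical, not a fixed part of the API:

```python
import torch


class NegElboLoss:
    """Sketch: black-box (score-function) estimate of the negative ELBO."""

    def __init__(self):
        self.history = []  # loss history collected during training

    def __call__(self, model, batch):
        # Drive the model through its API (call, sample, log_prob).
        model.inference(batch)                    # forward pass of q(z|x)
        z = model.inference.sample()              # non-reparametrized sample
        log_q = model.inference.log_prob(z)       # log q(z|x), shape (batch,)
        model.generative(z)                       # forward pass of the model
        log_p = model.generative.log_prob(batch)  # assumed: log p(x, z)

        # Per-sample ELBO estimate and its score-function surrogate:
        # differentiating the surrogate yields an unbiased gradient of the
        # ELBO without reparametrizing z (log-derivative trick).
        elbo = log_p - log_q
        surrogate = elbo + elbo.detach() * log_q
        loss = -surrogate.mean()

        self.history.append(loss.item())
        return loss
```

A reweighted wake-sleep loss would follow the same pattern with a different forward-pass orchestration, e.g., sampling (z, x) pairs from the generative model during the sleep phase.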
TODO:

- Add abstract classes (templates)...
- Wake update sequence...
- Sleep update sequence...
- Short version with one-liners...
- Concrete examples with, say, logits for spike-inference...
References:

- Diederik P. Kingma and Max Welling, Auto-Encoding Variational Bayes, arXiv:1312.6114v10 [stat.ML], 2013.
- Danilo Jimenez Rezende, Shakir Mohamed, Daan Wierstra, Stochastic Backpropagation and Approximate Inference in Deep Generative Models, arXiv:1401.4082 [stat.ML], 2014.
- Yuri Burda, Roger Grosse, Ruslan Salakhutdinov, Importance Weighted Autoencoders, arXiv:1509.00519v4 [cs.LG], 2015.
- John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel, Gradient Estimation Using Stochastic Computation Graphs, arXiv:1506.05254v3 [cs.LG], 2016.
- Shakir Mohamed, Mihaela Rosca, Michael Figurnov, Andriy Mnih, Monte Carlo Gradient Estimation in Machine Learning, arXiv:1906.10652v1 [stat.ML], 2019.
- Andriy Mnih and Danilo J. Rezende, Variational Inference for Monte Carlo Objectives, arXiv:1602.06725v2 [cs.LG], 2016.
- Andriy Mnih and Karol Gregor, Neural Variational Inference and Learning in Belief Networks, arXiv:1402.0030v2 [cs.LG], 2014.
- Jörg Bornschein and Yoshua Bengio, Reweighted Wake-Sleep, arXiv:1406.2751v4 [cs.LG], 2015.
- Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, Frank Wood, Revisiting Reweighted Wake-Sleep, arXiv:1805.10469v2 [stat.ML], 2019.
- Chris J. Maddison, Andriy Mnih, Yee Whye Teh, The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables, arXiv:1611.00712v3 [cs.LG], 2017.
- Eric Jang, Shixiang Gu, Ben Poole, Categorical Reparameterization with Gumbel-Softmax, arXiv:1611.01144v5 [stat.ML], 2017.