This package implements the Rotation-Based Iterative Gaussianization (RBIG) algorithm using Jax. RBIG is a normalizing flow that can transform any multi-dimensional distribution into a Gaussian distribution by alternating simple marginal Gaussianization transforms (e.g. histogram-based) with rotations (e.g. PCA). Because the transform is invertible, you can both calculate probabilities and sample from your distribution. See the example below for details.
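To make the alternating "marginal Gaussianization, then rotation" scheme concrete, here is a minimal sketch in plain JAX. It is not this package's API, and it rests on some assumptions. The helper names (`marginal_forward`, `marginal_inverse`, `fit_layer`, `fit_rbig`, `inverse_rbig`) are hypothetical. The marginal step uses an empirical CDF built from sorted training samples with linear interpolation (`jnp.interp`) in place of the histogram estimator. The rotation is PCA computed with an SVD. The sketch also omits the log-determinant needed for probabilities, and the inverse is only exact for points inside the range of the training data.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def marginal_forward(x, support, probs):
    # Empirical CDF of each column (linear interpolation between sorted
    # samples), then the standard-normal quantile function. vmap maps over features.
    u = jax.vmap(jnp.interp, in_axes=(1, 1, None), out_axes=1)(x, support, probs)
    return norm.ppf(u)


def marginal_inverse(z, support, probs):
    # Undo the marginal step: normal CDF, then interpolate back to data space.
    u = norm.cdf(z)
    return jax.vmap(jnp.interp, in_axes=(1, None, 1), out_axes=1)(u, probs, support)


def fit_layer(x):
    # One RBIG iteration (hypothetical helper): marginal Gaussianization + PCA rotation.
    n = x.shape[0]
    support = jnp.sort(x, axis=0)          # (n, d) sorted samples per feature
    probs = (jnp.arange(n) + 0.5) / n      # (n,) plotting positions in (0, 1)
    z = marginal_forward(x, support, probs)
    # PCA rotation: right singular vectors of the centred, Gaussianized data.
    _, _, vt = jnp.linalg.svd(z - z.mean(axis=0), full_matrices=False)
    rotation = vt.T                        # (d, d) orthogonal matrix
    return (support, probs, rotation), z @ rotation


def fit_rbig(x, n_layers=5):
    # Stack layers, each one fitted on the output of the previous layer.
    layers = []
    for _ in range(n_layers):
        params, x = fit_layer(x)
        layers.append(params)
    return layers, x


def inverse_rbig(y, layers):
    # Invert the flow by undoing each layer in reverse order.
    for support, probs, rotation in reversed(layers):
        y = marginal_inverse(y @ rotation.T, support, probs)
    return y


# Toy non-Gaussian 2-D data: a noisy parabola.
key_t, key_noise = jax.random.split(jax.random.PRNGKey(0))
t = jax.random.uniform(key_t, (1000,), minval=-2.0, maxval=2.0)
x = jnp.stack([t, t**2 + 0.1 * jax.random.normal(key_noise, (1000,))], axis=1)

layers, y = fit_rbig(x)           # y should look roughly Gaussian
x_rec = inverse_rbig(y, layers)   # maps y back to the data space
```

Each layer is a monotone per-feature map followed by an orthogonal matrix. So the inverse simply applies the transposed rotation and the reverse interpolation, layer by layer in reverse order.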
Demo Colab Notebooks
Why Jax?
Mainly because I wanted to practice. RBIG is an iterative scheme, so perhaps Jax isn't the best fit for it, but I wanted to improve my functional programming skills. In addition, Jax can be much faster thanks to jit compilation and automatic batching (vmap), which take care of some difficult aspects of programming (compilation, vectorization) for you. The same code also runs on CPU, GPU and TPU with only minor changes. Overall, I didn't see any downside to getting some free speed-ups.
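A small, hypothetical example (not code from this package) can illustrate the jit and autobatching points. It writes a rank-based marginal Gaussianization for a single feature, turns it into an all-features version with `jax.vmap`, and compiles the result with `jax.jit`. The function names `gaussianize_column` and `gaussianize_all` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def gaussianize_column(col):
    # Rank-based marginal Gaussianization of a single feature (1-D array):
    # ranks -> plotting positions in (0, 1) -> standard-normal quantiles.
    n = col.shape[0]
    ranks = jnp.argsort(jnp.argsort(col))
    return norm.ppf((ranks + 0.5) / n)


# vmap turns the per-column function into one that handles every feature at once
# (in_axes=1 / out_axes=1 map over columns); jit compiles the batched version
# with XLA for whichever backend JAX is running on.
gaussianize_all = jax.jit(jax.vmap(gaussianize_column, in_axes=1, out_axes=1))

x = jax.random.exponential(jax.random.PRNGKey(0), (500, 3))
z = gaussianize_all(x)
```

No explicit loop over features is written, and the same `gaussianize_all` runs unchanged on CPU, GPU or TPU, since JAX picks the default backend at runtime.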
This repo uses the most up-to-date version of the jax library from GitHub, so installing it is essential: for example, it relies on the latest np.interp function, which isn't in the pip distribution yet. The environment.yml file lists the most up-to-date dependencies.
- Clone the repository.
git clone https://github.com/IPL-UV/rbig_jax
cd rbig_jax
- Install using conda.
conda env create -f environment.yml
- If you already have the environment installed, you can update it.
conda activate jaxrbig
conda env update --file environment.yml
- Python Code - github
- RBIG applied to Earth Observation Data - github
- Original Webpage - ISP
- Original MATLAB Code - webpage
- Original Python Code - github
- Paper - Iterative Gaussianization: from ICA to Random Rotations
This work was supported by the European Research Council (ERC) Synergy Grant “Understanding and Modelling the Earth System with Machine Learning (USMILE)” under Grant Agreement No 855187.