Last year, I released PureJaxRL, a simple repository that implements RL algorithms entirely end-to-end in JAX, which enables speedups of up to 4000x in RL training. PureJaxRL, in turn, was inspired by multiple projects, including CleanRL and Gymnax. Since its release, a large number of projects related to or inspired by PureJaxRL have come out, vastly expanding its use cases beyond standard single-agent RL settings. This curated list contains those projects alongside other relevant implementations of algorithms, environments, tools, and tutorials.
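As a rough illustration of what "end-to-end in JAX" means, here is a minimal sketch under stated assumptions. The environment step and the rollout loop are both JAX functions, `jax.lax.scan` keeps the loop inside one compiled program, and `jax.vmap` runs many environments in parallel. The toy 1-D environment, the random policy, and the constants `NUM_ENVS` and `NUM_STEPS` are hypothetical and are not taken from PureJaxRL or any library on this list.

```python
import jax
import jax.numpy as jnp

NUM_ENVS = 1024   # hypothetical number of parallel environments
NUM_STEPS = 100   # hypothetical rollout length

def env_step(pos, action):
    """Hypothetical 1-D toy environment: action 0 moves left, action 1 moves right,
    and the reward penalises distance from the origin."""
    new_pos = pos + jnp.where(action == 1, 1.0, -1.0)
    reward = -jnp.abs(new_pos)
    return new_pos, reward

def rollout(key, init_pos):
    """Runs one NUM_STEPS-long episode with a uniform random policy; returns the total reward."""
    def step(carry, _):
        pos, key = carry
        key, subkey = jax.random.split(key)
        action = jax.random.bernoulli(subkey).astype(jnp.int32)  # random action in {0, 1}
        pos, reward = env_step(pos, action)
        return (pos, key), reward
    # lax.scan keeps the whole environment loop inside the compiled program
    _, rewards = jax.lax.scan(step, (init_pos, key), None, length=NUM_STEPS)
    return rewards.sum()

@jax.jit
def batched_returns(key):
    # vmap runs NUM_ENVS independent rollouts in parallel on a single device
    keys = jax.random.split(key, NUM_ENVS)
    init_pos = jnp.zeros(NUM_ENVS)
    return jax.vmap(rollout)(keys, init_pos)

returns = batched_returns(jax.random.PRNGKey(0))
print(returns.shape, returns.mean())
```

The environment and the loop are compiled together by XLA, so the rollout never returns to Python between steps. PureJaxRL applies the same idea to the full training loop, including the network updates.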
To understand more about the benefits of PureJaxRL, I recommend reading the original blog post or tweet thread.
The PureJaxRL repository can be found here:
https://github.com/luchris429/purejaxrl/.
The format of this list follows awesome and awesome-jax. While this list is curated, it is certainly not complete. If you have a repository you would like to add, please contribute!
If you find this resource useful, please star the repo! It helps establish and grow the end-to-end JAX RL community.
- purejaxrl - Classic and simple end-to-end RL training in pure JAX.
- rejax - Modular and importable end-to-end JAX RL training.
- Stoix - End-to-end JAX RL training with advanced logging, configs, and more.
- purejaxql - Simple single-file end-to-end JAX baselines for Q-Learning.
- jym - Educational and beginner-friendly end-to-end JAX RL training.
- cleanrl - Clean implementations of RL algorithms (in both PyTorch and JAX!).
- jaxrl - JAX implementation of algorithms for Deep Reinforcement Learning with continuous action spaces.
- rlbase - Single-file JAX implementations of Deep RL algorithms.
- JaxMARL - Multi-Agent RL Algorithms and Environments in pure JAX.
- Mava - Multi-Agent RL Algorithms in pure JAX (previously TensorFlow-based).
- pax - Scalable Opponent Shaping Algorithms in pure JAX.
- JAX-CORL - Single-file implementations of offline RL algorithms in JAX.
- jaxirl - Pure JAX for Inverse Reinforcement Learning.
- minimax - Canonical implementations of UED algorithms in pure JAX, including SSM-based acceleration.
- jaxued - Single-file implementations of UED algorithms in pure JAX.
- QDax - Quality-Diversity algorithms in pure JAX.
- popjaxrl - Partially-observed RL environments (POPGym) and architectures (incl. SSMs) in pure JAX.
- discovered-policy-optimisation - Library for LPO meta-RL in pure JAX.
- rl-learned-optimization - Library for OPEN in pure JAX.
- gymnax - Classic RL environments in JAX.
- brax - Continuous control environments in JAX.
- JaxMARL - Multi-agent algorithms and environments in pure JAX.
- jumanji - Suite of unique RL environments in JAX.
- pgx - Suite of popular board games in JAX.
- popjaxrl - Partially-observed RL environments (POPGym) in JAX.
- waymax - Self-driving car simulator in JAX.
- Craftax - A challenging Crafter-like and NetHack-inspired benchmark in JAX.
- xland-minigrid - A large-scale meta-RL environment in JAX.
- navix - Classic MiniGrid environments in JAX.
- autoverse - A fast, evolvable description language for reinforcement learning environments.
- qdx - Quantum Error Correction with JAX.
- matrax - Matrix games in JAX.
- AlphaTrade - Limit Order Book (LOB) in JAX.
- evosax - Evolution strategies in JAX.
- evojax - Evolution strategies in JAX.
- flashbax - Accelerated replay buffers in JAX.
- dejax - Accelerated replay buffers in JAX.
- rlax - RL components and building blocks in JAX.
- mctx - Monte Carlo tree search in JAX.
- distrax - Distributions in JAX.
- optax - Gradient-based optimizers in JAX.
- flax - Neural networks in JAX.
- Achieving 4000x Speedups with PureJaxRL - A blog post on how JAX can massively speed up RL training through vectorisation.
- Breaking down State-of-the-Art PPO Implementations in JAX - A blog post explaining PureJaxRL's PPO implementation in depth.
- A Gentle Introduction to Deep Reinforcement Learning in JAX - A JAX tutorial on Deep RL.
- Writing an RL Environment in JAX - A JAX tutorial on making environments.
- Getting started with JAX (MLPs, CNNs & RNNs) - A basic JAX neural network tutorial.
- awesome-jax - A list of useful libraries in JAX.