This is me learning JAX by porting geohot's tinygrad implementation of this paper. I'm quite curious how this JAX version compares speed-wise to the tinygrad and PyTorch versions. As this is my first encounter with JAX, the code is surely not optimal, so pull requests are encouraged. I plan to benchmark the implementations and share the results here in this repo.
After trying out Numba, Triton, and raw CUDA, JAX and raw CUDA have felt the most intuitive to me so far, so this was definitely worth implementing.
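To give a flavor of what the port involves, here is a minimal, hypothetical JAX sketch of the paper's core idea: fake-quantizing weights with a straight-through estimator so that a differentiable "size" term can be added to the loss. The function names, the power-of-two scale, and the signed-integer clipping range are my own illustrative assumptions, not the actual code in this repo or in tinygrad.

```python
import jax
import jax.numpy as jnp

def quantize(w, bits, scale_exp):
    # Fake-quantize w to roughly `bits` bits using a power-of-two scale
    # (illustrative assumption). The straight-through estimator below
    # makes the forward pass use the quantized values while gradients
    # flow through as if quantization were the identity.
    scale = 2.0 ** scale_exp
    lo = -(2.0 ** (bits - 1))
    hi = 2.0 ** (bits - 1) - 1
    q = jnp.clip(jnp.round(w / scale), lo, hi)
    wq = q * scale
    return w + jax.lax.stop_gradient(wq - w)  # STE trick

def size_penalty(bits, n_weights):
    # Differentiable proxy for model size: total bits spent on weights.
    # Adding this to the task loss lets the network shrink itself.
    return jnp.maximum(bits, 0.0) * n_weights
```

Because `bits` enters the loss through `size_penalty`, the optimizer can trade accuracy against model size during training, which is the self-compression mechanism the paper describes.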
```sh
python -m venv self-compressing-nn-jax
source self-compressing-nn-jax/bin/activate
pip install -e .
python3 train_mnist.py
```
- tinygrad
- PyTorch
- JAX
I saw two weird drops during this run; I still need to investigate why.