Installation | Quickstart | Documentation
📣 Scalify has been accepted to the ICML 2024 WANT workshop! 📣
JAX Scalify is a library implementing end-to-end scale propagation and scaled arithmetic, allowing easy training and inference of deep neural networks in low precision (BF16, FP16, FP8).
Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. Most of these works focus on ad-hoc approaches around scaling of matrix multiplications (and sometimes reduction operations). Scalify adopts a more systematic approach with end-to-end scale propagation, i.e. transforming the full computational graph into a `ScaledArray` graph, where every operation takes `ScaledArray` inputs and returns `ScaledArray` outputs:
```python
from dataclasses import dataclass

from jax import Array


@dataclass
class ScaledArray:
    # Main data component, in low precision.
    data: Array
    # Scale, usually scalar, in FP32 or E8M0.
    scale: Array

    def __array__(self) -> Array:
        # Tensor represented by this `ScaledArray`.
        return self.data * self.scale.astype(self.data.dtype)
```
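As a short illustration (a sketch, not taken from the library documentation), a regular JAX array can be wrapped with `as_scaled_array` and decoded back by multiplying `data` by `scale`:

```python
import jax.numpy as jnp

import jax_scalify as jsa

# Wrap a plain JAX array into a ScaledArray (the default scale is an assumption).
x = jnp.array([0.5, 1.0, 2.0], dtype=jnp.float16)
sx = jsa.as_scaled_array(x)
print(sx.data.dtype, sx.scale)
# Recover the represented tensor, following the definition above.
print(sx.data * sx.scale.astype(sx.data.dtype))
```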
The main benefits of the `scalify` approach are:
- Agnostic to neural-net model definition;
- Decoupling scaling from low-precision, reducing the computational overhead of dynamic rescaling;
- FP8 matrix multiplications and reductions as simple as a cast;
- Out-of-the-box support of FP16 (scaled) master weights and optimizer state;
- Composable with the JAX ecosystem: Flax, Optax, ...
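For instance, because scaling is decoupled from the model code, a scaled matrix multiplication is written as plain JAX. A minimal sketch (shapes, dtypes and the `forward` function below are illustrative assumptions, not library API):

```python
import jax.numpy as jnp

import jax_scalify as jsa


@jsa.scalify
def forward(w, x):
    # Standard JAX op; `scalify` propagates the scales of `w` and `x`.
    return jnp.matmul(x, w)


w = jsa.as_scaled_array(jnp.ones((128, 64), dtype=jnp.float16))
x = jsa.as_scaled_array(jnp.ones((32, 128), dtype=jnp.float16))
y = forward(w, x)  # `y` is a ScaledArray with FP16 data and its own scale.
```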
JAX Scalify can be directly installed from PyPI:
```bash
pip install jax-scalify
```
Please follow the JAX documentation for a proper JAX installation on GPU/TPU.
The latest version of JAX Scalify is available directly from GitHub:
```bash
pip install git+https://github.com/graphcore-research/jax-scalify.git
```
A typical JAX training loop just requires a couple of modifications to take advantage of `scalify`. More specifically:

- Represent input and state as `ScaledArray` using the `as_scaled_array` method (or variations of it);
- End-to-end scale propagation in the `update` training method using the `scalify` decorator;
- (Optionally) add `dynamic_rescale` calls to improve low-precision accuracy and stability.
The following (simplified) example presents how `scalify` can be incorporated into a JAX training loop.
```python
import jax
import jax_scalify as jsa


# Scalify transform on FWD + BWD + optimizer.
# Propagating scale in the computational graph.
@jsa.scalify
def update(state, data, labels):
    # Forward and backward pass on the NN model.
    loss, grads = jax.value_and_grad(model)(state, data, labels)
    # Optimizer applied on scaled state.
    state = optimizer.apply(state, grads)
    return loss, state


# Model + optimizer state.
state = (model.init(...), optimizer.init(...))
# Transform state to scaled array(s).
sc_state = jsa.as_scaled_array(state)

for (data, labels) in dataset:
    # If necessary (e.g. images), scale input data.
    data = jsa.as_scaled_array(data)
    # State update, with full scale propagation.
    sc_state = update(sc_state, data, labels)
    # Optional dynamic rescaling of state.
    sc_state = jsa.ops.dynamic_rescale_l2(sc_state)
```
As presented in the code above, the model state is represented as a JAX PyTree of `ScaledArray`, propagated end-to-end through the model (forward and backward passes) as well as the optimizer.
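For logging, evaluation or checkpointing, individual leaves of the scaled state can be decoded back into plain arrays. A small sketch (the parameter path below is hypothetical and depends on the actual model/optimizer PyTree layout):

```python
# Hypothetical path into the scaled state PyTree.
scaled_kernel = sc_state[0]["params"]["dense"]["kernel"]
print(scaled_kernel.data.dtype, scaled_kernel.scale)  # low-precision data + scale
# Recover the represented tensor, as in the `ScaledArray` definition above.
kernel = scaled_kernel.data * scaled_kernel.scale.astype(scaled_kernel.data.dtype)
```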
A full collection of examples is available:

- Scalify quickstart notebook: basics of `ScaledArray` and the `scalify` transform;
- MNIST FP16 training example: adapting the JAX MNIST example to `scalify`;
- MNIST FP8 training example: easy FP8 support in `scalify`;
- MNIST Flax example: `scalify` Flax training, with Optax optimizer integration.
For a local development setup, we recommend an interactive install:
```bash
git clone git@github.com:graphcore-research/jax-scalify.git
pip install -e ./
```
Running `pre-commit` and `pytest` on the JAX Scalify repository:
```bash
pip install pre-commit
pre-commit run --all-files
pytest -v ./tests
```
A Python wheel can be built with the usual `python -m build` command.