Unxt provides unitful quantities and calculations in JAX, built on Equinox and Quax.
Unxt supports JAX's compelling features:
- JIT compilation (jit)
- vectorization (vmap, etc.)
- auto-differentiation (grad, jacobian, hessian)
- GPU/TPU/multi-host acceleration
And best of all, unxt doesn't force you to use special unit-compatible re-exports of JAX libraries. You can use unxt with existing JAX code, and with quax's simple decorator, JAX will work with unxt.Quantity.
pip install unxt
using uv
uv add unxt
from source, using pip
pip install git+https://github.com/GalacticDynamics/unxt.git
building from source
cd /path/to/parent
git clone https://github.com/GalacticDynamics/unxt.git
cd unxt
pip install -e . # editable mode
import jax.numpy as jnp
import unxt as u
x = u.Quantity(jnp.arange(1, 5, dtype=float), "km")
print(x)
# Quantity['length'](Array([1., 2., 3., 4.], dtype=float64), unit='km')
The constituent value and unit are accessible as attributes:
print(x.value)
# Array([1., 2., 3., 4.], dtype=float64)
print(x.unit)
# Unit("km")
Quantity objects obey the rules of unitful arithmetic.
# Addition / Subtraction
print(x + x)
# Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='km')
# Multiplication / Division
print(2 * x)
# Quantity['length'](Array([2., 4., 6., 8.], dtype=float64), unit='km')
y = u.Quantity(jnp.arange(4, 8, dtype=float), "yr")
print(x / y)
# Quantity['speed'](Array([0.25 , 0.4 , 0.5 , 0.57142857], dtype=float64), unit='km / yr')
# Exponentiation
print(x**2)
# Quantity['area'](Array([ 1., 4., 9., 16.], dtype=float64), unit='km2')
# Unit checking on operations
try:
    x + y
except Exception as e:
    print(e)
# 'yr' (time) and 'km' (length) are not convertible
Quantities can be converted to different units:
print(u.uconvert("m", x)) # via function
# Quantity['length'](Array([1000., 2000., 3000., 4000.], dtype=float64), unit='m')
print(x.uconvert("m")) # via method
# Quantity['length'](Array([1000., 2000., 3000., 4000.], dtype=float64), unit='m')
Since Quantity is parametric, it can do runtime dimension checking!
LengthQuantity = u.Quantity["length"]
print(LengthQuantity(2, "km"))
# Quantity['length'](Array(2, dtype=int64, weak_type=True), unit='km')
try:
    LengthQuantity(2, "s")
except ValueError as e:
    print(e)
# Physical type mismatch.
unxt is built on quax, which enables custom array-ish objects in JAX. For convenience we use the quaxed library, which is just a quax.quaxify wrapper around jax to avoid boilerplate code.
Note
Using quaxed is optional. You can directly use quaxify, and even apply it to the top-level function instead of individual functions.
from quaxed import grad, vmap
import quaxed.numpy as jnp
print(jnp.square(x))
# Quantity['area'](Array([ 1., 4., 9., 16.], dtype=float64), unit='km2')
print(jnp.power(x, 3))
# Quantity['volume'](Array([ 1., 8., 27., 64.], dtype=float64), unit='km3')
print(vmap(grad(lambda x: x**3))(x))
# Quantity['area'](Array([ 3., 12., 27., 48.], dtype=float64), unit='km2')
See the documentation for more examples and details of JIT and AD.
If you found this library to be useful and want to support the development and maintenance of lower-level code libraries for the scientific community, please consider citing this work.
We welcome contributions!