Adding JIT functionality to Plan finalization. #1

Open · wants to merge 34 commits into base: m1-jax

Changes from all commits (34 commits)
69a906c
Adding Jax tests for the M1 mac.
alxmrs Jul 20, 2024
d571970
Providing flag to default to single precision instead of double for n…
alxmrs Jul 21, 2024
9a043da
Matmul on jax-metal only supports floats and complex types. :/
alxmrs Jul 21, 2024
f2afa77
Fixes for more tests.
alxmrs Jul 21, 2024
c0c7751
Run jax tests on PRs.
alxmrs Jul 22, 2024
eb059d9
Implemented feedback from Tom.
alxmrs Jul 22, 2024
fbb78f5
Added device argument.
alxmrs Jul 22, 2024
83cdd4f
Exclude 'indexing' case.
alxmrs Jul 22, 2024
ff5e13f
Removing 'indexing' from default types for now.
alxmrs Jul 22, 2024
a6fcaf3
Fixed bag argument in test.
alxmrs Jul 22, 2024
aa0fdfc
Random should respect dtypes.
alxmrs Jul 23, 2024
ecf2a42
I think I got the linspace tests passing. Need to check if array equa…
alxmrs Jul 23, 2024
f37a5eb
All jax tests pass locally.
alxmrs Jul 23, 2024
3adca6b
Cleanup for PR.
alxmrs Jul 23, 2024
f3ab3b3
Added developer instructions for getting set up on the M1+ processor.
alxmrs Jul 23, 2024
8c0384a
Fix env variable expressions.
alxmrs Jul 23, 2024
20cad4b
The macos-14 image requires at least Python 3.10.
alxmrs Jul 23, 2024
bfe0d87
Looks like 14 specifically supports python 3.11 and up.
alxmrs Jul 23, 2024
448c60d
Splitting install steps for jax for mac and non-mac.
alxmrs Jul 27, 2024
b2a16d8
Mac install now enables jaxlib compatibility.
alxmrs Jul 27, 2024
d8c6066
Put the mac-specific env setting in the right place.
alxmrs Jul 27, 2024
d52fce5
Test the xlarge runner.
alxmrs Jul 27, 2024
bdb5777
Extracted to method `_to_default_precision`.
alxmrs Jul 30, 2024
a9ad5df
Added helpful note.
alxmrs Jul 30, 2024
0d436b1
No need to cast in the linspace test.
alxmrs Jul 30, 2024
a67fee2
Remove macos-specific tests.
alxmrs Jul 30, 2024
2e0a2f9
Back tracking to previous macos build
alxmrs Jul 30, 2024
a19ca2e
Mac should use CPU Jax
alxmrs Jul 31, 2024
780e349
Remove jax mac extra install.
alxmrs Jul 31, 2024
2f7c324
Tensordot testing with and without hardcoded types.
alxmrs Jul 31, 2024
ecf7f15
Defaulting to 64 bit precision.
alxmrs Jul 31, 2024
e189c09
rm unneeded test arg.
alxmrs Jul 31, 2024
9a7a419
Revert tensordot change.
alxmrs Jul 31, 2024
995436b
Merged conflicts.
alxmrs Jul 31, 2024
16 changes: 11 additions & 5 deletions .github/workflows/jax-tests.yml
@@ -1,6 +1,7 @@
name: JAX tests

on:
pull_request:
schedule:
# Every weekday at 03:53 UTC, see https://crontab.guru/
- cron: "53 3 * * 1-5"
@@ -16,8 +17,9 @@ jobs:
strategy:
fail-fast: false
matrix:
os: ["ubuntu-latest"]
python-version: ["3.9"]
# How to set up Jax on an ARM Mac: https://developer.apple.com/metal/jax/
os: ["ubuntu-latest", "macos-14"]
python-version: ["3.11"]

steps:
- name: Checkout source
@@ -37,13 +39,17 @@ jobs:
- name: Install
run: |
python -m pip install --upgrade pip
python -m pip install -e '.[test]' 'jax[cpu]'
python -m pip uninstall -y lithops # tests don't run on Lithops
python -m pip install -e '.[test-jax]'
# Verify jax
python -c 'import jax; print(jax.numpy.arange(10))'

- name: Run tests
run: |
# exclude tests that rely on structured types since JAX doesn't support these
pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby"
# exclude tests that rely on randomness because JAX is picky about this.
# TODO(#494): Turn back on tests that do visualization when the "FileNotFound" error is fixed. These are "visualization", "plan_scaling", and "optimization".
pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not random and not visualization and not plan_scaling and not optimization"
env:
CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
JAX_ENABLE_X64: True
ENABLE_PJRT_COMPATIBILITY: True
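
For reference, a minimal local sketch of the same configuration the workflow sets up. This is not part of the PR; it assumes the two environment variables are read at import time, so they are set before cubed and jax are imported:

import os

# Mirror the workflow's env block above.
os.environ["CUBED_BACKEND_ARRAY_API_MODULE"] = "jax.numpy"
os.environ["JAX_ENABLE_X64"] = "True"

import jax

# Same sanity check as the "Verify jax" install step in the workflow.
print(jax.numpy.arange(10))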
43 changes: 29 additions & 14 deletions cubed/array_api/creation_functions.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Iterable, List

from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import default_dtypes
from cubed.core import Plan, gensym
from cubed.core.ops import map_blocks
from cubed.storage.virtual import (
@@ -17,14 +18,25 @@
from .array_object import Array


def _to_default_precision(dtype, *, device=None):
"""Returns a dtype of the same kind with the default precision."""
for k, dtype_ in default_dtypes(device=device).items():
if nxp.isdtype(dtype, k):
return dtype_


def arange(
start, /, stop=None, step=1, *, dtype=None, device=None, chunks="auto", spec=None
) -> "Array":
if stop is None:
start, stop = 0, start
num = int(max(math.ceil((stop - start) / step), 0))
if dtype is None:
# TODO: Use inspect API
dtype = nxp.arange(start, stop, step * num if num else step).dtype
# the default nxp call does not adjust the data type to the default precision.
dtype = _to_default_precision(dtype, device=device)

chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
chunksize = chunks[0][0]

@@ -65,7 +77,7 @@ def asarray(
# ensure blocks are arrays
a = nxp.asarray(a, dtype=dtype)
if dtype is None:
dtype = a.dtype
dtype = _to_default_precision(a.dtype, device=device)

chunksize = to_chunksize(normalize_chunks(chunks, shape=a.shape, dtype=dtype))
name = gensym()
@@ -90,7 +102,7 @@ def empty_virtual_array(
shape, *, dtype=None, device=None, chunks="auto", spec=None, hidden=True
) -> "Array":
if dtype is None:
dtype = nxp.float64
dtype = default_dtypes(device=device)['real floating']

chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
name = gensym()
@@ -108,7 +120,7 @@ def eye(
if n_cols is None:
n_cols = n_rows
if dtype is None:
dtype = nxp.float64
dtype = default_dtypes(device=device)['real floating']

shape = (n_rows, n_cols)
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
@@ -139,12 +151,13 @@ def full(
shape = normalize_shape(shape)
if dtype is None:
# check bool first since True/False are instances of int and float
defaults = default_dtypes(device=device)
if isinstance(fill_value, bool):
dtype = nxp.bool
elif isinstance(fill_value, int):
dtype = nxp.int64
dtype = defaults['integral']
elif isinstance(fill_value, float):
dtype = nxp.float64
dtype = defaults['real floating']
else:
raise TypeError("Invalid input to full")
chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
@@ -191,7 +204,7 @@ def linspace(
div = 1
step = float(range_) / div
if dtype is None:
dtype = nxp.float64
dtype = default_dtypes(device=device)['real floating']
chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
chunksize = chunks[0][0]

@@ -208,18 +221,20 @@ def linspace(
step=step,
endpoint=endpoint,
linspace_dtype=dtype,
device=device,
)


def _linspace(x, size, start, step, endpoint, linspace_dtype, block_id=None):
def _linspace(x, size, start, step, endpoint, linspace_dtype, device=None, block_id=None):
bs = x.shape[0]
i = block_id[0]
adjusted_bs = bs - 1 if endpoint else bs
blockstart = start + (i * size * step)
blockstop = blockstart + (adjusted_bs * step)
return nxp.linspace(
blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype
)
# While the Array API supports `nxp.astype(x, dtype)`, using this method causes precision
# errors with Jax. For now, let's see how this works with other implementations.
float_ = default_dtypes(device=device)['real floating']
blockstart = float_(start + (i * size * step))
blockstop = float_(blockstart + float_(adjusted_bs * step))
return nxp.linspace(blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype)


def meshgrid(*arrays, indexing="xy") -> List["Array"]:
@@ -255,7 +270,7 @@ def meshgrid(*arrays, indexing="xy") -> List["Array"]:

def ones(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
if dtype is None:
dtype = nxp.float64
dtype = default_dtypes(device=device)['real floating']
return full(shape, 1, dtype=dtype, device=device, chunks=chunks, spec=spec)


@@ -301,7 +316,7 @@ def _tri_mask(N, M, k, chunks, spec):

def zeros(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
if dtype is None:
dtype = nxp.float64
dtype = default_dtypes(device=device)['real floating']
return full(shape, 0, dtype=dtype, device=device, chunks=chunks, spec=spec)


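A rough sketch of how the new default-dtype selection in these creation functions behaves. This is illustrative only; the 32-bit values assume CUBED_DEFAULT_PRECISION_X32 is set, otherwise the 64-bit defaults apply:

from cubed.backend_array_api import default_dtypes, namespace as nxp

defaults = default_dtypes()
# Without CUBED_DEFAULT_PRECISION_X32: float64 / complex128 / int64.
# With it set: float32 / complex64 / int32.
float_dtype = defaults["real floating"]

# _to_default_precision maps a dtype to the default of the same kind,
# e.g. a float16 input resolves to float_dtype, an int8 input to defaults["integral"].
assert nxp.isdtype(float_dtype, "real floating")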
16 changes: 16 additions & 0 deletions cubed/backend_array_api.py
@@ -33,6 +33,22 @@

namespace = array_api_compat.numpy

_DEFAULT_DTYPES = {
"real floating": namespace.float64,
"complex floating": namespace.complex128,
"integral": namespace.int64,
}
if "CUBED_DEFAULT_PRECISION_X32" in os.environ:
if os.environ['CUBED_DEFAULT_PRECISION_X32']:
_DEFAULT_DTYPES = {
"real floating": namespace.float32,
"complex floating": namespace.complex64,
"integral": namespace.int32,
}


def default_dtypes(*, device=None) -> dict:
return _DEFAULT_DTYPES

# These functions to convert to/from backend arrays
# assume that no extra memory is allocated, by using the
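A sketch of opting into the 32-bit defaults. It assumes the variable is read once at module import, as in the check above, so it must be set before cubed is imported; note that any non-empty value enables the switch:

import os

os.environ["CUBED_DEFAULT_PRECISION_X32"] = "1"  # any non-empty string enables X32 mode

from cubed.backend_array_api import default_dtypes

print(default_dtypes())
# expected: {'real floating': float32, 'complex floating': complex64, 'integral': int32}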
35 changes: 33 additions & 2 deletions cubed/core/plan.py
@@ -1,3 +1,4 @@
import dataclasses
import atexit
import inspect
import shutil
@@ -33,6 +34,8 @@ def delete_on_exit(context_dir: str) -> None:

sym_counter = 0

Decorator = Callable[[Callable], Callable]


def gensym(name="op"):
global sym_counter
@@ -194,13 +197,40 @@ def _create_lazy_zarr_arrays(self, dag):

return dag

def _compile_blockwise(self, dag, jit_function: Decorator) -> nx.MultiDiGraph:
"""JIT-compiles the functions from all blockwise ops by mutating the input dag."""
# Recommended: make a copy of the dag before calling this function.
for n in dag.nodes:
node = dag.nodes[n]

if "primitive_op" not in node:
continue

if not isinstance(node["pipeline"].config, BlockwiseSpec):
continue

# node is a blockwise primitive_op.
# maybe we should investigate some sort of optics library for frozen dataclasses...
new_pipeline = dataclasses.replace(
node["pipeline"],
config=dataclasses.replace(
node["pipeline"].config,
function=jit_function(node["pipeline"].config.function)
)
)
node["pipeline"] = new_pipeline

return dag

@lru_cache
def _finalize_dag(
self, optimize_graph: bool = True, optimize_function=None
self, optimize_graph: bool = True, optimize_function=None, jit_function: Optional[Decorator] = None,
) -> nx.MultiDiGraph:
dag = self.optimize(optimize_function).dag if optimize_graph else self.dag
# create a copy since _create_lazy_zarr_arrays mutates the dag
dag = dag.copy()
if callable(jit_function):
dag = self._compile_blockwise(dag, jit_function)
dag = self._create_lazy_zarr_arrays(dag)
return nx.freeze(dag)

@@ -210,11 +240,12 @@ def execute(
callbacks=None,
optimize_graph=True,
optimize_function=None,
jit_function=None,
resume=None,
spec=None,
**kwargs,
):
dag = self._finalize_dag(optimize_graph, optimize_function)
dag = self._finalize_dag(optimize_graph, optimize_function, jit_function)

compute_id = f"compute-{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}"

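The net effect is that a JIT decorator can be handed to plan finalization. A hedged usage sketch, assuming compute() forwards extra keyword arguments to Plan.execute and that the configured backend module is jax.numpy; the spec values are illustrative, not from the PR:

import jax
import cubed
import cubed.array_api as xp

spec = cubed.Spec(allowed_mem="1GB")  # illustrative spec values
a = xp.ones((1000, 1000), chunks=(100, 100), spec=spec)
b = xp.add(a, a)

# jit_function is threaded through execute() -> _finalize_dag() -> _compile_blockwise(),
# which rewraps each blockwise pipeline's function with jax.jit before execution.
result = b.compute(jit_function=jax.jit)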
13 changes: 6 additions & 7 deletions cubed/random.py
@@ -2,17 +2,16 @@

from numpy.random import Generator, Philox

from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.backend_array_api import numpy_array_to_backend_array, default_dtypes
from cubed.core.ops import map_blocks
from cubed.utils import block_id_to_offset, normalize_shape
from cubed.vendor.dask.array.core import normalize_chunks


def random(size, *, chunks=None, spec=None):
def random(size, *, chunks=None, spec=None, device=None):
"""Return random floats in the half-open interval [0.0, 1.0)."""
shape = normalize_shape(size)
dtype = nxp.float64
dtype = default_dtypes(device=device)['real floating']
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
numblocks = tuple(map(len, chunks))
root_seed = pyrandom.getrandbits(128)
@@ -27,9 +26,9 @@ def random(size, *, chunks=None, spec=None):
)


def _random(x, numblocks=None, root_seed=None, block_id=None):
def _random(x, numblocks=None, root_seed=None, block_id=None, dtype=None):
stream_id = block_id_to_offset(block_id, numblocks)
rg = Generator(Philox(key=root_seed + stream_id))
out = rg.random(x.shape)
out = numpy_array_to_backend_array(out)
out = rg.random(x.shape, dtype=dtype)
out = numpy_array_to_backend_array(out, dtype=dtype)
return out
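
For reference, a standalone sketch of the per-block seeding scheme _random relies on. The numblocks/block_id values and the row-major offset are assumptions standing in for map_blocks and block_id_to_offset:

import numpy as np
from numpy.random import Generator, Philox

root_seed = 12345                    # the PR uses pyrandom.getrandbits(128)
numblocks = (4, 4)
block_id = (2, 3)
stream_id = block_id[0] * numblocks[1] + block_id[1]  # assumed row-major offset of this block

# Each block gets an independent Philox stream keyed off the shared root seed,
# and the draw now honours the requested dtype (float32 under X32 defaults).
rg = Generator(Philox(key=root_seed + stream_id))
block = rg.random((5, 5), dtype=np.float32)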
16 changes: 9 additions & 7 deletions cubed/tests/test_array_api.py
@@ -115,15 +115,15 @@ def test_eye(spec, k):
def test_linspace(spec, endpoint):
a = xp.linspace(6, 49, 50, endpoint=endpoint, chunks=5, spec=spec)
npa = np.linspace(6, 49, 50, endpoint=endpoint)
assert_allclose(a, npa)
assert_allclose(a, npa, rtol=1e-5)

a = xp.linspace(1.4, 4.9, 13, endpoint=endpoint, chunks=5, spec=spec)
npa = np.linspace(1.4, 4.9, 13, endpoint=endpoint)
assert_allclose(a, npa)
assert_allclose(a, npa, rtol=1e-5)

a = xp.linspace(0, 0, 0, endpoint=endpoint)
npa = np.linspace(0, 0, 0, endpoint=endpoint)
assert_allclose(a, npa)
assert_allclose(a, npa, rtol=1e-5)


def test_ones(spec, executor):
@@ -351,15 +351,17 @@ def test_matmul(spec, executor):
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
chunks=(2, 2),
spec=spec,
dtype=xp.float32,
)
b = xp.asarray(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
chunks=(2, 2),
spec=spec,
dtype=xp.float32,
)
c = xp.matmul(a, b)
x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
y = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=xp.float32)
y = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=xp.float32)
expected = np.matmul(x, y)
assert_array_equal(c.compute(executor=executor), expected)

@@ -415,8 +417,8 @@ def test_matmul_modal(modal_executor):


def test_outer(spec, executor):
a = xp.asarray([0, 1, 2], chunks=2, spec=spec)
b = xp.asarray([10, 50, 100], chunks=2, spec=spec)
a = xp.asarray([0, 1, 2], chunks=2, spec=spec, dtype=xp.float32)
b = xp.asarray([10, 50, 100], chunks=2, spec=spec, dtype=xp.float32)
c = xp.outer(a, b)
assert_array_equal(c.compute(executor=executor), np.outer([0, 1, 2], [10, 50, 100]))

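A short illustration of why the looser rtol is needed once 32-bit defaults are in play; this uses plain NumPy just to show the size of the float32 error:

import numpy as np
from numpy.testing import assert_allclose

a32 = np.linspace(1.4, 4.9, 13, dtype=np.float32)
a64 = np.linspace(1.4, 4.9, 13)       # float64 reference, as in the test

assert_allclose(a32, a64, rtol=1e-5)  # passes
# assert_array_equal(a32, a64) would fail: float32 cannot represent 1.4 exactly.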