diff --git a/.github/workflows/jax-tests.yml b/.github/workflows/jax-tests.yml
index e8ca8662..730084e1 100644
--- a/.github/workflows/jax-tests.yml
+++ b/.github/workflows/jax-tests.yml
@@ -1,6 +1,7 @@
 name: JAX tests
 
 on:
+  pull_request:
   schedule:
     # Every weekday at 03:53 UTC, see https://crontab.guru/
     - cron: "53 3 * * 1-5"
@@ -16,8 +17,9 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: ["ubuntu-latest"]
-        python-version: ["3.9"]
+        # How to set up JAX on an ARM Mac: https://developer.apple.com/metal/jax/
+        os: ["ubuntu-latest", "macos-14"]
+        python-version: ["3.11"]
 
     steps:
       - name: Checkout source
@@ -37,13 +39,17 @@ jobs:
       - name: Install
         run: |
           python -m pip install --upgrade pip
-          python -m pip install -e '.[test]' 'jax[cpu]'
-          python -m pip uninstall -y lithops  # tests don't run on Lithops
+          python -m pip install -e '.[test-jax]'
+          # Verify jax
+          python -c 'import jax; print(jax.numpy.arange(10))'
 
       - name: Run tests
         run: |
           # exclude tests that rely on structured types since JAX doesn't support these
-          pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby"
+          # exclude tests that rely on randomness, since JAX handles random number generation differently
+          # TODO(#494): re-enable the visualization-related tests ("visualization", "plan_scaling", and "optimization") once the "FileNotFound" error is fixed
+          pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not random and not visualization and not plan_scaling and not optimization"
         env:
           CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
           JAX_ENABLE_X64: True
+          ENABLE_PJRT_COMPATIBILITY: True
diff --git a/cubed/array_api/creation_functions.py b/cubed/array_api/creation_functions.py
index 717ec725..06e0dcb7 100644
--- a/cubed/array_api/creation_functions.py
+++ b/cubed/array_api/creation_functions.py
@@ -2,6 +2,7 @@
 from typing import TYPE_CHECKING, Iterable, List
 
 from cubed.backend_array_api import namespace as nxp
+from cubed.backend_array_api import default_dtypes
 from cubed.core import Plan, gensym
 from cubed.core.ops import map_blocks
 from cubed.storage.virtual import (
@@ -17,6 +18,13 @@
     from .array_object import Array
 
 
+def _to_default_precision(dtype, *, device=None):
+    """Returns a dtype of the same kind with the default precision."""
+    for k, dtype_ in default_dtypes(device=device).items():
+        if nxp.isdtype(dtype, k):
+            return dtype_
+
+
 def arange(
     start, /, stop=None, step=1, *, dtype=None, device=None, chunks="auto", spec=None
 ) -> "Array":
@@ -24,7 +32,11 @@ def arange(
         start, stop = 0, start
     num = int(max(math.ceil((stop - start) / step), 0))
     if dtype is None:
+        # TODO: Use inspect API
        dtype = nxp.arange(start, stop, step * num if num else step).dtype
+        # the nxp call above does not adjust the dtype to the configured default precision, so do that here
+        dtype = _to_default_precision(dtype, device=device)
+
     chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
     chunksize = chunks[0][0]
@@ -65,7 +77,7 @@ def asarray(
     # ensure blocks are arrays
     a = nxp.asarray(a, dtype=dtype)
     if dtype is None:
-        dtype = a.dtype
+        dtype = _to_default_precision(a.dtype, device=device)
 
     chunksize = to_chunksize(normalize_chunks(chunks, shape=a.shape, dtype=dtype))
     name = gensym()
@@ -90,7 +102,7 @@ def empty_virtual_array(
     shape, *, dtype=None, device=None, chunks="auto", spec=None, hidden=True
 ) -> "Array":
     if dtype is None:
-        dtype = nxp.float64
+        dtype = default_dtypes(device=device)['real floating']
 
     chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
     name = gensym()
@@ -108,7 +120,7 @@ def eye(
     if n_cols is None:
         n_cols = n_rows
     if dtype is None:
-        dtype = nxp.float64
+        dtype = default_dtypes(device=device)['real floating']
 
     shape = (n_rows, n_cols)
     chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
@@ -139,12 +151,13 @@ def full(
     shape = normalize_shape(shape)
     if dtype is None:
         # check bool first since True/False are instances of int and float
+        defaults = default_dtypes(device=device)
         if isinstance(fill_value, bool):
             dtype = nxp.bool
         elif isinstance(fill_value, int):
-            dtype = nxp.int64
+            dtype = defaults['integral']
         elif isinstance(fill_value, float):
-            dtype = nxp.float64
+            dtype = defaults['real floating']
         else:
             raise TypeError("Invalid input to full")
     chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
@@ -191,7 +204,7 @@ def linspace(
         div = 1
     step = float(range_) / div
     if dtype is None:
-        dtype = nxp.float64
+        dtype = default_dtypes(device=device)['real floating']
 
     chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
     chunksize = chunks[0][0]
@@ -208,18 +221,20 @@ def linspace(
         step=step,
         endpoint=endpoint,
         linspace_dtype=dtype,
+        device=device,
     )
 
 
-def _linspace(x, size, start, step, endpoint, linspace_dtype, block_id=None):
+def _linspace(x, size, start, step, endpoint, linspace_dtype, device=None, block_id=None):
     bs = x.shape[0]
     i = block_id[0]
     adjusted_bs = bs - 1 if endpoint else bs
-    blockstart = start + (i * size * step)
-    blockstop = blockstart + (adjusted_bs * step)
-    return nxp.linspace(
-        blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype
-    )
+    # While the Array API supports `nxp.astype(x, dtype)`, using it here causes precision
+    # errors with JAX. For now, let's see how this approach works with other implementations.
+    float_ = default_dtypes(device=device)['real floating']
+    blockstart = float_(start + (i * size * step))
+    blockstop = float_(blockstart + float_(adjusted_bs * step))
+    return nxp.linspace(blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype)
 
 
 def meshgrid(*arrays, indexing="xy") -> List["Array"]:
@@ -255,7 +270,7 @@ def meshgrid(*arrays, indexing="xy") -> List["Array"]:
 
 
 def ones(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
     if dtype is None:
-        dtype = nxp.float64
+        dtype = default_dtypes(device=device)['real floating']
     return full(shape, 1, dtype=dtype, device=device, chunks=chunks, spec=spec)
@@ -301,7 +316,7 @@ def _tri_mask(N, M, k, chunks, spec):
 
 
 def zeros(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
     if dtype is None:
-        dtype = nxp.float64
+        dtype = default_dtypes(device=device)['real floating']
     return full(shape, 0, dtype=dtype, device=device, chunks=chunks, spec=spec)
diff --git a/cubed/backend_array_api.py b/cubed/backend_array_api.py
index bee99704..e976f322 100644
--- a/cubed/backend_array_api.py
+++ b/cubed/backend_array_api.py
@@ -33,6 +33,22 @@
 
 namespace = array_api_compat.numpy
 
+_DEFAULT_DTYPES = {
+    "real floating": namespace.float64,
+    "complex floating": namespace.complex128,
+    "integral": namespace.int64,
+}
+if "CUBED_DEFAULT_PRECISION_X32" in os.environ:
+    if os.environ['CUBED_DEFAULT_PRECISION_X32']:
+        _DEFAULT_DTYPES = {
+            "real floating": namespace.float32,
+            "complex floating": namespace.complex64,
+            "integral": namespace.int32,
+        }
+
+
+def default_dtypes(*, device=None) -> dict:
+    return _DEFAULT_DTYPES
 
 # These functions to convert to/from backend arrays
 # assume that no extra memory is allocated, by using the
diff --git a/cubed/core/plan.py b/cubed/core/plan.py
index 919bd14c..dfa78ba5 100644
--- a/cubed/core/plan.py
+++ b/cubed/core/plan.py
@@ -1,3 +1,4 @@
+import dataclasses
 import atexit
 import inspect
 import shutil
@@ -33,6 +34,8 @@ def delete_on_exit(context_dir: str) -> None:
 
 sym_counter = 0
 
+Decorator = Callable[[Callable], Callable]
+
 
 def gensym(name="op"):
     global sym_counter
@@ -194,13 +197,40 @@ def _create_lazy_zarr_arrays(self, dag):
 
         return dag
 
+    def _compile_blockwise(self, dag, jit_function: Decorator) -> nx.MultiDiGraph:
+        """JIT-compiles the functions from all blockwise ops by mutating the input dag."""
+        # Recommended: make a copy of the dag before calling this function.
+        for n in dag.nodes:
+            node = dag.nodes[n]
+
+            if "primitive_op" not in node:
+                continue
+
+            if not isinstance(node["pipeline"].config, BlockwiseSpec):
+                continue
+
+            # node is a blockwise primitive_op.
+            # maybe we should investigate some sort of optics library for frozen dataclasses...
+            new_pipeline = dataclasses.replace(
+                node["pipeline"],
+                config=dataclasses.replace(
+                    node["pipeline"].config,
+                    function=jit_function(node["pipeline"].config.function),
+                ),
+            )
+            node["pipeline"] = new_pipeline
+
+        return dag
+
     @lru_cache
     def _finalize_dag(
-        self, optimize_graph: bool = True, optimize_function=None
+        self, optimize_graph: bool = True, optimize_function=None, jit_function: Optional[Decorator] = None,
     ) -> nx.MultiDiGraph:
         dag = self.optimize(optimize_function).dag if optimize_graph else self.dag
         # create a copy since _create_lazy_zarr_arrays mutates the dag
         dag = dag.copy()
+        if callable(jit_function):
+            dag = self._compile_blockwise(dag, jit_function)
         dag = self._create_lazy_zarr_arrays(dag)
         return nx.freeze(dag)
@@ -210,11 +240,12 @@ def execute(
         callbacks=None,
         optimize_graph=True,
         optimize_function=None,
+        jit_function=None,
         resume=None,
         spec=None,
         **kwargs,
     ):
-        dag = self._finalize_dag(optimize_graph, optimize_function)
+        dag = self._finalize_dag(optimize_graph, optimize_function, jit_function)
 
         compute_id = f"compute-{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}"
diff --git a/cubed/random.py b/cubed/random.py
index 6c60a6c9..f1bab4a2 100644
--- a/cubed/random.py
+++ b/cubed/random.py
@@ -2,17 +2,16 @@
 
 from numpy.random import Generator, Philox
 
-from cubed.backend_array_api import namespace as nxp
-from cubed.backend_array_api import numpy_array_to_backend_array
+from cubed.backend_array_api import numpy_array_to_backend_array, default_dtypes
 from cubed.core.ops import map_blocks
 from cubed.utils import block_id_to_offset, normalize_shape
 from cubed.vendor.dask.array.core import normalize_chunks
 
 
-def random(size, *, chunks=None, spec=None):
+def random(size, *, chunks=None, spec=None, device=None):
     """Return random floats in the half-open interval [0.0, 1.0)."""
     shape = normalize_shape(size)
-    dtype = nxp.float64
+    dtype = default_dtypes(device=device)['real floating']
     chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
     numblocks = tuple(map(len, chunks))
     root_seed = pyrandom.getrandbits(128)
@@ -27,9 +26,9 @@ def random(size, *, chunks=None, spec=None):
     )
 
 
-def _random(x, numblocks=None, root_seed=None, block_id=None):
+def _random(x, numblocks=None, root_seed=None, block_id=None, dtype=None):
     stream_id = block_id_to_offset(block_id, numblocks)
     rg = Generator(Philox(key=root_seed + stream_id))
-    out = rg.random(x.shape)
-    out = numpy_array_to_backend_array(out)
+    out = rg.random(x.shape, dtype=dtype)
+    out = numpy_array_to_backend_array(out, dtype=dtype)
     return out
diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py
index 2be010f0..6ce42db0 100644
--- a/cubed/tests/test_array_api.py
+++ b/cubed/tests/test_array_api.py
@@ -115,15 +115,15 @@ def test_eye(spec, k):
 def test_linspace(spec, endpoint):
     a = xp.linspace(6, 49, 50, endpoint=endpoint, chunks=5, spec=spec)
     npa = np.linspace(6, 49, 50, endpoint=endpoint)
-    assert_allclose(a, npa)
+    assert_allclose(a, npa, rtol=1e-5)
 
     a = xp.linspace(1.4, 4.9, 13, endpoint=endpoint, chunks=5, spec=spec)
     npa = np.linspace(1.4, 4.9, 13, endpoint=endpoint)
-    assert_allclose(a, npa)
+    assert_allclose(a, npa, rtol=1e-5)
 
     a = xp.linspace(0, 0, 0, endpoint=endpoint)
     npa = np.linspace(0, 0, 0, endpoint=endpoint)
-    assert_allclose(a, npa)
+    assert_allclose(a, npa, rtol=1e-5)
 
 
 def test_ones(spec, executor):
@@ -351,15 +351,17 @@ def test_matmul(spec, executor):
         [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
         chunks=(2, 2),
         spec=spec,
+        dtype=xp.float32,
     )
     b = xp.asarray(
         [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
         chunks=(2, 2),
         spec=spec,
+        dtype=xp.float32,
     )
     c = xp.matmul(a, b)
-    x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
-    y = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
+    x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=xp.float32)
+    y = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=xp.float32)
     expected = np.matmul(x, y)
     assert_array_equal(c.compute(executor=executor), expected)
@@ -415,8 +417,8 @@ def test_matmul_modal(modal_executor):
 
 
 def test_outer(spec, executor):
-    a = xp.asarray([0, 1, 2], chunks=2, spec=spec)
-    b = xp.asarray([10, 50, 100], chunks=2, spec=spec)
+    a = xp.asarray([0, 1, 2], chunks=2, spec=spec, dtype=xp.float32)
+    b = xp.asarray([10, 50, 100], chunks=2, spec=spec, dtype=xp.float32)
     c = xp.outer(a, b)
     assert_array_equal(c.compute(executor=executor), np.outer([0, 1, 2], [10, 50, 100]))
diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py
index f40c7a0f..11a391ec 100644
--- a/cubed/tests/test_core.py
+++ b/cubed/tests/test_core.py
@@ -383,7 +383,7 @@ def test_reduction_not_enough_memory(tmp_path):
 
 
 def test_partial_reduce(spec):
-    a = xp.asarray(np.arange(242).reshape((11, 22)), chunks=(3, 4), spec=spec)
+    a = xp.asarray(np.arange(242, dtype=np.int32).reshape((11, 22)), chunks=(3, 4), spec=spec, dtype=xp.int32)
     b = partial_reduce(a, np.sum, split_every={0: 2})
     c = partial_reduce(b, np.sum, split_every={0: 2})
     assert_array_equal(
@@ -468,13 +468,13 @@ def test_compute_multiple_different_specs(tmp_path):
 
 
 def test_visualize(tmp_path):
-    a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=xp.float64, chunks=(2, 2))
+    a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=xp.float32, chunks=(2, 2))
     b = cubed.random.random((3, 3), chunks=(2, 2))
     c = xp.add(a, b)
     d = c.rechunk((3, 1))
     e = c * 3
-    f = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))
+    f = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), dtype=xp.float32)
     g = f * 4
 
     assert not (tmp_path / "e.dot").exists()
@@ -504,11 +504,13 @@ def test_array_pickle(spec, executor):
         [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
         chunks=(2, 2),
         spec=spec,
+        dtype=xp.float32,
     )
     b = xp.asarray(
         [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
         chunks=(2, 2),
         spec=spec,
+        dtype=xp.float32,
     )
     c = xp.matmul(a, b)
@@ -516,8 +518,8 @@ def test_array_pickle(spec, executor):
     # note we have to use dill which can serialize local functions, unlike pickle
     c = dill.loads(dill.dumps(c))
 
-    x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
-    y = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
+    x = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32)
+    y = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], dtype=np.float32)
     expected = np.matmul(x, y)
     assert_array_equal(c.compute(executor=executor), expected)
@@ -539,10 +541,10 @@ def test_plan_scaling(tmp_path, factor):
     spec = cubed.Spec(tmp_path, allowed_mem="2GB")
     chunksize = 5000
     a = cubed.random.random(
-        (factor * chunksize, factor * chunksize), chunks=chunksize, spec=spec
+        (factor * chunksize, factor * chunksize), chunks=chunksize, spec=spec,
     )
     b = cubed.random.random(
-        (factor * chunksize, factor * chunksize), chunks=chunksize, spec=spec
+        (factor * chunksize, factor * chunksize), chunks=chunksize, spec=spec,
     )
 
     c = xp.matmul(a, b)
diff --git a/cubed/tests/test_executor_features.py b/cubed/tests/test_executor_features.py
index 22fce7bf..c8fc5390 100644
--- a/cubed/tests/test_executor_features.py
+++ b/cubed/tests/test_executor_features.py
@@ -315,3 +315,30 @@ def test_check_runtime_memory_processes(spec, executor):
 
     # OK if we use fewer workers
     c.compute(executor=executor, max_workers=max_workers // 2)
+
+
+JIT_FUNCTIONS = [lambda fn: fn]
+
+try:
+    from numba import jit as numba_jit
+    JIT_FUNCTIONS.append(numba_jit)
+except ModuleNotFoundError:
+    pass
+
+try:
+    if 'jax' in os.environ.get('CUBED_BACKEND_ARRAY_API_MODULE', ''):
+        from jax import jit as jax_jit
+        JIT_FUNCTIONS.append(jax_jit)
+except ModuleNotFoundError:
+    pass
+
+
+@pytest.mark.parametrize("jit_function", JIT_FUNCTIONS)
+def test_check_jit_compilation(spec, executor, jit_function):
+    a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
+    b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
+    c = xp.add(a, b)
+    assert_array_equal(
+        c.compute(executor=executor, jit_function=jit_function), np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]])
+    )
\ No newline at end of file
diff --git a/docs/contributing.md b/docs/contributing.md
index 477fdc46..0d442de9 100644
--- a/docs/contributing.md
+++ b/docs/contributing.md
@@ -12,3 +12,17 @@ conda activate cubed
 pip install -r requirements.txt
 pip install -e .
 ```
+
+Optionally, to run JAX on an Apple silicon (M1 or later) Mac, follow these instructions from Apple:
+https://developer.apple.com/metal/jax/
+
+To summarize:
+```shell
+pip install jax-metal
+export CUBED_BACKEND_ARRAY_API_MODULE=jax.numpy
+export JAX_ENABLE_X64=False
+export CUBED_DEFAULT_PRECISION_X32=True
+export ENABLE_PJRT_COMPATIBILITY=True
+```
+
+Please make sure that your Python interpreter and all dependencies are compiled for ARM.
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 97e1be3c..c5b6531f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -111,6 +111,15 @@ test-modal = [
     "pytest-mock",
 ]
+test-jax = [
+    "cubed[diagnostics]",
+    "dill",
+    "numpy_groupies",
+    "pytest",
+    "pytest-cov",
+    "pytest-mock",
+    "jax",
+]
 
 [project.urls]
 homepage = "https://github.com/cubed-dev/cubed"
 documentation = "https://tomwhite.github.io/cubed"
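For reference, a minimal sketch of how the new `CUBED_DEFAULT_PRECISION_X32` switch behaves with the default NumPy backend. The override is read once, when `cubed.backend_array_api` is first imported, and any non-empty value enables it (the check is a truthiness test on the environment string), so it must be set before cubed is imported:

```python
import os

# Assumption for this sketch: set before cubed is imported; any non-empty string enables x32.
os.environ["CUBED_DEFAULT_PRECISION_X32"] = "1"

import cubed.array_api as xp
from cubed.backend_array_api import default_dtypes

print(default_dtypes()["real floating"])     # float32 instead of float64
print(xp.ones((4, 4), chunks=(2, 2)).dtype)  # float32: creation functions now consult default_dtypes()
print(xp.asarray([1, 2, 3]).dtype)           # int32, via _to_default_precision on the inferred int64
```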
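A minimal sketch of the new `jit_function` hook, mirroring `test_check_jit_compilation` above: `Plan.execute` (reached via `Array.compute`) passes the decorator to `_compile_blockwise`, which wraps the function of every blockwise op before the dag is finalized. This assumes the JAX backend is configured (`CUBED_BACKEND_ARRAY_API_MODULE=jax.numpy`) and omits the `spec`/`executor` fixtures the test uses:

```python
import jax
import numpy as np
from numpy.testing import assert_array_equal

import cubed.array_api as xp

a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2))
c = xp.add(a, b)

# Each blockwise function in the plan is wrapped with jax.jit before execution;
# numba.jit (the other decorator exercised by the test) is passed the same way.
result = c.compute(jit_function=jax.jit)
assert_array_equal(result, np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]))
```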
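Because the hook is typed as `Decorator = Callable[[Callable], Callable]`, anything that wraps a function can be plugged in, not only a JIT compiler. A sketch of a hypothetical timing decorator (the `timed` helper below is illustrative only and not part of this change):

```python
import functools
import time


def timed(fn):
    """Wrap a blockwise function and report how long each call takes."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        out = fn(*args, **kwargs)
        print(f"{getattr(fn, '__name__', 'blockwise')} took {time.perf_counter() - start:.4f}s")
        return out

    return wrapper


# Used the same way as jax.jit or numba.jit above, e.g. c.compute(jit_function=timed)
```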