diff --git a/.github/workflows/jax-tests.yml b/.github/workflows/jax-tests.yml index 02fdac7b..3332ac3e 100644 --- a/.github/workflows/jax-tests.yml +++ b/.github/workflows/jax-tests.yml @@ -17,8 +17,10 @@ jobs: strategy: fail-fast: false matrix: - os: ["ubuntu-latest"] - python-version: ["3.10"] + # How to set up Jax on an ARM Mac: https://developer.apple.com/metal/jax/ + os: ["ubuntu-latest", "macos-14"] + python-version: ["3.11"] + precision: ["64", "32"] steps: - name: Checkout source @@ -38,13 +40,18 @@ jobs: - name: Install run: | python -m pip install --upgrade pip - python -m pip install -e '.[test]' 'jax[cpu]' - python -m pip uninstall -y lithops # tests don't run on Lithops + python -m pip install -e '.[test-jax]' + # Verify jax + python -c 'import jax; print(jax.numpy.arange(10))' - name: Run tests run: | # exclude tests that rely on structured types since JAX doesn't support these - pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby" + # exclude tests that rely on randomness because JAX is picky about this. + # TODO(#494): Turn back on tests that do visualization when the "FileNotFound" error is fixed. These are "visualization", "plan_scaling", and "optimization". + pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not random and not visualization and not plan_scaling and not optimization" env: CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy - JAX_ENABLE_X64: True + JAX_ENABLE_X64: ${{ matrix.precision == "64" }} + CUBED_DEFAULT_PRECISION_X32: ${{ matrix.precision == "32" }} + ENABLE_PJRT_COMPATIBILITY: True diff --git a/cubed/array_api/creation_functions.py b/cubed/array_api/creation_functions.py index 717ec725..86f41858 100644 --- a/cubed/array_api/creation_functions.py +++ b/cubed/array_api/creation_functions.py @@ -1,7 +1,10 @@ import math from typing import TYPE_CHECKING, Iterable, List -from cubed.backend_array_api import namespace as nxp +import numpy as np + +from cubed.backend_array_api import namespace as nxp, to_default_precision +from cubed.backend_array_api import default_dtypes from cubed.core import Plan, gensym from cubed.core.ops import map_blocks from cubed.storage.virtual import ( @@ -17,6 +20,23 @@ from .array_object import Array +def _iterable_to_default_dtype(it, device=None): + """Determines the default precision dtype of a collection (of collections) of scalars""" + w = it + while isinstance(w, Iterable): + w = next(iter(w)) + + defaults = default_dtypes(device=device) + if nxp.issubdtype(type(w), np.integer): + return defaults["integral"] + elif nxp.isreal(w): + return defaults["real floating"] + elif nxp.iscomplex(w): + return defaults["complex floating"] + else: + raise ValueError(f"there are no default data types supported for {it}.") + + def arange( start, /, stop=None, step=1, *, dtype=None, device=None, chunks="auto", spec=None ) -> "Array": @@ -24,7 +44,11 @@ def arange( start, stop = 0, start num = int(max(math.ceil((stop - start) / step), 0)) if dtype is None: + # TODO: Use inspect API dtype = nxp.arange(start, stop, step * num if num else step).dtype + # the default nxp call does not adjust the data type to the default precision. + dtype = to_default_precision(dtype, device=device) + chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype) chunksize = chunks[0][0] @@ -62,10 +86,12 @@ def asarray( ): # pragma: no cover return asarray(a.data) elif not isinstance(getattr(a, "shape", None), Iterable): - # ensure blocks are arrays + dtype = _iterable_to_default_dtype(a, device=device) a = nxp.asarray(a, dtype=dtype) + if dtype is None: - dtype = a.dtype + dtype = to_default_precision(a.dtype, device=device) + a = a.astype(dtype) chunksize = to_chunksize(normalize_chunks(chunks, shape=a.shape, dtype=dtype)) name = gensym() @@ -90,7 +116,7 @@ def empty_virtual_array( shape, *, dtype=None, device=None, chunks="auto", spec=None, hidden=True ) -> "Array": if dtype is None: - dtype = nxp.float64 + dtype = default_dtypes(device=device)['real floating'] chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype)) name = gensym() @@ -108,7 +134,7 @@ def eye( if n_cols is None: n_cols = n_rows if dtype is None: - dtype = nxp.float64 + dtype = default_dtypes(device=device)['real floating'] shape = (n_rows, n_cols) chunks = normalize_chunks(chunks, shape=shape, dtype=dtype) @@ -139,12 +165,13 @@ def full( shape = normalize_shape(shape) if dtype is None: # check bool first since True/False are instances of int and float + defaults = default_dtypes(device=device) if isinstance(fill_value, bool): dtype = nxp.bool elif isinstance(fill_value, int): - dtype = nxp.int64 + dtype = defaults['integral'] elif isinstance(fill_value, float): - dtype = nxp.float64 + dtype = defaults['real floating'] else: raise TypeError("Invalid input to full") chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype)) @@ -191,7 +218,7 @@ def linspace( div = 1 step = float(range_) / div if dtype is None: - dtype = nxp.float64 + dtype = default_dtypes(device=device)['real floating'] chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype) chunksize = chunks[0][0] @@ -208,18 +235,20 @@ def linspace( step=step, endpoint=endpoint, linspace_dtype=dtype, + device=device, ) -def _linspace(x, size, start, step, endpoint, linspace_dtype, block_id=None): +def _linspace(x, size, start, step, endpoint, linspace_dtype, device=None, block_id=None): bs = x.shape[0] i = block_id[0] adjusted_bs = bs - 1 if endpoint else bs - blockstart = start + (i * size * step) - blockstop = blockstart + (adjusted_bs * step) - return nxp.linspace( - blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype - ) + # While the Array API supports `nxp.astype(x, dtype)`, using this method causes precision + # errors with Jax. For now, let's see how this works with other implementations. + float_ = default_dtypes(device=device)['real floating'] + blockstart = float_(start + (i * size * step)) + blockstop = float_(blockstart + float_(adjusted_bs * step)) + return nxp.linspace(blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype) def meshgrid(*arrays, indexing="xy") -> List["Array"]: @@ -255,7 +284,7 @@ def meshgrid(*arrays, indexing="xy") -> List["Array"]: def ones(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array": if dtype is None: - dtype = nxp.float64 + dtype = default_dtypes(device=device)['real floating'] return full(shape, 1, dtype=dtype, device=device, chunks=chunks, spec=spec) @@ -301,7 +330,7 @@ def _tri_mask(N, M, k, chunks, spec): def zeros(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array": if dtype is None: - dtype = nxp.float64 + dtype = default_dtypes(device=device)['real floating'] return full(shape, 0, dtype=dtype, device=device, chunks=chunks, spec=spec) diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index 7ee6525e..c0a64c8e 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -11,10 +11,12 @@ complex128, float32, float64, + int32, + uint32, int64, uint64, ) -from cubed.backend_array_api import namespace as nxp +from cubed.backend_array_api import namespace as nxp, PRECISION from cubed.core import reduction @@ -129,15 +131,13 @@ def prod( if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes: raise TypeError("Only numeric or boolean dtypes are allowed in prod") if dtype is None: - if x.dtype in _boolean_dtypes: - dtype = int64 - elif x.dtype in _signed_integer_dtypes: - dtype = int64 + if x.dtype in _signed_integer_dtypes: + dtype = int64 if PRECISION == 64 else int32 elif x.dtype in _unsigned_integer_dtypes: - dtype = uint64 - elif x.dtype == float32: + dtype = uint64 if PRECISION == 64 else uint32 + elif x.dtype == float32 and PRECISION == 64: dtype = float64 - elif x.dtype == complex64: + elif x.dtype == complex64 and PRECISION == 64: dtype = complex128 else: dtype = x.dtype @@ -161,15 +161,13 @@ def sum( if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes: raise TypeError("Only numeric or boolean dtypes are allowed in sum") if dtype is None: - if x.dtype in _boolean_dtypes: - dtype = int64 - elif x.dtype in _signed_integer_dtypes: - dtype = int64 + if x.dtype in _signed_integer_dtypes: + dtype = int64 if PRECISION == 64 else int32 elif x.dtype in _unsigned_integer_dtypes: - dtype = uint64 - elif x.dtype == float32: + dtype = uint64 if PRECISION == 64 else uint32 + elif x.dtype == float32 and PRECISION == 64: dtype = float64 - elif x.dtype == complex64: + elif x.dtype == complex64 and PRECISION == 64: dtype = complex128 else: dtype = x.dtype diff --git a/cubed/backend_array_api.py b/cubed/backend_array_api.py index bee99704..da01d11b 100644 --- a/cubed/backend_array_api.py +++ b/cubed/backend_array_api.py @@ -33,6 +33,24 @@ namespace = array_api_compat.numpy +_DEFAULT_DTYPES = { + "real floating": namespace.float64, + "complex floating": namespace.complex128, + "integral": namespace.int64, +} +PRECISION=64 +if "CUBED_DEFAULT_PRECISION_X32" in os.environ: + if os.environ['CUBED_DEFAULT_PRECISION_X32']: + _DEFAULT_DTYPES = { + "real floating": namespace.float32, + "complex floating": namespace.complex64, + "integral": namespace.int32, + } + PRECISION=32 + + +def default_dtypes(*, device=None) -> dict: + return _DEFAULT_DTYPES # These functions to convert to/from backend arrays # assume that no extra memory is allocated, by using the @@ -48,3 +66,12 @@ def numpy_array_to_backend_array(arr, *, dtype=None): if isinstance(arr, dict): return {k: namespace.asarray(v, dtype=dtype) for k, v in arr.items()} return namespace.asarray(arr, dtype=dtype) + + +def to_default_precision(dtype, *, device=None): + """Returns a dtype of the same kind with the default precision.""" + for k, dtype_ in default_dtypes(device=device).items(): + if namespace.isdtype(dtype, k): + return dtype_ + + diff --git a/cubed/core/ops.py b/cubed/core/ops.py index f6549e43..b5612bc5 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -15,7 +15,7 @@ from toolz import accumulate, map from cubed import config -from cubed.backend_array_api import backend_array_to_numpy_array +from cubed.backend_array_api import backend_array_to_numpy_array, to_default_precision from cubed.backend_array_api import namespace as nxp from cubed.backend_array_api import numpy_array_to_backend_array from cubed.core.array import CoreArray, check_array_specs, compute, gensym @@ -41,7 +41,7 @@ from cubed.array_api.array_object import Array -def from_array(x, chunks="auto", asarray=None, spec=None) -> "Array": +def from_array(x, chunks="auto", asarray=None, spec=None, device=None) -> "Array": """Create a Cubed array from an array-like object.""" if isinstance(x, CoreArray): @@ -49,9 +49,16 @@ def from_array(x, chunks="auto", asarray=None, spec=None) -> "Array": "Array is already a Cubed array. Use 'asarray' or 'rechunk' instead." ) + dtype = to_default_precision(x.dtype) + if x.dtype != dtype: + if hasattr(x, 'astype'): + x = x.astype(dtype) + elif hasattr(x, '__array__'): + x = x.__array__(dtype) + previous_chunks = getattr(x, "chunks", None) outchunks = normalize_chunks( - chunks, x.shape, dtype=x.dtype, previous_chunks=previous_chunks + chunks, x.shape, dtype=dtype, previous_chunks=previous_chunks ) if isinstance(x, zarr.Array): # zarr fast path diff --git a/cubed/nan_functions.py b/cubed/nan_functions.py index 2acd308b..3aaf5b74 100644 --- a/cubed/nan_functions.py +++ b/cubed/nan_functions.py @@ -8,10 +8,12 @@ complex128, float32, float64, + int32, + uint32, int64, uint64, ) -from cubed.backend_array_api import namespace as nxp +from cubed.backend_array_api import namespace as nxp, PRECISION from cubed.core import reduction # TODO: refactor once nan functions are standardized: @@ -69,12 +71,12 @@ def nansum( raise TypeError("Only numeric dtypes are allowed in nansum") if dtype is None: if x.dtype in _signed_integer_dtypes: - dtype = int64 + dtype = int64 if PRECISION == 64 else int32 elif x.dtype in _unsigned_integer_dtypes: - dtype = uint64 - elif x.dtype == float32: + dtype = uint64 if PRECISION == 64 else uint32 + elif x.dtype == float32 and PRECISION == 64: dtype = float64 - elif x.dtype == complex64: + elif x.dtype == complex64 and PRECISION == 64: dtype = complex128 else: dtype = x.dtype diff --git a/cubed/random.py b/cubed/random.py index 6c60a6c9..f1bab4a2 100644 --- a/cubed/random.py +++ b/cubed/random.py @@ -2,17 +2,16 @@ from numpy.random import Generator, Philox -from cubed.backend_array_api import namespace as nxp -from cubed.backend_array_api import numpy_array_to_backend_array +from cubed.backend_array_api import numpy_array_to_backend_array, default_dtypes from cubed.core.ops import map_blocks from cubed.utils import block_id_to_offset, normalize_shape from cubed.vendor.dask.array.core import normalize_chunks -def random(size, *, chunks=None, spec=None): +def random(size, *, chunks=None, spec=None, device=None): """Return random floats in the half-open interval [0.0, 1.0).""" shape = normalize_shape(size) - dtype = nxp.float64 + dtype = default_dtypes(device=device)['real floating'] chunks = normalize_chunks(chunks, shape=shape, dtype=dtype) numblocks = tuple(map(len, chunks)) root_seed = pyrandom.getrandbits(128) @@ -27,9 +26,9 @@ def random(size, *, chunks=None, spec=None): ) -def _random(x, numblocks=None, root_seed=None, block_id=None): +def _random(x, numblocks=None, root_seed=None, block_id=None, dtype=None): stream_id = block_id_to_offset(block_id, numblocks) rg = Generator(Philox(key=root_seed + stream_id)) - out = rg.random(x.shape) - out = numpy_array_to_backend_array(out) + out = rg.random(x.shape, dtype=dtype) + out = numpy_array_to_backend_array(out, dtype=dtype) return out diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index b34b8bbb..a0332031 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -115,15 +115,15 @@ def test_eye(spec, k): def test_linspace(spec, endpoint): a = xp.linspace(6, 49, 50, endpoint=endpoint, chunks=5, spec=spec) npa = np.linspace(6, 49, 50, endpoint=endpoint) - assert_allclose(a, npa) + assert_allclose(a, npa, rtol=1e-5) a = xp.linspace(1.4, 4.9, 13, endpoint=endpoint, chunks=5, spec=spec) npa = np.linspace(1.4, 4.9, 13, endpoint=endpoint) - assert_allclose(a, npa) + assert_allclose(a, npa, rtol=1e-5) a = xp.linspace(0, 0, 0, endpoint=endpoint) npa = np.linspace(0, 0, 0, endpoint=endpoint) - assert_allclose(a, npa) + assert_allclose(a, npa, rtol=1e-5) def test_ones(spec, executor): diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index 1e415092..d415eaa9 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -1,3 +1,4 @@ +import os import platform import random from functools import partial @@ -59,9 +60,13 @@ def modal_executor(request): def test_as_array_fails(spec): a = np.ones((1000, 1000)) + expected_size = "8" + if os.environ.get('CUBED_DEFAULT_PRECISION_X32', False): + expected_size = "4" + with pytest.raises( ValueError, - match="Size of in memory array is 8.0 MB which exceeds maximum of 1.0 MB.", + match=f"Size of in memory array is {expected_size}.0 MB which exceeds maximum of 1.0 MB.", ): xp.asarray(a, chunks=(100, 100), spec=spec) @@ -183,55 +188,67 @@ def test_map_blocks_with_kwargs(spec, executor): def test_map_blocks_with_block_id(spec, executor): + dtype = "int64" + if os.environ.get('CUBED_DEFAULT_PRECISION_X32', False): + dtype = "int32" + # based on dask test def func(block, block_id=None, c=0): return nxp.ones_like(block) * int(sum(block_id)) + c - a = xp.arange(10, dtype="int64", chunks=(2,)) - b = cubed.map_blocks(func, a, dtype="int64") + a = xp.arange(10, dtype=dtype, chunks=(2,)) + b = cubed.map_blocks(func, a, dtype=dtype) assert_array_equal( b.compute(executor=executor), - np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], dtype="int64"), + np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], dtype=dtype), ) a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) - b = cubed.map_blocks(func, a, dtype="int64") + b = cubed.map_blocks(func, a, dtype=dtype) assert_array_equal( b.compute(executor=executor), - np.array([[0, 0, 1], [0, 0, 1], [1, 1, 2]], dtype="int64"), + np.array([[0, 0, 1], [0, 0, 1], [1, 1, 2]], dtype=dtype), ) - c = cubed.map_blocks(func, a, dtype="int64", c=1) + c = cubed.map_blocks(func, a, dtype=dtype, c=1) assert_array_equal( c.compute(executor=executor), - np.array([[0, 0, 1], [0, 0, 1], [1, 1, 2]], dtype="int64") + 1, + np.array([[0, 0, 1], [0, 0, 1], [1, 1, 2]], dtype=dtype) + 1, ) def test_map_blocks_no_array_args(spec, executor): + dtype = "int64" + if os.environ.get('CUBED_DEFAULT_PRECISION_X32', False): + dtype = "int32" + def func(block, block_id=None): return nxp.ones_like(block) * int(sum(block_id)) - a = cubed.map_blocks(func, dtype="int64", chunks=((5, 3),), spec=spec) + a = cubed.map_blocks(func, dtype=dtype, chunks=((5, 3),), spec=spec) assert a.chunks == ((5, 3),) assert_array_equal( a.compute(executor=executor), - np.array([0, 0, 0, 0, 0, 1, 1, 1], dtype="int64"), + np.array([0, 0, 0, 0, 0, 1, 1, 1], dtype=dtype), ) def test_map_blocks_with_different_block_shapes(spec): + dtype = "int64" + if os.environ.get('CUBED_DEFAULT_PRECISION_X32', False): + dtype = "int32" + def func(x, y): return x a = xp.asarray([[[12, 13]]], spec=spec) b = xp.asarray([14, 15], spec=spec) c = cubed.map_blocks( - func, a, b, dtype="int64", chunks=(1, 1, 2), drop_axis=2, new_axis=2 + func, a, b, dtype=dtype, chunks=(1, 1, 2), drop_axis=2, new_axis=2 ) assert_array_equal(c.compute(), np.array([[[12, 13]]])) diff --git a/cubed/tests/test_gufunc.py b/cubed/tests/test_gufunc.py index deb7d583..e99aacbc 100644 --- a/cubed/tests/test_gufunc.py +++ b/cubed/tests/test_gufunc.py @@ -1,3 +1,5 @@ +import os + import numpy as np import pytest from numpy.testing import assert_allclose, assert_equal @@ -71,13 +73,17 @@ def foo(x): def test_gufunc_two_inputs(spec): + dtype = int + if os.environ.get('CUBED_DEFAULT_PRECISION_X32', False): + dtype = nxp.int32 + def foo(x, y): return np.einsum("...ij,...jk->ik", x, y) - a = xp.ones((2, 3), chunks=100, dtype=int, spec=spec) - b = xp.ones((3, 4), chunks=100, dtype=int, spec=spec) - x = apply_gufunc(foo, "(i,j),(j,k)->(i,k)", a, b, output_dtypes=int) - assert_equal(x, 3 * np.ones((2, 4), dtype=int)) + a = xp.ones((2, 3), chunks=100, dtype=dtype, spec=spec) + b = xp.ones((3, 4), chunks=100, dtype=dtype, spec=spec) + x = apply_gufunc(foo, "(i,j),(j,k)->(i,k)", a, b, output_dtypes=dtype) + assert_equal(x, 3 * np.ones((2, 4), dtype=dtype)) def test_apply_gufunc_axes_two_kept_coredims(spec): diff --git a/cubed/tests/test_nan_functions.py b/cubed/tests/test_nan_functions.py index 53264e79..754ac2d4 100644 --- a/cubed/tests/test_nan_functions.py +++ b/cubed/tests/test_nan_functions.py @@ -27,10 +27,10 @@ def test_nanmean_allnan(spec): def test_nansum(spec): - a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec) + a = xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, xp.nan]], chunks=(2, 2), spec=spec) b = cubed.nansum(a) assert_array_equal( - b.compute(), np.nansum(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]])) + b.compute(), np.nansum(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, np.nan]])) ) diff --git a/docs/contributing.md b/docs/contributing.md index 9cc18990..d5d90be5 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -12,3 +12,17 @@ conda activate cubed pip install -r requirements.txt pip install -e . ``` + +Optionally, to run Jax on the M1+ Mac, please follow these instructions from Apple: +https://developer.apple.com/metal/jax/ + +To summarize: +```shell +pip install jax-metal +export CUBED_BACKEND_ARRAY_API_MODULE=jax.numpy +export JAX_ENABLE_X64=False +export CUBED_DEFAULT_PRECISION_X32=True +export ENABLE_PJRT_COMPATIBILITY=True +``` + +Please make sure that your version of Python and all dependencies are compiled for ARM. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 87336f14..2b06ec3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,6 +120,15 @@ test-modal = [ "pytest-mock", ] +test-jax = [ + "cubed[diagnostics]", + "dill", + "numpy_groupies", + "pytest", + "pytest-cov", + "pytest-mock", + "jax", +] [project.urls] homepage = "https://github.com/cubed-dev/cubed" documentation = "https://tomwhite.github.io/cubed"