Adding Jax tests for the M1 mac. #508

Open
alxmrs wants to merge 37 commits into main

Commits (37)
90c65c6
Adding Jax tests for the M1 mac.
alxmrs Jul 20, 2024
f94d083
Providing flag to default to single precision instead of double for n…
alxmrs Jul 21, 2024
5c10122
Matmul on jax-metal only supports floats and complex types. :/
alxmrs Jul 21, 2024
83a752b
Fixes for more tests.
alxmrs Jul 21, 2024
1f84ac8
Implemented feedback from Tom.
alxmrs Jul 22, 2024
b8445a0
Added device argument.
alxmrs Jul 22, 2024
61e51be
Exclude 'indexing' case.
alxmrs Jul 22, 2024
a81be3d
Removing 'indexing' from default types for now.
alxmrs Jul 22, 2024
5dcab9f
Fixed bag argument in test.
alxmrs Jul 22, 2024
4ce8e7a
Random should respect dtypes.
alxmrs Jul 23, 2024
38d0b7b
I think I got the linspace tests passing. Need to check if array equa…
alxmrs Jul 23, 2024
e71fa63
All jax tests pass locally.
alxmrs Jul 23, 2024
896a738
Cleanup for PR.
alxmrs Jul 23, 2024
a0080a9
Added developer instructions for getting set up on the M1+ processor.
alxmrs Jul 23, 2024
5b00d82
Fix env variable expressions.
alxmrs Jul 23, 2024
e9b53e3
Looks like 14 specifically supports python 3.11 and up.
alxmrs Jul 23, 2024
1ed8ac2
Splitting install steps for jax for mac and non-mac.
alxmrs Jul 27, 2024
72bd026
Mac install now enables jaxlib compatibility.
alxmrs Jul 27, 2024
8b6b1f4
Put the mac-specific env setting in the right place.
alxmrs Jul 27, 2024
c4b51ad
Test the xlarge runner.
alxmrs Jul 27, 2024
7566849
Extracted to method `_to_default_precision`.
alxmrs Jul 30, 2024
2f470d5
Added helpful note.
alxmrs Jul 30, 2024
06e030d
No need to cast in the linspace test.
alxmrs Jul 30, 2024
e642f8d
Remove macos-specific tests.
alxmrs Jul 30, 2024
2afba45
Back tracking to previous macos build
alxmrs Jul 30, 2024
38831b6
Mac should use CPU Jax
alxmrs Jul 31, 2024
9bd0210
Remove jax mac extra install.
alxmrs Jul 31, 2024
18e0aa0
Tensordot testing with and without hardcoded types.
alxmrs Jul 31, 2024
4d132ff
Defaulting to 64 bit precision.
alxmrs Jul 31, 2024
9f6017a
rm unneeded test arg.
alxmrs Jul 31, 2024
1449008
Revert tensordot change.
alxmrs Jul 31, 2024
38dfc32
Revert hardcoding precision as much as possible.
alxmrs Aug 3, 2024
aebb2ab
Make precision a focus point in test matrix.
alxmrs Aug 3, 2024
4687d33
Fix underlying bug in asarray.
alxmrs Aug 3, 2024
92a31cc
Exposing `to_default_precision` as a backend_array_api.
alxmrs Aug 3, 2024
dbefff4
All cubed tests pass!
alxmrs Aug 5, 2024
b8ce6ec
Inlined extracted function.
alxmrs Aug 6, 2024
19 changes: 13 additions & 6 deletions .github/workflows/jax-tests.yml
@@ -17,8 +17,10 @@ jobs:
strategy:
fail-fast: false
matrix:
os: ["ubuntu-latest"]
python-version: ["3.10"]
# How to set up Jax on an ARM Mac: https://developer.apple.com/metal/jax/
os: ["ubuntu-latest", "macos-14"]
python-version: ["3.11"]
precision: ["64", "32"]

steps:
- name: Checkout source
@@ -38,13 +38,18 @@ jobs:
- name: Install
run: |
python -m pip install --upgrade pip
python -m pip install -e '.[test]' 'jax[cpu]'
python -m pip uninstall -y lithops # tests don't run on Lithops
python -m pip install -e '.[test-jax]'
# Verify jax
python -c 'import jax; print(jax.numpy.arange(10))'

- name: Run tests
run: |
# exclude tests that rely on structured types since JAX doesn't support these
pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby"
# exclude tests that rely on randomness because JAX is picky about this.
# TODO(#494): Turn back on tests that do visualization when the "FileNotFound" error is fixed. These are "visualization", "plan_scaling", and "optimization".
pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not random and not visualization and not plan_scaling and not optimization"
env:
CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy
JAX_ENABLE_X64: True
JAX_ENABLE_X64: ${{ matrix.precision == '64' }}
CUBED_DEFAULT_PRECISION_X32: ${{ matrix.precision == '32' }}
ENABLE_PJRT_COMPATIBILITY: True
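
For reference, a minimal local sketch of what one matrix entry (precision == "32") amounts to at runtime. This assumes jax[cpu] and cubed are installed; the env vars must be set before the imports, since both JAX and cubed read their flags at import time:

import os

os.environ["JAX_ENABLE_X64"] = "False"
os.environ["CUBED_DEFAULT_PRECISION_X32"] = "True"
os.environ["CUBED_BACKEND_ARRAY_API_MODULE"] = "jax.numpy"

import jax.numpy as jnp  # noqa: E402 (import must follow the env setup)

print(jnp.arange(3).dtype)          # int32 when x64 is disabled
print(jnp.linspace(0, 1, 3).dtype)  # float32
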
61 changes: 45 additions & 16 deletions cubed/array_api/creation_functions.py
@@ -1,7 +1,10 @@
import math
from typing import TYPE_CHECKING, Iterable, List

from cubed.backend_array_api import namespace as nxp
import numpy as np

from cubed.backend_array_api import namespace as nxp, to_default_precision
from cubed.backend_array_api import default_dtypes
from cubed.core import Plan, gensym
from cubed.core.ops import map_blocks
from cubed.storage.virtual import (
@@ -17,14 +17,35 @@
from .array_object import Array


def _iterable_to_default_dtype(it, device=None):
"""Determines the default precision dtype of a collection (of collections) of scalars"""
w = it
while isinstance(w, Iterable):
w = next(iter(w))

defaults = default_dtypes(device=device)
if nxp.issubdtype(type(w), np.integer):
return defaults["integral"]
elif nxp.isreal(w):
return defaults["real floating"]
elif nxp.iscomplex(w):
return defaults["complex floating"]
else:
raise ValueError(f"there are no default data types supported for {it}.")


def arange(
start, /, stop=None, step=1, *, dtype=None, device=None, chunks="auto", spec=None
) -> "Array":
if stop is None:
start, stop = 0, start
num = int(max(math.ceil((stop - start) / step), 0))
if dtype is None:
# TODO: Use inspect API
dtype = nxp.arange(start, stop, step * num if num else step).dtype
# the default nxp call does not adjust the data type to the default precision.
dtype = to_default_precision(dtype, device=device)

chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
chunksize = chunks[0][0]

@@ -62,10 +86,12 @@ def asarray(
): # pragma: no cover
return asarray(a.data)
elif not isinstance(getattr(a, "shape", None), Iterable):
# ensure blocks are arrays
dtype = _iterable_to_default_dtype(a, device=device)
a = nxp.asarray(a, dtype=dtype)

if dtype is None:
dtype = a.dtype
dtype = to_default_precision(a.dtype, device=device)
a = a.astype(dtype)

chunksize = to_chunksize(normalize_chunks(chunks, shape=a.shape, dtype=dtype))
name = gensym()
@@ -90,7 +116,7 @@ def empty_virtual_array(
shape, *, dtype=None, device=None, chunks="auto", spec=None, hidden=True
) -> "Array":
if dtype is None:
dtype = nxp.float64
dtype = default_dtypes(device=device)['real floating']

chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
name = gensym()
@@ -108,7 +134,7 @@ def eye(
if n_cols is None:
n_cols = n_rows
if dtype is None:
dtype = nxp.float64
dtype = default_dtypes(device=device)['real floating']

shape = (n_rows, n_cols)
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
@@ -139,12 +165,13 @@ def full(
shape = normalize_shape(shape)
if dtype is None:
# check bool first since True/False are instances of int and float
defaults = default_dtypes(device=device)
if isinstance(fill_value, bool):
dtype = nxp.bool
elif isinstance(fill_value, int):
dtype = nxp.int64
dtype = defaults['integral']
elif isinstance(fill_value, float):
dtype = nxp.float64
dtype = defaults['real floating']
else:
raise TypeError("Invalid input to full")
chunksize = to_chunksize(normalize_chunks(chunks, shape=shape, dtype=dtype))
@@ -191,7 +218,7 @@ def linspace(
div = 1
step = float(range_) / div
if dtype is None:
dtype = nxp.float64
dtype = default_dtypes(device=device)['real floating']
chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
chunksize = chunks[0][0]

@@ -208,18 +235,20 @@
step=step,
endpoint=endpoint,
linspace_dtype=dtype,
device=device,
)


def _linspace(x, size, start, step, endpoint, linspace_dtype, block_id=None):
def _linspace(x, size, start, step, endpoint, linspace_dtype, device=None, block_id=None):
bs = x.shape[0]
i = block_id[0]
adjusted_bs = bs - 1 if endpoint else bs
blockstart = start + (i * size * step)
blockstop = blockstart + (adjusted_bs * step)
return nxp.linspace(
blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype
)
# While the Array API supports `nxp.astype(x, dtype)`, using this method causes precision
# errors with Jax. For now, let's see how this works with other implementations.
float_ = default_dtypes(device=device)['real floating']
blockstart = float_(start + (i * size * step))
blockstop = float_(blockstart + float_(adjusted_bs * step))
return nxp.linspace(blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype)


def meshgrid(*arrays, indexing="xy") -> List["Array"]:
@@ -255,7 +284,7 @@ def meshgrid(*arrays, indexing="xy") -> List["Array"]:

def ones(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
if dtype is None:
dtype = nxp.float64
dtype = default_dtypes(device=device)['real floating']
return full(shape, 1, dtype=dtype, device=device, chunks=chunks, spec=spec)


@@ -301,7 +330,7 @@ def _tri_mask(N, M, k, chunks, spec):

def zeros(shape, *, dtype=None, device=None, chunks="auto", spec=None) -> "Array":
if dtype is None:
dtype = nxp.float64
dtype = default_dtypes(device=device)['real floating']
return full(shape, 0, dtype=dtype, device=device, chunks=chunks, spec=spec)


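
To illustrate the dtype inference added above, here is a standalone sketch of the `_iterable_to_default_dtype` logic in plain numpy; the 32-bit defaults dict is an assumption mirroring `default_dtypes`, not the cubed API:

from typing import Iterable

import numpy as np

DEFAULTS_32 = {
    "integral": np.int32,
    "real floating": np.float32,
    "complex floating": np.complex64,
}


def first_scalar_dtype(it, defaults):
    # Walk into nested iterables until the first scalar is reached.
    w = it
    while isinstance(w, Iterable):
        w = next(iter(w))
    if np.issubdtype(type(w), np.integer):
        return defaults["integral"]
    elif np.isreal(w):
        return defaults["real floating"]
    elif np.iscomplex(w):
        return defaults["complex floating"]
    raise ValueError(f"there are no default data types supported for {it!r}")


print(first_scalar_dtype([[1, 2], [3, 4]], DEFAULTS_32))  # int32
print(first_scalar_dtype([0.5, 1.5], DEFAULTS_32))        # float32
print(first_scalar_dtype([1 + 2j], DEFAULTS_32))          # complex64
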
28 changes: 13 additions & 15 deletions cubed/array_api/statistical_functions.py
@@ -11,10 +11,12 @@
complex128,
float32,
float64,
int32,
uint32,
int64,
uint64,
)
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import namespace as nxp, PRECISION
from cubed.core import reduction


@@ -129,15 +131,13 @@ def prod(
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in prod")
if dtype is None:
if x.dtype in _boolean_dtypes:
dtype = int64
elif x.dtype in _signed_integer_dtypes:
dtype = int64
if x.dtype in _signed_integer_dtypes:
dtype = int64 if PRECISION == 64 else int32
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
elif x.dtype == float32:
dtype = uint64 if PRECISION == 64 else uint32
elif x.dtype == float32 and PRECISION == 64:
dtype = float64
elif x.dtype == complex64:
elif x.dtype == complex64 and PRECISION == 64:
dtype = complex128
else:
dtype = x.dtype
@@ -161,15 +161,13 @@ def sum(
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
raise TypeError("Only numeric or boolean dtypes are allowed in sum")
if dtype is None:
if x.dtype in _boolean_dtypes:
dtype = int64
elif x.dtype in _signed_integer_dtypes:
dtype = int64
if x.dtype in _signed_integer_dtypes:
dtype = int64 if PRECISION == 64 else int32
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
elif x.dtype == float32:
dtype = uint64 if PRECISION == 64 else uint32
elif x.dtype == float32 and PRECISION == 64:
dtype = float64
elif x.dtype == complex64:
elif x.dtype == complex64 and PRECISION == 64:
dtype = complex128
else:
dtype = x.dtype
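
The rule in prod and sum above reads as: integer accumulators follow PRECISION, while float32 and complex64 are widened only in 64-bit mode. A standalone sketch of that decision table (illustrative, not the cubed code path):

import numpy as np


def accumulator_dtype(dtype, precision):
    # Mirrors the dtype defaulting in sum/prod above.
    if np.issubdtype(dtype, np.signedinteger):
        return np.int64 if precision == 64 else np.int32
    if np.issubdtype(dtype, np.unsignedinteger):
        return np.uint64 if precision == 64 else np.uint32
    if dtype == np.float32 and precision == 64:
        return np.float64
    if dtype == np.complex64 and precision == 64:
        return np.complex128
    return dtype


assert accumulator_dtype(np.dtype(np.int16), 32) == np.int32
assert accumulator_dtype(np.dtype(np.float32), 32) == np.float32  # no widening
assert accumulator_dtype(np.dtype(np.float32), 64) == np.float64
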
27 changes: 27 additions & 0 deletions cubed/backend_array_api.py
@@ -33,6 +33,24 @@

namespace = array_api_compat.numpy

_DEFAULT_DTYPES = {
    "real floating": namespace.float64,
    "complex floating": namespace.complex128,
    "integral": namespace.int64,
}
PRECISION = 64
# Env vars are strings: only treat explicit truthy values as enabling 32-bit
# mode, so CUBED_DEFAULT_PRECISION_X32=false (as set by the CI matrix) does
# not accidentally trigger it.
if os.environ.get("CUBED_DEFAULT_PRECISION_X32", "").lower() in ("1", "true"):
    _DEFAULT_DTYPES = {
        "real floating": namespace.float32,
        "complex floating": namespace.complex64,
        "integral": namespace.int32,
    }
    PRECISION = 32


def default_dtypes(*, device=None) -> dict:
    return _DEFAULT_DTYPES

# These functions to convert to/from backend arrays
# assume that no extra memory is allocated, by using the
Expand All @@ -48,3 +66,12 @@ def numpy_array_to_backend_array(arr, *, dtype=None):
if isinstance(arr, dict):
return {k: namespace.asarray(v, dtype=dtype) for k, v in arr.items()}
return namespace.asarray(arr, dtype=dtype)


def to_default_precision(dtype, *, device=None):
    """Returns a dtype of the same kind with the default precision."""
    for k, dtype_ in default_dtypes(device=device).items():
        if namespace.isdtype(dtype, k):
            return dtype_
    # No configured default for this kind (e.g. bool): keep the dtype unchanged.
    return dtype
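
A usage sketch for the two helpers, assuming the numpy backend and CUBED_DEFAULT_PRECISION_X32 unset (i.e. 64-bit defaults):

import numpy as np

from cubed.backend_array_api import default_dtypes, to_default_precision

assert default_dtypes()["real floating"] == np.float64
assert to_default_precision(np.dtype(np.float32)) == np.float64
assert to_default_precision(np.dtype(np.int32)) == np.int64
# Kinds with no entry in default_dtypes (e.g. bool) pass through unchanged.
assert to_default_precision(np.dtype(np.bool_)) == np.bool_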


13 changes: 10 additions & 3 deletions cubed/core/ops.py
@@ -15,7 +15,7 @@
from toolz import accumulate, map

from cubed import config
from cubed.backend_array_api import backend_array_to_numpy_array
from cubed.backend_array_api import backend_array_to_numpy_array, to_default_precision
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.core.array import CoreArray, check_array_specs, compute, gensym
@@ -41,17 +41,24 @@
from cubed.array_api.array_object import Array


def from_array(x, chunks="auto", asarray=None, spec=None) -> "Array":
def from_array(x, chunks="auto", asarray=None, spec=None, device=None) -> "Array":
"""Create a Cubed array from an array-like object."""

if isinstance(x, CoreArray):
raise ValueError(
"Array is already a Cubed array. Use 'asarray' or 'rechunk' instead."
)

dtype = to_default_precision(x.dtype)
if x.dtype != dtype:
if hasattr(x, 'astype'):
x = x.astype(dtype)
elif hasattr(x, '__array__'):
x = x.__array__(dtype)

previous_chunks = getattr(x, "chunks", None)
outchunks = normalize_chunks(
chunks, x.shape, dtype=x.dtype, previous_chunks=previous_chunks
chunks, x.shape, dtype=dtype, previous_chunks=previous_chunks
)

if isinstance(x, zarr.Array): # zarr fast path
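
A short sketch of the effect on from_array under 64-bit defaults (with CUBED_DEFAULT_PRECISION_X32 set, the float32 input would be kept as-is instead):

import numpy as np

from cubed.core.ops import from_array

x = np.arange(6, dtype=np.float32).reshape(2, 3)
a = from_array(x, chunks=(1, 3))
print(a.dtype)  # float64 under 64-bit defaults; float32 in 32-bit mode
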
12 changes: 7 additions & 5 deletions cubed/nan_functions.py
@@ -8,10 +8,12 @@
complex128,
float32,
float64,
int32,
uint32,
int64,
uint64,
)
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import namespace as nxp, PRECISION
from cubed.core import reduction

# TODO: refactor once nan functions are standardized:
@@ -69,12 +71,12 @@ def nansum(
raise TypeError("Only numeric dtypes are allowed in nansum")
if dtype is None:
if x.dtype in _signed_integer_dtypes:
dtype = int64
dtype = int64 if PRECISION == 64 else int32
elif x.dtype in _unsigned_integer_dtypes:
dtype = uint64
elif x.dtype == float32:
dtype = uint64 if PRECISION == 64 else uint32
elif x.dtype == float32 and PRECISION == 64:
dtype = float64
elif x.dtype == complex64:
elif x.dtype == complex64 and PRECISION == 64:
dtype = complex128
else:
dtype = x.dtype
13 changes: 6 additions & 7 deletions cubed/random.py
@@ -2,17 +2,16 @@

from numpy.random import Generator, Philox

from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.backend_array_api import numpy_array_to_backend_array, default_dtypes
from cubed.core.ops import map_blocks
from cubed.utils import block_id_to_offset, normalize_shape
from cubed.vendor.dask.array.core import normalize_chunks


def random(size, *, chunks=None, spec=None):
def random(size, *, chunks=None, spec=None, device=None):
"""Return random floats in the half-open interval [0.0, 1.0)."""
shape = normalize_shape(size)
dtype = nxp.float64
dtype = default_dtypes(device=device)['real floating']
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)
numblocks = tuple(map(len, chunks))
root_seed = pyrandom.getrandbits(128)
@@ -27,9 +26,9 @@ def _random(x, numblocks=None, root_seed=None, block_id=None):
)


def _random(x, numblocks=None, root_seed=None, block_id=None):
def _random(x, numblocks=None, root_seed=None, block_id=None, dtype=None):
stream_id = block_id_to_offset(block_id, numblocks)
rg = Generator(Philox(key=root_seed + stream_id))
out = rg.random(x.shape)
out = numpy_array_to_backend_array(out)
out = rg.random(x.shape, dtype=dtype)
out = numpy_array_to_backend_array(out, dtype=dtype)
return out
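
The seeding scheme above gives each block an independent, reproducible counter-based stream: the 128-bit root seed plus the block's linear offset is used as the Philox key. A minimal sketch with a stand-in root seed:

import numpy as np
from numpy.random import Generator, Philox

root_seed = 2024  # stand-in; cubed draws a random 128-bit value
for stream_id in range(3):  # one stream per block
    rg = Generator(Philox(key=root_seed + stream_id))
    print(rg.random(4, dtype=np.float32))  # per-block draws at the requested dtype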