Skip to content

Commit

Permalink
Exposing to_default_precision as a backend_array_api.
Browse files Browse the repository at this point in the history
  • Loading branch information
alxmrs committed Sep 21, 2024
1 parent 4687d33 commit 92a31cc
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 13 deletions.
13 changes: 3 additions & 10 deletions cubed/array_api/creation_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
from typing import TYPE_CHECKING, Iterable, List

from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import namespace as nxp, to_default_precision
from cubed.backend_array_api import default_dtypes
from cubed.core import Plan, gensym
from cubed.core.ops import map_blocks
Expand All @@ -18,13 +18,6 @@
from .array_object import Array


def _to_default_precision(dtype, *, device=None):
"""Returns a dtype of the same kind with the default precision."""
for k, dtype_ in default_dtypes(device=device).items():
if nxp.isdtype(dtype, k):
return dtype_


def arange(
start, /, stop=None, step=1, *, dtype=None, device=None, chunks="auto", spec=None
) -> "Array":
Expand All @@ -35,7 +28,7 @@ def arange(
# TODO: Use inspect API
dtype = nxp.arange(start, stop, step * num if num else step).dtype
# the default nxp call does not adjust the data type to the default precision.
dtype = _to_default_precision(dtype, device=device)
dtype = to_default_precision(dtype, device=device)

chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype)
chunksize = chunks[0][0]
Expand Down Expand Up @@ -77,7 +70,7 @@ def asarray(
# ensure blocks are arrays
a = nxp.asarray(a, dtype=dtype)
if dtype is None:
dtype = _to_default_precision(a.dtype, device=device)
dtype = to_default_precision(a.dtype, device=device)
a = a.astype(dtype)

chunksize = to_chunksize(normalize_chunks(chunks, shape=a.shape, dtype=dtype))
Expand Down
7 changes: 7 additions & 0 deletions cubed/backend_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,10 @@ def numpy_array_to_backend_array(arr, *, dtype=None):
if isinstance(arr, dict):
return {k: namespace.asarray(v, dtype=dtype) for k, v in arr.items()}
return namespace.asarray(arr, dtype=dtype)


def to_default_precision(dtype, *, device=None):
"""Returns a dtype of the same kind with the default precision."""
for k, dtype_ in default_dtypes(device=device).items():
if namespace.isdtype(dtype, k):
return dtype_
10 changes: 7 additions & 3 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from toolz import accumulate, map

from cubed import config
from cubed.backend_array_api import backend_array_to_numpy_array
from cubed.backend_array_api import backend_array_to_numpy_array, to_default_precision
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.core.array import CoreArray, check_array_specs, compute, gensym
Expand All @@ -41,17 +41,21 @@
from cubed.array_api.array_object import Array


def from_array(x, chunks="auto", asarray=None, spec=None) -> "Array":
def from_array(x, chunks="auto", asarray=None, spec=None, device=None) -> "Array":
"""Create a Cubed array from an array-like object."""

if isinstance(x, CoreArray):
raise ValueError(
"Array is already a Cubed array. Use 'asarray' or 'rechunk' instead."
)

dtype = to_default_precision(x.dtype)
if x.dtype != dtype:
x = x.astype(dtype)

previous_chunks = getattr(x, "chunks", None)
outchunks = normalize_chunks(
chunks, x.shape, dtype=x.dtype, previous_chunks=previous_chunks
chunks, x.shape, dtype=dtype, previous_chunks=previous_chunks
)

if isinstance(x, zarr.Array): # zarr fast path
Expand Down

0 comments on commit 92a31cc

Please sign in to comment.