From e41e6faefeed650a55894048f72241308f5a6511 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Sat, 3 Aug 2024 18:00:14 +0100 Subject: [PATCH] Exposing `to_default_precision` as a backend_array_api. --- cubed/array_api/creation_functions.py | 13 +++---------- cubed/backend_array_api.py | 7 +++++++ cubed/core/ops.py | 10 +++++++--- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/cubed/array_api/creation_functions.py b/cubed/array_api/creation_functions.py index f45d9a847..a5317a422 100644 --- a/cubed/array_api/creation_functions.py +++ b/cubed/array_api/creation_functions.py @@ -1,7 +1,7 @@ import math from typing import TYPE_CHECKING, Iterable, List -from cubed.backend_array_api import namespace as nxp +from cubed.backend_array_api import namespace as nxp, to_default_precision from cubed.backend_array_api import default_dtypes from cubed.core import Plan, gensym from cubed.core.ops import map_blocks @@ -18,13 +18,6 @@ from .array_object import Array -def _to_default_precision(dtype, *, device=None): - """Returns a dtype of the same kind with the default precision.""" - for k, dtype_ in default_dtypes(device=device).items(): - if nxp.isdtype(dtype, k): - return dtype_ - - def arange( start, /, stop=None, step=1, *, dtype=None, device=None, chunks="auto", spec=None ) -> "Array": @@ -35,7 +28,7 @@ def arange( # TODO: Use inspect API dtype = nxp.arange(start, stop, step * num if num else step).dtype # the default nxp call does not adjust the data type to the default precision. - dtype = _to_default_precision(dtype, device=device) + dtype = to_default_precision(dtype, device=device) chunks = normalize_chunks(chunks, shape=(num,), dtype=dtype) chunksize = chunks[0][0] @@ -77,7 +70,7 @@ def asarray( # ensure blocks are arrays a = nxp.asarray(a, dtype=dtype) if dtype is None: - dtype = _to_default_precision(a.dtype, device=device) + dtype = to_default_precision(a.dtype, device=device) a = a.astype(dtype) chunksize = to_chunksize(normalize_chunks(chunks, shape=a.shape, dtype=dtype)) diff --git a/cubed/backend_array_api.py b/cubed/backend_array_api.py index e976f3220..b6743ba87 100644 --- a/cubed/backend_array_api.py +++ b/cubed/backend_array_api.py @@ -64,3 +64,10 @@ def numpy_array_to_backend_array(arr, *, dtype=None): if isinstance(arr, dict): return {k: namespace.asarray(v, dtype=dtype) for k, v in arr.items()} return namespace.asarray(arr, dtype=dtype) + + +def to_default_precision(dtype, *, device=None): + """Returns a dtype of the same kind with the default precision.""" + for k, dtype_ in default_dtypes(device=device).items(): + if namespace.isdtype(dtype, k): + return dtype_ diff --git a/cubed/core/ops.py b/cubed/core/ops.py index e6d1ebc97..a02204f7a 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -15,7 +15,7 @@ from toolz import accumulate, map from cubed import config -from cubed.backend_array_api import backend_array_to_numpy_array +from cubed.backend_array_api import backend_array_to_numpy_array, to_default_precision from cubed.backend_array_api import namespace as nxp from cubed.backend_array_api import numpy_array_to_backend_array from cubed.core.array import CoreArray, check_array_specs, compute, gensym @@ -41,7 +41,7 @@ from cubed.array_api.array_object import Array -def from_array(x, chunks="auto", asarray=None, spec=None) -> "Array": +def from_array(x, chunks="auto", asarray=None, spec=None, device=None) -> "Array": """Create a Cubed array from an array-like object.""" if isinstance(x, CoreArray): @@ -49,9 +49,13 @@ def from_array(x, chunks="auto", asarray=None, spec=None) -> "Array": "Array is already a Cubed array. Use 'asarray' or 'rechunk' instead." ) + dtype = to_default_precision(x.dtype) + if x.dtype != dtype: + x = x.astype(dtype) + previous_chunks = getattr(x, "chunks", None) outchunks = normalize_chunks( - chunks, x.shape, dtype=x.dtype, previous_chunks=previous_chunks + chunks, x.shape, dtype=dtype, previous_chunks=previous_chunks ) if isinstance(x, zarr.Array): # zarr fast path