From c72a1c4a4c52152bdab83f60f35615de28e8be7f Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 28 Jun 2022 10:38:47 +0100 Subject: [PATCH] Support NumPy array API (experimental) --- xarray/core/duck_array_ops.py | 6 ++++- xarray/core/indexing.py | 42 +++++++++++++++++++++++++++++++++++ xarray/core/utils.py | 5 +++-- xarray/core/variable.py | 4 +++- 4 files changed, 53 insertions(+), 4 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6e73ee41b40..2cd2fb3af04 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -329,7 +329,11 @@ def f(values, axis=None, skipna=None, **kwargs): if name in ["sum", "prod"]: kwargs.pop("min_count", None) - func = getattr(np, name) + if hasattr(values, "__array_namespace__"): + xp = values.__array_namespace__() + func = getattr(xp, name) + else: + func = getattr(np, name) try: with warnings.catch_warnings(): diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index cbbd507eeff..008532fd449 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -672,6 +672,8 @@ def as_indexable(array): return DaskIndexingAdapter(array) if hasattr(array, "__array_function__"): return NdArrayLikeIndexingAdapter(array) + if hasattr(array, "__array_namespace__"): + return ArrayApiIndexingAdapter(array) raise TypeError(f"Invalid array type: {type(array)}") @@ -1281,6 +1283,46 @@ def __init__(self, array): self.array = array +class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap an array API array to use explicit indexing.""" + __slots__ = ("array",) + + def __init__(self, array): + if not hasattr(array, "__array_namespace__"): + raise TypeError( + "ArrayApiIndexingAdapter must wrap an object that " + "implements the __array_namespace__ protocol" + ) + self.array = array + + def __getitem__(self, key): + if isinstance(key, BasicIndexer): + return self.array[key.tuple] + elif isinstance(key, OuterIndexer): + # manual orthogonal indexing (implemented like DaskIndexingAdapter) + key = key.tuple + value = self.array + for axis, subkey in reversed(list(enumerate(key))): + value = value[(slice(None),) * axis + (subkey, Ellipsis)] + return value + else: + assert isinstance(key, VectorizedIndexer) + raise TypeError("Vectorized indexing is not supported") + + def __setitem__(self, key, value): + if isinstance(key, BasicIndexer): + self.array[key.tuple] = value + elif isinstance(key, OuterIndexer): + self.array[key.tuple] = value + else: + assert isinstance(key, VectorizedIndexer) + raise TypeError("Vectorized indexing is not supported") + + def transpose(self, order): + xp = self.array.__array_namespace__() + return xp.permute_dims(self.array, order) + + class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin): """Wrap a dask array to support explicit indexing.""" diff --git a/xarray/core/utils.py b/xarray/core/utils.py index b253f1661ae..4be532d6bcf 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -262,8 +262,8 @@ def is_duck_array(value: Any) -> bool: hasattr(value, "ndim") and hasattr(value, "shape") and hasattr(value, "dtype") - and hasattr(value, "__array_function__") - and hasattr(value, "__array_ufunc__") + and ((hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__")) + or hasattr(value, "__array_namespace__")) ) @@ -297,6 +297,7 @@ def _is_scalar(value, include_0d): or not ( isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES) or hasattr(value, "__array_function__") + or hasattr(value, "__array_namespace__") ) ) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 2d115ff0ed9..baeae480e81 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -201,7 +201,9 @@ def as_compatible_data(data, fastpath=False): if isinstance(data, (Variable, DataArray)): return data.data - if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): + if (isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES) + or hasattr(data, "__array_function__") + or hasattr(data, "__array_namespace__")): return _maybe_wrap_data(data) if isinstance(data, tuple):