Skip to content

Commit

Permalink
Support NumPy array API (experimental)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jul 5, 2022
1 parent 787a96c commit c72a1c4
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 4 deletions.
6 changes: 5 additions & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,11 @@ def f(values, axis=None, skipna=None, **kwargs):
if name in ["sum", "prod"]:
kwargs.pop("min_count", None)

func = getattr(np, name)
if hasattr(values, "__array_namespace__"):
xp = values.__array_namespace__()
func = getattr(xp, name)
else:
func = getattr(np, name)

try:
with warnings.catch_warnings():
Expand Down
42 changes: 42 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,8 @@ def as_indexable(array):
return DaskIndexingAdapter(array)
if hasattr(array, "__array_function__"):
return NdArrayLikeIndexingAdapter(array)
if hasattr(array, "__array_namespace__"):
return ArrayApiIndexingAdapter(array)

raise TypeError(f"Invalid array type: {type(array)}")

Expand Down Expand Up @@ -1281,6 +1283,46 @@ def __init__(self, array):
self.array = array


class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap an array API array to use explicit indexing."""
__slots__ = ("array",)

def __init__(self, array):
if not hasattr(array, "__array_namespace__"):
raise TypeError(
"ArrayApiIndexingAdapter must wrap an object that "
"implements the __array_namespace__ protocol"
)
self.array = array

def __getitem__(self, key):
if isinstance(key, BasicIndexer):
return self.array[key.tuple]
elif isinstance(key, OuterIndexer):
# manual orthogonal indexing (implemented like DaskIndexingAdapter)
key = key.tuple
value = self.array
for axis, subkey in reversed(list(enumerate(key))):
value = value[(slice(None),) * axis + (subkey, Ellipsis)]
return value
else:
assert isinstance(key, VectorizedIndexer)
raise TypeError("Vectorized indexing is not supported")

def __setitem__(self, key, value):
if isinstance(key, BasicIndexer):
self.array[key.tuple] = value
elif isinstance(key, OuterIndexer):
self.array[key.tuple] = value
else:
assert isinstance(key, VectorizedIndexer)
raise TypeError("Vectorized indexing is not supported")

def transpose(self, order):
xp = self.array.__array_namespace__()
return xp.permute_dims(self.array, order)


class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a dask array to support explicit indexing."""

Expand Down
5 changes: 3 additions & 2 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ def is_duck_array(value: Any) -> bool:
hasattr(value, "ndim")
and hasattr(value, "shape")
and hasattr(value, "dtype")
and hasattr(value, "__array_function__")
and hasattr(value, "__array_ufunc__")
and ((hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
or hasattr(value, "__array_namespace__"))
)


Expand Down Expand Up @@ -297,6 +297,7 @@ def _is_scalar(value, include_0d):
or not (
isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES)
or hasattr(value, "__array_function__")
or hasattr(value, "__array_namespace__")
)
)

Expand Down
4 changes: 3 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,9 @@ def as_compatible_data(data, fastpath=False):
if isinstance(data, (Variable, DataArray)):
return data.data

if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
if (isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES)
or hasattr(data, "__array_function__")
or hasattr(data, "__array_namespace__")):
return _maybe_wrap_data(data)

if isinstance(data, tuple):
Expand Down

0 comments on commit c72a1c4

Please sign in to comment.