From 7eb6af62a8f9670c2dabbbb637b6728f6b92051a Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 1 Nov 2022 09:22:07 +0000 Subject: [PATCH 1/2] Make `broadcast` and `concat` work with the Array API. --- xarray/core/duck_array_ops.py | 10 +++++++++- xarray/tests/test_array_api.py | 21 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6aba6617c37..f2dbcb5ac7d 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -17,7 +17,7 @@ from numpy import all as array_all # noqa from numpy import any as array_any # noqa from numpy import zeros_like # noqa -from numpy import around, broadcast_to # noqa +from numpy import around # noqa from numpy import concatenate as _concatenate from numpy import ( # noqa einsum, @@ -207,6 +207,11 @@ def as_shared_dtype(scalars_or_arrays, xp=np): return [astype(x, out_type, copy=False) for x in arrays] +def broadcast_to(array, shape): + xp = get_array_namespace(array) + return xp.broadcast_to(array, shape) + + def lazy_array_equiv(arr1, arr2): """Like array_equal, but doesn't actually compare values. Returns True when arr1, arr2 identical or their dask tokens are equal. @@ -311,6 +316,9 @@ def fillna(data, other): def concatenate(arrays, axis=0): """concatenate() with better dtype promotion rules.""" + if hasattr(arrays[0], "__array_namespace__"): + xp = get_array_namespace(arrays[0]) + return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) return _concatenate(as_shared_dtype(arrays), axis=axis) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index 7940c979249..fddaa120970 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -64,6 +64,27 @@ def test_astype(arrays) -> None: assert_equal(actual, expected) +def test_broadcast(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + np_arr2 = xr.DataArray(np.array([1.0, 2.0]), dims="x") + xp_arr2 = xr.DataArray(xp.asarray([1.0, 2.0]), dims="x") + + expected = xr.broadcast(np_arr, np_arr2) + actual = xr.broadcast(xp_arr, xp_arr2) + assert len(actual) == len(expected) + for a, e in zip(actual, expected): + assert isinstance(a.data, Array) + assert_equal(a, e) + + +def test_concat(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + expected = xr.concat((np_arr, np_arr), dim="x") + actual = xr.concat((xp_arr, xp_arr), dim="x") + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + def test_indexing(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: np_arr, xp_arr = arrays expected = np_arr[:, 0] From d483a577871e46098416ce97cac06505ca12d498 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Dec 2022 10:50:44 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/duck_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index f2dbcb5ac7d..35239004af4 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -16,8 +16,8 @@ import pandas as pd from numpy import all as array_all # noqa from numpy import any as array_any # noqa -from numpy import zeros_like # noqa from numpy import around # noqa +from numpy import zeros_like # noqa from numpy import concatenate as _concatenate from numpy import ( # noqa einsum,