diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6aba6617c37..35239004af4 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -16,8 +16,8 @@ import pandas as pd from numpy import all as array_all # noqa from numpy import any as array_any # noqa +from numpy import around # noqa from numpy import zeros_like # noqa -from numpy import around, broadcast_to # noqa from numpy import concatenate as _concatenate from numpy import ( # noqa einsum, @@ -207,6 +207,11 @@ def as_shared_dtype(scalars_or_arrays, xp=np): return [astype(x, out_type, copy=False) for x in arrays] +def broadcast_to(array, shape): + xp = get_array_namespace(array) + return xp.broadcast_to(array, shape) + + def lazy_array_equiv(arr1, arr2): """Like array_equal, but doesn't actually compare values. Returns True when arr1, arr2 identical or their dask tokens are equal. @@ -311,6 +316,9 @@ def fillna(data, other): def concatenate(arrays, axis=0): """concatenate() with better dtype promotion rules.""" + if hasattr(arrays[0], "__array_namespace__"): + xp = get_array_namespace(arrays[0]) + return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) return _concatenate(as_shared_dtype(arrays), axis=axis) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index 7940c979249..fddaa120970 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -64,6 +64,27 @@ def test_astype(arrays) -> None: assert_equal(actual, expected) +def test_broadcast(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + np_arr2 = xr.DataArray(np.array([1.0, 2.0]), dims="x") + xp_arr2 = xr.DataArray(xp.asarray([1.0, 2.0]), dims="x") + + expected = xr.broadcast(np_arr, np_arr2) + actual = xr.broadcast(xp_arr, xp_arr2) + assert len(actual) == len(expected) + for a, e in zip(actual, expected): + assert isinstance(a.data, Array) + assert_equal(a, e) + + +def test_concat(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + expected = xr.concat((np_arr, np_arr), dim="x") + actual = xr.concat((xp_arr, xp_arr), dim="x") + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + def test_indexing(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: np_arr, xp_arr = arrays expected = np_arr[:, 0]