From 6df8bd606a8a9a3378c7672c087e08ced00b2e15 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Sat, 9 Nov 2024 00:33:42 -0500 Subject: [PATCH] Dispatch to Dask if nanquantile is available (#9719) * Dispatch to Dask is nanquantile is available * Fixup * Change test --- xarray/core/variable.py | 3 ++- xarray/tests/__init__.py | 1 + xarray/tests/test_variable.py | 14 +++++++++++--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 7b6598a6406..d732a18fe23 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -46,6 +46,7 @@ ) from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array +from xarray.namedarray.utils import module_available from xarray.util.deprecation_helpers import deprecate_dims NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( @@ -1948,7 +1949,7 @@ def _wrapper(npa, **kwargs): output_core_dims=[["quantile"]], output_dtypes=[np.float64], dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}), - dask="parallelized", + dask="allowed" if module_available("dask", "2024.11.0") else "parallelized", kwargs=kwargs, ) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 7293a6fd931..5ed334e61dd 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -107,6 +107,7 @@ def _importorskip( has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf") has_cftime, requires_cftime = _importorskip("cftime") has_dask, requires_dask = _importorskip("dask") +has_dask_ge_2024_11_0, requires_dask_ge_2024_11_0 = _importorskip("dask", "2024.11.0") with warnings.catch_warnings(): warnings.filterwarnings( "ignore", diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 9ebd4e4a4d3..0ed47c2b5fe 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -36,6 +36,7 @@ assert_equal, assert_identical, assert_no_warnings, + has_dask_ge_2024_11_0, has_pandas_3, raise_if_dask_computes, requires_bottleneck, @@ -1871,9 +1872,16 @@ def test_quantile_interpolation_deprecation(self, method) -> None: def test_quantile_chunked_dim_error(self): v = Variable(["x", "y"], self.d).chunk({"x": 2}) - # this checks for ValueError in dask.array.apply_gufunc - with pytest.raises(ValueError, match=r"consists of multiple chunks"): - v.quantile(0.5, dim="x") + if has_dask_ge_2024_11_0: + # Dask rechunks + np.testing.assert_allclose( + v.compute().quantile(0.5, dim="x"), v.quantile(0.5, dim="x") + ) + + else: + # this checks for ValueError in dask.array.apply_gufunc + with pytest.raises(ValueError, match=r"consists of multiple chunks"): + v.quantile(0.5, dim="x") @pytest.mark.parametrize("compute_backend", ["numbagg", None], indirect=True) @pytest.mark.parametrize("q", [-0.1, 1.1, [2], [0.25, 2]])