From 3ef925722cfabf1b560495c13a8986e96173c572 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 1 Jan 2020 14:35:33 -0700 Subject: [PATCH] apply_func: Set meta=np.ndarray when vectorize=True and dask="parallelized" Closes #3574 --- doc/whats-new.rst | 3 +++ xarray/core/computation.py | 14 +++++++++++++- xarray/tests/test_computation.py | 18 ++++++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 00d1c50780e..58f842b10c8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -44,6 +44,9 @@ New Features Bug fixes ~~~~~~~~~ +- Make applying a user-defined function that adds new dimensions using + :py:func:`apply_ufunc` with ``vectorize=True`` work with ``dask > 2.0``. + By `Deepak Cherian `_. - Fix :py:meth:`xarray.combine_by_coords` to allow for combining incomplete hypercubes of Datasets (:issue:`3648`). By `Ian Bolliger `_. diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 643c1137d6c..15a0a95da40 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -547,6 +547,7 @@ def apply_variable_ufunc( output_dtypes=None, output_sizes=None, keep_attrs=False, + vectorize=False, ): """Apply a ndarray level function over Variable and/or ndarray objects. """ @@ -579,6 +580,7 @@ def apply_variable_ufunc( elif dask == "parallelized": input_dims = [broadcast_dims + dims for dims in signature.input_core_dims] numpy_func = func + meta = np.ndarray if vectorize else None def func(*arrays): return _apply_blockwise( @@ -589,6 +591,7 @@ def func(*arrays): signature, output_dtypes, output_sizes, + meta, ) elif dask == "allowed": @@ -647,7 +650,14 @@ def func(*arrays): def _apply_blockwise( - func, args, input_dims, output_dims, signature, output_dtypes, output_sizes=None + func, + args, + input_dims, + output_dims, + signature, + output_dtypes, + output_sizes=None, + meta=None, ): import dask.array @@ -719,6 +729,7 @@ def _apply_blockwise( dtype=dtype, concatenate=True, new_axes=output_sizes, + meta=meta, ) @@ -1005,6 +1016,7 @@ def earth_mover_distance(first_samples, dask=dask, output_dtypes=output_dtypes, output_sizes=output_sizes, + vectorize=vectorize, ) if any(isinstance(a, GroupBy) for a in args): diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 1f2634cc9b0..5a386de64d2 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -817,6 +817,24 @@ def test_vectorize_dask(): assert_identical(expected, actual) +@requires_dask +def test_vectorize_dask_new_output_dims(): + # regression test for GH3574 + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + func = lambda x: x[np.newaxis, ...] + expected = data_array.expand_dims("z") + actual = apply_ufunc( + func, + data_array.chunk({"x": 1}), + output_core_dims=[["z"]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + output_sizes={"z": 1}, + ).transpose(*expected.dims) + assert_identical(expected, actual) + + def test_output_wrong_number(): variable = xr.Variable("x", np.arange(10))