Skip to content
forked from pydata/xarray

Commit

Permalink
apply_func: Set meta=np.ndarray when vectorize=True and dask="paralle…
Browse files Browse the repository at this point in the history
…lized"

Closes pydata#3574
  • Loading branch information
dcherian committed Jan 2, 2020
1 parent b3d3b44 commit 3ef9257
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ New Features

Bug fixes
~~~~~~~~~
- Make applying a user-defined function that adds new dimensions using
:py:func:`apply_ufunc` with ``vectorize=True`` work with ``dask > 2.0``.
By `Deepak Cherian <https://github.com/dcherian>`_.
- Fix :py:meth:`xarray.combine_by_coords` to allow for combining incomplete
hypercubes of Datasets (:issue:`3648`). By `Ian Bolliger
<https://github.com/bolliger32>`_.
Expand Down
14 changes: 13 additions & 1 deletion xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def apply_variable_ufunc(
output_dtypes=None,
output_sizes=None,
keep_attrs=False,
vectorize=False,
):
"""Apply a ndarray level function over Variable and/or ndarray objects.
"""
Expand Down Expand Up @@ -579,6 +580,7 @@ def apply_variable_ufunc(
elif dask == "parallelized":
input_dims = [broadcast_dims + dims for dims in signature.input_core_dims]
numpy_func = func
meta = np.ndarray if vectorize else None

def func(*arrays):
return _apply_blockwise(
Expand All @@ -589,6 +591,7 @@ def func(*arrays):
signature,
output_dtypes,
output_sizes,
meta,
)

elif dask == "allowed":
Expand Down Expand Up @@ -647,7 +650,14 @@ def func(*arrays):


def _apply_blockwise(
func, args, input_dims, output_dims, signature, output_dtypes, output_sizes=None
func,
args,
input_dims,
output_dims,
signature,
output_dtypes,
output_sizes=None,
meta=None,
):
import dask.array

Expand Down Expand Up @@ -719,6 +729,7 @@ def _apply_blockwise(
dtype=dtype,
concatenate=True,
new_axes=output_sizes,
meta=meta,
)


Expand Down Expand Up @@ -1005,6 +1016,7 @@ def earth_mover_distance(first_samples,
dask=dask,
output_dtypes=output_dtypes,
output_sizes=output_sizes,
vectorize=vectorize,
)

if any(isinstance(a, GroupBy) for a in args):
Expand Down
18 changes: 18 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,24 @@ def test_vectorize_dask():
assert_identical(expected, actual)


@requires_dask
def test_vectorize_dask_new_output_dims():
# regression test for GH3574
data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y"))
func = lambda x: x[np.newaxis, ...]
expected = data_array.expand_dims("z")
actual = apply_ufunc(
func,
data_array.chunk({"x": 1}),
output_core_dims=[["z"]],
vectorize=True,
dask="parallelized",
output_dtypes=[float],
output_sizes={"z": 1},
).transpose(*expected.dims)
assert_identical(expected, actual)


def test_output_wrong_number():
variable = xr.Variable("x", np.arange(10))

Expand Down

0 comments on commit 3ef9257

Please sign in to comment.