diff --git a/xarray/core/missing.py b/xarray/core/missing.py index f608468ed9f..5116a6d2651 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -734,6 +734,13 @@ def interp_func(var, x, new_x, method, kwargs): # if usefull, re-use localize for each chunk of new_x localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None) + # scipy.interpolate.interp1d always forces to float. + # Use the same check for blockwise as well: + if not issubclass(var.dtype.type, np.inexact): + dtype = np.float_ + else: + dtype = var.dtype + return da.blockwise( _dask_aware_interpnd, out_ind, @@ -742,7 +749,7 @@ def interp_func(var, x, new_x, method, kwargs): interp_kwargs=kwargs, localize=localize, concatenate=True, - dtype=var.dtype, + dtype=dtype, new_axes=new_axes, ) diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index 21d82b1948b..2ab3508b667 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -370,6 +370,20 @@ def test_interpolate_dask_raises_for_invalid_chunk_dim(): da.interpolate_na("time") +@requires_dask +@requires_scipy +@pytest.mark.parametrize("dtype, method", [(int, "linear"), (int, "nearest")]) +def test_interpolate_dask_expected_dtype(dtype, method): + da = xr.DataArray( + data=np.array([0, 1], dtype=dtype), + dims=["time"], + coords=dict(time=np.array([0, 1])), + ).chunk(dict(time=2)) + da = da.interp(time=np.array([0, 0.5, 1, 2]), method=method) + + assert da.dtype == da.compute().dtype + + @requires_bottleneck def test_ffill(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")