From d1f7cb8fd95d588d3f7a7e90916c25747b90ad5a Mon Sep 17 00:00:00 2001 From: Keisuke Fujii Date: Tue, 26 May 2020 05:02:36 +0900 Subject: [PATCH] Improve interp performance (#4069) * Fixes 2223 * more tests * add @requires_scipy to test * fix tests * black * update whatsnew. Added a test for nearest --- doc/whats-new.rst | 7 +++++++ xarray/core/missing.py | 15 ++++++++++++++- xarray/testing.py | 7 +------ xarray/tests/test_interp.py | 18 ++++++++++++++++++ 4 files changed, 40 insertions(+), 7 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e1012283c94..59c7faa8973 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,13 @@ Breaking changes (:pull:`3274`) By `Elliott Sales de Andrade <https://github.com/QuLogic>`_ +Enhancements +~~~~~~~~~~~~ +- Performance improvement of :py:meth:`DataArray.interp` and :py:meth:`Dataset.interp`. + For orthogonal linear- and nearest-neighbor interpolation, we do 1d-interpolation sequentially + rather than interpolating in multidimensional space. (:issue:`2223`) + By `Keisuke Fujii <https://github.com/fujiisoup>`_. + New Features ~~~~~~~~~~~~ diff --git a/xarray/core/missing.py b/xarray/core/missing.py index f973b4a5468..374eaec1fa7 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -619,6 +619,19 @@ def interp(var, indexes_coords, method, **kwargs): # default behavior kwargs["bounds_error"] = kwargs.get("bounds_error", False) + # check if the interpolation can be done in orthogonal manner + if ( + len(indexes_coords) > 1 + and method in ["linear", "nearest"] + and all(dest[1].ndim == 1 for dest in indexes_coords.values()) + and len(set([d[1].dims[0] for d in indexes_coords.values()])) + == len(indexes_coords) + ): + # interpolate sequentially + for dim, dest in indexes_coords.items(): + var = interp(var, {dim: dest}, method, **kwargs) + return var + # target dimensions dims = list(indexes_coords) x, new_x = zip(*[indexes_coords[d] for d in dims]) @@ -659,7 +672,7 @@ def interp_func(var, x, new_x, method, kwargs): New coordinates.
Should not contain NaN. method: string {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for - 1-dimensional itnterpolation. + 1-dimensional interpolation. {'linear', 'nearest'} for multidimensional interpolation **kwargs: Optional keyword arguments to be passed to scipy.interpolator diff --git a/xarray/testing.py b/xarray/testing.py index ac189f7e023..e7bf5f9221a 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -10,12 +10,7 @@ from xarray.core.indexes import default_indexes from xarray.core.variable import IndexVariable, Variable -__all__ = ( - "assert_allclose", - "assert_chunks_equal", - "assert_equal", - "assert_identical", -) +__all__ = ("assert_allclose", "assert_chunks_equal", "assert_equal", "assert_identical") def _decode_string_data(data): diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 0502348160e..7a0dda216e2 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -699,3 +699,21 @@ def test_3641(): times = xr.cftime_range("0001", periods=3, freq="500Y") da = xr.DataArray(range(3), dims=["time"], coords=[times]) da.interp(time=["0002-05-01"]) + + +@requires_scipy +@pytest.mark.parametrize("method", ["nearest", "linear"]) +def test_decompose(method): + da = xr.DataArray( + np.arange(6).reshape(3, 2), + dims=["x", "y"], + coords={"x": [0, 1, 2], "y": [-0.1, -0.3]}, + ) + x_new = xr.DataArray([0.5, 1.5, 2.5], dims=["x1"]) + y_new = xr.DataArray([-0.15, -0.25], dims=["y1"]) + x_broadcast, y_broadcast = xr.broadcast(x_new, y_new) + assert x_broadcast.ndim == 2 + + actual = da.interp(x=x_new, y=y_new, method=method).drop(("x", "y")) + expected = da.interp(x=x_broadcast, y=y_broadcast, method=method).drop(("x", "y")) + assert_allclose(actual, expected)