Skip to content
forked from pydata/xarray

Commit

Permalink
Improve interp performance (pydata#4069)
Browse files Browse the repository at this point in the history
* Fixes 2223

* more tests

* add @requires_scipy to test

* fix tests

* black

* update whatsnew. Added a test for nearest
  • Loading branch information
fujiisoup authored May 25, 2020
1 parent 1de38bc commit d1f7cb8
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 7 deletions.
7 changes: 7 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ Breaking changes
(:pull:`3274`)
By `Elliott Sales de Andrade <https://github.com/QuLogic>`_

Enhancements
~~~~~~~~~~~~
- Performance improvement of :py:meth:`DataArray.interp` and :py:func:`Dataset.interp`
For orthogonal linear- and nearest-neighbor interpolation, we do 1d-interpolation sequentially
rather than interpolating in multidimensional space. (:issue:`2223`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

New Features
~~~~~~~~~~~~

Expand Down
15 changes: 14 additions & 1 deletion xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,19 @@ def interp(var, indexes_coords, method, **kwargs):
# default behavior
kwargs["bounds_error"] = kwargs.get("bounds_error", False)

# check if the interpolation can be done in orthogonal manner
if (
len(indexes_coords) > 1
and method in ["linear", "nearest"]
and all(dest[1].ndim == 1 for dest in indexes_coords.values())
and len(set([d[1].dims[0] for d in indexes_coords.values()]))
== len(indexes_coords)
):
# interpolate sequentially
for dim, dest in indexes_coords.items():
var = interp(var, {dim: dest}, method, **kwargs)
return var

# target dimensions
dims = list(indexes_coords)
x, new_x = zip(*[indexes_coords[d] for d in dims])
Expand Down Expand Up @@ -659,7 +672,7 @@ def interp_func(var, x, new_x, method, kwargs):
New coordinates. Should not contain NaN.
method: string
{'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for
1-dimensional itnterpolation.
1-dimensional interpolation.
{'linear', 'nearest'} for multidimensional interpolation
**kwargs:
Optional keyword arguments to be passed to scipy.interpolator
Expand Down
7 changes: 1 addition & 6 deletions xarray/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,7 @@
from xarray.core.indexes import default_indexes
from xarray.core.variable import IndexVariable, Variable

__all__ = (
"assert_allclose",
"assert_chunks_equal",
"assert_equal",
"assert_identical",
)
__all__ = ("assert_allclose", "assert_chunks_equal", "assert_equal", "assert_identical")


def _decode_string_data(data):
Expand Down
18 changes: 18 additions & 0 deletions xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,3 +699,21 @@ def test_3641():
times = xr.cftime_range("0001", periods=3, freq="500Y")
da = xr.DataArray(range(3), dims=["time"], coords=[times])
da.interp(time=["0002-05-01"])


@requires_scipy
@pytest.mark.parametrize("method", ["nearest", "linear"])
def test_decompose(method):
da = xr.DataArray(
np.arange(6).reshape(3, 2),
dims=["x", "y"],
coords={"x": [0, 1, 2], "y": [-0.1, -0.3]},
)
x_new = xr.DataArray([0.5, 1.5, 2.5], dims=["x1"])
y_new = xr.DataArray([-0.15, -0.25], dims=["y1"])
x_broadcast, y_broadcast = xr.broadcast(x_new, y_new)
assert x_broadcast.ndim == 2

actual = da.interp(x=x_new, y=y_new, method=method).drop(("x", "y"))
expected = da.interp(x=x_broadcast, y=y_broadcast, method=method).drop(("x", "y"))
assert_allclose(actual, expected)

0 comments on commit d1f7cb8

Please sign in to comment.