Basic curvefit implementation #4849

Merged
merged 14 commits into from
Mar 31, 2021
3 changes: 2 additions & 1 deletion doc/api.rst
@@ -178,6 +178,7 @@ Computation
Dataset.integrate
Dataset.map_blocks
Dataset.polyfit
Dataset.curvefit

**Aggregation**:
:py:attr:`~Dataset.all`
@@ -371,7 +372,7 @@ Computation
DataArray.integrate
DataArray.polyfit
DataArray.map_blocks

DataArray.curvefit

**Aggregation**:
:py:attr:`~DataArray.all`
83 changes: 83 additions & 0 deletions doc/computation.rst
@@ -443,6 +443,89 @@ The inverse operation is done with :py:meth:`~xarray.polyval`,
.. note::
   These methods replicate the behaviour of :py:func:`numpy.polyfit` and :py:func:`numpy.polyval`.


.. _compute.curvefit:

Fitting arbitrary functions
===========================

Xarray objects also provide an interface for fitting more complex functions using
:py:func:`scipy.optimize.curve_fit`. :py:meth:`~xarray.DataArray.curvefit` accepts
user-defined functions and can fit along multiple coordinates.

For example, we can fit a relationship between two ``DataArray`` objects, maintaining
a unique fit at each spatial coordinate but aggregating over the time dimension:

.. ipython:: python

    def exponential(x, a, xc):
        return np.exp((x - xc) / a)


    x = np.arange(-5, 5, 0.1)
    t = np.arange(-5, 5, 0.1)
    X, T = np.meshgrid(x, t)
    Z1 = np.random.uniform(low=-5, high=5, size=X.shape)
    Z2 = exponential(Z1, 3, X)
    Z3 = exponential(Z1, 1, -X)

    ds = xr.Dataset(
        data_vars=dict(
            var1=(["t", "x"], Z1), var2=(["t", "x"], Z2), var3=(["t", "x"], Z3)
        ),
        coords={"t": t, "x": x},
    )
    ds[["var2", "var3"]].curvefit(
        coords=ds.var1,
        func=exponential,
        reduce_dim="t",
        bounds={"a": (0.5, 5), "xc": (-5, 5)},
    )
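
The ``bounds`` argument above is keyed by parameter name, while ``scipy.optimize.curve_fit`` expects positional lower and upper sequences. As a rough illustration only (the helper ``bounds_to_sequences`` is hypothetical and is not xarray's implementation), such a dictionary could be expanded as follows, leaving any parameter that is not listed unbounded:

```python
import math


def bounds_to_sequences(param_names, bounds=None):
    """Hypothetical helper: map {name: (lower, upper)} to positional bounds.

    Parameters missing from `bounds` are left unbounded, i.e. (-inf, inf).
    """
    bounds = bounds or {}
    unbounded = (-math.inf, math.inf)
    lower = [bounds.get(name, unbounded)[0] for name in param_names]
    upper = [bounds.get(name, unbounded)[1] for name in param_names]
    return lower, upper


# Only "a" is constrained; "xc" falls back to the unbounded default.
lower, upper = bounds_to_sequences(["a", "xc"], {"a": (0.5, 5)})
print(lower, upper)
```

Because ``xc`` is not listed here, it falls back to ``(-inf, inf)``, which mirrors the unbounded default that scipy applies when no bounds are given.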

We can also fit multi-dimensional functions, and even use a wrapper function to
simultaneously fit a summation of several functions, such as this field containing
two gaussian peaks:

.. ipython:: python

    def gaussian_2d(coords, a, xc, yc, xalpha, yalpha):
        x, y = coords
        z = a * np.exp(
            -np.square(x - xc) / 2 / np.square(xalpha)
            - np.square(y - yc) / 2 / np.square(yalpha)
        )
        return z


    def multi_peak(coords, *args):
        z = np.zeros(coords[0].shape)
        for i in range(len(args) // 5):
            z += gaussian_2d(coords, *args[i * 5 : i * 5 + 5])
        return z


    x = np.arange(-5, 5, 0.1)
    y = np.arange(-5, 5, 0.1)
    X, Y = np.meshgrid(x, y)

    n_peaks = 2
    names = ["a", "xc", "yc", "xalpha", "yalpha"]
    names = [f"{name}{i}" for i in range(n_peaks) for name in names]
    Z = gaussian_2d((X, Y), 3, 1, 1, 2, 1) + gaussian_2d((X, Y), 2, -1, -2, 1, 1)
    Z += np.random.normal(scale=0.1, size=Z.shape)

    da = xr.DataArray(Z, dims=["y", "x"], coords={"y": y, "x": x})
    da.curvefit(
        coords=["x", "y"],
        func=multi_peak,
        param_names=names,
        kwargs={"maxfev": 10000},
    )
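
The ``multi_peak`` wrapper above assumes that the flat parameter tuple is laid out peak by peak, five values at a time, so ``param_names`` has to follow the same order. The following sketch uses plain Python and a hypothetical helper ``split_peaks`` (not part of xarray) to show how the generated names line up with that layout:

```python
base_names = ["a", "xc", "yc", "xalpha", "yalpha"]
n_peaks = 2

# Same ordering as in the example: all five parameters of peak 0, then peak 1.
names = [f"{name}{i}" for i in range(n_peaks) for name in base_names]


def split_peaks(args, n_per_peak=5):
    """Hypothetical helper: group a flat parameter tuple into one tuple per peak."""
    return [tuple(args[i : i + n_per_peak]) for i in range(0, len(args), n_per_peak)]


# Parameter values used to generate the synthetic field in the example.
flat = (3, 1, 1, 2, 1, 2, -1, -2, 1, 1)
groups = split_peaks(flat)

# Pair every name with its value to check that the ordering is consistent.
labelled = dict(zip(names, flat))
print(names)
print(groups)
```

Each tuple in ``groups`` corresponds to one ``gaussian_2d`` call inside ``multi_peak``, and the suffix on each name identifies the peak it belongs to.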

.. note::
   This method replicates the behavior of :py:func:`scipy.optimize.curve_fit`.


.. _compute.broadcasting:

Broadcasting by dimension name
3 changes: 3 additions & 0 deletions doc/whats-new.rst
@@ -25,6 +25,7 @@ Breaking changes
- xarray no longer supports python 3.6

The minimum versions of some other dependencies were changed:

============ ====== ====
Package Old New
============ ====== ====
@@ -62,6 +63,8 @@ New Features
- :py:meth:`DataArray.swap_dims` & :py:meth:`Dataset.swap_dims` now accept dims
  in the form of kwargs as well as a dict, like most similar methods.
  By `Maximilian Roos <https://github.com/max-sixty>`_.
- Added :py:meth:`DataArray.curvefit` and :py:meth:`Dataset.curvefit` for general
  curve fitting applications (:issue:`4300`, :pull:`4849`).
  By `Sam Levang <https://github.com/slevang>`_.

Bug fixes
~~~~~~~~~
83 changes: 83 additions & 0 deletions xarray/core/dataarray.py
@@ -4306,6 +4306,89 @@ def argmax(
        else:
            return self._replace_maybe_drop_dims(result)

    def curvefit(
        self,
        coords: Union["DataArray", Iterable["DataArray"]],
        func: Callable[..., Any],
        reduce_dim: Union[Hashable, Iterable[Hashable]] = None,
        skipna: bool = True,
        cov: bool = False,
        p0: Dict[str, Any] = None,
        bounds: Dict[str, Any] = None,
        param_names: Sequence[str] = None,
        kwargs: Dict[str, Any] = None,
    ):
        """
        Curve fitting optimization for arbitrary functions.

        Wraps `scipy.optimize.curve_fit` with `apply_ufunc`.

        Parameters
        ----------
        coords : DataArray, str or sequence of DataArray, str
            Independent coordinate(s) over which to perform the curve fitting. Must share
            at least one dimension with the calling object. When fitting multi-dimensional
            functions, supply `coords` as a sequence in the same order as arguments in
            `func`. To fit along existing dimensions of the calling object, `coords` can
            also be specified as a str or sequence of strs.
        func : callable
            User specified function in the form `f(x, *params)` which returns a numpy
            array of length `len(x)`. `params` are the fittable parameters which are
            optimized by scipy curve_fit. `x` can also be specified as a sequence
            containing multiple coordinates, e.g. `f((x, y), *params)`.
        reduce_dim : str or sequence of str
            Additional dimension(s) over which to aggregate while fitting. For example,
            calling `ds.curvefit(coords='time', reduce_dim=['lat', 'lon'], ...)` will
            aggregate all lat and lon points and fit the specified function along the
            time dimension.
        skipna : bool, optional
            Whether to skip missing values when fitting. Default is True.
        cov : bool, optional
            Whether to return the covariance matrix in addition to the coefficients.
        p0 : dictionary, optional
            Optional dictionary of parameter names to initial guesses passed to the
            `curve_fit` `p0` arg. If none or only some parameters are passed, the rest will
            be assigned initial values following the default scipy behavior.
        bounds : dictionary, optional
            Optional dictionary of parameter names to bounding values passed to the
            `curve_fit` `bounds` arg. If none or only some parameters are passed, the rest
            will be unbounded following the default scipy behavior.
        param_names : sequence, optional
            Sequence of names for the fittable parameters of `func`. If not supplied,
            this will be automatically determined by arguments of `func`. `param_names`
            should be manually supplied when fitting a function that takes a variable
            number of parameters.
        kwargs : dictionary
            Additional keyword arguments to pass to scipy curve_fit.

        Returns
        -------
        curvefit_results : Dataset
            A single dataset which contains:

            [var]_curvefit_coefficients
                The coefficients of the best fit.
            [var]_curvefit_covariance
                The covariance matrix of the coefficient estimates (only included if
                `cov=True`)

        See Also
        --------
        DataArray.polyfit
        scipy.optimize.curve_fit
        """
        return self._to_temp_dataset().curvefit(
            coords,
            func,
            reduce_dim=reduce_dim,
            skipna=skipna,
            cov=cov,
            p0=p0,
            bounds=bounds,
            param_names=param_names,
            kwargs=kwargs,
        )

    # this needs to be at the end, or mypy will confuse with `str`
    # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
    str = utils.UncachedAccessor(StringAccessor)
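
The `param_names` entry in the `curvefit` docstring says that names are determined from the arguments of `func` unless the function takes a variable number of parameters. As a rough illustration of why, the sketch below infers names from a signature with the standard-library `inspect` module. The helper `infer_param_names` is hypothetical and is not claimed to be how xarray implements this; it only shows that `*args` leaves no names to discover.

```python
import inspect


def infer_param_names(func):
    """Hypothetical helper: names of the parameters after the first argument.

    Raises ValueError if `func` takes *args, since no names can be inferred.
    """
    params = list(inspect.signature(func).parameters.values())[1:]
    if any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params):
        raise ValueError("func takes *args; supply param_names explicitly")
    return [p.name for p in params]


# Stand-ins with the same signatures as the documentation examples.
def exponential(x, a, xc):
    ...


def multi_peak(coords, *args):
    ...


exp_names = infer_param_names(exponential)

try:
    infer_param_names(multi_peak)
    needs_names = False
except ValueError:
    needs_names = True

print(exp_names, needs_names)
```

For a fixed signature such as `exponential`, the names can be read off directly, whereas a wrapper like `multi_peak` needs the caller to pass `param_names`.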