Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic curvefit implementation #4849

Merged
merged 14 commits into from
Mar 31, 2021
Merged

Basic curvefit implementation #4849

merged 14 commits into from
Mar 31, 2021

Conversation

slevang
Copy link
Contributor

@slevang slevang commented Jan 30, 2021

  • Closes General curve fitting method #4300
  • Tests added
  • Passes pre-commit run --all-files
  • User visible changes (including notable bug fixes) are documented in whats-new.rst
  • New functions/methods are listed in api.rst

This is a simple implementation of a more general curve-fitting API as discussed in #4300, using the existing scipy curve_fit functionality wrapped with apply_ufunc. It works for arbitrary user-supplied 1D functions that ingest numpy arrays. Formatting and nomenclature of the outputs was largely copied from .polyfit, but could probably be improved.

@pep8speaks
Copy link

pep8speaks commented Jan 30, 2021

Hello @slevang! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-03-30 22:35:02 UTC

@slevang slevang marked this pull request as draft January 30, 2021 12:36
@TomNicholas
Copy link
Member

This is great, thanks for submitting this!

I just had a go with it, and it worked nicely. I have a couple of suggestions for improving it though:

  1. Different fit coefficients as differently-named variables in the output, rather than indexed with a coordinate. This would then be consistent with Dataset.polyfit, which returns a set of different variables [var]_polyfit_coefficients. We could also get the names from the names of the keyword args when inspecting the function, and if it fails to get names just call them like param1_fit_coefficients, param2_fit_coefficients etc.

    What you have now works nicely though, so perhaps you could just reorganise the result before returning, like Dataset.polyfit does?

  2. Initial guesses for each fit parameter. At the moment the user has to pass an ordered array of initial guesses through like

    da.curvefit(x=da.x, dim='x', func=linear, kwargs={'p0': [m_guess, c_guess]})

    but it would be nicer to just pass them as a dictionary like

    da.curvefit(x=da.x, dim='x', func=linear, initial_guess={'m': m_guess, 'c': c_guess})

    or even have the guesses read from the function definition maybe? i.e.

    def linear(x, m=m_guess, c=c_guess):
        return m*x + c
  3. (Stretch goal) Ability to fit >1D functions, e.g. fit a 2D gaussian to find a peak in a 2D image. But if we get the API right then this could be left to a later PR.

Also, the whole argument inspection thing probably deserves a few dedicated tests, in addition to testing the fitting functionality.

@slevang
Copy link
Contributor Author

slevang commented Jan 31, 2021

  1. Different fit coefficients as differently-named variables in the output, rather than indexed with a coordinate.

I think the way I configured things now does replicate the polyfit results. For example:

ds = xr.tutorial.open_dataset('air_temperature')
ds['air2'] = ds.air.copy()
ds.polyfit(dim='time', deg=2)

<xarray.Dataset>
Dimensions:                    (degree: 3, lat: 25, lon: 53)
Coordinates:
  * degree                     (degree) int64 2 1 0
  * lat                        (lat) float64 75.0 72.5 70.0 ... 20.0 17.5 15.0
  * lon                        (lon) float64 200.0 202.5 205.0 ... 327.5 330.0
Data variables:
    air_polyfit_coefficients   (degree, lat, lon) float64 -1.162e-32 ... 1.13...
    air2_polyfit_coefficients  (degree, lat, lon) float64 -1.162e-32 ... 1.14...

Compared to this:

def square(x, a, b ,c):
    return a*np.power(x, 2) + b*x + c

ds.curvefit(x=ds.time, dim='time', func=square)

<xarray.Dataset>
Dimensions:                     (lat: 25, lon: 53, param: 3)
Coordinates:
  * lat                         (lat) float32 75.0 72.5 70.0 ... 20.0 17.5 15.0
  * lon                         (lon) float32 200.0 202.5 205.0 ... 327.5 330.0
  * param                       (param) <U1 'a' 'b' 'c'
Data variables:
    air_curvefit_coefficients   (param, lat, lon) float64 -1.162e-32 ... 1.13...
    air2_curvefit_coefficients  (param, lat, lon) float64 -1.162e-32 ... 1.14...

In both cases, each variable in the dataset returns a separate coefficients variable, and all fittable coefficients are stacked along a dimension, degree for the more specific polyfit case and a generic param for curvefit.

  1. Initial guesses for each fit parameter.

Yeah this would be good. Should be easy to look for default values in the function itself using inspect.

  1. (Stretch goal) Ability to fit >1D functions

Looks like this could be possible with a call to ravel like this. I'll do some experimenting.

@TomNicholas
Copy link
Member

I think the way I configured things now does replicate the polyfit results.

You're right! My bad. The consistency with polyfit looks good.

Looks like this could be possible with a call to ravel

Oh nice! That looks like it would allow for ND functions fit to ND data. It looks like there is a dask version of ravel which might be useful. (And judging by the comments on that blog post I think @StanczakDominik would appreciate this feature too!)

@slevang
Copy link
Contributor Author

slevang commented Feb 2, 2021

Some more progress here.

  1. I liked the idea of being able to supply p0 and bounds as a dictionary mapping between parameter names and values, which is much more in the style of xarray. So I've included the logic to translate between this and the scipy args which are just ordered lists. The logic first prioritizes any of these values supplied directly in keyword args, then looks for defaults in the function definition, and otherwise falls back to the scipy defaults.
  2. There's now a working implementation of multi-dimensional fits using ravel to flatten the coordinate arrays. One limitation is that func must be specified like f((x, y), z) as scipy takes a single x argument for the fitting coordinate. Maybe there is a clever way around this, but I think it's ok.
  3. Some changes to the API.

The best way to specify the fitting coordinates is a bit tricky to figure out. My original use case for this was needing to fit a relationship between two time/lat/lon dataarrays with the fit done over all time. But probably a more common use would be to just fit a curve over one or two dimensions that already exist in your data. So it would be great to handle these possibilities seamlessly.

What I've settled on for now is a coords argument that takes a list of coordinates that should be the same length as the input coordinates of your fitting function. Additionally, there is a reduce_dim arg, so that, say if you want to fit a function in time but aggregate over all lat and lon, you can supply those dimensions here. So with a 3D x/y/time array, any of the following should work:

# Fit a 1d function in time, returns parameters with dims (x, y)
da.curvefit(coords='time', ...)
# Fit a 2d function in space, returns parameters with dims (t)
da.curvefit(coords=['x', 'y'], ...)
# Fit a 1d function with another 3d dataarray and aggregate over time, returns parameters with dims (x, y)
da.curvefit(coords=da1, reduce_dim='time', ...)

The logic to make this work got a bit complicated, since we need to supply the right input_core_dims to apply_ufunc, and also to explicitly broadcast the coordinates to ensure the ravel operation works. Any thoughts on cleanup here appreciated.

Will eventually need to add tests and improve docs and examples. Tests especially I could use some help on.

@slevang
Copy link
Contributor Author

slevang commented Feb 2, 2021

Added a couple usage examples in the docs, including one that replicates the scipy example of fitting multiple peaks. Because of the wrapper function and variable args, this requires supplying param_names manually. Works nicely though.

xarray/core/dataset.py Outdated Show resolved Hide resolved
@slevang
Copy link
Contributor Author

slevang commented Feb 3, 2021

Added some checks that will raise errors if inspect.signature is not able to determine the function arguments, either for functions with variable-length args or things like numpy ufuncs that don't seem to work with signature. In these cases, param_names can be passed manually.

Also added minimal tests, but these should probably be expanded.

@slevang slevang marked this pull request as ready for review February 4, 2021 01:18
xarray/core/dataset.py Outdated Show resolved Hide resolved
@slevang
Copy link
Contributor Author

slevang commented Feb 11, 2021

I've been playing around with this some more, and found the performance to be much better using a process-heavy dask scheduler. For example:

import xarray as xr
import numpy as np
import time
import daskdef exponential(x, a, xc):
    return np.exp((x - xc) / a)
​
x = np.arange(-5, 5, 0.001)
t = np.arange(-5, 5, 0.01)
X, T = np.meshgrid(x, t)
Z1 = np.random.uniform(low=-5, high=5, size=X.shape)
Z2 = exponential(Z1, 3, X) + np.random.normal(scale=0.1, size=Z1.shape)
​
ds = xr.Dataset(
    data_vars=dict(var1=(["t", "x"], Z1), var2=(["t", "x"], Z2)),
    coords={"t": t, "x": x},
)
​
ds = ds.chunk({'x':10})
​
def test_fit():
    start = time.time()
    fit = ds.var2.curvefit(
        coords=ds.var1,
        func=exponential,
        reduce_dim="t",
    ).compute()
    print(f'Fitting time: {time.time() - start:.2f}s')

with dask.config.set(scheduler='threads'):
    test_fit()
with dask.config.set(scheduler='processes'):
    test_fit()
with dask.distributed.Client() as client:
    test_fit()
with dask.distributed.Client(n_workers=8, threads_per_worker=1) as client:
    test_fit()

On my 8-core machine, takes:

Fitting time: 8.32s
Fitting time: 2.71s
Fitting time: 4.40s
Fitting time: 3.43s

According to this the underlying scipy routines should be thread safe.

@slevang slevang mentioned this pull request Feb 18, 2021
6 tasks
Copy link
Contributor

@dcherian dcherian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @slevang This is an amazing first PR!

It's very thorough and nicely written. I have just minor comments.

The only major comment is that I suggest refactoring some code out to a couple of helper functions which can then be tested independently. A few more test cases would be nice but I think you've covered most of the functionality.

xarray/core/dataset.py Outdated Show resolved Hide resolved
xarray/core/dataset.py Outdated Show resolved Hide resolved
xarray/core/dataset.py Outdated Show resolved Hide resolved
xarray/core/dataset.py Outdated Show resolved Hide resolved
xarray/core/dataset.py Outdated Show resolved Hide resolved
xarray/core/dataset.py Outdated Show resolved Hide resolved
xarray/core/dataset.py Outdated Show resolved Hide resolved
xarray/tests/test_dataarray.py Show resolved Hide resolved
xarray/core/dataset.py Show resolved Hide resolved
xarray/core/dataset.py Outdated Show resolved Hide resolved
@slevang
Copy link
Contributor Author

slevang commented Feb 20, 2021

Thanks for the review @dcherian! The latest commit has been refactored with a couple helper functions and associated tests, and any steps that served no purpose other than consistency with polyfit have been removed.

If you can think of any more specific test cases that should be included, happy to add them.

@TomNicholas
Copy link
Member

This seems ready to be merged?

@slevang
Copy link
Contributor Author

slevang commented Mar 30, 2021

I think so. I pushed a merge commit to get this up to date with the current release.

@dcherian
Copy link
Contributor

Thanks @slevang. Sorry for the delay!

@dcherian dcherian merged commit ddc352f into pydata:master Mar 31, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

General curve fitting method
7 participants