Commit

Merge branch 'master' into drop_duplicates
ahuang11 authored Apr 4, 2021
2 parents daa6e42 + 3cbd21a commit cc94bbe
Showing 32 changed files with 622 additions and 63 deletions.
6 changes: 3 additions & 3 deletions ci/min_deps_check.py
@@ -7,7 +7,7 @@
 from datetime import datetime
 from typing import Dict, Iterator, Optional, Tuple

-import conda.api
+import conda.api  # type: ignore[import]
 import yaml
 from dateutil.relativedelta import relativedelta

@@ -76,9 +76,9 @@ def parse_requirements(fname) -> Iterator[Tuple[str, int, int, Optional[int]]]:
             raise ValueError("non-numerical version: " + row)

         if len(version_tup) == 2:
-            yield (pkg, *version_tup, None)  # type: ignore
+            yield (pkg, *version_tup, None)  # type: ignore[misc]
         elif len(version_tup) == 3:
-            yield (pkg, *version_tup)  # type: ignore
+            yield (pkg, *version_tup)  # type: ignore[misc]
         else:
             raise ValueError("expected major.minor or major.minor.patch: " + row)

3 changes: 2 additions & 1 deletion doc/api.rst
@@ -180,6 +180,7 @@ Computation
    Dataset.integrate
    Dataset.map_blocks
    Dataset.polyfit
+   Dataset.curvefit

 **Aggregation**:
 :py:attr:`~Dataset.all`
@@ -377,7 +378,7 @@ Computation
    DataArray.integrate
    DataArray.polyfit
    DataArray.map_blocks
-
+   DataArray.curvefit

 **Aggregation**:
 :py:attr:`~DataArray.all`
83 changes: 83 additions & 0 deletions doc/user-guide/computation.rst
@@ -444,6 +444,89 @@ The inverse operation is done with :py:meth:`~xarray.polyval`,
 .. note::
     These methods replicate the behaviour of :py:func:`numpy.polyfit` and :py:func:`numpy.polyval`.

+
+.. _compute.curvefit:
+
+Fitting arbitrary functions
+===========================
+
+Xarray objects also provide an interface for fitting more complex functions using
+:py:meth:`scipy.optimize.curve_fit`. :py:meth:`~xarray.DataArray.curvefit` accepts
+user-defined functions and can fit along multiple coordinates.
+
+For example, we can fit a relationship between two ``DataArray`` objects, maintaining
+a unique fit at each spatial coordinate but aggregating over the time dimension:
+
+.. ipython:: python
+
+    def exponential(x, a, xc):
+        return np.exp((x - xc) / a)
+
+
+    x = np.arange(-5, 5, 0.1)
+    t = np.arange(-5, 5, 0.1)
+    X, T = np.meshgrid(x, t)
+    Z1 = np.random.uniform(low=-5, high=5, size=X.shape)
+    Z2 = exponential(Z1, 3, X)
+    Z3 = exponential(Z1, 1, -X)
+
+    ds = xr.Dataset(
+        data_vars=dict(
+            var1=(["t", "x"], Z1), var2=(["t", "x"], Z2), var3=(["t", "x"], Z3)
+        ),
+        coords={"t": t, "x": x},
+    )
+    ds[["var2", "var3"]].curvefit(
+        coords=ds.var1,
+        func=exponential,
+        reduce_dims="t",
+        bounds={"a": (0.5, 5), "xc": (-5, 5)},
+    )
+
+We can also fit multi-dimensional functions, and even use a wrapper function to
+simultaneously fit a summation of several functions, such as this field containing
+two gaussian peaks:
+
+.. ipython:: python
+
+    def gaussian_2d(coords, a, xc, yc, xalpha, yalpha):
+        x, y = coords
+        z = a * np.exp(
+            -np.square(x - xc) / 2 / np.square(xalpha)
+            - np.square(y - yc) / 2 / np.square(yalpha)
+        )
+        return z
+
+
+    def multi_peak(coords, *args):
+        z = np.zeros(coords[0].shape)
+        for i in range(len(args) // 5):
+            z += gaussian_2d(coords, *args[i * 5 : i * 5 + 5])
+        return z
+
+
+    x = np.arange(-5, 5, 0.1)
+    y = np.arange(-5, 5, 0.1)
+    X, Y = np.meshgrid(x, y)
+
+    n_peaks = 2
+    names = ["a", "xc", "yc", "xalpha", "yalpha"]
+    names = [f"{name}{i}" for i in range(n_peaks) for name in names]
+    Z = gaussian_2d((X, Y), 3, 1, 1, 2, 1) + gaussian_2d((X, Y), 2, -1, -2, 1, 1)
+    Z += np.random.normal(scale=0.1, size=Z.shape)
+
+    da = xr.DataArray(Z, dims=["y", "x"], coords={"y": y, "x": x})
+    da.curvefit(
+        coords=["x", "y"],
+        func=multi_peak,
+        param_names=names,
+        kwargs={"maxfev": 10000},
+    )
+
+.. note::
+    This method replicates the behavior of :py:func:`scipy.optimize.curve_fit`.
+
+
 .. _compute.broadcasting:

 Broadcasting by dimension name
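
As a usage note on the documentation added above (editorial, not part of the diff): ``curvefit`` returns a new object holding the fitted parameters rather than modifying its inputs. A minimal sketch of inspecting the output; the per-variable ``*_curvefit_coefficients`` / ``*_curvefit_covariance`` result naming is an assumption and should be verified against the released API docs:

    # Continues from the `ds` and `exponential` defined in the docs above.
    # The result-variable names below are an assumption, not confirmed by
    # this diff.
    fit = ds[["var2", "var3"]].curvefit(
        coords=ds.var1,
        func=exponential,
        reduce_dims="t",
        bounds={"a": (0.5, 5), "xc": (-5, 5)},
    )
    print(fit.var2_curvefit_coefficients.sel(param="a"))  # fitted "a" at each x
    print(fit.var2_curvefit_covariance)  # parameter covariance estimates
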
11 changes: 10 additions & 1 deletion doc/whats-new.rst
@@ -67,7 +67,11 @@ New Features
 - Implement :py:meth:`Dataset.drop_duplicate_coords` and :py:meth:`DataArray.drop_duplicate_coords`
   to remove duplicate coordinate values (:pull:`5089`).
   By `Andrew Huang <https://github.com/ahuang11>`_.
+- Add a ``combine_attrs`` parameter to :py:func:`open_mfdataset` (:pull:`4971`).
+  By `Justus Magin <https://github.com/keewis>`_.
+- Disable the `cfgrib` backend if the `eccodes` library is not installed (:pull:`5083`). By `Baudouin Raoult <https://github.com/b8raoult>`_.
+- Added :py:meth:`DataArray.curvefit` and :py:meth:`Dataset.curvefit` for general curve fitting applications. (:issue:`4300`, :pull:`4849`)
+  By `Sam Levang <https://github.com/slevang>`_.

 Breaking changes
 ~~~~~~~~~~~~~~~~
@@ -102,6 +106,9 @@ Documentation

 Internal Changes
 ~~~~~~~~~~~~~~~~
+- Enable displaying mypy error codes and ignore only specific error codes using
+  ``# type: ignore[error-code]`` (:pull:`5096`). By `Mathias Hauser <https://github.com/mathause>`_.
+


 .. _whats-new.0.17.0:

@@ -1363,7 +1370,9 @@ Enhancements

 - Added a repr (:pull:`3344`). Example::

-      >>> da.groupby("time.season")
+      da.groupby("time.season")
+      DataArrayGroupBy, grouped over 'season'
+      4 groups with labels 'DJF', 'JJA', 'MAM', 'SON'

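
For context on the ``drop_duplicate_coords`` entry above, a hedged sketch of the intended use. The call signature is an assumption based only on the entry's description (the method removes duplicate coordinate values), and may differ in the released version:

    import xarray as xr

    # Hypothetical: coordinate "x" carries a duplicated value 0.
    da = xr.DataArray([1, 2, 3], dims="x", coords={"x": [0, 0, 1]})

    # Assumed signature; keeps one entry per distinct coordinate value.
    deduped = da.drop_duplicate_coords("x")
    print(deduped["x"].values)  # expected: [0 1]
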
3 changes: 3 additions & 0 deletions setup.cfg
@@ -162,12 +162,15 @@ default_section = THIRDPARTY
 known_first_party = xarray

 [mypy]
+show_error_codes = True

 # Most of the numerical computing stack doesn't have type annotations yet.
 [mypy-affine.*]
 ignore_missing_imports = True
 [mypy-bottleneck.*]
 ignore_missing_imports = True
+[mypy-cartopy.*]
+ignore_missing_imports = True
 [mypy-cdms2.*]
 ignore_missing_imports = True
 [mypy-cf_units.*]
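
To illustrate what the ``show_error_codes = True`` setting above enables (hypothetical snippet, not from the commit): mypy appends a bracketed error code to each report, so a suppression can target exactly that code instead of silencing every possible error on the line. This is what the ``# type: ignore[misc]``, ``# type: ignore[import]``, and ``# type: ignore[attr-defined]`` changes throughout this diff rely on.

    from typing import Optional

    x: Optional[int] = None

    # mypy reports: Incompatible types in assignment (expression has type
    # "Optional[int]", variable has type "int")  [assignment]
    y: int = x  # type: ignore[assignment]  # narrow: ignores only this code
    z: int = x  # type: ignore  # broad: hides any error on this line
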
7 changes: 3 additions & 4 deletions xarray/backends/api.py
@@ -717,6 +717,7 @@ def open_mfdataset(
     parallel=False,
     join="outer",
     attrs_file=None,
+    combine_attrs="override",
     **kwargs,
 ):
     """Open multiple files as a single dataset.

@@ -931,7 +932,7 @@
             coords=coords,
             ids=ids,
             join=join,
-            combine_attrs="drop",
+            combine_attrs=combine_attrs,
         )
     elif combine == "by_coords":
         # Redo ordering from coordinates, ignoring how they were ordered

@@ -942,7 +943,7 @@
             data_vars=data_vars,
             coords=coords,
             join=join,
-            combine_attrs="drop",
+            combine_attrs=combine_attrs,
         )
     else:
         raise ValueError(

@@ -965,8 +966,6 @@ def multi_file_closer():
         if isinstance(attrs_file, Path):
             attrs_file = str(attrs_file)
         combined.attrs = datasets[paths.index(attrs_file)].attrs
-    else:
-        combined.attrs = datasets[0].attrs

     return combined
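
A hedged sketch of the new parameter in use (file names are hypothetical): ``combine_attrs`` is now forwarded to the combine step instead of being hard-coded to ``"drop"``, so callers can choose how attributes are merged. With the default ``"override"``, the combine step itself keeps the first dataset's attrs, which is why the manual ``combined.attrs = datasets[0].attrs`` fallback above could be removed.

    import xarray as xr

    # Hypothetical files; the combine_attrs value follows the options
    # accepted by xarray's combine functions (e.g. "drop", "override").
    ds = xr.open_mfdataset(
        ["a.nc", "b.nc"],
        combine="by_coords",
        combine_attrs="override",  # default: keep attrs of the first dataset
    )
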
18 changes: 17 additions & 1 deletion xarray/core/combine.py
@@ -44,6 +44,18 @@ def _infer_tile_ids_from_nested_list(entry, current_pos):
         yield current_pos, entry


+def _ensure_same_types(series, dim):
+
+    if series.dtype == object:
+        types = set(series.map(type))
+        if len(types) > 1:
+            types = ", ".join(t.__name__ for t in types)
+            raise TypeError(
+                f"Cannot combine along dimension '{dim}' with mixed types."
+                f" Found: {types}."
+            )
+
+
 def _infer_concat_order_from_coords(datasets):

     concat_dims = []
@@ -88,11 +100,15 @@ def _infer_concat_order_from_coords(datasets):
                     raise ValueError("Cannot handle size zero dimensions")
                 first_items = pd.Index([index[0] for index in indexes])

+                series = first_items.to_series()
+
+                # ensure series does not contain mixed types, e.g. cftime calendars
+                _ensure_same_types(series, dim)
+
                 # Sort datasets along dim
                 # We want rank but with identical elements given identical
                 # position indices - they should be concatenated along another
                 # dimension, not along this one
-                series = first_items.to_series()
                 rank = series.rank(
                     method="dense", ascending=ascending, numeric_only=False
                 )
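
A minimal sketch of the failure mode the new ``_ensure_same_types`` guard reports (hypothetical data; per the inline comment, the real motivation is mixed cftime calendars):

    import numpy as np
    import xarray as xr

    # Two datasets whose "x" coordinates hold different object types: the
    # first items ("a" vs. 0) differ in type, so combining now fails with a
    # clear TypeError instead of an obscure error during ranking.
    ds1 = xr.Dataset({"v": ("x", [1, 2])}, coords={"x": np.array(["a", "b"], dtype=object)})
    ds2 = xr.Dataset({"v": ("x", [3, 4])}, coords={"x": np.array([0, 1], dtype=object)})

    try:
        xr.combine_by_coords([ds1, ds2])
    except TypeError as err:
        print(err)  # Cannot combine along dimension 'x' with mixed types. ...
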
4 changes: 2 additions & 2 deletions xarray/core/common.py
@@ -58,7 +58,7 @@ def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs):

     else:

-        def wrapped_func(self, dim=None, axis=None, **kwargs):  # type: ignore
+        def wrapped_func(self, dim=None, axis=None, **kwargs):  # type: ignore[misc]
             return self.reduce(func, dim, axis, **kwargs)

     return wrapped_func
@@ -97,7 +97,7 @@ def wrapped_func(self, dim=None, skipna=None, **kwargs):

     else:

-        def wrapped_func(self, dim=None, **kwargs):  # type: ignore
+        def wrapped_func(self, dim=None, **kwargs):  # type: ignore[misc]
             return self.reduce(func, dim, numeric_only=numeric_only, **kwargs)

     return wrapped_func
8 changes: 5 additions & 3 deletions xarray/core/coordinates.py
@@ -50,7 +50,7 @@ def dims(self) -> Union[Mapping[Hashable, int], Tuple[Hashable, ...]]:

     @property
     def indexes(self) -> Indexes:
-        return self._data.indexes  # type: ignore
+        return self._data.indexes  # type: ignore[attr-defined]

     @property
     def variables(self):
@@ -105,9 +105,11 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
             raise ValueError("no valid index for a 0-dimensional object")
         elif len(ordered_dims) == 1:
             (dim,) = ordered_dims
-            return self._data.get_index(dim)  # type: ignore
+            return self._data.get_index(dim)  # type: ignore[attr-defined]
         else:
-            indexes = [self._data.get_index(k) for k in ordered_dims]  # type: ignore
+            indexes = [
+                self._data.get_index(k) for k in ordered_dims  # type: ignore[attr-defined]
+            ]

             # compute the sizes of the repeat and tile for the cartesian product
             # (taken from pandas.core.reshape.util)
2 changes: 1 addition & 1 deletion xarray/core/dask_array_compat.py
@@ -158,7 +158,7 @@ def ensure_minimum_chunksize(size, chunks):

 def sliding_window_view(x, window_shape, axis=None):
     from dask.array.overlap import map_overlap
-    from numpy.core.numeric import normalize_axis_tuple  # type: ignore
+    from numpy.core.numeric import normalize_axis_tuple

     from .npcompat import sliding_window_view as _np_sliding_window_view