
Feature/weighted #2922

Merged Mar 19, 2020 (60 commits; changes shown from 20 commits)

Commits:
0f2da8e  weighted for DataArray (mathause, Apr 26, 2019)
5f64492  remove some commented code (mathause, Apr 26, 2019)
685e5c4  pep8 and faulty import tests (mathause, Apr 26, 2019)
c9d612d  add weighted sum, replace 0s in sum_of_wgt (mathause, Apr 30, 2019)
a20a4cf  weighted: overhaul tests (mathause, Apr 30, 2019)
26c24b6  weighted: pep8 (mathause, Apr 30, 2019)
f3c6758  weighted: pep8 lines (mathause, Apr 30, 2019)
25c3c29  weighted update docs (mathause, May 2, 2019)
5d37d11  weighted: fix typo (mathause, May 2, 2019)
b1c572b  weighted: pep8 (mathause, May 8, 2019)
d1d1f2c  undo changes to avoid merge conflict (mathause, Oct 17, 2019)
6be1414  Merge branch 'master' into feature/weighted (mathause, Oct 17, 2019)
059263c  add weighted to dataarray again (mathause, Oct 17, 2019)
8b1904b  remove super (mathause, Oct 17, 2019)
8cad145  overhaul core/weighted.py (mathause, Oct 17, 2019)
49d4e43  add DatasetWeighted class (mathause, Oct 17, 2019)
527256e  _maybe_get_all_dims return sorted tuple (mathause, Oct 17, 2019)
739568f  work on: test_weighted (mathause, Oct 17, 2019)
f01305d  black and flake8 (mathause, Oct 17, 2019)
2e3880d  Apply suggestions from code review (docs) (mathause, Oct 17, 2019)
ae8d048  restructure interim (mathause, Oct 18, 2019)
dc7f605  restructure classes (mathause, Oct 18, 2019)
c646568  Merge branch 'master' into feature/weighted (mathause, Dec 4, 2019)
e2ad69e  update weighted.py (mathause, Dec 4, 2019)
bd4f048  black (mathause, Dec 4, 2019)
3c7695a  use map; add keep_attrs (mathause, Dec 4, 2019)
ef07edd  implement expected_weighted; update tests (mathause, Dec 4, 2019)
064b5a9  add whats new (mathause, Dec 4, 2019)
fec1a35  Merge branch 'master' into feature/weighted (mathause, Dec 4, 2019)
72c7942  undo changes to whats-new (mathause, Dec 4, 2019)
0e91411  F811: noqa where? (mathause, Dec 4, 2019)
1eb2913  api.rst (mathause, Dec 5, 2019)
118dfed  add to computation (mathause, Dec 5, 2019)
e08c921  small updates (mathause, Dec 5, 2019)
0fafe0b  add example to gallery (mathause, Dec 5, 2019)
a8d330d  typo (mathause, Dec 5, 2019)
ae0012f  another typo (mathause, Dec 5, 2019)
111259b  correct docstring in core/common.py (mathause, Dec 5, 2019)
5afc6f3  Merge branch 'master' into feature/weighted (mathause, Jan 14, 2020)
668b54b  typos (mathause, Jan 14, 2020)
d877022  adjust review (mathause, Jan 14, 2020)
ead681e  clean tests (mathause, Jan 14, 2020)
c4598ba  add test nonequal coords (mathause, Jan 14, 2020)
866fba5  comment on use of dot (mathause, Jan 14, 2020)
3cc00c1  fix erroneous merge (mathause, Jan 14, 2020)
8f34167  Merge branch 'master' into feature/weighted (mathause, Jan 21, 2020)
9f0a8cd  update tests (mathause, Jan 21, 2020)
98929f1  Merge branch 'master' into feature/weighted (mathause, Mar 5, 2020)
62c43e6  move example to notebook (mathause, Mar 5, 2020)
2e8aba2  move whats-new entry to 15.1 (mathause, Mar 5, 2020)
d14f668  some doc updates (mathause, Mar 5, 2020)
7fa78ae  dot to own function (mathause, Mar 5, 2020)
3ebb9d4  simplify some tests (mathause, Mar 5, 2020)
f01d47a  Doc updates (dcherian, Mar 17, 2020)
4b184f6  very minor changes. (dcherian, Mar 17, 2020)
1e06adc  fix & add references (dcherian, Mar 17, 2020)
706579a  doc: return 0/NaN on 0 weights (mathause, Mar 17, 2020)
b2718db  Merge branch 'feature/weighted' of https://github.com/mathause/xarray… (mathause, Mar 17, 2020)
4c17108  Merge branch 'master' into feature/weighted (mathause, Mar 17, 2020)
8acc78e  Update xarray/core/common.py (dcherian, Mar 18, 2020)
16 changes: 16 additions & 0 deletions xarray/core/common.py
@@ -730,6 +730,22 @@ def groupby_bins(
},
)

def weighted(self, weights):
"""
Weighted operations.
Parameters
----------
weights : DataArray
An array of weights associated with the values in this Dataset.
Each value in the data contributes to the reduction operation
according to its associated weight.
Note
----
Missing values in the weights are treated as 0 (i.e. no weight).
"""

return self._weighted_cls(self, weights)

def rolling(
self,
dim: Mapping[Hashable, int] = None,
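The semantics this accessor documents (missing weights count as zero; missing data is excluded from the normalisation) can be sketched in plain numpy. `weighted_mean` below is a hypothetical stand-in for illustration, not the xarray implementation:

```python
import numpy as np

def weighted_mean(data, weights):
    """Sketch: NaN data is skipped, NaN weights count as zero weight."""
    weights = np.nan_to_num(weights)          # missing weight -> no weight
    valid = ~np.isnan(data)                   # mask of usable data points
    sum_of_weights = (weights * valid).sum()  # weights of valid points only
    if sum_of_weights == 0:
        return np.nan                         # nothing to average
    return np.nansum(data * weights) / sum_of_weights

data = np.array([1.0, 2.0, np.nan])
weights = np.array([4.0, 1.0, 2.0])
print(weighted_mean(data, weights))  # (1*4 + 2*1) / (4 + 1) = 1.2
```

Note the NaN data point's weight (2.0) drops out of the denominator as well, so the result is a true mean of the valid values.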
2 changes: 2 additions & 0 deletions xarray/core/dataarray.py
@@ -33,6 +33,7 @@
resample,
rolling,
utils,
weighted,
)
from .accessor_dt import DatetimeAccessor
from .accessor_str import StringAccessor
@@ -269,6 +270,7 @@ class DataArray(AbstractArray, DataWithCoords):
_rolling_cls = rolling.DataArrayRolling
_coarsen_cls = rolling.DataArrayCoarsen
_resample_cls = resample.DataArrayResample
_weighted_cls = weighted.DataArrayWeighted

__default = ReprObject("<default>")

2 changes: 2 additions & 0 deletions xarray/core/dataset.py
@@ -44,6 +44,7 @@
resample,
rolling,
utils,
weighted,
)
from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align
from .common import (
@@ -432,6 +433,7 @@ class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords):
_rolling_cls = rolling.DatasetRolling
_coarsen_cls = rolling.DatasetCoarsen
_resample_cls = resample.DatasetResample
_weighted_cls = weighted.DatasetWeighted

def __init__(
self,
307 changes: 307 additions & 0 deletions xarray/core/weighted.py
@@ -0,0 +1,307 @@
from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Tuple, Union, overload

from .computation import dot, where

if TYPE_CHECKING:
from .dataarray import DataArray, Dataset

_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
Reduce this {cls}'s data by a weighted `{fcn}` along some dimension(s).

Parameters
----------
dim : str or sequence of str, optional
Dimension(s) over which to apply the weighted `{fcn}`.
axis : int or sequence of int, optional
Axis(es) over which to apply the weighted `{fcn}`. Only one of the
'dim' and 'axis' arguments can be supplied. If neither are supplied,
then the weighted `{fcn}` is calculated over all axes.
Collaborator Author: maybe remove as xr.dot does not provide this

Collaborator: Agree re removing (or am I missing something—why was this here?)

Collaborator Author: For consistency it would be nice to offer axis here. Originally sum was implemented as (weights * da).sum(...) and we got axis for free. With dot it is not as straightforward any more. Honestly, I never use axis with xarray, so I suppose it is fine to only implement it if anyone would ever request it...

Collaborator: OK. Unless I'm missing something, let's remove axis

skipna : bool, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
Note: Missing values in the weights are replaced with 0 (i.e. no
weight).
keep_attrs : bool, optional
If True, the attributes (`attrs`) will be copied from the original
object to the new one. If False (default), the new object will be
returned without attributes.
**kwargs : dict
Additional keyword arguments passed on to the appropriate array
function for calculating `{fcn}` on this object's data.

Returns
-------
reduced : {cls}
New {cls} object with weighted `{fcn}` applied to its data and
the indicated dimension(s) removed.
"""

_SUM_OF_WEIGHTS_DOCSTRING = """
Calculate the sum of weights, accounting for missing values.

Parameters
----------
dim : str or sequence of str, optional
Dimension(s) over which to sum the weights.
axis : int or sequence of int, optional
Axis(es) over which to sum the weights. Only one of the 'dim' and
'axis' arguments can be supplied. If neither are supplied, then
the weights are summed over all axes.

Returns
-------
reduced : {cls}
New {cls} object with the sum of the weights over the given dimension.
"""


# functions for weighted operations for one DataArray
# NOTE: weights must not contain missing values (this is taken care of in the
# DataArrayWeighted and DatasetWeighted cls)


def _maybe_get_all_dims(
dims: Optional[Union[Hashable, Iterable[Hashable]]],
dims1: Tuple[Hashable, ...],
dims2: Tuple[Hashable, ...],
):
""" the union of dims1 and dims2 if dims is None

`dims=None` behaves differently in `dot` and `sum`, so we have to apply
`dot` over the union of the dimensions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sure you're right, but can you help me understand the issue here? Doesn't each of dot and sum collapse all dimensions if dims=None? Or there's an issue where only a subset of the dimensions are shared?


"""

if dims is None:
dims = tuple(sorted(set(dims1) | set(dims2)))

return dims
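A minimal sketch of this helper's behaviour, in pure Python with made-up dimension names:

```python
# With dims=None, fall back to the union of both operands' dimensions,
# so `dot` reduces everything, matching what `sum(dim=None)` would do.
def maybe_get_all_dims(dims, dims1, dims2):
    if dims is None:
        dims = tuple(sorted(set(dims1) | set(dims2)))
    return dims

print(maybe_get_all_dims(None, ("time", "lat"), ("lat", "lon")))
# -> ('lat', 'lon', 'time'): sorted union, shared "lat" appears once
print(maybe_get_all_dims("time", ("time", "lat"), ("lat", "lon")))
# -> 'time': explicit dims are passed through untouched
```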


def _sum_of_weights(
da: "DataArray",
weights: "DataArray",
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
axis=None,
) -> "DataArray":
""" Calcualte the sum of weights, accounting for missing values """

# we need to mask DATA values that are nan; else the weights are wrong
mask = where(da.notnull(), 1, 0) # binary mask
Collaborator: Suggested change:
-    mask = where(da.notnull(), 1, 0)  # binary mask
+    mask = da.isnull()
bool works with dot from my limited checks


# need to infer dims as we use `dot`
dims = _maybe_get_all_dims(dim, da.dims, weights.dims)

# use `dot` to avoid creating large DataArrays (if da and weights do not share all dims)
sum_of_weights = dot(mask, weights, dims=dims)

# find all weights that are valid (not 0)
valid_weights = sum_of_weights != 0.0

# set invalid weights to nan
return sum_of_weights.where(valid_weights)
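`dot` is used here so the mask-times-weights product is reduced as it is computed instead of being materialised at full broadcast size. `np.einsum` plays the same role in a plain-numpy sketch (shapes and values are made up for illustration):

```python
import numpy as np

data = np.arange(12.0).reshape(3, 4)      # dims ("x", "y")
data[0, 0] = np.nan                       # one missing value
weights = np.array([0.5, 1.0, 1.5, 2.0])  # weights along "y" only

mask = (~np.isnan(data)).astype(float)    # 1.0 where data is valid
# contract over "y", like dot(mask, weights, dims=("y",)); the (3, 4)
# product mask * weights is reduced on the fly, never materialised
sum_of_weights = np.einsum("xy,y->x", mask, weights)
# row 0 excludes the 0.5 weight of its NaN -> 4.5; rows 1 and 2 -> 5.0
```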


Collaborator: Should we just put these on the Weighted object? (not a huge deal, trying to reduce indirection if possible though)

Collaborator Author: Do you mean like so?

class Weighted
    ...

    def _weighted_sum(self, da, dims, skipna, **kwargs)
        pass

Yes good idea, this might actually even work better as this avoids the hassle with the non-sanitized weights

def _weighted_sum(
da: "DataArray",
weights: "DataArray",
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
axis=None,
skipna: Optional[bool] = None,
**kwargs
) -> "DataArray":
"""Reduce a DataArray by a weighted `sum` along some dimension(s)."""

# need to infer dims as we use `dot`
dims = _maybe_get_all_dims(dim, da.dims, weights.dims)

# use `dot` to avoid creating large da's

# need to mask invalid DATA as dot does not implement skipna
if skipna or skipna is None:
return where(da.isnull(), 0.0, da).dot(weights, dims=dims)
Collaborator: Suggested change:
-    return where(da.isnull(), 0.0, da).dot(weights, dims=dims)
+    return da.fillna(0.0).dot(weights, dims=dims)
?


return dot(da, weights, dims=dims)
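Filling missing data with 0 before the contraction is what makes the sum honour `skipna`: a zero contributes nothing to the dot product. A numpy sketch of the same idea (values made up):

```python
import numpy as np

data = np.array([[1.0, np.nan], [3.0, 4.0]])  # dims ("x", "y")
weights = np.array([10.0, 1.0])               # weights along "y"

# skipna: replace missing data with 0 so it drops out of the contraction
filled = np.nan_to_num(data)
weighted_sum = np.einsum("xy,y->x", filled, weights)
# row 0: 1*10 + 0*1 = 10.0; row 1: 3*10 + 4*1 = 34.0
```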


def _weighted_mean(
da: "DataArray",
weights: "DataArray",
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
axis=None,
skipna: Optional[bool] = None,
**kwargs
) -> "DataArray":
"""Reduce a DataArray by a weighted `mean` along some dimension(s)."""

# get weighted sum
weighted_sum = _weighted_sum(
da, weights, dim=dim, axis=axis, skipna=skipna, **kwargs
)

# get the sum of weights
sum_of_weights = _sum_of_weights(da, weights, dim=dim, axis=axis)

# calculate weighted mean
return weighted_sum / sum_of_weights
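Because `_sum_of_weights` turns an all-zero weight total into NaN, this division yields NaN (rather than inf or a warning-laden 0/0) wherever no valid weights remain. A small numpy sketch of that interaction:

```python
import numpy as np

# weighted sums and weight totals for two points; the second point has
# zero total weight (e.g. all of its data was NaN)
weighted_sum = np.array([6.0, 0.0])
sum_of_weights = np.array([5.0, 0.0])

# mirror _sum_of_weights: an all-zero total is invalid, so mark it NaN
sum_of_weights = np.where(sum_of_weights != 0.0, sum_of_weights, np.nan)

mean = weighted_sum / sum_of_weights  # [1.2, nan]
```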


class Weighted:
"""A object that implements weighted operations.

You should create a Weighted object by using the `DataArray.weighted` or
`Dataset.weighted` methods.

See Also
--------
Dataset.weighted
DataArray.weighted
"""

__slots__ = ("obj", "weights")

@overload
def __init__(self, obj: "DataArray", weights: "DataArray") -> None:
...

@overload
def __init__(self, obj: "Dataset", weights: "DataArray") -> None:
...

def __init__(self, obj, weights) -> None:
"""
Create a Weighted object

Parameters
----------
obj : DataArray or Dataset
Object over which the weighted reduction operation is applied.
weights : DataArray
An array of weights associated with the values in this obj.
Each value in the obj contributes to the reduction operation
according to its associated weight.

Note
----
Missing values in the weights are replaced with 0 (i.e. no weight).

"""

from .dataarray import DataArray

msg = "'weights' must be a DataArray"
assert isinstance(weights, DataArray), msg

self.obj = obj
self.weights = weights.fillna(0)

def __repr__(self):
"""provide a nice str repr of our Weighted object"""

msg = "{klass} with weights along dimensions: {weight_dims}"
return msg.format(
klass=self.__class__.__name__, weight_dims=", ".join(self.weights.dims)
)


class DataArrayWeighted(Weighted):
def sum_of_weights(
self, dim: Optional[Union[Hashable, Iterable[Hashable]]] = None, axis=None
) -> "DataArray":

return _sum_of_weights(self.obj, self.weights, dim=dim, axis=axis)

def sum(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
axis=None,
skipna: Optional[bool] = None,
**kwargs
) -> "DataArray":

return _weighted_sum(
self.obj, self.weights, dim=dim, axis=axis, skipna=skipna, **kwargs
)

def mean(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
axis=None,
skipna: Optional[bool] = None,
**kwargs
) -> "DataArray":

return _weighted_mean(
self.obj, self.weights, dim=dim, axis=axis, skipna=skipna, **kwargs
)


# add docstrings
DataArrayWeighted.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(
cls="DataArray"
)
DataArrayWeighted.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
cls="DataArray", fcn="mean"
)
DataArrayWeighted.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
cls="DataArray", fcn="sum"
)


class DatasetWeighted(Weighted):
def _dataset_implementation(self, func, **kwargs) -> "Dataset":
Collaborator: Not for this PR, but we do this lots of places; we should find a common approach at some point!


from .dataset import Dataset

weighted = {}
for key, da in self.obj.data_vars.items():

weighted[key] = func(da, self.weights, **kwargs)

return Dataset(weighted, coords=self.obj.coords)
Collaborator: Will this work if coords are on dimensions that are reduced / removed by the function?
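Stripped of the xarray types, `_dataset_implementation` is just "apply the same reduction to every data variable and rebuild the container from the results", i.e. a dict comprehension (variable names below are made up):

```python
# Sketch of the per-variable mapping: apply func to each data variable,
# keeping the variable names as keys.
def map_over_vars(data_vars, func, **kwargs):
    return {key: func(values, **kwargs) for key, values in data_vars.items()}

data_vars = {"t2m": [1.0, 2.0], "precip": [3.0, 4.0]}
out = map_over_vars(data_vars, sum)
print(out)  # {'t2m': 3.0, 'precip': 7.0}
```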


def sum_of_weights(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
axis=None,
skipna: Optional[bool] = None,
) -> "Dataset":

return self._dataset_implementation(_sum_of_weights, dim=dim, axis=axis)

def sum(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
axis=None,
skipna: Optional[bool] = None,
**kwargs
) -> "Dataset":

return self._dataset_implementation(
_weighted_sum, dim=dim, axis=axis, skipna=skipna, **kwargs
)

def mean(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
axis=None,
skipna: Optional[bool] = None,
**kwargs
) -> "Dataset":

return self._dataset_implementation(
_weighted_mean, dim=dim, axis=axis, skipna=skipna, **kwargs
)


# add docstrings
DatasetWeighted.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls="Dataset")
DatasetWeighted.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
cls="Dataset", fcn="mean"
)
DatasetWeighted.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
cls="Dataset", fcn="sum"
)