Skip to content

Commit

Permalink
Add groupby.dims & Fix groupby reduce for DataArray (#3338)
Browse files Browse the repository at this point in the history
* Fix groupby reduce for DataArray

* bugfix.

* another bugfix.

* bugfix unique_and_monotonic for object indexes (uniqueness is enough)

* Add groupby.dims property.

* update reduce docstring to point to xarray.ALL_DIMS

* test for object index dims.

* test reduce dimensions error.

* Add whats-new

* fix docs build

* sq whats-new

* one more test.

* fix test.

* undo monotonic change.

* Add dimensions to repr.

* Raise error if no bins.

* Raise nice error if no groups were formed.

* Some more error raising and testing.

* Add dataset tests.

* update whats-new.

* fix tests.

* make dims a cached lazy property.

* fix whats-new.

* whitespace

* fix whats-new
  • Loading branch information
dcherian authored and Joe Hamman committed Oct 10, 2019
1 parent 3f0049f commit 291cb80
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 15 deletions.
2 changes: 1 addition & 1 deletion doc/groupby.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,4 +213,4 @@ applying your function, and then unstacking the result:
.. ipython:: python
stacked = da.stack(gridcell=['ny', 'nx'])
stacked.groupby('gridcell').sum().unstack('gridcell')
stacked.groupby('gridcell').sum(xr.ALL_DIMS).unstack('gridcell')
15 changes: 12 additions & 3 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,17 @@ New functions/methods
Enhancements
~~~~~~~~~~~~

- Add a repr for :py:class:`~xarray.core.GroupBy` objects.
Example::
- :py:class:`~xarray.core.GroupBy` enhancements. By `Deepak Cherian <https://github.com/dcherian>`_.

- Added a repr. Example::

>>> da.groupby("time.season")
DataArrayGroupBy, grouped over 'season'
4 groups with labels 'DJF', 'JJA', 'MAM', 'SON'

(:issue:`3344`) by `Deepak Cherian <https://github.com/dcherian>`_.
- Added a ``GroupBy.dims`` property that mirrors the dimensions
of each group. (:issue:`3344`)

- Speed up :meth:`Dataset.isel` up to 33% and :meth:`DataArray.isel` up to 25% for small
arrays (:issue:`2799`, :pull:`3375`) by
`Guido Imperiale <https://github.com/crusaderky>`_.
Expand All @@ -73,6 +76,12 @@ Bug fixes
- Line plots with the ``x`` or ``y`` argument set to a 1D non-dimensional coord
now plot the correct data for 2D DataArrays
(:issue:`3334`). By `Tom Nicholas <https://github.com/TomNicholas>`_.
- The default behaviour of reducing across all dimensions for
:py:class:`~xarray.core.groupby.DataArrayGroupBy` objects has now been properly removed
as was done for :py:class:`~xarray.core.groupby.DatasetGroupBy` in 0.13.0 (:issue:`3337`).
Use ``xarray.ALL_DIMS`` if you need to replicate previous behaviour.
Also raise nicer error message when no groups are created (:issue:`1764`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- Fix error in concatenating unlabeled dimensions (:pull:`3362`).
By `Deepak Cherian <https://github.com/dcherian/>`_.

Expand Down
48 changes: 43 additions & 5 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from . import dtypes, duck_array_ops, nputils, ops
from .arithmetic import SupportsArithmetic
from .common import ImplementsArrayReduce, ImplementsDatasetReduce
from .common import ALL_DIMS, ImplementsArrayReduce, ImplementsDatasetReduce
from .concat import concat
from .formatting import format_array_flat
from .options import _get_keep_attrs
Expand Down Expand Up @@ -248,6 +248,7 @@ class GroupBy(SupportsArithmetic):
"_restore_coord_dims",
"_stacked_dim",
"_unique_coord",
"_dims",
)

def __init__(
Expand Down Expand Up @@ -320,6 +321,8 @@ def __init__(
full_index = None

if bins is not None:
if np.isnan(bins).all():
raise ValueError("All bin edges are NaN.")
binned = pd.cut(group.values, bins, **cut_kwargs)
new_dim_name = group.name + "_bins"
group = DataArray(binned, group.coords, name=new_dim_name)
Expand Down Expand Up @@ -351,6 +354,16 @@ def __init__(
)
unique_coord = IndexVariable(group.name, unique_values)

if len(group_indices) == 0:
if bins is not None:
raise ValueError(
"None of the data falls within bins with edges %r" % bins
)
else:
raise ValueError(
"Failed to group data. Are you grouping by a variable that is all NaN?"
)

if (
isinstance(obj, DataArray)
and restore_coord_dims is None
Expand Down Expand Up @@ -379,6 +392,16 @@ def __init__(

# cached attributes
self._groups = None
self._dims = None

@property
def dims(self):
    """Dimensions of a single group, computed lazily and cached.

    Selects the first group out of the underlying object and returns its
    ``.dims``. NOTE(review): assumes the first group is representative of
    all groups' dimensions — confirm against how groups are formed.
    """
    if self._dims is None:
        first_group = self._group_indices[0]
        selected = self._obj.isel(**{self._group_dim: first_group})
        self._dims = selected.dims
    return self._dims

@property
def groups(self):
Expand All @@ -394,7 +417,7 @@ def __iter__(self):
return zip(self._unique_coord.values, self._iter_grouped())

def __repr__(self):
return "%s, grouped over %r \n%r groups with labels %s" % (
return "%s, grouped over %r \n%r groups with labels %s." % (
self.__class__.__name__,
self._unique_coord.name,
self._unique_coord.size,
Expand Down Expand Up @@ -689,7 +712,7 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
q : float in range of [0,1] (or sequence of floats)
Quantile to compute, which must be between 0 and 1
inclusive.
dim : str or sequence of str, optional
dim : xarray.ALL_DIMS, str or sequence of str, optional
Dimension(s) over which to apply quantile.
Defaults to the grouped dimension.
interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
Expand Down Expand Up @@ -746,7 +769,7 @@ def reduce(
Function which can be called in the form
`func(x, axis=axis, **kwargs)` to return the result of collapsing
an np.ndarray over an integer valued axis.
dim : str or sequence of str, optional
dim : xarray.ALL_DIMS, str or sequence of str, optional
Dimension(s) over which to apply `func`.
axis : int or sequence of int, optional
Axis(es) over which to apply `func`. Only one of the 'dimension'
Expand All @@ -765,9 +788,18 @@ def reduce(
Array with summarized data and the indicated dimension(s)
removed.
"""
if dim is None:
dim = self._group_dim

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)

if dim is not ALL_DIMS and dim not in self.dims:
raise ValueError(
"cannot reduce over dimension %r. expected either xarray.ALL_DIMS to reduce over all dimensions or one or more of %r."
% (dim, self.dims)
)

def reduce_array(ar):
return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs)

Expand Down Expand Up @@ -835,7 +867,7 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs):
Function which can be called in the form
`func(x, axis=axis, **kwargs)` to return the result of collapsing
an np.ndarray over an integer valued axis.
dim : str or sequence of str, optional
dim : xarray.ALL_DIMS, str or sequence of str, optional
Dimension(s) over which to apply `func`.
axis : int or sequence of int, optional
Axis(es) over which to apply `func`. Only one of the 'dimension'
Expand Down Expand Up @@ -863,6 +895,12 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs):
def reduce_dataset(ds):
return ds.reduce(func, dim, keep_attrs, **kwargs)

if dim is not ALL_DIMS and dim not in self.dims:
raise ValueError(
"cannot reduce over dimension %r. expected either xarray.ALL_DIMS to reduce over all dimensions or one or more of %r."
% (dim, self.dims)
)

return self.apply(reduce_dataset)

def assign(self, **kwargs):
Expand Down
9 changes: 9 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2579,6 +2579,15 @@ def change_metadata(x):
expected = change_metadata(expected)
assert_equal(expected, actual)

def test_groupby_reduce_dimension_error(self):
    """Reducing a DataArrayGroupBy without a valid dim raises ValueError."""
    array = self.make_groupby_example_array()

    # The implicit reduce-over-all-dimensions default was removed, so a
    # bare mean() over a squeezed groupby must raise a helpful error.
    with raises_regex(ValueError, "cannot reduce over dimension 'y'"):
        array.groupby("y").mean()

    # With squeeze=False the grouped dimension survives in each group, so
    # the default reduction is permitted and round-trips the array.
    grouped = array.groupby("y", squeeze=False)
    assert_identical(array, grouped.mean())

def test_groupby_math(self):
array = self.make_groupby_example_array()
for squeeze in [True, False]:
Expand Down
46 changes: 40 additions & 6 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import xarray as xr
from xarray.core.groupby import _consolidate_slices

from . import assert_identical
from . import assert_identical, raises_regex


def test_consolidate_slices():
Expand All @@ -21,6 +21,19 @@ def test_consolidate_slices():
_consolidate_slices([slice(3), 4])


def test_groupby_dims_property():
    """GroupBy.dims must equal the dims of one selected group."""
    data = np.random.randn(3, 4, 2)
    coords = {"x": ["a", "bcd", "c"], "y": [1, 2, 3, 4], "z": [1, 2]}
    ds = xr.Dataset({"foo": (("x", "y", "z"), data)}, coords)

    # Grouping over a plain dimension mirrors isel() along that dimension.
    for dim in ("x", "y"):
        assert ds.groupby(dim).dims == ds.isel(**{dim: 1}).dims

    # Also holds for a stacked (MultiIndex) dimension.
    stacked = ds.stack({"xy": ("x", "y")})
    assert stacked.groupby("xy").dims == stacked.isel(xy=0).dims


def test_multi_index_groupby_apply():
# regression test for GH873
ds = xr.Dataset(
Expand Down Expand Up @@ -222,13 +235,13 @@ def test_groupby_repr(obj, dim):
expected += ", grouped over %r " % dim
expected += "\n%r groups with labels " % (len(np.unique(obj[dim])))
if dim == "x":
expected += "1, 2, 3, 4, 5"
expected += "1, 2, 3, 4, 5."
elif dim == "y":
expected += "0, 1, 2, 3, 4, 5, ..., 15, 16, 17, 18, 19"
expected += "0, 1, 2, 3, 4, 5, ..., 15, 16, 17, 18, 19."
elif dim == "z":
expected += "'a', 'b', 'c'"
expected += "'a', 'b', 'c'."
elif dim == "month":
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12"
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12."
assert actual == expected


Expand All @@ -238,8 +251,29 @@ def test_groupby_repr_datetime(obj):
expected = "%sGroupBy" % obj.__class__.__name__
expected += ", grouped over 'month' "
expected += "\n%r groups with labels " % (len(np.unique(obj.t.dt.month)))
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12"
expected += "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12."
assert actual == expected


def test_groupby_grouping_errors():
    """Degenerate groupings raise informative ValueErrors (GH1764)."""
    dataset = xr.Dataset({"foo": ("x", [1, 1, 1])}, {"x": [1, 2, 3]})

    # Exercise both the Dataset and the DataArray code paths.
    for obj in (dataset, dataset.to_array()):
        # Bins that capture none of the data.
        with raises_regex(
            ValueError, "None of the data falls within bins with edges"
        ):
            obj.groupby_bins("x", bins=[0.1, 0.2, 0.3])

        # Bins whose edges are all NaN.
        with raises_regex(ValueError, "All bin edges are NaN."):
            obj.groupby_bins("x", bins=[np.nan, np.nan, np.nan])

        # Grouping variable that is entirely NaN yields zero groups.
        with raises_regex(ValueError, "Failed to group data."):
            obj.groupby(dataset.foo * np.nan)


# TODO: move other groupby tests from test_dataset and test_dataarray over here

0 comments on commit 291cb80

Please sign in to comment.