Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor groupby binary ops code. #6789

Merged
merged 1 commit into from
Jul 20, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 33 additions & 38 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from . import dtypes, duck_array_ops, nputils, ops
from ._reductions import DataArrayGroupByReductions, DatasetGroupByReductions
from .alignment import align
from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
from .concat import concat
from .formatting import format_array_flat
Expand Down Expand Up @@ -309,7 +310,7 @@ class GroupBy(Generic[T_Xarray]):
"_squeeze",
# Save unstacked object for flox
"_original_obj",
"_unstacked_group",
"_original_group",
"_bins",
)
_obj: T_Xarray
Expand Down Expand Up @@ -374,7 +375,7 @@ def __init__(
group.name = "group"

self._original_obj: T_Xarray = obj
self._unstacked_group = group
self._original_group = group
self._bins = bins

group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj)
Expand Down Expand Up @@ -571,11 +572,22 @@ def _binary_op(self, other, f, reflexive=False):

g = f if not reflexive else lambda x, y: f(y, x)

obj = self._obj
group = self._group
dim = self._group_dim
if self._bins is None:
obj = self._original_obj
group = self._original_group
dims = group.dims
else:
obj = self._maybe_unstack(self._obj)
group = self._maybe_unstack(self._group)
dims = (self._group_dim,)

if isinstance(group, _DummyGroup):
group = obj[dim]
group = obj[group.name]
coord = group
else:
coord = self._unique_coord
if not isinstance(coord, DataArray):
coord = DataArray(self._unique_coord)
name = group.name

if not isinstance(other, (Dataset, DataArray)):
Expand All @@ -592,37 +604,19 @@ def _binary_op(self, other, f, reflexive=False):
"is not a dimension on the other argument"
)

try:
expanded = other.sel({name: group})
except KeyError:
# some labels are absent i.e. other is not aligned
# so we align by reindexing and then rename dimensions.

# Broadcast out scalars for backwards compatibility
# TODO: get rid of this when fixing GH2145
for var in other.coords:
if other[var].ndim == 0:
other[var] = (
other[var].drop_vars(var).expand_dims({name: other.sizes[name]})
)
expanded = (
other.reindex({name: group.data})
.rename({name: dim})
.assign_coords({dim: obj[dim]})
)
# Broadcast out scalars for backwards compatibility
# TODO: get rid of this when fixing GH2145
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
for var in other.coords:
if other[var].ndim == 0:
other[var] = (
other[var].drop_vars(var).expand_dims({name: other.sizes[name]})
)

if self._bins is not None and name == dim and dim not in obj.xindexes:
# When binning by unindexed coordinate we need to reindex obj.
# _full_index is IntervalIndex, so idx will be -1 where
# a value does not belong to any bin. Using IntervalIndex
# accounts for any non-default cut_kwargs passed to the constructor
idx = pd.cut(group, bins=self._full_index).codes
obj = obj.isel({dim: np.arange(group.size)[idx != -1]})
other, _ = align(other, coord, join="outer")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the real cleanup. Using `align` simplifies the logic a lot.

expanded = other.sel({name: group})

result = g(obj, expanded)

result = self._maybe_unstack(result)
group = self._maybe_unstack(group)
if group.ndim > 1:
# backcompat:
# TODO: get rid of this when fixing GH2145
Expand All @@ -632,8 +626,9 @@ def _binary_op(self, other, f, reflexive=False):

if isinstance(result, Dataset) and isinstance(obj, Dataset):
for var in set(result):
if dim not in obj[var].dims:
result[var] = result[var].transpose(dim, ...)
for d in dims:
if d not in obj[var].dims:
result[var] = result[var].transpose(d, ...)
return result

def _maybe_restore_empty_groups(self, combined):
Expand Down Expand Up @@ -695,10 +690,10 @@ def _flox_reduce(self, dim, keep_attrs=None, **kwargs):
# group is only passed by resample
group = kwargs.pop("group", None)
if group is None:
if isinstance(self._unstacked_group, _DummyGroup):
group = self._unstacked_group.name
if isinstance(self._original_group, _DummyGroup):
group = self._original_group.name
else:
group = self._unstacked_group
group = self._original_group

unindexed_dims = tuple()
if isinstance(group, str):
Expand Down