Skip to content

Commit

Permalink
Return a dataclass from factorize.
Browse files Browse the repository at this point in the history
Also auto-compute group_indices if not set.
  • Loading branch information
dcherian committed Jan 4, 2024
1 parent 5d72f2f commit 78e3880
Showing 1 changed file with 72 additions and 41 deletions.
113 changes: 72 additions & 41 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@
GroupKey = Any
GroupIndex = Union[int, slice, list[int]]
T_GroupIndices = list[GroupIndex]
T_FactorizeOut = tuple[
DataArray, T_GroupIndices, Union[pd.Index, "_DummyGroup"], pd.Index, DataArray
]


def check_reduce_dims(reduce_dims, dimensions):
Expand Down Expand Up @@ -96,7 +93,7 @@ def _maybe_squeeze_indices(

def unique_value_groups(
ar, sort: bool = True
) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]:
) -> tuple[np.ndarray | pd.Index, np.ndarray]:
"""Group an array by its unique values.
Parameters
Expand All @@ -117,11 +114,11 @@ def unique_value_groups(
inverse, values = pd.factorize(ar, sort=sort)
if isinstance(values, pd.MultiIndex):
values.names = ar.names
groups = _codes_to_groups(inverse, len(values))
return values, groups, inverse
return values, inverse


def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices:
def _codes_to_group_indices(inverse: np.ndarray, N: int) -> T_GroupIndices:
assert inverse.ndim == 1
groups: T_GroupIndices = [[] for _ in range(N)]
for n, g in enumerate(inverse):
if g >= 0:
Expand Down Expand Up @@ -342,16 +339,35 @@ def _apply_loffset(


@dataclass
class ResolvedGrouper:
class EncodedGroups:
"""
Parameters
----------
codes:
full_index:
group_indices: optional,
Inferred if not provided.
unique_coord:
Inferred if not provided
"""

codes: DataArray
full_index: pd.Index
group_indices: T_GroupIndices | None = field(default=None)
unique_coord: IndexVariable | _DummyGroup | None = field(default=None)


@dataclass
class ResolvedGrouper(Generic[T_Xarray]):
grouper: Grouper
group: T_Group
obj: T_Xarray

# Defined by factorize:
# returned by factorize:
codes: DataArray = field(init=False)
full_index: pd.Index = field(init=False)
group_indices: T_GroupIndices = field(init=False)
unique_coord: IndexVariable | _DummyGroup = field(init=False)
full_index: pd.Index = field(init=False)

# _ensure_1d:
group1d: T_Group = field(init=False)
Expand Down Expand Up @@ -395,20 +411,29 @@ def dims(self):
return self.group1d.dims

def factorize(self) -> None:
# This design makes it clear to mypy that
# codes, group_indices, unique_coord, and full_index
# are set by the factorize method on the derived class.
(
self.codes,
self.group_indices,
self.unique_coord,
self.full_index,
) = self.grouper.factorize(self.group1d)
encoded = self.grouper.factorize(self.group1d)

self.codes = encoded.codes
self.full_index = encoded.full_index

if encoded.group_indices is not None:
self.group_indices = encoded.group_indices
else:
self.group_indices = [
g
for g in _codes_to_group_indices(self.codes.data, len(self.full_index))
if g
]
if encoded.unique_coord is None:
# TODO
raise NotImplementedError
else:
self.unique_coord = encoded.unique_coord


class Grouper(ABC):
@abstractmethod
def factorize(self, group) -> T_FactorizeOut:
def factorize(self, group: T_Group) -> EncodedGroups:
pass


Expand All @@ -418,7 +443,7 @@ class Resampler(Grouper):

@dataclass
class UniqueGrouper(Grouper):
group_as_index: pd.Index | None = field(default=None, init=False)
group_as_index: pd.Index = field(init=False)

@property
def is_unique_and_monotonic(self) -> bool:
Expand All @@ -432,7 +457,7 @@ def can_squeeze(self):
is_dimension = self.group.dims == (self.group.name,)
return is_dimension and self.is_unique_and_monotonic

def factorize(self, group1d) -> T_FactorizeOut:
def factorize(self, group1d) -> EncodedGroups:
self.group = group1d
self.group_as_index = group1d.to_index()

Expand All @@ -441,26 +466,25 @@ def factorize(self, group1d) -> T_FactorizeOut:
else:
return self._factorize_unique()

def _factorize_unique(self) -> T_FactorizeOut:
def _factorize_unique(self) -> EncodedGroups:
# look through group to find the unique values
sort = not isinstance(self.group_as_index, pd.MultiIndex)
unique_values, group_indices, codes_ = unique_value_groups(
self.group_as_index, sort=sort
)
if len(group_indices) == 0:
unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort)
if (codes_ == -1).all():
raise ValueError(
"Failed to group data. Are you grouping by a variable that is all NaN?"
)
codes = self.group.copy(data=codes_)
group_indices = group_indices
unique_coord = IndexVariable(
self.group.name, unique_values, attrs=self.group.attrs
)
full_index = unique_coord

return codes, group_indices, unique_coord, full_index
return EncodedGroups(
codes=codes, full_index=full_index, unique_coord=unique_coord
)

def _factorize_dummy(self) -> T_FactorizeOut:
def _factorize_dummy(self) -> EncodedGroups:
size = self.group.size
# no need to factorize
# use slices to do views instead of fancy indexing
Expand All @@ -475,8 +499,12 @@ def _factorize_dummy(self) -> T_FactorizeOut:
full_index = IndexVariable(
self.group.name, unique_coord.values, self.group.attrs
)

return codes, group_indices, unique_coord, full_index
return EncodedGroups(
codes=codes,
group_indices=group_indices,
full_index=full_index,
unique_coord=unique_coord,
)


@dataclass
Expand All @@ -490,7 +518,7 @@ def __post_init__(self) -> None:
if duck_array_ops.isnull(self.bins).all():
raise ValueError("All bin edges are NaN.")

def factorize(self, group) -> T_FactorizeOut:
def factorize(self, group) -> EncodedGroups:
from xarray.core.dataarray import DataArray

data = group.data
Expand All @@ -508,11 +536,7 @@ def factorize(self, group) -> T_FactorizeOut:
full_index = binned.categories
uniques = np.sort(pd.unique(binned_codes))
unique_values = full_index[uniques[uniques != -1]]
group_indices = [
g for g in _codes_to_groups(binned_codes, len(full_index)) if g
]

if len(group_indices) == 0:
if (binned_codes == -1).all():
raise ValueError(
f"None of the data falls within bins with edges {self.bins!r}"
)
Expand All @@ -521,7 +545,9 @@ def factorize(self, group) -> T_FactorizeOut:
binned_codes, getattr(group, "coords", None), name=new_dim_name
)
unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs)
return codes, group_indices, unique_coord, full_index
return EncodedGroups(
codes=codes, full_index=full_index, unique_coord=unique_coord
)


@dataclass
Expand Down Expand Up @@ -620,7 +646,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
_apply_loffset(self.loffset, first_items)
return first_items, codes

def _factorize(self, group) -> T_FactorizeOut:
def factorize(self, group) -> EncodedGroups:
self._init_properties(group)
full_index, first_items, codes_ = self._get_index_and_items()
sbins = first_items.values.astype(np.int64)
Expand All @@ -632,7 +658,12 @@ def _factorize(self, group) -> T_FactorizeOut:
unique_coord = IndexVariable(group.name, first_items.index, group.attrs)
codes = group.copy(data=codes_)

return codes, group_indices, unique_coord, full_index
return EncodedGroups(
codes=codes,
group_indices=group_indices,
full_index=full_index,
unique_coord=unique_coord,
)


def _validate_groupby_squeeze(squeeze: bool | None) -> None:
Expand Down

0 comments on commit 78e3880

Please sign in to comment.