Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple groupers v3 #76

Merged
merged 17 commits
Mar 15, 2022
151 changes: 116 additions & 35 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
import operator
from collections import namedtuple
from functools import partial
from functools import partial, reduce
from typing import TYPE_CHECKING, Any, Callable, Dict, Mapping, Sequence, Union

import numpy as np
Expand Down Expand Up @@ -59,16 +59,14 @@ def _prepare_for_flox(group_idx, array):
return group_idx, ordered_array


def _get_expected_groups(by, sort, *, raise_if_dask=True) -> pd.Index | None:
    """Discover the unique group labels present in ``by``.

    Parameters
    ----------
    by : array-like
        Array of group labels. Must be a numpy-like array; dask arrays
        cannot be inspected without computing.
    sort : bool
        Whether to sort the discovered labels.
    raise_if_dask : bool, keyword-only
        When ``by`` is a dask array: raise if True, return None if False.

    Returns
    -------
    pd.Index or None
        Index of unique, non-null labels (sorted when ``sort``), or None
        for a dask ``by`` with ``raise_if_dask=False``.
    """
    if is_duck_dask_array(by):
        if raise_if_dask:
            raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
        return None
    flatby = by.ravel()
    # Drop missing values before computing uniques.
    expected = pd.unique(flatby[~isnull(flatby)])
    # Reuse the tuple-based converter; unwrap the single-element result.
    return _convert_expected_groups_to_index((expected,), isbin=(False,), sort=sort)[0]


def _get_chunk_reduction(reduction_type: str) -> Callable:
Expand Down Expand Up @@ -378,6 +376,7 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
Copied from xhistogram &
https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy
"""
assert labels.ndim > 1
offset: np.ndarray = (
labels + np.arange(np.prod(labels.shape[:-1])).reshape((*labels.shape[:-1], -1)) * ngroups
)
Expand All @@ -388,7 +387,12 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:


def factorize_(
by: tuple, axis, expected_groups: tuple[pd.Index, ...] = None, reindex=False, sort=True
by: tuple,
axis,
expected_groups: tuple[pd.Index, ...] = None,
reindex=False,
sort=True,
fastpath=False,
):
"""
Returns an array of integer codes for groups (and associated data)
Expand All @@ -413,7 +417,7 @@ def factorize_(
raise ValueError("Please pass bin edges in expected_groups.")
# TODO: fix for binning
found_groups.append(expect)
# pd.cut with bins = IntervalIndex[datetime64] doesn't work...
# pd.cut with bins = IntervalIndex[datetime64] doesn't work...
if groupvar.dtype.kind == "M":
expect = np.concatenate([expect.left.to_numpy(), [expect.right[-1].to_numpy()]])
idx = pd.cut(groupvar.ravel(), bins=expect).codes.copy()
Expand All @@ -440,10 +444,15 @@ def factorize_(
grp_shape = tuple(len(grp) for grp in found_groups)
ngroups = np.prod(grp_shape)
if len(by) > 1:
group_idx = np.ravel_multi_index(factorized, grp_shape).reshape(by[0].shape)
group_idx = np.ravel_multi_index(factorized, grp_shape, mode="wrap").reshape(by[0].shape)
nan_by_mask = reduce(np.logical_or, [isnull(b) for b in by])
group_idx[nan_by_mask] = -1
else:
group_idx = factorized[0]

if fastpath:
return group_idx, found_groups, grp_shape

if np.isscalar(axis) and groupvar.ndim > 1:
# Not reducing along all dimensions of by
# this is OK because for 3D by and axis=(1,2),
Expand Down Expand Up @@ -1244,33 +1253,78 @@ def _validate_reindex(reindex: bool, func, method, expected_groups) -> bool:


def _assert_by_is_aligned(shape, by):
if shape[-by.ndim :] != by.shape:
raise ValueError(
"`array` and `by` arrays must be aligned "
"i.e. array.shape[-by.ndim :] == by.shape. "
"for every array in `by`."
f"Received array of shape {shape} but "
f"`by` has shape {by.shape}."
for idx, b in enumerate(by):
if shape[-b.ndim :] != b.shape:
raise ValueError(
"`array` and `by` arrays must be aligned "
"i.e. array.shape[-by.ndim :] == by.shape. "
"for every array in `by`."
f"Received array of shape {shape} but "
f"array {idx} in `by` has shape {b.shape}."
)


def _convert_expected_groups_to_index(
expected_groups: tuple, isbin: bool, sort: bool
) -> pd.Index | None:
out = []
for ex, isbin_ in zip(expected_groups, isbin):
if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin):
if sort:
ex = ex.sort_values()
out.append(ex)
elif ex is not None:
if isbin_:
out.append(pd.IntervalIndex.from_arrays(ex[:-1], ex[1:]))
else:
if sort:
ex = np.sort(ex)
out.append(pd.Index(ex))
else:
assert ex is None
out.append(None)
return tuple(out)


def _lazy_factorize_wrapper(*by, **kwargs):
    """Adapter used as the ``dask.array.map_blocks`` callable in
    ``_factorize_multiple``: factorize one block and return only the
    integer group codes, discarding the rest of ``factorize_``'s result.
    """
    return factorize_(by, **kwargs)[0]


def _factorize_multiple(by, expected_groups, by_is_dask):
    """Factorize multiple grouping arrays down to a single integer code array.

    Parameters
    ----------
    by : tuple of arrays
        The grouping variables.
    expected_groups : tuple of (pd.Index or None)
        Expected group labels, one entry per array in ``by``.
    by_is_dask : bool
        True when any array in ``by`` is a dask array; factorization is
        then deferred with ``dask.array.map_blocks``.

    Returns
    -------
    (group_idx,), final_groups, grp_shape
        A single-element tuple with the flat integer codes, the group
        labels per variable, and the shape of the (conceptual) group grid.

    Raises
    ------
    ValueError
        If grouping by a dask array without providing ``expected_groups``
        (labels cannot be discovered without computing).
    """
    kwargs = dict(
        expected_groups=expected_groups,
        axis=None,  # always None, we offset later if necessary.
        fastpath=True,
    )
    if by_is_dask:
        import dask.array

        # Defer factorization; only the integer codes are produced lazily.
        group_idx = dask.array.map_blocks(
            _lazy_factorize_wrapper,
            *np.broadcast_arrays(*by),
            meta=np.array((), dtype=np.int64),
            **kwargs,
        )
        # Labels of a dask grouper are unknown without computing -> None.
        found_groups = tuple(None if is_duck_dask_array(b) else pd.unique(b) for b in by)
        grp_shape = tuple(len(e) for e in expected_groups)
    else:
        group_idx, found_groups, grp_shape = factorize_(by, **kwargs)

    # Prefer user-provided expected_groups over discovered labels.
    final_groups = tuple(
        found if expect is None else expect.to_numpy()
        for found, expect in zip(found_groups, expected_groups)
    )

    if any(grp is None for grp in final_groups):
        raise ValueError("Please provide expected_groups when grouping by a dask array.")
    return (group_idx,), final_groups, grp_shape


def groupby_reduce(
array: np.ndarray | DaskArray,
by: np.ndarray | DaskArray,
*by: np.ndarray | DaskArray,
func: str | Aggregation,
*,
expected_groups: Sequence | np.ndarray | None = None,
sort: bool = True,
isbin: bool = False,
Expand Down Expand Up @@ -1383,18 +1437,38 @@ def groupby_reduce(
)
reindex = _validate_reindex(reindex, func, method, expected_groups)

if not is_duck_array(by):
by = np.asarray(by)
by: tuple = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
nby = len(by)
by_is_dask = any(is_duck_dask_array(b) for b in by)
if not is_duck_array(array):
array = np.asarray(array)
if isinstance(isbin, bool):
isbin = (isbin,) * len(by)
if expected_groups is None:
expected_groups = (None,) * len(by)

_assert_by_is_aligned(array.shape, by)

if len(by) == 1 and not isinstance(expected_groups, tuple):
expected_groups = (np.asarray(expected_groups),)
elif len(expected_groups) != len(by):
raise ValueError("len(expected_groups) != len(by)")

# We convert to pd.Index since that lets us know if we are binning or not
# (pd.IntervalIndex or not)
expected_groups = _convert_expected_groups_to_index(expected_groups, isbin)
if expected_groups is not None and sort:
expected_groups = expected_groups.sort_values()
expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort)

# when grouping by multiple variables, we factorize early.
# TODO: could restrict this to dask-only
if nby > 1:
by, final_groups, grp_shape = _factorize_multiple(
by, expected_groups, by_is_dask=by_is_dask
)
expected_groups = (pd.RangeIndex(np.prod(grp_shape)),)

assert len(by) == 1
by = by[0]
expected_groups = expected_groups[0]

if axis is None:
axis = tuple(array.ndim + np.arange(-by.ndim, 0))
Expand All @@ -1408,7 +1482,7 @@ def groupby_reduce(

# TODO: make sure expected_groups is unique
if len(axis) == 1 and by.ndim > 1 and expected_groups is None:
if not is_duck_dask_array(by):
if not by_is_dask:
expected_groups = _get_expected_groups(by, sort)
else:
# When we reduce along all axes, we are guaranteed to see all
Expand All @@ -1422,6 +1496,7 @@ def groupby_reduce(
"Please provide ``expected_groups`` when not reducing along all axes."
)

assert len(axis) <= by.ndim
if len(axis) < by.ndim:
by = _move_reduce_dims_to_end(by, -array.ndim + np.array(axis) + by.ndim)
array = _move_reduce_dims_to_end(array, axis)
Expand Down Expand Up @@ -1514,7 +1589,7 @@ def groupby_reduce(
result, *groups = partial_agg(
array,
by,
expected_groups=expected_groups,
expected_groups=None if method == "blockwise" else expected_groups,
reindex=reindex,
method=method,
sort=sort,
Expand All @@ -1526,4 +1601,10 @@ def groupby_reduce(
result = result[..., sorted_idx]
groups = (groups[0][sorted_idx],)

if nby > 1:
# nan group labels are factorized to -1, and preserved
# now we get rid of them
nanmask = groups[0] == -1
groups = final_groups
result = result[..., ~nanmask].reshape(result.shape[:-1] + grp_shape)
return (result, *groups)
Loading