Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Feb 22, 2022
1 parent 0db44f7 commit b2be395
Showing 1 changed file with 80 additions and 17 deletions.
97 changes: 80 additions & 17 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _get_expected_groups(by, sort, raise_if_dask=True) -> pd.Index | None:
expected = pd.unique(flatby[~isnull(flatby)])
if sort:
expected = np.sort(expected)
return _convert_expected_groups_to_index(expected, isbin=False)
return _convert_expected_groups_to_index((expected,), isbin=(False,))[0]


def _get_chunk_reduction(reduction_type: str) -> Callable:
Expand Down Expand Up @@ -388,7 +388,12 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:


def factorize_(
by: tuple, axis, expected_groups: tuple[pd.Index, ...] = None, reindex=False, sort=True
by: tuple,
axis,
expected_groups: tuple[pd.Index, ...] = None,
reindex=False,
sort=True,
fastpath=False,
):
"""
Returns an array of integer codes for groups (and associated data)
Expand Down Expand Up @@ -440,10 +445,13 @@ def factorize_(
grp_shape = tuple(len(grp) for grp in found_groups)
ngroups = np.prod(grp_shape)
if len(by) > 1:
group_idx = np.ravel_multi_index(factorized, grp_shape).reshape(by[0].shape)
group_idx = np.ravel_multi_index(factorized, grp_shape)
else:
group_idx = factorized[0]

if fastpath:
return group_idx, found_groups, grp_shape

if np.isscalar(axis) and groupvar.ndim > 1:
# Not reducing along all dimensions of by
# this is OK because for 3D by and axis=(1,2),
Expand Down Expand Up @@ -1272,23 +1280,60 @@ def _assert_by_is_aligned(shape, by):
)


def _convert_expected_groups_to_index(expected_groups, isbin: bool) -> pd.Index | None:
    """
    Normalize a single ``expected_groups`` specification to a pandas Index.

    Parameters
    ----------
    expected_groups : pd.Index, sequence, or None
        Group labels, or bin edges when ``isbin`` is true.
    isbin : bool
        Whether ``expected_groups`` holds bin edges rather than labels.

    Returns
    -------
    pd.Index or None
        - the input itself when it is already a ``pd.IntervalIndex``, or when
          it is any ``pd.Index`` and ``isbin`` is false;
        - a ``pd.IntervalIndex`` built from consecutive edges when ``isbin``
          is true;
        - ``pd.Index(expected_groups)`` for any other non-None input;
        - ``None`` when the input is None and ``isbin`` is false.
    """
    # Interval indexes are always accepted verbatim.
    if isinstance(expected_groups, pd.IntervalIndex):
        return expected_groups
    # A plain Index is only usable as-is when it holds labels, not bin edges.
    if isinstance(expected_groups, pd.Index) and not isbin:
        return expected_groups

    if isbin:
        # Edge i and edge i + 1 bound the i-th bin.
        left_edges = expected_groups[:-1]
        right_edges = expected_groups[1:]
        return pd.IntervalIndex.from_arrays(left_edges, right_edges)

    if expected_groups is None:
        return expected_groups
    return pd.Index(expected_groups)
def _convert_expected_groups_to_index(
    expected_groups: tuple, isbin: bool | tuple[bool, ...]
) -> tuple[pd.Index | None, ...]:
    """
    Normalize each ``expected_groups`` entry to a pandas Index.

    Parameters
    ----------
    expected_groups : tuple
        One entry per grouping variable. Each entry is a ``pd.Index``, a
        sequence of labels (or of bin edges when the matching ``isbin`` flag
        is true), or None.
    isbin : bool or tuple of bool
        Per-entry flag marking that entry as bin edges. A single bool is
        applied to every entry.

    Returns
    -------
    tuple of (pd.Index or None)
        Same length as the zipped inputs. For each entry:

        - None stays None;
        - a ``pd.IntervalIndex``, or any ``pd.Index`` whose flag is false, is
          passed through unchanged;
        - bin edges become a ``pd.IntervalIndex`` of consecutive intervals;
        - anything else is wrapped with ``pd.Index``.
    """
    if isinstance(isbin, (bool, np.bool_)):
        # Accept the scalar form and apply it to every grouping variable.
        isbin = (isbin,) * len(expected_groups)

    out = []
    for ex, isbin_ in zip(expected_groups, isbin):
        if ex is None:
            out.append(None)
        elif isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin_):
            # Already usable: append this entry (not the whole input tuple),
            # and decide using this entry's own flag.
            out.append(ex)
        elif isbin_:
            # Edge i and edge i + 1 bound the i-th bin.
            out.append(pd.IntervalIndex.from_arrays(ex[:-1], ex[1:]))
        else:
            out.append(pd.Index(ex))
    return tuple(out)


def _lazy_factorize_wrapper(*by, **kwargs):
    """
    Factorize the ``by`` arrays and return only the integer group codes.

    All positional arguments are packed into the ``by`` tuple and forwarded,
    with ``kwargs``, to ``factorize_``. Only the first element of its result
    is returned.
    """
    # factorize_ returns a 3-tuple when fastpath=True, so a fixed two-target
    # unpack raises ValueError; keep the first element and discard the rest.
    group_idx, *_ = factorize_(by, **kwargs)
    return group_idx


def _factorize_multiple(by, expected_groups, by_is_dask):
    """
    Factorize several grouping arrays into a single array of group codes.

    Parameters
    ----------
    by : tuple of arrays
        The grouping variables.
    expected_groups : tuple
        One entry per grouping variable; an entry may be None, in which case
        the groups found in the data are used for that variable.
    by_is_dask : bool
        When true, factorization is applied per block through
        ``dask.array.map_blocks``.

    Returns
    -------
    tuple
        ``((group_idx,), final_groups, grp_shape)``: the combined codes wrapped
        in a 1-tuple, the group labels per variable, and the number of groups
        per variable.

    Raises
    ------
    ValueError
        If the labels of any grouping variable cannot be determined, i.e. a
        dask ``by`` array was passed without matching ``expected_groups``.
    """
    kwargs = dict(
        expected_groups=expected_groups,
        axis=None,  # always None, we offset later if necessary.
        fastpath=True,
    )
    if by_is_dask:
        import dask.array

        group_idx = dask.array.map_blocks(
            _lazy_factorize_wrapper,
            *np.broadcast_arrays(*by),
            meta=np.array((), dtype=np.int64),
            **kwargs,
        )
        # Labels of dask-backed variables are unknown without computing.
        found_groups = tuple(None if is_duck_dask_array(b) else np.unique(b) for b in by)
        # Nothing was factorized eagerly; the shape is derived below.
        grp_shape = None
    else:
        group_idx, found_groups, grp_shape = factorize_(by, **kwargs)

    final_groups = tuple(
        pd.Index(found) if expect is None else expect
        for found, expect in zip(found_groups, expected_groups)
    )

    if any(grp is None for grp in final_groups):
        # A bare ``raise`` has no active exception here and would only surface
        # as an uninformative RuntimeError.
        raise ValueError("Please provide expected_groups when grouping by a dask array.")

    if grp_shape is None:
        # Dask path: the shape follows from the resolved labels.
        # NOTE(review): assumes the per-block factorization uses these same
        # label sets -- confirm against factorize_.
        grp_shape = tuple(len(grp) for grp in final_groups)
    return (group_idx,), final_groups, grp_shape


def groupby_reduce(
array: np.ndarray | DaskArray,
by: np.ndarray | DaskArray,
*by: np.ndarray | DaskArray,
func: str | Aggregation,
*,
expected_groups: Sequence | np.ndarray | None = None,
sort: bool = True,
isbin: bool = False,
Expand Down Expand Up @@ -1402,6 +1447,7 @@ def groupby_reduce(
reindex = _validate_reindex(reindex, func, method, expected_groups)

by: tuple = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
nby = len(by)
by_is_dask = any(is_duck_dask_array(b) for b in by)
if not is_duck_array(array):
array = np.asarray(array)
Expand All @@ -1423,6 +1469,20 @@ def groupby_reduce(
if expected_groups is not None and sort:
expected_groups = expected_groups.sort_values()

# when grouping by multiple variables, we factorize early.
# TODO: could restrict this to dask-only
if nby > 1:
by, final_groups, grp_shape = _factorize_multiple(
by, expected_groups, by_is_dask=by_is_dask
)
expected_groups = (pd.RangeIndex(np.prod(grp_shape)),)
else:
final_groups = expected_groups

assert len(by) == 1
by = by[0]
expected_groups = expected_groups[0]

if axis is None:
axis = tuple(array.ndim + np.arange(-by.ndim, 0))
else:
Expand All @@ -1434,7 +1494,7 @@ def groupby_reduce(
)

# TODO: make sure expected_groups is unique
if len(axis) == 1 and by_ndim > 1 and expected_groups[0] is None:
if len(axis) == 1 and by.ndim > 1 and expected_groups is None:
# TODO: hack
if not by_is_dask:
expected_groups = _get_expected_groups(by, sort)
Expand Down Expand Up @@ -1540,7 +1600,7 @@ def groupby_reduce(
result, *groups = partial_agg(
array,
by,
expected_groups=expected_groups,
expected_groups=None if method == "blockwise" else expected_groups,
reindex=reindex,
method=method,
sort=sort,
Expand All @@ -1552,4 +1612,7 @@ def groupby_reduce(
result = result[..., sorted_idx]
groups = (groups[0][sorted_idx],)

if nby > 1:
groups = final_groups
result = result.reshape(result.shape[:-1] + grp_shape)
return (result, *groups)

0 comments on commit b2be395

Please sign in to comment.