Skip to content

Commit

Permalink
groupby_reduce updates
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Feb 22, 2022
1 parent 97972fb commit 0db44f7
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
35 changes: 23 additions & 12 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,14 +1261,15 @@ def _validate_reindex(reindex: bool, func, method, expected_groups) -> bool:


def _assert_by_is_aligned(shape, by):
    """Verify that each array in ``by`` matches the trailing axes of ``shape``.

    Parameters
    ----------
    shape
        Shape of the array being reduced; the caller passes ``array.shape``.
        It is compared with ``!=`` against each ``b.shape``.
    by
        Iterable of grouping arrays. Each element must expose ``ndim`` and
        ``shape`` attributes.

    Raises
    ------
    ValueError
        If ``shape[-b.ndim :] != b.shape`` for some ``b`` in ``by``. The
        message reports the position of the first offending array.
    """
    # `by` is a collection of arrays, so each one is checked individually;
    # a single up-front `by.ndim` / `by.shape` check would fail on a tuple.
    for idx, b in enumerate(by):
        if shape[-b.ndim :] != b.shape:
            raise ValueError(
                "`array` and `by` arrays must be aligned "
                "i.e. array.shape[-by.ndim :] == by.shape. "
                # Trailing space keeps the two sentences from running together.
                "for every array in `by`. "
                f"Received array of shape {shape} but "
                f"array {idx} in `by` has shape {b.shape}."
            )


def _convert_expected_groups_to_index(expected_groups, isbin: bool) -> pd.Index | None:
Expand Down Expand Up @@ -1400,13 +1401,22 @@ def groupby_reduce(
)
reindex = _validate_reindex(reindex, func, method, expected_groups)

if not is_duck_array(by):
by = np.asarray(by)
by: tuple = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
by_is_dask = any(is_duck_dask_array(b) for b in by)
if not is_duck_array(array):
array = np.asarray(array)
if isinstance(isbin, bool):
isbin = (isbin,) * len(by)
if expected_groups is None:
expected_groups = (None,) * len(by)

_assert_by_is_aligned(array.shape, by)

if len(by) == 1 and not isinstance(expected_groups, tuple):
expected_groups = (np.asarray(expected_groups),)
elif len(expected_groups) != len(by):
raise ValueError("len(expected_groups) != len(by)")

# We convert to pd.Index since that lets us know if we are binning or not
# (pd.IntervalIndex or not)
expected_groups = _convert_expected_groups_to_index(expected_groups, isbin)
Expand All @@ -1424,8 +1434,9 @@ def groupby_reduce(
)

# TODO: make sure expected_groups is unique
if len(axis) == 1 and by.ndim > 1 and expected_groups is None:
if not is_duck_dask_array(by):
if len(axis) == 1 and by_ndim > 1 and expected_groups[0] is None:
# TODO: hack
if not by_is_dask:
expected_groups = _get_expected_groups(by, sort)
else:
# When we reduce along all axes, we are guaranteed to see all
Expand Down
4 changes: 2 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_groupby_reduce(
elif func == "count":
expected = np.array(expected, dtype=int)

result, groups = groupby_reduce(
result, groups, = groupby_reduce(
array,
by,
func=func,
Expand Down Expand Up @@ -780,7 +780,7 @@ def test_datetime_binning():
time_bins = pd.date_range(start="2010-08-01", end="2010-08-15", freq="24H")
by = pd.date_range("2010-08-01", "2010-08-15", freq="15min")

actual = _convert_expected_groups_to_index(time_bins, isbin=True)
(actual,) = _convert_expected_groups_to_index((time_bins,), isbin=(True,))
expected = pd.IntervalIndex.from_arrays(time_bins[:-1], time_bins[1:])
assert_equal(actual, expected)

Expand Down

0 comments on commit 0db44f7

Please sign in to comment.