diff --git a/flox/core.py b/flox/core.py
index df1b9292a..e6e30d430 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -340,9 +340,10 @@ def invert(x) -> tuple[np.ndarray, ...]:
     # TODO: we can optimize this to loop over chunk_cohorts instead
     # by zeroing out rows that are already in a cohort
     for rowidx in order:
-        cohort_ = containment.indices[
+        cohidx = containment.indices[
             slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
         ]
+        cohort_ = present_labels[cohidx]
         cohort = [elem for elem in cohort_ if elem not in merged_keys]
         if not cohort:
             continue
diff --git a/tests/test_core.py b/tests/test_core.py
index b32cd4d41..385378910 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -857,6 +857,16 @@ def test_find_group_cohorts(expected, labels, chunks: tuple[int]) -> None:
     assert actual == expected, (actual, expected)
 
 
+@requires_dask
+def test_find_cohorts_missing_groups():
+    by = np.array([np.nan, np.nan, np.nan, 2.0, 2.0, 1.0, 1.0, 2.0, 2.0, 1.0, np.nan, np.nan])
+    kwargs = {"func": "sum", "expected_groups": [0, 1, 2], "fill_value": 123}
+    array = dask.array.ones_like(by, chunks=(3,))
+    actual, _ = groupby_reduce(array, by, method="cohorts", **kwargs)
+    expected, _ = groupby_reduce(array.compute(), by, **kwargs)
+    assert_equal(expected, actual)
+
+
 @pytest.mark.parametrize("chunksize", [12, 13, 14, 24, 36, 48, 72, 71])
 def test_verify_complex_cohorts(chunksize: int) -> None:
     time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))
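For context, the core.py change fixes an indexing subtlety: `containment` is a sparse CSR matrix whose columns correspond to the labels actually *present* in the data, so `containment.indices` yields column positions, not label values. When `expected_groups` names groups that never occur (e.g. rows that are all NaN, as in the new test), positions and labels diverge, and the raw indices must be mapped back through `present_labels`. Below is a minimal standalone sketch of that mapping, using a made-up 2x2 containment matrix and label values rather than flox's real data structures:

```python
# Illustrative sketch only, not flox's actual code. Assumes a toy
# containment matrix and hypothetical label values to show why CSR
# `indices` must be translated through `present_labels`.
import numpy as np
import scipy.sparse

# Suppose expected_groups is [0, 1, 2] but label 0 never occurs,
# so only these labels are present in the data:
present_labels = np.array([1.0, 2.0])

# Toy containment matrix: row i marks which present labels share
# chunks with present label i. Columns are positions, not labels.
containment = scipy.sparse.csr_matrix(np.array([[1, 0], [1, 1]]))

rowidx = 1
cohidx = containment.indices[
    slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
]
print(cohidx)                  # [0 1]: positions into present_labels
print(present_labels[cohidx])  # [1. 2.]: the actual group labels
```

Before the fix, the positional indices (`[0 1]` here) were treated as labels themselves, which only happens to work when every expected group is present and labels coincide with column positions.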