diff --git a/flox/core.py b/flox/core.py index b4e68c23f..6de0db58f 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1310,11 +1310,12 @@ def _lazy_factorize_wrapper(*by, **kwargs): return group_idx -def _factorize_multiple(by, expected_groups, by_is_dask): +def _factorize_multiple(by, expected_groups, by_is_dask, reindex): kwargs = dict( expected_groups=expected_groups, axis=None, # always None, we offset later if necessary. fastpath=True, + reindex=reindex, ) if by_is_dask: import dask.array @@ -1325,7 +1326,9 @@ def _factorize_multiple(by, expected_groups, by_is_dask): meta=np.array((), dtype=np.int64), **kwargs, ) - found_groups = tuple(None if is_duck_dask_array(b) else pd.unique(b) for b in by) + found_groups = tuple( + None if is_duck_dask_array(b) else pd.unique(np.array(b).reshape(-1)) for b in by + ) grp_shape = tuple(len(e) for e in expected_groups) else: group_idx, found_groups, grp_shape = factorize_(by, **kwargs) @@ -1489,7 +1492,7 @@ def groupby_reduce( ) if factorize_early: by, final_groups, grp_shape = _factorize_multiple( - by, expected_groups, by_is_dask=by_is_dask + by, expected_groups, by_is_dask=by_is_dask, reindex=reindex ) expected_groups = (pd.RangeIndex(np.prod(grp_shape)),) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 5f779d0ad..fa54e4e54 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -451,3 +451,37 @@ def test_groupby_bins_indexed_coordinate(): method="split-reduce", ) xr.testing.assert_allclose(expected, actual) + + +@pytest.mark.parametrize("chunk", (True, False)) +def test_mixed_grouping(chunk): + if not has_dask and chunk: + pytest.skip() + # regression test for https://github.com/dcherian/flox/pull/111 + sa = 10 + sb = 13 + sc = 3 + + x = xr.Dataset( + { + "v0": xr.DataArray( + ((np.arange(sa * sb * sc) / sa) % 1).reshape((sa, sb, sc)), + dims=("a", "b", "c"), + ), + "v1": xr.DataArray((np.arange(sa * sb) % 3).reshape(sa, sb), dims=("a", "b")), + } + ) + if chunk: + x["v0"] = x["v0"].chunk({"a": 5}) + + r = xarray_reduce( + x["v0"], + x["v1"], + x["v0"], + expected_groups=(np.arange(6), np.linspace(0, 1, num=5)), + isbin=[False, True], + func="count", + dim="b", + fill_value=0, + ) + assert (r.sel(v1=[3, 4, 5]) == 0).all().data