diff --git a/flox/core.py b/flox/core.py index 022f29582..ca0cb922d 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1559,7 +1559,10 @@ def _validate_expected_groups(nby: int, expected_groups: T_ExpectedGroupsOpt) -> return (None,) * nby if nby == 1 and not isinstance(expected_groups, tuple): - return (np.asarray(expected_groups),) + if isinstance(expected_groups, pd.Index): + return (expected_groups,) + else: + return (np.asarray(expected_groups),) if nby > 1 and not isinstance(expected_groups, tuple): # TODO: test for list raise ValueError( @@ -1734,9 +1737,11 @@ def groupby_reduce( # (pd.IntervalIndex or not) expected_groups = _convert_expected_groups_to_index(expected_groups, isbins, sort) + is_binning = any([isinstance(e, pd.IntervalIndex) for e in expected_groups]) + # TODO: could restrict this to dask-only factorize_early = (nby > 1) or ( - any(isbins) and method == "cohorts" and is_duck_dask_array(array) + is_binning and method == "cohorts" and is_duck_dask_array(array) ) if factorize_early: bys, final_groups, grp_shape = _factorize_multiple( diff --git a/flox/xarray.py b/flox/xarray.py index 1bd384875..e02065dab 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -313,7 +313,9 @@ def xarray_reduce( group_names: tuple[Any, ...] = () group_sizes: dict[Any, int] = {} for idx, (b_, expect, isbin_) in enumerate(zip(by_da, expected_groups, isbins)): - group_name = b_.name if not isbin_ else f"{b_.name}_bins" + group_name = ( + f"{b_.name}_bins" if isbin_ or isinstance(expect, pd.IntervalIndex) else b_.name + ) group_names += (group_name,) if isbin_ and isinstance(expect, int): diff --git a/tests/test_core.py b/tests/test_core.py index 0841e531a..1c208c211 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -644,10 +644,17 @@ def test_npg_nanarg_bug(func): assert_equal(actual, expected) +@pytest.mark.parametrize( + "kwargs", + ( + dict(expected_groups=np.array([1, 2, 4, 5]), isbin=True), + dict(expected_groups=pd.IntervalIndex.from_breaks([1, 2, 4, 5])), + ), +) @pytest.mark.parametrize("method", ["cohorts", "map-reduce"]) @pytest.mark.parametrize("chunk_labels", [False, True]) @pytest.mark.parametrize("chunks", ((), (1,), (2,))) -def test_groupby_bins(chunk_labels, chunks, engine, method) -> None: +def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None: array = [1, 1, 1, 1, 1, 1] labels = [0.2, 1.5, 1.9, 2, 3, 20] @@ -663,14 +670,7 @@ def test_groupby_bins(chunk_labels, chunks, engine, method) -> None: with raise_if_dask_computes(): actual, groups = groupby_reduce( - array, - labels, - func="count", - expected_groups=np.array([1, 2, 4, 5]), - isbin=True, - fill_value=0, - engine=engine, - method=method, + array, labels, func="count", fill_value=0, engine=engine, method=method, **kwargs ) expected = np.array([3, 1, 0], dtype=np.intp) for left, right in zip(groups, pd.IntervalIndex.from_arrays([1, 2, 4], [2, 4, 5]).to_numpy()): diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 50864a247..e17c7b98e 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -457,7 +457,8 @@ def test_datetime_array_reduce(use_cftime, func, engine): @requires_dask -def test_groupby_bins_indexed_coordinate(): +@pytest.mark.parametrize("method", ["cohorts", "map-reduce"]) +def test_groupby_bins_indexed_coordinate(method): ds = ( xr.tutorial.open_dataset("air_temperature") .isel(time=slice(100)) @@ -472,7 +473,17 @@ def test_groupby_bins_indexed_coordinate(): expected_groups=([40, 50, 60, 70],), isbin=(True,), func="mean", - method="split-reduce", + method=method, + ) + xr.testing.assert_allclose(expected, actual) + + actual = xarray_reduce( + ds, + ds.lat, + dim=ds.air.dims, + expected_groups=pd.IntervalIndex.from_breaks([40, 50, 60, 70]), + func="mean", + method=method, ) xr.testing.assert_allclose(expected, actual)