diff --git a/flox/core.py b/flox/core.py
index cb66ec883..8b4a21ade 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -4,7 +4,7 @@
 import itertools
 import operator
 from collections import namedtuple
-from functools import partial
+from functools import partial, reduce
 from typing import TYPE_CHECKING, Any, Callable, Dict, Mapping, Sequence, Union
 
 import numpy as np
@@ -59,16 +59,14 @@ def _prepare_for_flox(group_idx, array):
     return group_idx, ordered_array
 
 
-def _get_expected_groups(by, sort, raise_if_dask=True) -> pd.Index | None:
+def _get_expected_groups(by, sort, *, raise_if_dask=True) -> pd.Index | None:
     if is_duck_dask_array(by):
         if raise_if_dask:
-            raise ValueError("Please provide `expected_groups`.")
+            raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
         return None
     flatby = by.ravel()
     expected = pd.unique(flatby[~isnull(flatby)])
-    if sort:
-        expected = np.sort(expected)
-    return _convert_expected_groups_to_index(expected, isbin=False)
+    return _convert_expected_groups_to_index((expected,), isbin=(False,), sort=sort)[0]
 
 
 def _get_chunk_reduction(reduction_type: str) -> Callable:
@@ -378,6 +376,7 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
     Copied from xhistogram &
     https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy
     """
+    assert labels.ndim > 1
     offset: np.ndarray = (
         labels + np.arange(np.prod(labels.shape[:-1])).reshape((*labels.shape[:-1], -1)) * ngroups
     )
@@ -388,7 +387,12 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
 
 
 def factorize_(
-    by: tuple, axis, expected_groups: tuple[pd.Index, ...] = None, reindex=False, sort=True
+    by: tuple,
+    axis,
+    expected_groups: tuple[pd.Index, ...] = None,
+    reindex=False,
+    sort=True,
+    fastpath=False,
 ):
     """
     Returns an array of integer codes for groups (and associated data)
@@ -413,7 +417,7 @@ def factorize_(
                 raise ValueError("Please pass bin edges in expected_groups.")
             # TODO: fix for binning
             found_groups.append(expect)
-            # pd.cut with bins = IntervalIndex[datetime64] doesn't work... 
+            # pd.cut with bins = IntervalIndex[datetime64] doesn't work...
             if groupvar.dtype.kind == "M":
                 expect = np.concatenate([expect.left.to_numpy(), [expect.right[-1].to_numpy()]])
             idx = pd.cut(groupvar.ravel(), bins=expect).codes.copy()
@@ -440,10 +444,15 @@ def factorize_(
     grp_shape = tuple(len(grp) for grp in found_groups)
     ngroups = np.prod(grp_shape)
     if len(by) > 1:
-        group_idx = np.ravel_multi_index(factorized, grp_shape).reshape(by[0].shape)
+        group_idx = np.ravel_multi_index(factorized, grp_shape, mode="wrap").reshape(by[0].shape)
+        nan_by_mask = reduce(np.logical_or, [isnull(b) for b in by])
+        group_idx[nan_by_mask] = -1
     else:
         group_idx = factorized[0]
 
+    if fastpath:
+        return group_idx, found_groups, grp_shape
+
     if np.isscalar(axis) and groupvar.ndim > 1:
         # Not reducing along all dimensions of by
         # this is OK because for 3D by and axis=(1,2),
@@ -1244,33 +1253,78 @@ def _validate_reindex(reindex: bool, func, method, expected_groups) -> bool:
 
 
 def _assert_by_is_aligned(shape, by):
-    if shape[-by.ndim :] != by.shape:
-        raise ValueError(
-            "`array` and `by` arrays must be aligned "
-            "i.e. array.shape[-by.ndim :] == by.shape. "
-            "for every array in `by`."
-            f"Received array of shape {shape} but "
-            f"`by` has shape {by.shape}."
+    for idx, b in enumerate(by):
+        if shape[-b.ndim :] != b.shape:
+            raise ValueError(
+                "`array` and `by` arrays must be aligned "
+                "i.e. array.shape[-by.ndim :] == by.shape "
" + "for every array in `by`." + f"Received array of shape {shape} but " + f"array {idx} in `by` has shape {b.shape}." + ) + + +def _convert_expected_groups_to_index( + expected_groups: tuple, isbin: bool, sort: bool +) -> pd.Index | None: + out = [] + for ex, isbin_ in zip(expected_groups, isbin): + if isinstance(ex, pd.IntervalIndex) or (isinstance(ex, pd.Index) and not isbin): + if sort: + ex = ex.sort_values() + out.append(ex) + elif ex is not None: + if isbin_: + out.append(pd.IntervalIndex.from_arrays(ex[:-1], ex[1:])) + else: + if sort: + ex = np.sort(ex) + out.append(pd.Index(ex)) + else: + assert ex is None + out.append(None) + return tuple(out) + + +def _lazy_factorize_wrapper(*by, **kwargs): + group_idx, *rest = factorize_(by, **kwargs) + return group_idx + + +def _factorize_multiple(by, expected_groups, by_is_dask): + kwargs = dict( + expected_groups=expected_groups, + axis=None, # always None, we offset later if necessary. + fastpath=True, + ) + if by_is_dask: + import dask.array + + group_idx = dask.array.map_blocks( + _lazy_factorize_wrapper, + *np.broadcast_arrays(*by), + meta=np.array((), dtype=np.int64), + **kwargs, ) + found_groups = tuple(None if is_duck_dask_array(b) else pd.unique(b) for b in by) + grp_shape = tuple(len(e) for e in expected_groups) + else: + group_idx, found_groups, grp_shape = factorize_(by, **kwargs) + final_groups = tuple( + found if expect is None else expect.to_numpy() + for found, expect in zip(found_groups, expected_groups) + ) -def _convert_expected_groups_to_index(expected_groups, isbin: bool) -> pd.Index | None: - if isinstance(expected_groups, pd.IntervalIndex) or ( - isinstance(expected_groups, pd.Index) and not isbin - ): - return expected_groups - if isbin: - return pd.IntervalIndex.from_arrays(expected_groups[:-1], expected_groups[1:]) - elif expected_groups is not None: - return pd.Index(expected_groups) - return expected_groups + if any(grp is None for grp in final_groups): + raise ValueError("Please provide expected_groups when grouping by a dask array.") + return (group_idx,), final_groups, grp_shape def groupby_reduce( array: np.ndarray | DaskArray, - by: np.ndarray | DaskArray, + *by: np.ndarray | DaskArray, func: str | Aggregation, - *, expected_groups: Sequence | np.ndarray | None = None, sort: bool = True, isbin: bool = False, @@ -1383,18 +1437,38 @@ def groupby_reduce( ) reindex = _validate_reindex(reindex, func, method, expected_groups) - if not is_duck_array(by): - by = np.asarray(by) + by: tuple = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by) + nby = len(by) + by_is_dask = any(is_duck_dask_array(b) for b in by) if not is_duck_array(array): array = np.asarray(array) + if isinstance(isbin, bool): + isbin = (isbin,) * len(by) + if expected_groups is None: + expected_groups = (None,) * len(by) _assert_by_is_aligned(array.shape, by) + if len(by) == 1 and not isinstance(expected_groups, tuple): + expected_groups = (np.asarray(expected_groups),) + elif len(expected_groups) != len(by): + raise ValueError("len(expected_groups) != len(by)") + # We convert to pd.Index since that lets us know if we are binning or not # (pd.IntervalIndex or not) - expected_groups = _convert_expected_groups_to_index(expected_groups, isbin) - if expected_groups is not None and sort: - expected_groups = expected_groups.sort_values() + expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort) + + # when grouping by multiple variables, we factorize early. 
+    # TODO: could restrict this to dask-only
+    if nby > 1:
+        by, final_groups, grp_shape = _factorize_multiple(
+            by, expected_groups, by_is_dask=by_is_dask
+        )
+        expected_groups = (pd.RangeIndex(np.prod(grp_shape)),)
+
+    assert len(by) == 1
+    by = by[0]
+    expected_groups = expected_groups[0]
 
     if axis is None:
         axis = tuple(array.ndim + np.arange(-by.ndim, 0))
@@ -1408,7 +1482,7 @@
 
     # TODO: make sure expected_groups is unique
     if len(axis) == 1 and by.ndim > 1 and expected_groups is None:
-        if not is_duck_dask_array(by):
+        if not by_is_dask:
             expected_groups = _get_expected_groups(by, sort)
         else:
             # When we reduce along all axes, we are guaranteed to see all
@@ -1422,6 +1496,7 @@
                 "Please provide ``expected_groups`` when not reducing along all axes."
             )
 
+    assert len(axis) <= by.ndim
     if len(axis) < by.ndim:
         by = _move_reduce_dims_to_end(by, -array.ndim + np.array(axis) + by.ndim)
         array = _move_reduce_dims_to_end(array, axis)
@@ -1514,7 +1589,7 @@
     result, *groups = partial_agg(
         array,
         by,
-        expected_groups=expected_groups,
+        expected_groups=None if method == "blockwise" else expected_groups,
        reindex=reindex,
         method=method,
         sort=sort,
@@ -1526,4 +1601,10 @@
         result = result[..., sorted_idx]
         groups = (groups[0][sorted_idx],)
 
+    if nby > 1:
+        # nan group labels are factorized to -1, and preserved
+        # now we get rid of them
+        nanmask = groups[0] == -1
+        groups = final_groups
+        result = result[..., ~nanmask].reshape(result.shape[:-1] + grp_shape)
     return (result, *groups)
diff --git a/flox/xarray.py b/flox/xarray.py
index 842d22f5d..bf06328e9 100644
--- a/flox/xarray.py
+++ b/flox/xarray.py
@@ -9,26 +9,25 @@
 from .aggregations import Aggregation, _atleast_1d
 from .core import (
-    factorize_,
+    _convert_expected_groups_to_index,
+    _get_expected_groups,
     groupby_reduce,
     rechunk_for_blockwise as rechunk_array_for_blockwise,
     rechunk_for_cohorts as rechunk_array_for_cohorts,
-    reindex_,
 )
-from .xrutils import is_duck_dask_array, isnull
 
 if TYPE_CHECKING:
     from xarray import DataArray, Dataset, Resample
 
 
-def _get_input_core_dims(group_names, dim, ds, to_group):
+def _get_input_core_dims(group_names, dim, ds, grouper_dims):
     input_core_dims = [[], []]
     for g in group_names:
         if g in dim:
             continue
         if g in ds.dims:
             input_core_dims[0].extend([g])
-        if g in to_group.dims:
+        if g in grouper_dims:
             input_core_dims[1].extend([g])
     input_core_dims[0].extend(dim)
     input_core_dims[1].extend(dim)
@@ -182,6 +181,13 @@
 
     if isinstance(isbin, bool):
         isbin = (isbin,) * len(by)
+    if expected_groups is None:
+        expected_groups = (None,) * len(by)
+    if isinstance(expected_groups, (np.ndarray, list)):  # TODO: test for list
+        if len(by) == 1:
+            expected_groups = (expected_groups,)
+        else:
+            raise ValueError("expected_groups must be a tuple when grouping by multiple variables.")
 
     if not sort:
         raise NotImplementedError
@@ -196,10 +202,7 @@
 
     by: tuple[DataArray] = tuple(obj[g] if isinstance(g, str) else g for g in by)  # type: ignore
 
-    if len(by) > 1 and any(is_duck_dask_array(by_.data) for by_ in by):
-        raise NotImplementedError("Grouping by multiple variables will compute dask variables.")
-
-    grouper_dims = set(itertools.chain(*tuple(g.dims for g in by)))
+    grouper_dims = tuple(itertools.chain(*tuple(g.dims for g in by)))
 
     if isinstance(obj, xr.DataArray):
         ds = obj._to_temp_dataset()
@@ -222,7 +225,7 @@
     # in the case where dim is Ellipsis, and by.ndim < obj.ndim
     # then we also broadcast `by` to all `obj.dims`
     # TODO: avoid this broadcasting
-    exclude_dims = set(ds.dims) - grouper_dims
+    exclude_dims = set(ds.dims) - set(grouper_dims)
     if dim is not None:
         exclude_dims -= set(dim)
     ds, *by = xr.broadcast(ds, *by, exclude=exclude_dims)
@@ -254,42 +257,29 @@
         axis = tuple(range(-len(dim), 0))
 
     group_names = tuple(g.name if not binned else f"{g.name}_bins" for g, binned in zip(by, isbin))
-    if len(by) > 1:
-        group_idx, expected_groups, group_shape, _, _, _ = factorize_(
-            tuple(g.data for g in by),
-            axis,
-            expected_groups,
-        )
-        to_group = xr.DataArray(group_idx, dims=dim, coords={d: by[0][d] for d in by[0].indexes})
-    else:
-        if expected_groups is None and isinstance(by[0].data, np.ndarray):
-            uniques = np.unique(by[0].data)
-            nans = isnull(uniques)
-            if nans.any():
-                uniques = uniques[~nans]
-            expected_groups = (uniques,)
-        if expected_groups is None:
+    group_shape = [None] * len(by)
+    expected_groups = list(expected_groups)
+
+    # Set expected_groups and convert to index since we need coords, sizes
+    # for output xarray objects
+    for idx, (b, expect, isbin_) in enumerate(zip(by, expected_groups, isbin)):
+        if isbin_ and isinstance(expect, int):
             raise NotImplementedError(
-                "Please provide expected_groups if not grouping by a numpy-backed DataArray"
+                "flox does not support binning into an integer number of bins yet."
             )
-        if isinstance(expected_groups, np.ndarray):
-            expected_groups = (expected_groups,)
-        if isbin[0]:
-            if isinstance(expected_groups[0], int):
-                raise NotImplementedError(
-                    "Does not support binning into an integer number of bins yet."
+        if expect is None:
+            if isbin_:
+                raise ValueError(
+                    f"Please provide bin edges for group variable {idx} "
+                    f"named {group_names[idx]} in expected_groups."
                 )
-            # factorized, bins = pd.cut(by[0], bins=expected_groups[0], retbins=True)
-            group_shape = (expected_groups[0],)
-        else:
-            group_shape = (len(expected_groups[0]) - 1,)
-    else:
-        group_shape = (len(expected_groups[0]),)
-        to_group = by[0]
+            expected_groups[idx] = _get_expected_groups(b.data, sort=sort, raise_if_dask=True)
+
+    expected_groups = _convert_expected_groups_to_index(expected_groups, isbin, sort=sort)
+    group_shape = tuple(len(e) for e in expected_groups)
     group_sizes = dict(zip(group_names, group_shape))
 
-    def wrapper(array, to_group, *, func, skipna, **kwargs):
+    def wrapper(array, *by, func, skipna, **kwargs):
         # Handle skipna here because I need to know dtype to make a good default choice.
         # We cannnot handle this easily for xarray Datasets in xarray_reduce
         if skipna and func in ["all", "any", "count"]:
@@ -299,19 +289,7 @@ def wrapper(array, to_group, *, func, skipna, **kwargs):
         if "nan" not in func and func not in ["all", "any", "count"]:
             func = f"nan{func}"
 
-        result, groups = groupby_reduce(array, to_group, func=func, **kwargs)
-        if len(by) > 1:
-            # all groups need not be present. reindex here
-            # TODO: add test
-            reindexed = reindex_(
-                result,
-                from_=groups,
-                to=pd.Index(np.arange(np.prod(group_shape))),
-                fill_value=fill_value,
-                axis=-1,
-            )
-            result = reindexed.reshape(result.shape[:-1] + group_shape)
-
+        result, *groups = groupby_reduce(array, *by, func=func, **kwargs)
         return result
 
     # These data variables do not have any of the core dimension,
@@ -327,11 +305,13 @@ def wrapper(array, to_group, *, func, skipna, **kwargs):
         if is_missing_dim:
             missing_dim[k] = v
 
-    input_core_dims = _get_input_core_dims(group_names, dim, ds, to_group)
+    input_core_dims = _get_input_core_dims(group_names, dim, ds, grouper_dims)
+    input_core_dims += [input_core_dims[-1]] * (len(by) - 1)
+
     actual = xr.apply_ufunc(
         wrapper,
-        ds.drop_vars(tuple(missing_dim) + bad_dtypes).transpose(..., *to_group.dims),
-        to_group,
+        ds.drop_vars(tuple(missing_dim) + bad_dtypes).transpose(..., *grouper_dims),
+        *by,
         input_core_dims=input_core_dims,
         # for xarray's test_groupby_duplicate_coordinate_labels
         exclude_dims=set(dim),
@@ -350,14 +330,8 @@ def wrapper(array, to_group, *, func, skipna, **kwargs):
             "skipna": skipna,
             "engine": engine,
             "reindex": reindex,
-            # The following mess exists because for multiple `by`s I factorize eagerly
-            # here before passing it on; this means I have to handle the
-            # "binning by single by variable" case explicitly where the factorization
-            # happens later allowing `by` to be a dask variable.
-            # Another annoyance is that for resampling expected_groups is "disconnected"
-            # from "by" so we need the isbin part of the condition
-            "expected_groups": expected_groups[0] if len(by) == 1 and isbin[0] else None,
-            "isbin": isbin[0] if len(by) == 1 else False,
+            "expected_groups": tuple(expected_groups),
+            "isbin": isbin,
             "finalize_kwargs": finalize_kwargs,
         },
     )
@@ -368,9 +342,10 @@ def wrapper(array, to_group, *, func, skipna, **kwargs):
         if all(d not in ds[var].dims for d in dim):
             actual[var] = ds[var]
 
-    for name, expect, isbin_ in zip(group_names, expected_groups, isbin):
-        if isbin_:
-            expect = [pd.Interval(left, right) for left, right in zip(expect[:-1], expect[1:])]
+    for name, expect in zip(group_names, expected_groups):
+        # Can't remove this till xarray handles IntervalIndex
+        if isinstance(expect, pd.IntervalIndex):
+            expect = expect.to_numpy()
         if isinstance(actual, xr.Dataset) and name in actual:
             actual = actual.drop_vars(name)
         actual[name] = expect
@@ -525,11 +500,11 @@ def resample_reduce(
             by,
             func=func,
             method="blockwise",
-            expected_groups=(resampler._unique_coord.data,),
             keep_attrs=keep_attrs,
             **kwargs,
         )
         .rename({"__resample_dim__": dim})
         .transpose(dim, ...)
     )
+    result[dim] = resampler._unique_coord.data
     return result
diff --git a/tests/test_core.py b/tests/test_core.py
index 6999c8650..649b03f62 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -116,7 +116,7 @@ def test_groupby_reduce(
     elif func == "count":
         expected = np.array(expected, dtype=int)
 
-    result, groups = groupby_reduce(
+    result, groups, = groupby_reduce(
         array,
         by,
         func=func,
@@ -143,7 +143,7 @@ def gen_array_by(size, func):
 
 
 @pytest.mark.parametrize("chunks", [None, 3, 4])
-@pytest.mark.parametrize("nby", [1])
+@pytest.mark.parametrize("nby", [1, 2, 3])
 @pytest.mark.parametrize("size", ((12,), (12, 8)))
 @pytest.mark.parametrize("add_nan_by", [True, False])
 @pytest.mark.parametrize("func", ALL_FUNCS)
@@ -163,7 +163,6 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
             by[idx][2 * idx : 2 * idx + 3] = np.nan
         by = tuple(by)
         nanmask = reduce(np.logical_or, (np.isnan(b) for b in by))
-        by = by[0]
 
     finalize_kwargs = [{}]
     if "var" in func or "std" in func:
@@ -183,7 +182,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
         for _ in range(nby):
             expected = np.expand_dims(expected, -1)
 
-        actual, *groups = groupby_reduce(array, by, **flox_kwargs)
+        actual, *groups = groupby_reduce(array, *by, **flox_kwargs)
         assert actual.ndim == (array.ndim + nby - 1)
         assert expected.ndim == (array.ndim + nby - 1)
         expected_groups = tuple(np.array([idx + 1.0]) for idx in range(nby))
@@ -198,7 +197,9 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
         for method in ["map-reduce", "cohorts", "split-reduce"]:
             if "arg" in func and method != "map-reduce":
                 continue
-            actual, _ = groupby_reduce(array, by, method=method, **flox_kwargs)
+            actual, *groups = groupby_reduce(array, *by, method=method, **flox_kwargs)
+            for actual_group, expect in zip(groups, expected_groups):
+                assert_equal(actual_group, expect)
             if "arg" in func:
                 assert actual.dtype.kind == "i"
             assert_equal(actual, expected)
@@ -813,7 +814,7 @@ def test_datetime_binning():
 
     time_bins = pd.date_range(start="2010-08-01", end="2010-08-15", freq="24H")
     by = pd.date_range("2010-08-01", "2010-08-15", freq="15min")
 
-    actual = _convert_expected_groups_to_index(time_bins, isbin=True)
+    (actual,) = _convert_expected_groups_to_index((time_bins,), isbin=(True,), sort=False)
     expected = pd.IntervalIndex.from_arrays(time_bins[:-1], time_bins[1:])
     assert_equal(actual, expected)
diff --git a/tests/test_xarray.py b/tests/test_xarray.py
index e33a7203e..3f6550aa5 100644
--- a/tests/test_xarray.py
+++ b/tests/test_xarray.py
@@ -73,9 +73,17 @@ def test_xarray_reduce(skipna, add_nan, min_count, engine, reindex):
     # assert_equal(expected, actual)
 
 
-def test_xarray_reduce_multiple_groupers(engine):
-    arr = np.ones((4, 12))
+# TODO: sort
+@pytest.mark.parametrize("pass_expected_groups", [True, False])
+@pytest.mark.parametrize("chunk", (True, False))
+def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine):
+    if not has_dask and chunk:
+        pytest.skip()
 
+    if chunk and pass_expected_groups is False:
+        pytest.skip()
+
+    arr = np.ones((4, 12))
     labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
     labels = np.array(labels)
     labels2 = np.array([1, 2, 2, 1])
@@ -84,31 +92,39 @@ def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine):
         arr, dims=("x", "y"), coords={"labels2": ("x", labels2), "labels": ("y", labels)}
     ).expand_dims(z=4)
 
+    if chunk:
+        da = da.chunk({"x": 2, "z": 1})
+
     expected = xr.DataArray(
         [[4, 4], [8, 8], [10, 10], [2, 2]],
         dims=("labels", "labels2"),
coords={"labels": ["a", "b", "c", "f"], "labels2": [1, 2]}, ).expand_dims(z=4) - actual = xarray_reduce(da, da.labels, da.labels2, func="count", engine=engine) + kwargs = dict(func="count", engine=engine) + if pass_expected_groups: + kwargs["expected_groups"] = (expected.labels.data, expected.labels2.data) + + with raise_if_dask_computes(): + actual = xarray_reduce(da, da.labels, da.labels2, **kwargs) xr.testing.assert_identical(expected, actual) - actual = xarray_reduce(da, "labels", da.labels2, func="count", engine=engine) + with raise_if_dask_computes(): + actual = xarray_reduce(da, "labels", da.labels2, **kwargs) xr.testing.assert_identical(expected, actual) - actual = xarray_reduce(da, "labels", "labels2", func="count", engine=engine) + with raise_if_dask_computes(): + actual = xarray_reduce(da, "labels", "labels2", **kwargs) xr.testing.assert_identical(expected, actual) - if has_dask: - with raise_if_dask_computes(): - actual = xarray_reduce( - da.chunk({"x": 2, "z": 1}), da.labels, da.labels2, func="count", engine=engine - ) - xr.testing.assert_identical(expected, actual) - with pytest.raises(NotImplementedError): - xarray_reduce(da.chunk({"x": 2, "z": 1}), "labels", "labels2", func="count") - # xr.testing.assert_identical(expected, actual) +@requires_dask +def test_dask_groupers_error(): + da = xr.DataArray( + [1.0, 2.0], dims="x", coords={"labels": ("x", [1, 2]), "labels2": ("x", [1, 2])} + ) + with pytest.raises(ValueError): + xarray_reduce(da.chunk({"x": 2, "z": 1}), "labels", "labels2", func="count") @requires_dask @@ -165,7 +181,7 @@ def test_xarray_reduce_errors(): xarray_reduce(da, by, func="mean", dim="foo") if has_dask: - with pytest.raises(NotImplementedError, match="provide expected_groups"): + with pytest.raises(ValueError, match="provide expected_groups"): xarray_reduce(da, by.chunk(), func="mean")