From 7043bdb24d6e74b34598aa6e0f4f9ca4fb9e7e44 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 13 May 2022 11:22:13 -0600 Subject: [PATCH] Enable `flox` in `GroupBy` and `resample` (#5734) Closes #5734 Closes #4473 Closes #4498 Closes #659 Closes #2237 xref https://github.com/pangeo-data/pangeo/issues/271 Co-authored-by: Anderson Banihirwe Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> Co-authored-by: Mathias Hauser Co-authored-by: Stephan Hoyer --- asv_bench/asv.conf.json | 2 + asv_bench/benchmarks/groupby.py | 10 +- ci/install-upstream-wheels.sh | 2 + ci/requirements/all-but-dask.yml | 1 + ci/requirements/environment-windows.yml | 1 + ci/requirements/environment.yml | 1 + ci/requirements/min-all-deps.yml | 1 + doc/whats-new.rst | 5 + setup.cfg | 1 + xarray/core/_reductions.py | 1075 ++++++++++++++++------- xarray/core/groupby.py | 126 ++- xarray/core/options.py | 6 + xarray/core/resample.py | 33 +- xarray/core/utils.py | 18 + xarray/tests/__init__.py | 2 + xarray/tests/test_groupby.py | 116 ++- xarray/tests/test_units.py | 8 +- xarray/util/generate_reductions.py | 130 ++- xarray/util/print_versions.py | 2 + 19 files changed, 1185 insertions(+), 355 deletions(-) diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index 3e4137cf807..10b8aead374 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -65,6 +65,8 @@ "bottleneck": [""], "dask": [""], "distributed": [""], + "flox": [""], + "numpy_groupies": [""], "sparse": [""] }, diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index fa93ce9e8b5..490c2ccbd4c 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -13,6 +13,7 @@ def setup(self, *args, **kwargs): { "a": xr.DataArray(np.r_[np.repeat(1, self.n), np.repeat(2, self.n)]), "b": xr.DataArray(np.arange(2 * self.n)), + "c": xr.DataArray(np.arange(2 * self.n)), } ) self.ds2d = self.ds1d.expand_dims(z=10) @@ -50,10 +51,11 @@ class GroupByDask(GroupBy): def setup(self, *args, **kwargs): requires_dask() super().setup(**kwargs) - self.ds1d = self.ds1d.sel(dim_0=slice(None, None, 2)).chunk({"dim_0": 50}) - self.ds2d = self.ds2d.sel(dim_0=slice(None, None, 2)).chunk( - {"dim_0": 50, "z": 5} - ) + + self.ds1d = self.ds1d.sel(dim_0=slice(None, None, 2)) + self.ds1d["c"] = self.ds1d["c"].chunk({"dim_0": 50}) + self.ds2d = self.ds2d.sel(dim_0=slice(None, None, 2)) + self.ds2d["c"] = self.ds2d["c"].chunk({"dim_0": 50, "z": 5}) self.ds1d_mean = self.ds1d.groupby("b").mean() self.ds2d_mean = self.ds2d.groupby("b").mean() diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh index 96a39ccd20b..ff5615c17c6 100755 --- a/ci/install-upstream-wheels.sh +++ b/ci/install-upstream-wheels.sh @@ -15,6 +15,7 @@ conda uninstall -y --force \ pint \ bottleneck \ sparse \ + flox \ h5netcdf \ xarray # to limit the runtime of Upstream CI @@ -47,4 +48,5 @@ python -m pip install \ git+https://github.com/pydata/sparse \ git+https://github.com/intake/filesystem_spec \ git+https://github.com/SciTools/nc-time-axis \ + git+https://github.com/dcherian/flox \ git+https://github.com/h5netcdf/h5netcdf diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index cb9ec8d3bc5..e20ec2016ed 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -13,6 +13,7 @@ dependencies: - cfgrib - cftime - coveralls + - flox - h5netcdf - h5py - hdf5 diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 6c389c22ce6..634140fe84b 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -10,6 +10,7 @@ dependencies: - cftime - dask-core - distributed + - flox - fsspec!=2021.7.0 - h5netcdf - h5py diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 516c964afc7..d37bb7dc44a 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -12,6 +12,7 @@ dependencies: - cftime - dask-core - distributed + - flox - fsspec!=2021.7.0 - h5netcdf - h5py diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index ecabde06622..34879af730b 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -17,6 +17,7 @@ dependencies: - coveralls - dask-core=2021.04 - distributed=2021.04 + - flox=0.5 - h5netcdf=0.11 - h5py=3.1 # hdf5 1.12 conflicts with h5py=3.1 diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6cba8563ecd..680c8219a38 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -141,6 +141,11 @@ Performance - GroupBy binary operations are now vectorized. Previously this involved looping over all groups. (:issue:`5804`,:pull:`6160`) By `Deepak Cherian `_. +- Substantially improved GroupBy operations using `flox `_. + This is auto-enabled when ``flox`` is installed. Use ``xr.set_options(use_flox=False)`` to use + the old algorithm. (:issue:`4473`, :issue:`4498`, :issue:`659`, :issue:`2237`, :pull:`271`). + By `Deepak Cherian `_,`Anderson Banihirwe `_, + `Jimmy Westling `_. Internal Changes ~~~~~~~~~~~~~~~~ diff --git a/setup.cfg b/setup.cfg index 6a0a06d2367..f5dd4dde810 100644 --- a/setup.cfg +++ b/setup.cfg @@ -98,6 +98,7 @@ accel = scipy bottleneck numbagg + flox parallel = dask[complete] diff --git a/xarray/core/_reductions.py b/xarray/core/_reductions.py index 31365f39e65..d782363760a 100644 --- a/xarray/core/_reductions.py +++ b/xarray/core/_reductions.py @@ -4,11 +4,18 @@ from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional, Sequence, Union from . import duck_array_ops +from .options import OPTIONS +from .utils import contains_only_dask_or_numpy if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset +try: + import flox +except ImportError: + flox = None # type: ignore + class DatasetReductions: __slots__ = () @@ -1941,7 +1948,7 @@ def median( class DatasetGroupByReductions: - __slots__ = () + _obj: "Dataset" def reduce( self, @@ -1955,6 +1962,13 @@ def reduce( ) -> "Dataset": raise NotImplementedError() + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "Dataset": + raise NotImplementedError() + def count( self, dim: Union[None, Hashable, Sequence[Hashable]] = None, @@ -2021,13 +2035,23 @@ def count( Data variables: da (labels) int64 1 2 2 """ - return self.reduce( - duck_array_ops.count, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="count", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.count, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def all( self, @@ -2095,13 +2119,23 @@ def all( Data variables: da (labels) bool False True True """ - return self.reduce( - duck_array_ops.array_all, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="all", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_all, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def any( self, @@ -2169,13 +2203,23 @@ def any( Data variables: da (labels) bool True True True """ - return self.reduce( - duck_array_ops.array_any, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="any", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_any, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def max( self, @@ -2259,14 +2303,25 @@ def max( Data variables: da (labels) float64 nan 2.0 3.0 """ - return self.reduce( - duck_array_ops.max, - dim=dim, - skipna=skipna, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="max", + dim=dim, + skipna=skipna, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def min( self, @@ -2350,14 +2405,25 @@ def min( Data variables: da (labels) float64 nan 2.0 1.0 """ - return self.reduce( - duck_array_ops.min, - dim=dim, - skipna=skipna, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="min", + dim=dim, + skipna=skipna, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def mean( self, @@ -2445,14 +2511,25 @@ def mean( Data variables: da (labels) float64 nan 2.0 2.0 """ - return self.reduce( - duck_array_ops.mean, - dim=dim, - skipna=skipna, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="mean", + dim=dim, + skipna=skipna, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def prod( self, @@ -2557,15 +2634,27 @@ def prod( Data variables: da (labels) float64 nan 4.0 3.0 """ - return self.reduce( - duck_array_ops.prod, - dim=dim, - skipna=skipna, - min_count=min_count, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="prod", + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def sum( self, @@ -2670,15 +2759,27 @@ def sum( Data variables: da (labels) float64 nan 4.0 4.0 """ - return self.reduce( - duck_array_ops.sum, - dim=dim, - skipna=skipna, - min_count=min_count, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="sum", + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def std( self, @@ -2780,15 +2881,27 @@ def std( Data variables: da (labels) float64 nan 0.0 1.414 """ - return self.reduce( - duck_array_ops.std, - dim=dim, - skipna=skipna, - ddof=ddof, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="std", + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def var( self, @@ -2890,15 +3003,27 @@ def var( Data variables: da (labels) float64 nan 0.0 2.0 """ - return self.reduce( - duck_array_ops.var, - dim=dim, - skipna=skipna, - ddof=ddof, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="var", + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def median( self, @@ -2997,7 +3122,7 @@ def median( class DatasetResampleReductions: - __slots__ = () + _obj: "Dataset" def reduce( self, @@ -3011,6 +3136,13 @@ def reduce( ) -> "Dataset": raise NotImplementedError() + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "Dataset": + raise NotImplementedError() + def count( self, dim: Union[None, Hashable, Sequence[Hashable]] = None, @@ -3077,13 +3209,23 @@ def count( Data variables: da (time) int64 1 3 1 """ - return self.reduce( - duck_array_ops.count, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="count", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.count, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def all( self, @@ -3151,13 +3293,23 @@ def all( Data variables: da (time) bool True True False """ - return self.reduce( - duck_array_ops.array_all, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="all", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_all, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def any( self, @@ -3225,13 +3377,23 @@ def any( Data variables: da (time) bool True True True """ - return self.reduce( - duck_array_ops.array_any, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="any", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_any, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def max( self, @@ -3315,14 +3477,25 @@ def max( Data variables: da (time) float64 1.0 3.0 nan """ - return self.reduce( - duck_array_ops.max, - dim=dim, - skipna=skipna, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="max", + dim=dim, + skipna=skipna, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def min( self, @@ -3406,14 +3579,25 @@ def min( Data variables: da (time) float64 1.0 1.0 nan """ - return self.reduce( - duck_array_ops.min, - dim=dim, - skipna=skipna, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="min", + dim=dim, + skipna=skipna, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def mean( self, @@ -3501,14 +3685,25 @@ def mean( Data variables: da (time) float64 1.0 2.0 nan """ - return self.reduce( - duck_array_ops.mean, - dim=dim, - skipna=skipna, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="mean", + dim=dim, + skipna=skipna, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def prod( self, @@ -3613,15 +3808,27 @@ def prod( Data variables: da (time) float64 nan 6.0 nan """ - return self.reduce( - duck_array_ops.prod, - dim=dim, - skipna=skipna, - min_count=min_count, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="prod", + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def sum( self, @@ -3726,15 +3933,27 @@ def sum( Data variables: da (time) float64 nan 6.0 nan """ - return self.reduce( - duck_array_ops.sum, - dim=dim, - skipna=skipna, - min_count=min_count, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="sum", + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def std( self, @@ -3836,15 +4055,27 @@ def std( Data variables: da (time) float64 nan 1.0 nan """ - return self.reduce( - duck_array_ops.std, - dim=dim, - skipna=skipna, - ddof=ddof, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="std", + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def var( self, @@ -3946,15 +4177,27 @@ def var( Data variables: da (time) float64 nan 1.0 nan """ - return self.reduce( - duck_array_ops.var, - dim=dim, - skipna=skipna, - ddof=ddof, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="var", + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def median( self, @@ -4053,7 +4296,7 @@ def median( class DataArrayGroupByReductions: - __slots__ = () + _obj: "DataArray" def reduce( self, @@ -4067,6 +4310,13 @@ def reduce( ) -> "DataArray": raise NotImplementedError() + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "DataArray": + raise NotImplementedError() + def count( self, dim: Union[None, Hashable, Sequence[Hashable]] = None, @@ -4128,12 +4378,21 @@ def count( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.count, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="count", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.count, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def all( self, @@ -4196,12 +4455,21 @@ def all( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.array_all, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="all", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_all, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def any( self, @@ -4264,12 +4532,21 @@ def any( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.array_any, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="any", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_any, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def max( self, @@ -4346,13 +4623,23 @@ def max( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.max, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="max", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def min( self, @@ -4429,13 +4716,23 @@ def min( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.min, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="min", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def mean( self, @@ -4516,13 +4813,23 @@ def mean( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.mean, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="mean", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def prod( self, @@ -4618,14 +4925,25 @@ def prod( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.prod, - dim=dim, - skipna=skipna, - min_count=min_count, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="prod", + dim=dim, + skipna=skipna, + min_count=min_count, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) def sum( self, @@ -4721,14 +5039,25 @@ def sum( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.sum, - dim=dim, - skipna=skipna, - min_count=min_count, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="sum", + dim=dim, + skipna=skipna, + min_count=min_count, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) def std( self, @@ -4821,14 +5150,25 @@ def std( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.std, - dim=dim, - skipna=skipna, - ddof=ddof, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="std", + dim=dim, + skipna=skipna, + ddof=ddof, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) def var( self, @@ -4921,14 +5261,25 @@ def var( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.var, - dim=dim, - skipna=skipna, - ddof=ddof, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="var", + dim=dim, + skipna=skipna, + ddof=ddof, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) def median( self, @@ -5019,7 +5370,7 @@ def median( class DataArrayResampleReductions: - __slots__ = () + _obj: "DataArray" def reduce( self, @@ -5033,6 +5384,13 @@ def reduce( ) -> "DataArray": raise NotImplementedError() + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "DataArray": + raise NotImplementedError() + def count( self, dim: Union[None, Hashable, Sequence[Hashable]] = None, @@ -5094,12 +5452,21 @@ def count( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.count, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="count", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.count, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def all( self, @@ -5162,12 +5529,21 @@ def all( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.array_all, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="all", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_all, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def any( self, @@ -5230,12 +5606,21 @@ def any( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.array_any, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="any", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_any, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def max( self, @@ -5312,13 +5697,23 @@ def max( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.max, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="max", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def min( self, @@ -5395,13 +5790,23 @@ def min( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.min, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="min", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def mean( self, @@ -5482,13 +5887,23 @@ def mean( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.mean, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="mean", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def prod( self, @@ -5584,14 +5999,25 @@ def prod( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.prod, - dim=dim, - skipna=skipna, - min_count=min_count, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="prod", + dim=dim, + skipna=skipna, + min_count=min_count, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) def sum( self, @@ -5687,14 +6113,25 @@ def sum( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.sum, - dim=dim, - skipna=skipna, - min_count=min_count, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="sum", + dim=dim, + skipna=skipna, + min_count=min_count, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) def std( self, @@ -5787,14 +6224,25 @@ def std( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.std, - dim=dim, - skipna=skipna, - ddof=ddof, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="std", + dim=dim, + skipna=skipna, + ddof=ddof, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) def var( self, @@ -5887,14 +6335,25 @@ def var( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.var, - dim=dim, - skipna=skipna, - ddof=ddof, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="var", + dim=dim, + skipna=skipna, + ddof=ddof, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) def median( self, diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 151ef844f44..fec8954c9e2 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -264,6 +264,10 @@ class GroupBy: "_stacked_dim", "_unique_coord", "_dims", + "_squeeze", + # Save unstacked object for flox + "_original_obj", + "_unstacked_group", "_bins", ) @@ -326,6 +330,10 @@ def __init__( if getattr(group, "name", None) is None: group.name = "group" + self._original_obj = obj + self._unstacked_group = group + self._bins = bins + group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj) (group_dim,) = group.dims @@ -342,7 +350,7 @@ def __init__( if bins is not None: if duck_array_ops.isnull(bins).all(): raise ValueError("All bin edges are NaN.") - binned = pd.cut(group.values, bins, **cut_kwargs) + binned, bins = pd.cut(group.values, bins, **cut_kwargs, retbins=True) new_dim_name = group.name + "_bins" group = DataArray(binned, group.coords, name=new_dim_name) full_index = binned.categories @@ -403,6 +411,7 @@ def __init__( self._full_index = full_index self._restore_coord_dims = restore_coord_dims self._bins = bins + self._squeeze = squeeze # cached attributes self._groups = None @@ -570,6 +579,121 @@ def _maybe_unstack(self, obj): obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords)) return obj + def _flox_reduce(self, dim, **kwargs): + """Adaptor function that translates our groupby API to that of flox.""" + from flox.xarray import xarray_reduce + + from .dataset import Dataset + + obj = self._original_obj + + # preserve current strategy (approximately) for dask groupby. + # We want to control the default anyway to prevent surprises + # if flox decides to change its default + kwargs.setdefault("method", "split-reduce") + + numeric_only = kwargs.pop("numeric_only", None) + if numeric_only: + non_numeric = { + name: var + for name, var in obj.data_vars.items() + if not (np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_)) + } + else: + non_numeric = {} + + # weird backcompat + # reducing along a unique indexed dimension with squeeze=True + # should raise an error + if ( + dim is None or dim == self._group.name + ) and self._group.name in obj.xindexes: + index = obj.indexes[self._group.name] + if index.is_unique and self._squeeze: + raise ValueError(f"cannot reduce over dimensions {self._group.name!r}") + + # group is only passed by resample + group = kwargs.pop("group", None) + if group is None: + if isinstance(self._unstacked_group, _DummyGroup): + group = self._unstacked_group.name + else: + group = self._unstacked_group + + unindexed_dims = tuple() + if isinstance(group, str): + if group in obj.dims and group not in obj._indexes and self._bins is None: + unindexed_dims = (group,) + group = self._original_obj[group] + + if isinstance(dim, str): + dim = (dim,) + elif dim is None: + dim = group.dims + elif dim is Ellipsis: + dim = tuple(self._original_obj.dims) + + # Do this so we raise the same error message whether flox is present or not. + # Better to control it here than in flox. + if any(d not in group.dims and d not in self._original_obj.dims for d in dim): + raise ValueError(f"cannot reduce over dimensions {dim}.") + + if self._bins is not None: + # TODO: fix this; When binning by time, self._bins is a DatetimeIndex + expected_groups = (np.array(self._bins),) + isbin = (True,) + # This is an annoying hack. Xarray returns np.nan + # when there are no observations in a bin, instead of 0. + # We can fake that here by forcing min_count=1. + if kwargs["func"] == "count": + if "fill_value" not in kwargs or kwargs["fill_value"] is None: + kwargs["fill_value"] = np.nan + # note min_count makes no sense in the xarray world + # as a kwarg for count, so this should be OK + kwargs["min_count"] = 1 + # empty bins have np.nan regardless of dtype + # flox's default would not set np.nan for integer dtypes + kwargs.setdefault("fill_value", np.nan) + else: + expected_groups = (self._unique_coord.values,) + isbin = False + + result = xarray_reduce( + self._original_obj.drop_vars(non_numeric), + group, + dim=dim, + expected_groups=expected_groups, + isbin=isbin, + **kwargs, + ) + + # Ignore error when the groupby reduction is effectively + # a reduction of the underlying dataset + result = result.drop_vars(unindexed_dims, errors="ignore") + + # broadcast and restore non-numeric data variables (backcompat) + for name, var in non_numeric.items(): + if all(d not in var.dims for d in dim): + result[name] = var.variable.set_dims( + (group.name,) + var.dims, (result.sizes[group.name],) + var.shape + ) + + if self._bins is not None: + # bins provided to flox are at full precision + # the bin edge labels have a default precision of 3 + # reassign to fix that. + new_coord = [ + pd.Interval(inter.left, inter.right) for inter in self._full_index + ] + result[self._group.name] = new_coord + # Fix dimension order when binning a dimension coordinate + # Needed as long as we do a separate code path for pint; + # For some reason Datasets and DataArrays behave differently! + if isinstance(self._obj, Dataset) and self._group_dim in self._obj.dims: + result = result.transpose(self._group.name, ...) + + return result + def fillna(self, value): """Fill missing values in this object by group. diff --git a/xarray/core/options.py b/xarray/core/options.py index 399afe90b66..d31f2577601 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -27,6 +27,7 @@ class T_Options(TypedDict): keep_attrs: Literal["default", True, False] warn_for_unclosed_files: bool use_bottleneck: bool + use_flox: bool OPTIONS: T_Options = { @@ -45,6 +46,7 @@ class T_Options(TypedDict): "file_cache_maxsize": 128, "keep_attrs": "default", "use_bottleneck": True, + "use_flox": True, "warn_for_unclosed_files": False, } @@ -70,6 +72,7 @@ def _positive_integer(value): "file_cache_maxsize": _positive_integer, "keep_attrs": lambda choice: choice in [True, False, "default"], "use_bottleneck": lambda value: isinstance(value, bool), + "use_flox": lambda value: isinstance(value, bool), "warn_for_unclosed_files": lambda value: isinstance(value, bool), } @@ -180,6 +183,9 @@ class set_options: use_bottleneck : bool, default: True Whether to use ``bottleneck`` to accelerate 1D reductions and 1D rolling reduction operations. + use_flox : bool, default: True + Whether to use ``numpy_groupies`` and `flox`` to + accelerate groupby and resampling reductions. warn_for_unclosed_files : bool, default: False Whether or not to issue a warning when unclosed files are deallocated. This is mostly useful for debugging. diff --git a/xarray/core/resample.py b/xarray/core/resample.py index ed665ad4048..bcc4bfb90cd 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,13 +1,15 @@ import warnings from typing import Any, Callable, Hashable, Sequence, Union +import numpy as np + from ._reductions import DataArrayResampleReductions, DatasetResampleReductions -from .groupby import DataArrayGroupByBase, DatasetGroupByBase +from .groupby import DataArrayGroupByBase, DatasetGroupByBase, GroupBy RESAMPLE_DIM = "__resample_dim__" -class Resample: +class Resample(GroupBy): """An object that extends the `GroupBy` object with additional logic for handling specialized re-sampling operations. @@ -21,6 +23,29 @@ class Resample: """ + def _flox_reduce(self, dim, **kwargs): + + from .dataarray import DataArray + + kwargs.setdefault("method", "cohorts") + + # now create a label DataArray since resample doesn't do that somehow + repeats = [] + for slicer in self._group_indices: + stop = ( + slicer.stop + if slicer.stop is not None + else self._obj.sizes[self._group_dim] + ) + repeats.append(stop - slicer.start) + labels = np.repeat(self._unique_coord.data, repeats) + group = DataArray(labels, dims=(self._group_dim,), name=self._unique_coord.name) + + result = super()._flox_reduce(dim=dim, group=group, **kwargs) + result = self._maybe_restore_empty_groups(result) + result = result.rename({RESAMPLE_DIM: self._group_dim}) + return result + def _upsample(self, method, *args, **kwargs): """Dispatch function to call appropriate up-sampling methods on data. @@ -158,7 +183,7 @@ def _interpolate(self, kind="linear"): ) -class DataArrayResample(DataArrayGroupByBase, DataArrayResampleReductions, Resample): +class DataArrayResample(Resample, DataArrayGroupByBase, DataArrayResampleReductions): """DataArrayGroupBy object specialized to time resampling operations over a specified dimension """ @@ -249,7 +274,7 @@ def apply(self, func, args=(), shortcut=None, **kwargs): return self.map(func=func, shortcut=shortcut, args=args, **kwargs) -class DatasetResample(DatasetGroupByBase, DatasetResampleReductions, Resample): +class DatasetResample(Resample, DatasetGroupByBase, DatasetResampleReductions): """DatasetGroupBy object specialized to resampling a specified dimension""" def __init__(self, *args, dim=None, resample_dim=None, **kwargs): diff --git a/xarray/core/utils.py b/xarray/core/utils.py index aaa087a3532..eda08becc20 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -928,3 +928,21 @@ def iterate_nested(nested_list): yield from iterate_nested(item) else: yield item + + +def contains_only_dask_or_numpy(obj) -> bool: + """Returns True if xarray object contains only numpy or dask arrays. + + Expects obj to be Dataset or DataArray""" + from .dataarray import DataArray + from .pycompat import is_duck_dask_array + + if isinstance(obj, DataArray): + obj = obj._to_temp_dataset() + + return all( + [ + isinstance(var.data, np.ndarray) or is_duck_dask_array(var.data) + for var in obj.variables.values() + ] + ) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 7872fec2e62..65f0bc08261 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -78,6 +78,8 @@ def _importorskip(modname, minversion=None): has_cartopy, requires_cartopy = _importorskip("cartopy") has_pint, requires_pint = _importorskip("pint") has_numexpr, requires_numexpr = _importorskip("numexpr") +has_flox, requires_flox = _importorskip("flox") + # some special cases has_scipy_or_netCDF4 = has_scipy or has_netCDF4 diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index b4b93d1dba3..8c745dc640d 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -17,6 +17,7 @@ assert_identical, create_test_data, requires_dask, + requires_flox, requires_scipy, ) @@ -24,7 +25,10 @@ @pytest.fixture def dataset(): ds = xr.Dataset( - {"foo": (("x", "y", "z"), np.random.randn(3, 4, 2))}, + { + "foo": (("x", "y", "z"), np.random.randn(3, 4, 2)), + "baz": ("x", ["e", "f", "g"]), + }, {"x": ["a", "b", "c"], "y": [1, 2, 3, 4], "z": [1, 2]}, ) ds["boo"] = (("z", "y"), [["f", "g", "h", "j"]] * 2) @@ -71,6 +75,15 @@ def test_multi_index_groupby_map(dataset) -> None: assert_equal(expected, actual) +def test_reduce_numeric_only(dataset) -> None: + gb = dataset.groupby("x", squeeze=False) + with xr.set_options(use_flox=False): + expected = gb.sum() + with xr.set_options(use_flox=True): + actual = gb.sum() + assert_identical(expected, actual) + + def test_multi_index_groupby_sum() -> None: # regression test for GH873 ds = xr.Dataset( @@ -961,6 +974,17 @@ def test_groupby_dataarray_map_dataset_func(): assert_identical(actual, expected) +@requires_flox +@pytest.mark.parametrize("kwargs", [{"method": "map-reduce"}, {"engine": "numpy"}]) +def test_groupby_flox_kwargs(kwargs): + ds = Dataset({"a": ("x", range(5))}, {"c": ("x", [0, 0, 1, 1, 1])}) + with xr.set_options(use_flox=False): + expected = ds.groupby("c").mean() + with xr.set_options(use_flox=True): + actual = ds.groupby("c").mean(**kwargs) + assert_identical(expected, actual) + + class TestDataArrayGroupBy: @pytest.fixture(autouse=True) def setup(self): @@ -1016,19 +1040,22 @@ def test_groupby_properties(self): assert_array_equal(expected_groups[key], grouped.groups[key]) assert 3 == len(grouped) - def test_groupby_map_identity(self): + @pytest.mark.parametrize( + "by, use_da", [("x", False), ("y", False), ("y", True), ("abc", False)] + ) + @pytest.mark.parametrize("shortcut", [True, False]) + @pytest.mark.parametrize("squeeze", [True, False]) + def test_groupby_map_identity(self, by, use_da, shortcut, squeeze) -> None: expected = self.da - idx = expected.coords["y"] + if use_da: + by = expected.coords[by] def identity(x): return x - for g in ["x", "y", "abc", idx]: - for shortcut in [False, True]: - for squeeze in [False, True]: - grouped = expected.groupby(g, squeeze=squeeze) - actual = grouped.map(identity, shortcut=shortcut) - assert_identical(expected, actual) + grouped = expected.groupby(by, squeeze=squeeze) + actual = grouped.map(identity, shortcut=shortcut) + assert_identical(expected, actual) def test_groupby_sum(self): array = self.da @@ -1083,19 +1110,21 @@ def test_groupby_sum(self): assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y")) assert_allclose(expected_sum_axis1, grouped.sum("y")) - def test_groupby_sum_default(self): + @pytest.mark.parametrize("method", ["sum", "mean", "median"]) + def test_groupby_reductions(self, method): array = self.da grouped = array.groupby("abc") - expected_sum_all = Dataset( + reduction = getattr(np, method) + expected = Dataset( { "foo": Variable( ["x", "abc"], np.array( [ - self.x[:, :9].sum(axis=-1), - self.x[:, 10:].sum(axis=-1), - self.x[:, 9:10].sum(axis=-1), + reduction(self.x[:, :9], axis=-1), + reduction(self.x[:, 10:], axis=-1), + reduction(self.x[:, 9:10], axis=-1), ] ).T, ), @@ -1103,7 +1132,14 @@ def test_groupby_sum_default(self): } )["foo"] - assert_allclose(expected_sum_all, grouped.sum(dim="y")) + with xr.set_options(use_flox=False): + actual_legacy = getattr(grouped, method)(dim="y") + + with xr.set_options(use_flox=True): + actual_npg = getattr(grouped, method)(dim="y") + + assert_allclose(expected, actual_legacy) + assert_allclose(expected, actual_npg) def test_groupby_count(self): array = DataArray( @@ -1318,13 +1354,23 @@ def test_groupby_bins(self): expected = DataArray( [1, 5], dims="dim_0_bins", coords={"dim_0_bins": bin_coords} ) - # the problem with this is that it overwrites the dimensions of array! - # actual = array.groupby('dim_0', bins=bins).sum() - actual = array.groupby_bins("dim_0", bins).map(lambda x: x.sum()) + actual = array.groupby_bins("dim_0", bins=bins).sum() + assert_identical(expected, actual) + + actual = array.groupby_bins("dim_0", bins=bins).map(lambda x: x.sum()) assert_identical(expected, actual) + # make sure original array dims are unchanged assert len(array.dim_0) == 4 + da = xr.DataArray(np.ones((2, 3, 4))) + bins = [-1, 0, 1, 2] + with xr.set_options(use_flox=False): + actual = da.groupby_bins("dim_0", bins).mean(...) + with xr.set_options(use_flox=True): + expected = da.groupby_bins("dim_0", bins).mean(...) + assert_allclose(actual, expected) + def test_groupby_bins_empty(self): array = DataArray(np.arange(4), [("x", range(4))]) # one of these bins will be empty @@ -1350,6 +1396,27 @@ def test_groupby_bins_multidim(self): actual = array.groupby_bins("lat", bins).map(lambda x: x.sum()) assert_identical(expected, actual) + bins = [-2, -1, 0, 1, 2] + field = DataArray(np.ones((5, 3)), dims=("x", "y")) + by = DataArray( + np.array([[-1.5, -1.5, 0.5, 1.5, 1.5] * 3]).reshape(5, 3), dims=("x", "y") + ) + actual = field.groupby_bins(by, bins=bins).count() + + bincoord = np.array( + [ + pd.Interval(left, right, closed="right") + for left, right in zip(bins[:-1], bins[1:]) + ], + dtype=object, + ) + expected = DataArray( + np.array([6, np.nan, 3, 6]), + dims="group_bins", + coords={"group_bins": bincoord}, + ) + assert_identical(actual, expected) + def test_groupby_bins_sort(self): data = xr.DataArray( np.arange(100), dims="x", coords={"x": np.linspace(-100, 100, num=100)} @@ -1357,6 +1424,12 @@ def test_groupby_bins_sort(self): binned_mean = data.groupby_bins("x", bins=11).mean() assert binned_mean.to_index().is_monotonic_increasing + with xr.set_options(use_flox=True): + actual = data.groupby_bins("x", bins=11).count() + with xr.set_options(use_flox=False): + expected = data.groupby_bins("x", bins=11).count() + assert_identical(actual, expected) + def test_groupby_assign_coords(self): array = DataArray([1, 2, 3, 4], {"c": ("x", [0, 0, 1, 1])}, dims="x") @@ -1769,7 +1842,7 @@ def test_resample_min_count(self): ], dim=actual["time"], ) - assert_equal(expected, actual) + assert_allclose(expected, actual) def test_resample_by_mean_with_keep_attrs(self): times = pd.date_range("2000-01-01", freq="6H", periods=10) @@ -1903,7 +1976,7 @@ def test_resample_ds_da_are_the_same(self): "x": np.arange(5), } ) - assert_identical( + assert_allclose( ds.resample(time="M").mean()["foo"], ds.foo.resample(time="M").mean() ) @@ -1916,6 +1989,3 @@ def func(arg1, arg2, arg3=0.0): expected = xr.Dataset({"foo": ("time", [3.0, 3.0, 3.0]), "time": times}) actual = ds.resample(time="D").map(func, args=(1.0,), arg3=1.0) assert_identical(expected, actual) - - -# TODO: move other groupby tests from test_dataset and test_dataarray over here diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index c18b7d18c04..679733e1ecf 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -5344,8 +5344,12 @@ def test_computation_objects(self, func, variant, dtype): units = extract_units(ds) args = [] if func.name != "groupby" else ["y"] - expected = attach_units(func(strip_units(ds)).mean(*args), units) - actual = func(ds).mean(*args) + # Doesn't work with flox because pint doesn't implement + # ufunc.reduceat or np.bincount + # kwargs = {"engine": "numpy"} if "groupby" in func.name else {} + kwargs = {} + expected = attach_units(func(strip_units(ds)).mean(*args, **kwargs), units) + actual = func(ds).mean(*args, **kwargs) assert_units_equal(expected, actual) assert_allclose(expected, actual) diff --git a/xarray/util/generate_reductions.py b/xarray/util/generate_reductions.py index e79c94e8907..96b91c16906 100644 --- a/xarray/util/generate_reductions.py +++ b/xarray/util/generate_reductions.py @@ -23,13 +23,19 @@ from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional, Sequence, Union from . import duck_array_ops +from .options import OPTIONS +from .utils import contains_only_dask_or_numpy if TYPE_CHECKING: from .dataarray import DataArray - from .dataset import Dataset''' + from .dataset import Dataset +try: + import flox +except ImportError: + flox = None # type: ignore''' -CLASS_PREAMBLE = """ +DEFAULT_PREAMBLE = """ class {obj}{cls}Reductions: __slots__ = () @@ -46,6 +52,54 @@ def reduce( ) -> "{obj}": raise NotImplementedError()""" +GROUPBY_PREAMBLE = """ + +class {obj}{cls}Reductions: + _obj: "{obj}" + + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any, + ) -> "{obj}": + raise NotImplementedError() + + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "{obj}": + raise NotImplementedError()""" + +RESAMPLE_PREAMBLE = """ + +class {obj}{cls}Reductions: + _obj: "{obj}" + + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any, + ) -> "{obj}": + raise NotImplementedError() + + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "{obj}": + raise NotImplementedError()""" + TEMPLATE_REDUCTION_SIGNATURE = ''' def {method}( self, @@ -113,11 +167,7 @@ def {method}( These could include dask-specific kwargs like ``split_every``.""" NAN_CUM_METHODS = ["cumsum", "cumprod"] - -NUMERIC_ONLY_METHODS = [ - "cumsum", - "cumprod", -] +NUMERIC_ONLY_METHODS = ["cumsum", "cumprod"] _NUMERIC_ONLY_NOTES = "Non-numeric variables will be removed prior to reducing." ExtraKwarg = collections.namedtuple("ExtraKwarg", "docs kwarg call example") @@ -182,6 +232,7 @@ def __init__( docref, docref_description, example_call_preamble, + definition_preamble, see_also_obj=None, ): self.datastructure = datastructure @@ -190,7 +241,7 @@ def __init__( self.docref = docref self.docref_description = docref_description self.example_call_preamble = example_call_preamble - self.preamble = CLASS_PREAMBLE.format(obj=datastructure.name, cls=cls) + self.preamble = definition_preamble.format(obj=datastructure.name, cls=cls) if not see_also_obj: self.see_also_obj = self.datastructure.name else: @@ -268,6 +319,53 @@ def generate_example(self, method): >>> {calculation}(){extra_examples}""" +class GroupByReductionGenerator(ReductionGenerator): + def generate_code(self, method): + extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call] + + if self.datastructure.numeric_only: + extra_kwargs.append(f"numeric_only={method.numeric_only},") + + # numpy_groupies & flox do not support median + # https://github.com/ml31415/numpy-groupies/issues/43 + if method.name == "median": + indent = 12 + else: + indent = 16 + + if extra_kwargs: + extra_kwargs = textwrap.indent("\n" + "\n".join(extra_kwargs), indent * " ") + else: + extra_kwargs = "" + + if method.name == "median": + return f"""\ + return self.reduce( + duck_array_ops.{method.array_method}, + dim=dim,{extra_kwargs} + keep_attrs=keep_attrs, + **kwargs, + )""" + + else: + return f"""\ + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="{method.name}", + dim=dim,{extra_kwargs} + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.{method.array_method}, + dim=dim,{extra_kwargs} + keep_attrs=keep_attrs, + **kwargs, + )""" + + class GenericReductionGenerator(ReductionGenerator): def generate_code(self, method): extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call] @@ -335,6 +433,7 @@ class DataStructure: docref_description="reduction or aggregation operations", example_call_preamble="", see_also_obj="DataArray", + definition_preamble=DEFAULT_PREAMBLE, ) DATAARRAY_GENERATOR = GenericReductionGenerator( cls="", @@ -344,39 +443,43 @@ class DataStructure: docref_description="reduction or aggregation operations", example_call_preamble="", see_also_obj="Dataset", + definition_preamble=DEFAULT_PREAMBLE, ) - -DATAARRAY_GROUPBY_GENERATOR = GenericReductionGenerator( +DATAARRAY_GROUPBY_GENERATOR = GroupByReductionGenerator( cls="GroupBy", datastructure=DATAARRAY_OBJECT, methods=REDUCTION_METHODS, docref="groupby", docref_description="groupby operations", example_call_preamble='.groupby("labels")', + definition_preamble=GROUPBY_PREAMBLE, ) -DATAARRAY_RESAMPLE_GENERATOR = GenericReductionGenerator( +DATAARRAY_RESAMPLE_GENERATOR = GroupByReductionGenerator( cls="Resample", datastructure=DATAARRAY_OBJECT, methods=REDUCTION_METHODS, docref="resampling", docref_description="resampling operations", example_call_preamble='.resample(time="3M")', + definition_preamble=RESAMPLE_PREAMBLE, ) -DATASET_GROUPBY_GENERATOR = GenericReductionGenerator( +DATASET_GROUPBY_GENERATOR = GroupByReductionGenerator( cls="GroupBy", datastructure=DATASET_OBJECT, methods=REDUCTION_METHODS, docref="groupby", docref_description="groupby operations", example_call_preamble='.groupby("labels")', + definition_preamble=GROUPBY_PREAMBLE, ) -DATASET_RESAMPLE_GENERATOR = GenericReductionGenerator( +DATASET_RESAMPLE_GENERATOR = GroupByReductionGenerator( cls="Resample", datastructure=DATASET_OBJECT, methods=REDUCTION_METHODS, docref="resampling", docref_description="resampling operations", example_call_preamble='.resample(time="3M")', + definition_preamble=RESAMPLE_PREAMBLE, ) @@ -386,6 +489,7 @@ class DataStructure: p = Path(os.getcwd()) filepath = p.parent / "xarray" / "xarray" / "core" / "_reductions.py" + # filepath = p.parent / "core" / "_reductions.py" # Run from script location with open(filepath, mode="w", encoding="utf-8") as f: f.write(MODULE_PREAMBLE + "\n") for gen in [ diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py index 561126ea05f..b8689e3a18f 100755 --- a/xarray/util/print_versions.py +++ b/xarray/util/print_versions.py @@ -122,6 +122,8 @@ def show_versions(file=sys.stdout): ("cupy", lambda mod: mod.__version__), ("pint", lambda mod: mod.__version__), ("sparse", lambda mod: mod.__version__), + ("flox", lambda mod: mod.__version__), + ("numpy_groupies", lambda mod: mod.__version__), # xarray setup/test ("setuptools", lambda mod: mod.__version__), ("pip", lambda mod: mod.__version__),