Skip to content
forked from pydata/xarray

Commit

Permalink
Better binning API
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jun 28, 2024
1 parent 01fbf50 commit e250895
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 22 deletions.
13 changes: 7 additions & 6 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6805,6 +6805,7 @@ def groupby_bins(
include_lowest: bool = False,
squeeze: bool | None = None,
restore_coord_dims: bool = False,
duplicates: Literal["raise", "drop"] = "raise",
) -> DataArrayGroupBy:
"""Returns a DataArrayGroupBy object for performing grouped operations.
Expand Down Expand Up @@ -6841,6 +6842,8 @@ def groupby_bins(
restore_coord_dims : bool, default: False
If True, also restore the dimension order of multi-dimensional
coordinates.
duplicates : {"raise", "drop"}, default: "raise"
If bin edges are not unique, either raise a ValueError ("raise") or
drop the non-unique edges ("drop").
Returns
-------
Expand Down Expand Up @@ -6873,12 +6876,10 @@ def groupby_bins(
_validate_groupby_squeeze(squeeze)
grouper = BinGrouper(
bins=bins,
cut_kwargs={
"right": right,
"labels": labels,
"precision": precision,
"include_lowest": include_lowest,
},
right=right,
labels=labels,
precision=precision,
include_lowest=include_lowest,
)
rgrouper = ResolvedGrouper(grouper, group, self)

Expand Down
13 changes: 7 additions & 6 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10342,6 +10342,7 @@ def groupby_bins(
include_lowest: bool = False,
squeeze: bool | None = None,
restore_coord_dims: bool = False,
duplicates: Literal["raise", "drop"] = "raise",
) -> DatasetGroupBy:
"""Returns a DatasetGroupBy object for performing grouped operations.
Expand Down Expand Up @@ -10378,6 +10379,8 @@ def groupby_bins(
restore_coord_dims : bool, default: False
If True, also restore the dimension order of multi-dimensional
coordinates.
duplicates : {"raise", "drop"}, default: "raise"
If bin edges are not unique, either raise a ValueError ("raise") or
drop the non-unique edges ("drop").
Returns
-------
Expand Down Expand Up @@ -10410,12 +10413,10 @@ def groupby_bins(
_validate_groupby_squeeze(squeeze)
grouper = BinGrouper(
bins=bins,
cut_kwargs={
"right": right,
"labels": labels,
"precision": precision,
"include_lowest": include_lowest,
},
right=right,
labels=labels,
precision=precision,
include_lowest=include_lowest,
)
rgrouper = ResolvedGrouper(grouper, group, self)

Expand Down
56 changes: 49 additions & 7 deletions xarray/core/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

import datetime
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Literal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -195,14 +196,46 @@ class BinGrouper(Grouper):
Attributes
----------
bins: int, sequence of scalars, or IntervalIndex
Specification for bins either as integer, or as bin edges.
cut_kwargs: dict
Keyword arguments forwarded to :py:func:`pandas.cut`.
bins : int, sequence of scalars, or IntervalIndex
The criteria to bin by.
* int : Defines the number of equal-width bins in the range of the data
being grouped. The range of the data is extended by .1% on each side to
include its minimum and maximum values.
* sequence of scalars : Defines the bin edges allowing for non-uniform
width. No extension of the range of the data is done.
* IntervalIndex : Defines the exact bins to be used. Note that
IntervalIndex for `bins` must be non-overlapping.
right : bool, default True
Indicates whether `bins` includes the rightmost edge or not. If
``right == True`` (the default), then the `bins` ``[1, 2, 3, 4]``
indicate (1,2], (2,3], (3,4]. This argument is ignored when
`bins` is an IntervalIndex.
labels : array or False, default None
Specifies the labels for the returned bins. Must be the same length as
the resulting bins. If False, returns only integer indicators of the
bins. This affects the type of the output container (see below).
This argument is ignored when `bins` is an IntervalIndex. If True,
raises an error. When `ordered=False`, labels must be provided.
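retbins : bool
Not a configurable attribute of this class: ``factorize`` always calls
:py:func:`pandas.cut` with ``retbins=True`` and stores the returned bin
edges on ``bins``. Useful when bins is provided as a scalar.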
precision : int, default 3
The precision at which to store and display the bins labels.
include_lowest : bool, default False
Whether the first interval should be left-inclusive or not.
duplicates : {"raise", "drop"}, default: "raise"
If bin edges are not unique, either raise a ValueError ("raise") or
drop the non-unique edges ("drop").
"""

bins: int | Sequence | pd.IntervalIndex
cut_kwargs: Mapping = field(default_factory=dict)
# The rest are copied from pandas
right: bool = True
labels: Any = None
precision: int = 3
include_lowest: bool = False
duplicates: Literal["raise", "drop"] = "raise"

def __post_init__(self) -> None:
if duck_array_ops.isnull(self.bins).all():
Expand All @@ -213,7 +246,16 @@ def factorize(self, group: T_Group) -> EncodedGroups:

data = group.data

binned, self.bins = pd.cut(data, self.bins, **self.cut_kwargs, retbins=True)
binned, self.bins = pd.cut(
data,
bins=self.bins,
right=self.right,
labels=self.labels,
precision=self.precision,
include_lowest=self.include_lowest,
duplicates=self.duplicates,
retbins=True,
)

binned_codes = binned.codes
if (binned_codes == -1).all():
Expand Down
4 changes: 1 addition & 3 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,9 +1034,7 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:

with xr.set_options(use_flox=use_flox):
actual = da.groupby(
x=BinGrouper(
bins=x_bins, cut_kwargs=dict(include_lowest=True, right=False)
),
x=BinGrouper(bins=x_bins, include_lowest=True, right=False),
).mean()
assert_identical(expected, actual)

Expand Down

0 comments on commit e250895

Please sign in to comment.