Skip to content

Commit

Permalink
Support skipna in core.groupby_reduce
Browse files Browse the repository at this point in the history
This is the right place to do it since skipna can be optionally True
for appropriate dtypes. It's hard to choose the right one when passing
Dataset to apply_ufunc. Instead we do it here where we are always
dealing with a single array and can choose based on dtype.
  • Loading branch information
dcherian committed Nov 10, 2021
1 parent 51297ed commit 7428689
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 12 deletions.
35 changes: 31 additions & 4 deletions dask_groupby/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,8 @@ def chunk_argreduce(
if not np.isnan(results["groups"]).all():
# will not work for empty groups...
# glorious
# TODO: npg bug
results["intermediates"][1] = results["intermediates"][1].astype(int)
newidx = np.broadcast_to(idx, array.shape)[
np.unravel_index(results["intermediates"][1], array.shape)
]
Expand Down Expand Up @@ -992,6 +994,7 @@ def groupby_reduce(
isbin: bool = False,
axis=None,
fill_value=None,
skipna: Optional[bool] = None,
min_count: Optional[int] = None,
split_out: int = 1,
method: str = "mapreduce",
Expand Down Expand Up @@ -1020,6 +1023,16 @@ def groupby_reduce(
Negative integers are normalized using array.ndim
fill_value: Any
Value when a label in `expected_groups` is not present
skipna : bool, default: None
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or ``skipna=True`` has not been
implemented (object, datetime64 or timedelta64).
min_count : int, default: None
The required number of valid values to perform the operation. If
fewer than min_count non-NA values are present the result will be
NA. Only used if skipna is set to True or defaults to True for the
array's dtype.
split_out: int, optional
Number of chunks along group axis in output (last axis)
method: {"mapreduce", "blockwise", "cohorts"}, optional
Expand Down Expand Up @@ -1062,10 +1075,24 @@ def groupby_reduce(
f"Received array of shape {array.shape} and by of shape {by.shape}"
)

if min_count is not None and min_count > 1 and func not in ["nansum", "nanprod"]:
raise ValueError(
"min_count can be > 1 only for nansum, nanprod. This is an Xarray limitation."
)
# Handle skipna here because I need to know dtype to make a good default choice.
# We cannot handle this easily for xarray Datasets in xarray_reduce
if skipna and func in ["all", "any", "count"]:
raise ValueError(f"skipna cannot be truthy for {func} reductions.")

if skipna or (skipna is None and array.dtype.kind in "cfO"):
if "nan" not in func and func not in ["all", "any", "count"]:
func = f"nan{func}"

if min_count is not None and min_count > 1:
if func not in ["nansum", "nanprod"]:
raise ValueError(
"min_count can be > 1 only for nansum, nanprod."
" or for sum, prod with skipna=True."
" This is an Xarray limitation."
)
elif "nan" not in func and skipna:
func = f"nan{func}"

if axis is None:
axis = tuple(array.ndim + np.arange(-by.ndim, 0))
Expand Down
19 changes: 11 additions & 8 deletions dask_groupby/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def xarray_reduce(
method: str = "mapreduce",
backend: str = "numpy",
keep_attrs: bool = True,
skipna: bool = True,
skipna: Optional[bool] = None,
min_count: Optional[int] = None,
**finalize_kwargs,
):
Expand Down Expand Up @@ -113,10 +113,15 @@ def xarray_reduce(
keep_attrs: bool, optional
Preserve attrs?
skipna: bool, optional
Use NaN-skipping aggregations like nanmean?
min_count: int, optional
NaN out when number of non-NaN values in aggregation is < min_count
Only applies to nansum, nanprod.
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or ``skipna=True`` has not been
implemented (object, datetime64 or timedelta64).
min_count : int, default: None
The required number of valid values to perform the operation. If
fewer than min_count non-NA values are present the result will be
NA. Only used if skipna is set to True or defaults to True for the
array's dtype.
finalize_kwargs: dict, optional
kwargs passed to the finalize function, like ddof for var, std.
Expand All @@ -130,9 +135,6 @@ def xarray_reduce(
FIXME: Add docs.
"""

if (skipna or min_count is not None) and func not in ["all", "any", "count"]:
func = f"nan{func}"

for b in by:
if isinstance(b, xr.DataArray) and b.name is None:
raise ValueError("Cannot group by unnamed DataArrays.")
Expand Down Expand Up @@ -285,6 +287,7 @@ def wrapper(*args, **kwargs):
"fill_value": fill_value,
"method": method,
"min_count": min_count,
"skipna": skipna,
"backend": backend,
# The following mess exists because for multiple `by`s I factorize eagerly
# here before passing it on; this means I have to handle the
Expand Down

0 comments on commit 7428689

Please sign in to comment.