diff --git a/dask_groupby/core.py b/dask_groupby/core.py index b20c16faf..a2cd17b0c 100644 --- a/dask_groupby/core.py +++ b/dask_groupby/core.py @@ -389,6 +389,8 @@ def chunk_argreduce( if not np.isnan(results["groups"]).all(): # will not work for empty groups... # glorious + # TODO: npg bug + results["intermediates"][1] = results["intermediates"][1].astype(int) newidx = np.broadcast_to(idx, array.shape)[ np.unravel_index(results["intermediates"][1], array.shape) ] @@ -992,6 +994,7 @@ def groupby_reduce( isbin: bool = False, axis=None, fill_value=None, + skipna: Optional[bool] = None, min_count: Optional[int] = None, split_out: int = 1, method: str = "mapreduce", @@ -1020,6 +1023,16 @@ def groupby_reduce( Negative integers are normalized using array.ndim fill_value: Any Value when a label in `expected_groups` is not present + skipna : bool, default: None + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int, default: None + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. split_out: int, optional Number of chunks along group axis in output (last axis) method: {"mapreduce", "blockwise", "cohorts"}, optional @@ -1062,10 +1075,24 @@ def groupby_reduce( f"Received array of shape {array.shape} and by of shape {by.shape}" ) - if min_count is not None and min_count > 1 and func not in ["nansum", "nanprod"]: - raise ValueError( - "min_count can be > 1 only for nansum, nanprod. This is an Xarray limitation." - ) + # Handle skipna here because I need to know dtype to make a good default choice. 
+ # We cannot handle this easily for xarray Datasets in xarray_reduce + if skipna and func in ["all", "any", "count"]: + raise ValueError(f"skipna cannot be truthy for {func} reductions.") + + if skipna or (skipna is None and array.dtype.kind in "cfO"): + if "nan" not in func and func not in ["all", "any", "count"]: + func = f"nan{func}" + + if min_count is not None and min_count > 1: + if func not in ["nansum", "nanprod"]: + raise ValueError( + "min_count can be > 1 only for nansum, nanprod." + " or for sum, prod with skipna=True." + " This is an Xarray limitation." + ) + elif "nan" not in func and skipna: + func = f"nan{func}" if axis is None: axis = tuple(array.ndim + np.arange(-by.ndim, 0)) diff --git a/dask_groupby/xarray.py b/dask_groupby/xarray.py index 39911b6d7..4a634112c 100644 --- a/dask_groupby/xarray.py +++ b/dask_groupby/xarray.py @@ -60,7 +60,7 @@ def xarray_reduce( method: str = "mapreduce", backend: str = "numpy", keep_attrs: bool = True, - skipna: bool = True, + skipna: Optional[bool] = None, min_count: Optional[int] = None, **finalize_kwargs, ): @@ -113,10 +113,15 @@ def xarray_reduce( keep_attrs: bool, optional Preserve attrs? skipna: bool, optional - Use NaN-skipping aggregations like nanmean? - min_count: int, optional - NaN out when number of non-NaN values in aggregation is < min_count - Only applies to nansum, nanprod. + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or ``skipna=True`` has not been + implemented (object, datetime64 or timedelta64). + min_count : int, default: None + The required number of valid values to perform the operation. If + fewer than min_count non-NA values are present the result will be + NA. Only used if skipna is set to True or defaults to True for the + array's dtype. finalize_kwargs: dict, optional kwargs passed to the finalize function, like ddof for var, std. 
@@ -130,9 +135,6 @@ def xarray_reduce( FIXME: Add docs. """ - if (skipna or min_count is not None) and func not in ["all", "any", "count"]: - func = f"nan{func}" - for b in by: if isinstance(b, xr.DataArray) and b.name is None: raise ValueError("Cannot group by unnamed DataArrays.") @@ -285,6 +287,7 @@ def wrapper(*args, **kwargs): "fill_value": fill_value, "method": method, "min_count": min_count, + "skipna": skipna, "backend": backend, # The following mess exists becuase for multiple `by`s I factorize eagerly # here before passing it on; this means I have to handle the