Skip to content

Commit

Permalink
Pass ddof through for numbagg (#302)
Browse files Browse the repository at this point in the history
* Support ddof with numbagg

* Fix tests
  • Loading branch information
dcherian authored Feb 2, 2024
1 parent 1368f0f commit f42f3ff
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 15 deletions.
1 change: 0 additions & 1 deletion ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ dependencies:
- pytest-xdist
- xarray
- pre-commit
- numbagg>=0.3
- numpy_groupies>=0.9.19
- pooch
- toolz
Expand Down
1 change: 1 addition & 0 deletions ci/no-dask.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ dependencies:
- pooch
- toolz
- numba
- numbagg>=0.3
25 changes: 18 additions & 7 deletions flox/aggregate_numbagg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import numbagg
import numbagg.grouped
import numpy as np
from packaging.version import Version

# True when the installed numbagg is at least v0.7.0. nanvar/nanstd consult
# this flag to decide whether a ``ddof`` keyword may be forwarded to numbagg.
NUMBAGG_SUPPORTS_DDOF = Version(numbagg.__version__) >= Version("0.7.0")

DEFAULT_FILL_VALUE = {
"nansum": 0,
Expand Down Expand Up @@ -42,6 +45,7 @@ def _numbagg_wrapper(
size=None,
fill_value=None,
dtype=None,
**kwargs,
):
cast_to = CAST_TO.get(func, None)
if cast_to:
Expand All @@ -56,6 +60,7 @@ def _numbagg_wrapper(
group_idx,
axis=axis,
num_labels=size,
**kwargs,
# The following are unsupported
# fill_value=fill_value,
# dtype=dtype,
Expand All @@ -65,30 +70,36 @@ def _numbagg_wrapper(


def nanvar(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0):
    """Run the ``"nanvar"`` aggregation through ``_numbagg_wrapper``.

    Parameters
    ----------
    group_idx, array
        Passed positionally to ``_numbagg_wrapper`` unchanged.
    axis, size
        Forwarded to ``_numbagg_wrapper`` as keyword arguments.
    fill_value, dtype
        Accepted but not forwarded.
    ddof
        Delta degrees of freedom. Forwarded only when the installed numbagg
        supports it (``NUMBAGG_SUPPORTS_DDOF``).

    Raises
    ------
    ValueError
        If the installed numbagg does not support ``ddof`` and a value other
        than 1 is requested.
    """
    kwargs = {}
    if NUMBAGG_SUPPORTS_DDOF:
        kwargs["ddof"] = ddof
    elif ddof != 1:
        # Without ddof support nothing is forwarded, so only ddof=1 is
        # accepted here; raise explicitly instead of relying on an assert,
        # which would be stripped under ``python -O``.
        raise ValueError("Need numbagg >= v0.7.0 to support ddof != 1")
    return _numbagg_wrapper(
        group_idx,
        array,
        axis=axis,
        size=size,
        func="nanvar",
        **kwargs,
        # fill_value and dtype are intentionally not forwarded.
    )


def nanstd(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None, ddof=0):
    """Run the ``"nanstd"`` aggregation through ``_numbagg_wrapper``.

    Parameters
    ----------
    group_idx, array
        Passed positionally to ``_numbagg_wrapper`` unchanged.
    axis, size
        Forwarded to ``_numbagg_wrapper`` as keyword arguments.
    fill_value, dtype
        Accepted but not forwarded.
    ddof
        Delta degrees of freedom. Forwarded only when the installed numbagg
        supports it (``NUMBAGG_SUPPORTS_DDOF``).

    Raises
    ------
    ValueError
        If the installed numbagg does not support ``ddof`` and a value other
        than 1 is requested.
    """
    kwargs = {}
    if NUMBAGG_SUPPORTS_DDOF:
        kwargs["ddof"] = ddof
    elif ddof != 1:
        # Without ddof support nothing is forwarded, so only ddof=1 is
        # accepted here; raise explicitly instead of relying on an assert,
        # which would be stripped under ``python -O``.
        raise ValueError("Need numbagg >= v0.7.0 to support ddof != 1")
    return _numbagg_wrapper(
        group_idx,
        array,
        axis=axis,
        size=size,
        func="nanstd",
        **kwargs,
        # fill_value and dtype are intentionally not forwarded.
    )
Expand Down
15 changes: 8 additions & 7 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,16 @@ def generic_aggregate(
from . import aggregate_numbagg

try:
if (
# numbagg hardcodes ddof=1
("var" in func or "std" in func)
and kwargs.get("ddof", 0) == 0
):
method = get_npg_aggregation(func, engine="numpy")

if "var" in func or "std" in func:
ddof = kwargs.get("ddof", 0)
if aggregate_numbagg.NUMBAGG_SUPPORTS_DDOF or (ddof != 0):
method = getattr(aggregate_numbagg, func)
else:
logger.debug(f"numbagg too old for ddof={ddof}. Falling back to numpy")
method = get_npg_aggregation(func, engine="numpy")
else:
method = getattr(aggregate_numbagg, func)

except AttributeError:
logger.debug(f"Couldn't find {func} for engine='numbagg'. Falling back to numpy")
method = get_npg_aggregation(func, engine="numpy")
Expand Down

0 comments on commit f42f3ff

Please sign in to comment.