Skip to content

Commit

Permalink
Enable nanargmax, nanargmin (xarray-contrib#171)
Browse files Browse the repository at this point in the history
* Support nanargmin, nanargmax

* Fix test

* Add blockwise test

* Fix blockwise test

* Apply suggestions from code review
  • Loading branch information
dcherian authored May 11, 2023
1 parent 6a5969f commit 096c080
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
4 changes: 2 additions & 2 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def _pick_second(*x):
chunk=("nanmax", "nanargmax"), # order is important
combine=("max", "argmax"),
reduction_type="argreduce",
fill_value=(dtypes.NINF, -1),
fill_value=(dtypes.NINF, 0),
final_fill_value=-1,
finalize=_pick_second,
dtypes=(None, np.intp),
Expand All @@ -434,7 +434,7 @@ def _pick_second(*x):
chunk=("nanmin", "nanargmin"), # order is important
combine=("min", "argmin"),
reduction_type="argreduce",
fill_value=(dtypes.INF, -1),
fill_value=(dtypes.INF, 0),
final_fill_value=-1,
finalize=_pick_second,
dtypes=(None, np.intp),
Expand Down
7 changes: 5 additions & 2 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,8 +1323,11 @@ def dask_groupby_agg(
by = dask.array.from_array(by, chunks=chunks)
_, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :])

# preprocess the array: for argreductions, this zips the index together with the array block
if agg.preprocess:
# preprocess the array:
# - for argreductions, this zips the index together with the array block
# - not necessary for blockwise with argreductions
# - if this is needed later, we can fix this then
if agg.preprocess and method != "blockwise":
array = agg.preprocess(array, axis=axis)

# 1. We first apply the groupby-reduction blockwise to generate "intermediates"
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest


@pytest.fixture(scope="module", params=["flox"])
@pytest.fixture(scope="module", params=["flox", "numpy", "numba"])
def engine(request):
if request.param == "numba":
try:
Expand Down
25 changes: 17 additions & 8 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def dask_array_ones(*args):
"nansum",
"argmax",
"nanfirst",
pytest.param("nanargmax", marks=(pytest.mark.skip,)),
"nanargmax",
"prod",
"nanprod",
"mean",
Expand All @@ -69,7 +69,7 @@ def dask_array_ones(*args):
"min",
"nanmin",
"argmin",
pytest.param("nanargmin", marks=(pytest.mark.skip,)),
"nanargmin",
"any",
"all",
"nanlast",
Expand Down Expand Up @@ -233,8 +233,13 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
# computing silences a bunch of dask warnings
array_ = array.compute() if chunks is not None else array
if "arg" in func and add_nan_by:
# NaNs are in by, but we can't call np.argmax([..., NaN, .. ])
# That would return index of the NaN
# This way, we insert NaNs where there are NaNs in by, and
# call np.nanargmax
func_ = f"nan{func}" if "nan" not in func else func
array_[..., nanmask] = np.nan
expected = getattr(np, "nan" + func)(array_, axis=-1, **kwargs)
expected = getattr(np, func_)(array_, axis=-1, **kwargs)
# elif func in ["first", "last"]:
# expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs)
elif func in ["nanfirst", "nanlast"]:
Expand All @@ -259,6 +264,9 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):

params = list(itertools.product(["map-reduce"], [True, False, None]))
params.extend(itertools.product(["cohorts"], [False, None]))
if chunks == -1:
params.extend([("blockwise", None)])

for method, reindex in params:
call = partial(
groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs
Expand All @@ -269,11 +277,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
call()
continue
actual, *groups = call()
if "arg" not in func:
# make sure we use simple combine
assert any("simple-combine" in key for key in actual.dask.layers.keys())
else:
assert any("grouped-combine" in key for key in actual.dask.layers.keys())
if method != "blockwise":
if "arg" not in func:
# make sure we use simple combine
assert any("simple-combine" in key for key in actual.dask.layers.keys())
else:
assert any("grouped-combine" in key for key in actual.dask.layers.keys())
for actual_group, expect in zip(groups, expected_groups):
assert_equal(actual_group, expect, tolerance)
if "arg" in func:
Expand Down

0 comments on commit 096c080

Please sign in to comment.