From 096c080385fffdb809e2bd14d320023fa0f0d04d Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 11 May 2023 14:28:27 -0600 Subject: [PATCH] Enable nanargmax, nanargmin (#171) * Support nanargmin, nanargmax * Fix test * Add blockwise test * Fix blockwise test * Apply suggestions from code review --- flox/aggregations.py | 4 ++-- flox/core.py | 7 +++++-- tests/conftest.py | 2 +- tests/test_core.py | 25 +++++++++++++++++-------- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/flox/aggregations.py b/flox/aggregations.py index e85c0699d..13b23fafe 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -421,7 +421,7 @@ def _pick_second(*x): chunk=("nanmax", "nanargmax"), # order is important combine=("max", "argmax"), reduction_type="argreduce", - fill_value=(dtypes.NINF, -1), + fill_value=(dtypes.NINF, 0), final_fill_value=-1, finalize=_pick_second, dtypes=(None, np.intp), @@ -434,7 +434,7 @@ def _pick_second(*x): chunk=("nanmin", "nanargmin"), # order is important combine=("min", "argmin"), reduction_type="argreduce", - fill_value=(dtypes.INF, -1), + fill_value=(dtypes.INF, 0), final_fill_value=-1, finalize=_pick_second, dtypes=(None, np.intp), diff --git a/flox/core.py b/flox/core.py index 2444df8e3..57ea4556f 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1323,8 +1323,11 @@ def dask_groupby_agg( by = dask.array.from_array(by, chunks=chunks) _, (array, by) = dask.array.unify_chunks(array, inds, by, inds[-by.ndim :]) - # preprocess the array: for argreductions, this zips the index together with the array block - if agg.preprocess: + # preprocess the array: + # - for argreductions, this zips the index together with the array block + # - not necessary for blockwise with argreductions + # - if this is needed later, we can fix this then + if agg.preprocess and method != "blockwise": array = agg.preprocess(array, axis=axis) # 1. We first apply the groupby-reduction blockwise to generate "intermediates" diff --git a/tests/conftest.py b/tests/conftest.py index 8e5039d28..5c3bb81f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import pytest -@pytest.fixture(scope="module", params=["flox"]) +@pytest.fixture(scope="module", params=["flox", "numpy", "numba"]) def engine(request): if request.param == "numba": try: diff --git a/tests/test_core.py b/tests/test_core.py index 7c152fd10..5c4db9248 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -55,7 +55,7 @@ def dask_array_ones(*args): "nansum", "argmax", "nanfirst", - pytest.param("nanargmax", marks=(pytest.mark.skip,)), + "nanargmax", "prod", "nanprod", "mean", @@ -69,7 +69,7 @@ def dask_array_ones(*args): "min", "nanmin", "argmin", - pytest.param("nanargmin", marks=(pytest.mark.skip,)), + "nanargmin", "any", "all", "nanlast", @@ -233,8 +233,13 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): # computing silences a bunch of dask warnings array_ = array.compute() if chunks is not None else array if "arg" in func and add_nan_by: + # NaNs are in by, but we can't call np.argmax([..., NaN, .. ]) + # That would return index of the NaN + # This way, we insert NaNs where there are NaNs in by, and + # call np.nanargmax + func_ = f"nan{func}" if "nan" not in func else func array_[..., nanmask] = np.nan - expected = getattr(np, "nan" + func)(array_, axis=-1, **kwargs) + expected = getattr(np, func_)(array_, axis=-1, **kwargs) # elif func in ["first", "last"]: # expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs) elif func in ["nanfirst", "nanlast"]: @@ -259,6 +264,9 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): params = list(itertools.product(["map-reduce"], [True, False, None])) params.extend(itertools.product(["cohorts"], [False, None])) + if chunks == -1: + params.extend([("blockwise", None)]) + for method, reindex in params: call = partial( groupby_reduce, array, *by, method=method, reindex=reindex, **flox_kwargs @@ -269,11 +277,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine): call() continue actual, *groups = call() - if "arg" not in func: - # make sure we use simple combine - assert any("simple-combine" in key for key in actual.dask.layers.keys()) - else: - assert any("grouped-combine" in key for key in actual.dask.layers.keys()) + if method != "blockwise": + if "arg" not in func: + # make sure we use simple combine + assert any("simple-combine" in key for key in actual.dask.layers.keys()) + else: + assert any("grouped-combine" in key for key in actual.dask.layers.keys()) for actual_group, expect in zip(groups, expected_groups): assert_equal(actual_group, expect, tolerance) if "arg" in func: