diff --git a/flox/core.py b/flox/core.py index dc121bf83..b199878d5 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1568,8 +1568,11 @@ def groupby_reduce( result = results[agg.name] else: - if agg.chunk is None: - raise NotImplementedError(f"{func} not implemented for dask arrays") + if agg.chunk[0] is None and method != "blockwise": + raise NotImplementedError( + f"Aggregation {func.name!r} is only implemented for dask arrays when method='blockwise'." + f"\n\n Received: {func}" + ) # we always need some fill_value (see above) so choose the default if needed if kwargs["fill_value"] is None: diff --git a/tests/test_core.py b/tests/test_core.py index 7f26bbc49..9c0bd6adb 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -5,6 +5,7 @@ import pytest from numpy_groupies.aggregate_numpy import aggregate +from flox.aggregations import Aggregation from flox.core import ( _convert_expected_groups_to_index, _get_optimal_chunks_for_groups, @@ -964,3 +965,47 @@ def test_factorize_reindex_sorting_ints(): expected = factorize_(**kwargs, reindex=True, sort=False)[0] assert_equal(expected, [6, 4, 6, 3, 2, 0]) + + +@requires_dask +def test_custom_aggregation_blockwise(): + def grouped_median(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None): + return aggregate( + group_idx, + array, + func=np.median, + axis=axis, + size=size, + fill_value=fill_value, + dtype=dtype, + ) + + agg_median = Aggregation( + name="median", numpy=grouped_median, fill_value=-1, chunk=None, combine=None + ) + + array = np.arange(100, dtype=np.float32).reshape(5, 20) + by = np.ones((20,)) + + actual, _ = groupby_reduce(array, by, func=agg_median, axis=-1) + expected = np.median(array, axis=-1, keepdims=True) + assert_equal(expected, actual) + + for method in ["map-reduce", "cohorts", "split-reduce"]: + with pytest.raises(NotImplementedError): + groupby_reduce( + dask.array.from_array(array, chunks=(1, -1)), + by, + func=agg_median, + axis=-1, + method=method, + ) + + actual, _ = groupby_reduce( + dask.array.from_array(array, chunks=(1, -1)), + by, + func=agg_median, + axis=-1, + method="blockwise", + ) + assert_equal(expected, actual)