diff --git a/asv_bench/benchmarks/cohorts.py b/asv_bench/benchmarks/cohorts.py index 9062611f1..2c19c881f 100644 --- a/asv_bench/benchmarks/cohorts.py +++ b/asv_bench/benchmarks/cohorts.py @@ -29,7 +29,22 @@ def track_num_tasks(self): )[0] return len(result.dask.to_dict()) + def track_num_tasks_optimized(self): + result = flox.groupby_reduce( + self.array, self.by, func="sum", axis=self.axis, method="cohorts" + )[0] + (opt,) = dask.optimize(result) + return len(opt.dask.to_dict()) + + def track_num_layers(self): + result = flox.groupby_reduce( + self.array, self.by, func="sum", axis=self.axis, method="cohorts" + )[0] + return len(result.dask.layers) + track_num_tasks.unit = "tasks" + track_num_tasks_optimized.unit = "tasks" + track_num_layers.unit = "layers" class NWMMidwest(Cohorts): @@ -45,16 +60,68 @@ def setup(self, *args, **kwargs): self.axis = (-2, -1) -class ERA5(Cohorts): +class ERA5Dataset: """ERA5""" + def __init__(self, *args, **kwargs): + self.time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H")) + self.axis = (-1,) + self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 48)) + + def rechunk(self): + self.array = flox.core.rechunk_for_cohorts( + self.array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True + ) + + +class ERA5DayOfYear(ERA5Dataset, Cohorts): + def setup(self, *args, **kwargs): + super().__init__() + self.by = self.time.dt.dayofyear.values + + +class ERA5DayOfYearRechunked(ERA5DayOfYear, Cohorts): + def setup(self, *args, **kwargs): + super().setup() + super().rechunk() + + +class ERA5MonthHour(ERA5Dataset, Cohorts): def setup(self, *args, **kwargs): - time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H")) + super().__init__() + by = (self.time.dt.month.values, self.time.dt.hour.values) + ret = flox.core._factorize_multiple( + by, + expected_groups=(pd.Index(np.arange(1, 13)), pd.Index(np.arange(1, 25))), + by_is_dask=False, + reindex=False, + ) + # Add one so the rechunk code is simpler and makes sense + self.by = ret[0][0] + 1 - self.by = time.dt.dayofyear.values + +class ERA5MonthHourRechunked(ERA5MonthHour, Cohorts): + def setup(self, *args, **kwargs): + super().setup() + super().rechunk() + + +class PerfectMonthly(Cohorts): + """Perfectly chunked for a "cohorts" monthly mean climatology""" + + def setup(self, *args, **kwargs): + self.time = pd.Series(pd.date_range("1961-01-01", "2018-12-31 23:59", freq="M")) self.axis = (-1,) + self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 4)) + self.by = self.time.dt.month.values - array = dask.array.random.random((721, 1440, len(time)), chunks=(-1, -1, 48)) + def rechunk(self): self.array = flox.core.rechunk_for_cohorts( - array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True + self.array, -1, self.by, force_new_chunk_at=[1], chunksize=4, ignore_old_chunks=True ) + + +class PerfectMonthlyRechunked(PerfectMonthly): + def setup(self, *args, **kwargs): + super().setup() + super().rechunk() diff --git a/flox/core.py b/flox/core.py index 0e2b73ac9..6bd390137 100644 --- a/flox/core.py +++ b/flox/core.py @@ -6,6 +6,7 @@ import operator from collections import namedtuple from functools import partial, reduce +from numbers import Integral from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Mapping, Sequence, Union import numpy as np @@ -288,7 +289,7 @@ def rechunk_for_cohorts( divisions = [] counter = 1 for idx, lab in enumerate(labels): - if lab in force_new_chunk_at: + if lab in force_new_chunk_at or idx == 0: divisions.append(idx) counter = 1 continue @@ -305,6 +306,7 @@ def rechunk_for_cohorts( divisions.append(idx) counter = 1 continue + counter += 1 divisions.append(len(labels)) @@ -313,6 +315,9 @@ def rechunk_for_cohorts( print(labels_at_breaks[:40]) newchunks = tuple(np.diff(divisions)) + if debug: + print(divisions[:10], newchunks[:10]) + print(divisions[-10:], newchunks[-10:]) assert sum(newchunks) == len(labels) if newchunks == array.chunks[axis]: @@ -1046,26 +1051,18 @@ def _reduce_blockwise( return result -def subset_to_blocks( - array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None -) -> DaskArray: +def _normalize_indexes(array, flatblocks, blkshape): """ - Advanced indexing of .blocks such that we always get a regular array back. + .blocks accessor can only accept one iterable at a time, + but can handle multiple slices. + To minimize tasks and layers, we normalize to produce slices + along as many axes as possible, and then repeatedly apply + any remaining iterables in a loop. - Parameters - ---------- - array : dask.array - flatblocks : flat indices of blocks to extract - blkshape : shape of blocks with which to unravel flatblocks - - Returns - ------- - dask.array + TODO: move this upstream """ - if blkshape is None: - blkshape = array.blocks.shape - unraveled = np.unravel_index(flatblocks, blkshape) + normalized: list[Union[int, np.ndarray, slice]] = [] for ax, idx in enumerate(unraveled): i = _unique(idx).squeeze() @@ -1077,30 +1074,65 @@ def subset_to_blocks( elif np.array_equal(i, np.arange(i[0], i[-1] + 1)): normalized.append(slice(i[0], i[-1] + 1)) else: - normalized.append(i) + normalized.append(list(i)) full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized) # has no iterables - noiter = tuple(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized) + noiter = list(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized) # has all iterables - alliter = { - ax: i if hasattr(i, "__len__") else slice(None) for ax, i in enumerate(full_normalized) - } + alliter = {ax: i for ax, i in enumerate(full_normalized) if hasattr(i, "__len__")} - # apply everything but the iterables - if all(i == slice(None) for i in noiter): + mesh = dict(zip(alliter.keys(), np.ix_(*alliter.values()))) + + full_tuple = tuple(i if ax not in mesh else mesh[ax] for ax, i in enumerate(noiter)) + + return full_tuple + + +def subset_to_blocks( + array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None +) -> DaskArray: + """ + Advanced indexing of .blocks such that we always get a regular array back. + + Parameters + ---------- + array : dask.array + flatblocks : flat indices of blocks to extract + blkshape : shape of blocks with which to unravel flatblocks + + Returns + ------- + dask.array + """ + import dask.array + from dask.array.slicing import normalize_index + from dask.base import tokenize + from dask.highlevelgraph import HighLevelGraph + + if blkshape is None: + blkshape = array.blocks.shape + + index = _normalize_indexes(array, flatblocks, blkshape) + + if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index): return array - subset = array.blocks[noiter] + # These rest is copied from dask.array.core.py with slight modifications + index = normalize_index(index, array.numblocks) + index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index) - for ax, inds in alliter.items(): - if isinstance(inds, slice): - continue - idxr = [slice(None, None)] * array.ndim - idxr[ax] = inds - subset = subset.blocks[tuple(idxr)] + name = "blocks-" + tokenize(array, index) + new_keys = array._key_array[index] + + squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index) + chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed)) + + keys = itertools.product(*(range(len(c)) for c in chunks)) + layer = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys} + graph = HighLevelGraph.from_collections(name, layer, dependencies=[array]) - return subset + return dask.array.Array(graph, name, chunks, meta=array) def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]: diff --git a/tests/__init__.py b/tests/__init__.py index 0cd967d11..b1a266652 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -115,6 +115,18 @@ def assert_equal(a, b, tolerance=None): np.testing.assert_allclose(a, b, equal_nan=True, **tolerance) +def assert_equal_tuple(a, b): + """assert_equal for .blocks indexing tuples""" + assert len(a) == len(b) + + for a_, b_ in zip(a, b): + assert type(a_) == type(b_) + if isinstance(a_, np.ndarray): + np.testing.assert_array_equal(a_, b_) + else: + assert a_ == b_ + + @pytest.fixture(scope="module", params=["flox", "numpy", "numba"]) def engine(request): if request.param == "numba": diff --git a/tests/test_core.py b/tests/test_core.py index f9d412182..e31f11e56 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -12,14 +12,23 @@ from flox.core import ( _convert_expected_groups_to_index, _get_optimal_chunks_for_groups, + _normalize_indexes, factorize_, find_group_cohorts, groupby_reduce, rechunk_for_cohorts, reindex_, + subset_to_blocks, ) -from . import assert_equal, engine, has_dask, raise_if_dask_computes, requires_dask +from . import ( + assert_equal, + assert_equal_tuple, + engine, + has_dask, + raise_if_dask_computes, + requires_dask, +) labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0]) nan_labels = labels.astype(float) # copy @@ -1035,3 +1044,84 @@ def test_dtype(func, dtype, engine): labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]) actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64) assert actual.dtype == np.dtype("float64") + + +@requires_dask +def test_subset_blocks(): + array = dask.array.random.random((120,), chunks=(4,)) + + blockid = (0, 3, 6, 9, 12, 15, 18, 21, 24, 27) + subset = subset_to_blocks(array, blockid) + assert subset.blocks.shape == (len(blockid),) + + +@requires_dask +@pytest.mark.parametrize( + "flatblocks, expected", + ( + ((0, 1, 2, 3, 4), (slice(None),)), + ((1, 2, 3), (slice(1, 4),)), + ((1, 3), ([1, 3],)), + ((0, 1, 3), ([0, 1, 3],)), + ), +) +def test_normalize_block_indexing_1d(flatblocks, expected): + nblocks = 5 + array = dask.array.ones((nblocks,), chunks=(1,)) + expected = tuple(np.array(i) if isinstance(i, list) else i for i in expected) + actual = _normalize_indexes(array, flatblocks, array.blocks.shape) + assert_equal_tuple(expected, actual) + + +@requires_dask +@pytest.mark.parametrize( + "flatblocks, expected", + ( + ((0, 1, 2, 3, 4), (0, slice(None))), + ((1, 2, 3), (0, slice(1, 4))), + ((1, 3), (0, [1, 3])), + ((0, 1, 3), (0, [0, 1, 3])), + (tuple(range(10)), (slice(0, 2), slice(None))), + ((0, 1, 3, 5, 6, 8), (slice(0, 2), [0, 1, 3])), + ((0, 3, 4, 5, 6, 8, 24), np.ix_([0, 1, 4], [0, 1, 3, 4])), + ), +) +def test_normalize_block_indexing_2d(flatblocks, expected): + nblocks = 5 + ndim = 2 + array = dask.array.ones((nblocks,) * ndim, chunks=(1,) * ndim) + expected = tuple(np.array(i) if isinstance(i, list) else i for i in expected) + actual = _normalize_indexes(array, flatblocks, array.blocks.shape) + assert_equal_tuple(expected, actual) + + +@requires_dask +def test_subset_block_passthrough(): + # full slice pass through + array = dask.array.ones((5,), chunks=(1,)) + subset = subset_to_blocks(array, np.arange(5)) + assert subset.name == array.name + + array = dask.array.ones((5, 5), chunks=1) + subset = subset_to_blocks(array, np.arange(25)) + assert subset.name == array.name + + +@requires_dask +@pytest.mark.parametrize( + "flatblocks, expectidx", + [ + (np.arange(10), (slice(2), slice(None))), + (np.arange(8), (slice(2), slice(None))), + ([0, 10], ([0, 2], slice(1))), + ([0, 7], (slice(2), [0, 2])), + ([0, 7, 9], (slice(2), [0, 2, 4])), + ([0, 6, 12, 14], (slice(3), [0, 1, 2, 4])), + ([0, 12, 14, 19], np.ix_([0, 2, 3], [0, 2, 4])), + ], +) +def test_subset_block_2d(flatblocks, expectidx): + array = dask.array.from_array(np.arange(25).reshape((5, 5)), chunks=1) + subset = subset_to_blocks(array, flatblocks) + assert len(subset.dask.layers) == 2 + assert_equal(subset, array.compute()[expectidx])