From 4a221de460380b3bf9f8c3675345887a6e9fe50c Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Tue, 17 Sep 2024 19:41:43 -0600
Subject: [PATCH] Faster subsetting for cohorts

Closes #396
---
 flox/core.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/flox/core.py b/flox/core.py
index 7e5362e1..1fb91dc2 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -1494,8 +1494,9 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
 def subset_to_blocks(
     array: DaskArray,
     flatblocks: Sequence[int],
-    blkshape: tuple[int] | None = None,
+    blkshape: tuple[int, ...] | None = None,
     reindexer=identity,
+    chunks_as_array: tuple[int, ...] | None = None,
 ) -> DaskArray:
     """
     Advanced indexing of .blocks such that we always get a regular array back.
@@ -1518,6 +1519,9 @@ def subset_to_blocks(
     if blkshape is None:
         blkshape = array.blocks.shape
 
+    if chunks_as_array is None:
+        chunks_as_array = tuple(np.array(c) for c in array.chunks)
+
     index = _normalize_indexes(array, flatblocks, blkshape)
 
     if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
@@ -1531,7 +1535,7 @@ def subset_to_blocks(
     new_keys = array._key_array[index]
 
     squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
-    chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))
+    chunks = tuple(tuple(c[i].tolist()) for c, i in zip(chunks_as_array, squeezed))
 
     keys = itertools.product(*(range(len(c)) for c in chunks))
     layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}
@@ -1726,6 +1730,7 @@ def dask_groupby_agg(
 
         reduced_ = []
         groups_ = []
+        chunks_as_array = tuple(np.array(c) for c in array.chunks)
         for blks, cohort in chunks_cohorts.items():
             cohort_index = pd.Index(cohort)
             reindexer = (
@@ -1733,7 +1738,7 @@ def dask_groupby_agg(
                 if do_simple_combine
                 else identity
             )
-            reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
+            reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array)
             # now that we have reindexed, we can set reindex=True explicitlly
             reduced_.append(
                 tree_reduce(