Skip to content

Commit

Permalink
Faster subsetting for cohorts
Browse files Browse the repository at this point in the history
Closes #396
  • Loading branch information
dcherian committed Sep 18, 2024
1 parent 8556811 commit 4a221de
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,8 +1494,9 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
def subset_to_blocks(
array: DaskArray,
flatblocks: Sequence[int],
blkshape: tuple[int] | None = None,
blkshape: tuple[int, ...] | None = None,
reindexer=identity,
chunks_as_array: tuple[int, ...] | None = None,
) -> DaskArray:
"""
Advanced indexing of .blocks such that we always get a regular array back.
Expand All @@ -1518,6 +1519,9 @@ def subset_to_blocks(
if blkshape is None:
blkshape = array.blocks.shape

if chunks_as_array is None:
chunks_as_array = tuple(np.array(c) for c in array.chunks)

index = _normalize_indexes(array, flatblocks, blkshape)

if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
Expand All @@ -1531,7 +1535,7 @@ def subset_to_blocks(
new_keys = array._key_array[index]

squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))
chunks = tuple(tuple(c[i].tolist()) for c, i in zip(chunks_as_array, squeezed))

keys = itertools.product(*(range(len(c)) for c in chunks))
layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}
Expand Down Expand Up @@ -1726,14 +1730,15 @@ def dask_groupby_agg(

reduced_ = []
groups_ = []
chunks_as_array = tuple(np.array(c) for c in array.chunks)
for blks, cohort in chunks_cohorts.items():
cohort_index = pd.Index(cohort)
reindexer = (
partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
if do_simple_combine
else identity
)
reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer, chunks_as_array)
# now that we have reindexed, we can set reindex=True explicitlly
reduced_.append(
tree_reduce(
Expand Down

0 comments on commit 4a221de

Please sign in to comment.