Major fix to subset_to_blocks (#173)
1. Copy and extend `dask.array.blocks.__getitem__` to support orthogonal indexing
   (see the sketch below). This means each cohort is a single layer in the graph.
2. Significantly extend cohorts.py benchmarks
dcherian authored Oct 16, 2022
1 parent 0bf35e0 commit 6897240
Showing 4 changed files with 239 additions and 38 deletions.
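For context, "orthogonal indexing" here means describing the selected blocks with one indexer per axis (a slice or an integer array, combined via `np.ix_`), so the whole selection becomes a single new layer in the graph rather than one `.blocks[...]` call per iterable axis. A minimal sketch of the idea (not taken from the diff; the block grid and indices are made up):

```python
import numpy as np

# Flat block ids 0, 12, 14, 19 on a 5x5 block grid unravel to rows {0, 2, 3}
# and columns {0, 2, 4}. np.ix_ turns those per-axis lists into an
# outer-product ("orthogonal") indexer that picks the 3x3 grid of blocks at once.
flatblocks = [0, 12, 14, 19]
blkshape = (5, 5)
rows, cols = (np.unique(i) for i in np.unravel_index(flatblocks, blkshape))
indexer = np.ix_(rows, cols)
print(indexer)  # (array([[0], [2], [3]]), array([[0, 2, 4]]))
```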
77 changes: 72 additions & 5 deletions asv_bench/benchmarks/cohorts.py
@@ -29,7 +29,22 @@ def track_num_tasks(self):
)[0]
return len(result.dask.to_dict())

def track_num_tasks_optimized(self):
result = flox.groupby_reduce(
self.array, self.by, func="sum", axis=self.axis, method="cohorts"
)[0]
(opt,) = dask.optimize(result)
return len(opt.dask.to_dict())

def track_num_layers(self):
result = flox.groupby_reduce(
self.array, self.by, func="sum", axis=self.axis, method="cohorts"
)[0]
return len(result.dask.layers)

track_num_tasks.unit = "tasks"
track_num_tasks_optimized.unit = "tasks"
track_num_layers.unit = "layers"


class NWMMidwest(Cohorts):
@@ -45,16 +60,68 @@ def setup(self, *args, **kwargs):
self.axis = (-2, -1)


class ERA5(Cohorts):
class ERA5Dataset:
"""ERA5"""

def __init__(self, *args, **kwargs):
self.time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))
self.axis = (-1,)
self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 48))

def rechunk(self):
self.array = flox.core.rechunk_for_cohorts(
self.array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True
)


class ERA5DayOfYear(ERA5Dataset, Cohorts):
def setup(self, *args, **kwargs):
super().__init__()
self.by = self.time.dt.dayofyear.values


class ERA5DayOfYearRechunked(ERA5DayOfYear, Cohorts):
def setup(self, *args, **kwargs):
super().setup()
super().rechunk()


class ERA5MonthHour(ERA5Dataset, Cohorts):
def setup(self, *args, **kwargs):
time = pd.Series(pd.date_range("2016-01-01", "2018-12-31 23:59", freq="H"))
super().__init__()
by = (self.time.dt.month.values, self.time.dt.hour.values)
ret = flox.core._factorize_multiple(
by,
expected_groups=(pd.Index(np.arange(1, 13)), pd.Index(np.arange(1, 25))),
by_is_dask=False,
reindex=False,
)
# Add one so the rechunk code is simpler and makes sense
self.by = ret[0][0] + 1

self.by = time.dt.dayofyear.values

class ERA5MonthHourRechunked(ERA5MonthHour, Cohorts):
def setup(self, *args, **kwargs):
super().setup()
super().rechunk()


class PerfectMonthly(Cohorts):
"""Perfectly chunked for a "cohorts" monthly mean climatology"""

def setup(self, *args, **kwargs):
self.time = pd.Series(pd.date_range("1961-01-01", "2018-12-31 23:59", freq="M"))
self.axis = (-1,)
self.array = dask.array.random.random((721, 1440, len(self.time)), chunks=(-1, -1, 4))
self.by = self.time.dt.month.values

array = dask.array.random.random((721, 1440, len(time)), chunks=(-1, -1, 48))
def rechunk(self):
self.array = flox.core.rechunk_for_cohorts(
array, -1, self.by, force_new_chunk_at=[1], chunksize=48, ignore_old_chunks=True
self.array, -1, self.by, force_new_chunk_at=[1], chunksize=4, ignore_old_chunks=True
)


class PerfectMonthlyRechunked(PerfectMonthly):
def setup(self, *args, **kwargs):
super().setup()
super().rechunk()
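For context, these are airspeed velocity (asv) tracking benchmarks: asv calls `setup` and then records the return value of each `track_*` method, labelled with the declared `unit`. A rough, self-contained sketch (not part of the diff; the array shape and labels are made up) of what the three counters measure for a `method="cohorts"` reduction:

```python
import dask
import dask.array
import numpy as np
import flox

# Small dask array chunked along the last axis, one "month" label per chunk.
array = dask.array.ones((10, 10, 360), chunks=(-1, -1, 30))
by = np.repeat(np.arange(12), 30)

result = flox.groupby_reduce(array, by, func="sum", axis=-1, method="cohorts")[0]

print(len(result.dask.to_dict()))  # track_num_tasks: tasks in the raw graph
print(len(result.dask.layers))     # track_num_layers: HighLevelGraph layers

(opt,) = dask.optimize(result)
print(len(opt.dask.to_dict()))     # track_num_tasks_optimized: tasks after dask.optimize
```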
96 changes: 64 additions & 32 deletions flox/core.py
@@ -6,6 +6,7 @@
import operator
from collections import namedtuple
from functools import partial, reduce
from numbers import Integral
from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Mapping, Sequence, Union

import numpy as np
@@ -288,7 +289,7 @@ def rechunk_for_cohorts(
divisions = []
counter = 1
for idx, lab in enumerate(labels):
if lab in force_new_chunk_at:
if lab in force_new_chunk_at or idx == 0:
divisions.append(idx)
counter = 1
continue
@@ -305,6 +306,7 @@
divisions.append(idx)
counter = 1
continue

counter += 1

divisions.append(len(labels))
@@ -313,6 +315,9 @@
print(labels_at_breaks[:40])

newchunks = tuple(np.diff(divisions))
if debug:
print(divisions[:10], newchunks[:10])
print(divisions[-10:], newchunks[-10:])
assert sum(newchunks) == len(labels)

if newchunks == array.chunks[axis]:
@@ -1046,26 +1051,18 @@ def _reduce_blockwise(
return result


def subset_to_blocks(
array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
) -> DaskArray:
def _normalize_indexes(array, flatblocks, blkshape):
"""
Advanced indexing of .blocks such that we always get a regular array back.
.blocks accessor can only accept one iterable at a time,
but can handle multiple slices.
To minimize tasks and layers, we normalize to produce slices
along as many axes as possible, and then repeatedly apply
any remaining iterables in a loop.
Parameters
----------
array : dask.array
flatblocks : flat indices of blocks to extract
blkshape : shape of blocks with which to unravel flatblocks
Returns
-------
dask.array
TODO: move this upstream
"""
if blkshape is None:
blkshape = array.blocks.shape

unraveled = np.unravel_index(flatblocks, blkshape)

normalized: list[Union[int, np.ndarray, slice]] = []
for ax, idx in enumerate(unraveled):
i = _unique(idx).squeeze()
@@ -1077,30 +1074,65 @@ def subset_to_blocks(
elif np.array_equal(i, np.arange(i[0], i[-1] + 1)):
normalized.append(slice(i[0], i[-1] + 1))
else:
normalized.append(i)
normalized.append(list(i))
full_normalized = (slice(None),) * (array.ndim - len(normalized)) + tuple(normalized)

# has no iterables
noiter = tuple(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
noiter = list(i if not hasattr(i, "__len__") else slice(None) for i in full_normalized)
# has all iterables
alliter = {
ax: i if hasattr(i, "__len__") else slice(None) for ax, i in enumerate(full_normalized)
}
alliter = {ax: i for ax, i in enumerate(full_normalized) if hasattr(i, "__len__")}

# apply everything but the iterables
if all(i == slice(None) for i in noiter):
mesh = dict(zip(alliter.keys(), np.ix_(*alliter.values())))

full_tuple = tuple(i if ax not in mesh else mesh[ax] for ax, i in enumerate(noiter))

return full_tuple


def subset_to_blocks(
array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
) -> DaskArray:
"""
Advanced indexing of .blocks such that we always get a regular array back.
Parameters
----------
array : dask.array
flatblocks : flat indices of blocks to extract
blkshape : shape of blocks with which to unravel flatblocks
Returns
-------
dask.array
"""
import dask.array
from dask.array.slicing import normalize_index
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph

if blkshape is None:
blkshape = array.blocks.shape

index = _normalize_indexes(array, flatblocks, blkshape)

if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
return array

subset = array.blocks[noiter]
# The rest is copied from dask.array.core.py with slight modifications
index = normalize_index(index, array.numblocks)
index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)

for ax, inds in alliter.items():
if isinstance(inds, slice):
continue
idxr = [slice(None, None)] * array.ndim
idxr[ax] = inds
subset = subset.blocks[tuple(idxr)]
name = "blocks-" + tokenize(array, index)
new_keys = array._key_array[index]

squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))

keys = itertools.product(*(range(len(c)) for c in chunks))
layer = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])

return subset
return dask.array.Array(graph, name, chunks, meta=array)


def _extract_unknown_groups(reduced, group_chunks, dtype) -> tuple[DaskArray]:
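A hedged usage sketch of the new helpers, mirroring the tests added below (assumes `flox` and `dask` are importable; the block indices are taken from one of the new test cases):

```python
import dask.array
import numpy as np
from flox.core import _normalize_indexes, subset_to_blocks

array = dask.array.from_array(np.arange(25).reshape(5, 5), chunks=1)

# Flat blocks 0, 12, 14, 19 normalize to the orthogonal indexer
# np.ix_([0, 2, 3], [0, 2, 4]) over the 5x5 block grid.
idx = _normalize_indexes(array, [0, 12, 14, 19], array.blocks.shape)

# The subset is built as one new "blocks-*" layer on top of the input's layer,
# so the new tests expect exactly two layers here.
subset = subset_to_blocks(array, [0, 12, 14, 19])
print(len(subset.dask.layers))  # 2
print(subset.shape)             # (3, 3), since every chunk is 1x1
```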
12 changes: 12 additions & 0 deletions tests/__init__.py
@@ -115,6 +115,18 @@ def assert_equal(a, b, tolerance=None):
np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)


def assert_equal_tuple(a, b):
"""assert_equal for .blocks indexing tuples"""
assert len(a) == len(b)

for a_, b_ in zip(a, b):
assert type(a_) == type(b_)
if isinstance(a_, np.ndarray):
np.testing.assert_array_equal(a_, b_)
else:
assert a_ == b_


@pytest.fixture(scope="module", params=["flox", "numpy", "numba"])
def engine(request):
if request.param == "numba":
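Hypothetical usage of the new helper, showing how the tests below compare mixed slice/array indexing tuples (the import path is an assumption for illustration):

```python
import numpy as np
from tests import assert_equal_tuple  # assumed import path

# Elements are compared pairwise: arrays via assert_array_equal, everything else via ==.
assert_equal_tuple((slice(0, 2), np.array([0, 1, 3])),
                   (slice(0, 2), np.array([0, 1, 3])))
```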
92 changes: 91 additions & 1 deletion tests/test_core.py
@@ -12,14 +12,23 @@
from flox.core import (
_convert_expected_groups_to_index,
_get_optimal_chunks_for_groups,
_normalize_indexes,
factorize_,
find_group_cohorts,
groupby_reduce,
rechunk_for_cohorts,
reindex_,
subset_to_blocks,
)

from . import assert_equal, engine, has_dask, raise_if_dask_computes, requires_dask
from . import (
assert_equal,
assert_equal_tuple,
engine,
has_dask,
raise_if_dask_computes,
requires_dask,
)

labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
nan_labels = labels.astype(float) # copy
@@ -1035,3 +1044,84 @@ def test_dtype(func, dtype, engine):
labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"])
actual, _ = groupby_reduce(arr, labels, func=func, dtype=np.float64)
assert actual.dtype == np.dtype("float64")


@requires_dask
def test_subset_blocks():
array = dask.array.random.random((120,), chunks=(4,))

blockid = (0, 3, 6, 9, 12, 15, 18, 21, 24, 27)
subset = subset_to_blocks(array, blockid)
assert subset.blocks.shape == (len(blockid),)


@requires_dask
@pytest.mark.parametrize(
"flatblocks, expected",
(
((0, 1, 2, 3, 4), (slice(None),)),
((1, 2, 3), (slice(1, 4),)),
((1, 3), ([1, 3],)),
((0, 1, 3), ([0, 1, 3],)),
),
)
def test_normalize_block_indexing_1d(flatblocks, expected):
nblocks = 5
array = dask.array.ones((nblocks,), chunks=(1,))
expected = tuple(np.array(i) if isinstance(i, list) else i for i in expected)
actual = _normalize_indexes(array, flatblocks, array.blocks.shape)
assert_equal_tuple(expected, actual)


@requires_dask
@pytest.mark.parametrize(
"flatblocks, expected",
(
((0, 1, 2, 3, 4), (0, slice(None))),
((1, 2, 3), (0, slice(1, 4))),
((1, 3), (0, [1, 3])),
((0, 1, 3), (0, [0, 1, 3])),
(tuple(range(10)), (slice(0, 2), slice(None))),
((0, 1, 3, 5, 6, 8), (slice(0, 2), [0, 1, 3])),
((0, 3, 4, 5, 6, 8, 24), np.ix_([0, 1, 4], [0, 1, 3, 4])),
),
)
def test_normalize_block_indexing_2d(flatblocks, expected):
nblocks = 5
ndim = 2
array = dask.array.ones((nblocks,) * ndim, chunks=(1,) * ndim)
expected = tuple(np.array(i) if isinstance(i, list) else i for i in expected)
actual = _normalize_indexes(array, flatblocks, array.blocks.shape)
assert_equal_tuple(expected, actual)


@requires_dask
def test_subset_block_passthrough():
# full slice pass through
array = dask.array.ones((5,), chunks=(1,))
subset = subset_to_blocks(array, np.arange(5))
assert subset.name == array.name

array = dask.array.ones((5, 5), chunks=1)
subset = subset_to_blocks(array, np.arange(25))
assert subset.name == array.name


@requires_dask
@pytest.mark.parametrize(
"flatblocks, expectidx",
[
(np.arange(10), (slice(2), slice(None))),
(np.arange(8), (slice(2), slice(None))),
([0, 10], ([0, 2], slice(1))),
([0, 7], (slice(2), [0, 2])),
([0, 7, 9], (slice(2), [0, 2, 4])),
([0, 6, 12, 14], (slice(3), [0, 1, 2, 4])),
([0, 12, 14, 19], np.ix_([0, 2, 3], [0, 2, 4])),
],
)
def test_subset_block_2d(flatblocks, expectidx):
array = dask.array.from_array(np.arange(25).reshape((5, 5)), chunks=1)
subset = subset_to_blocks(array, flatblocks)
assert len(subset.dask.layers) == 2
assert_equal(subset, array.compute()[expectidx])
