Implement partial_reduce and tree_reduce using generalized blockwise
tomwhite committed Jan 29, 2024
1 parent 5bd8108 commit abc0947
Showing 4 changed files with 212 additions and 5 deletions.
5 changes: 3 additions & 2 deletions cubed/array_api/statistical_functions.py
@@ -25,7 +25,7 @@ def max(x, /, *, axis=None, keepdims=False):
return reduction(x, nxp.max, axis=axis, dtype=x.dtype, keepdims=keepdims)


def mean(x, /, *, axis=None, keepdims=False):
def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False):
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in mean")
# This implementation uses NumPy and Zarr's structured arrays to store a
@@ -46,6 +46,7 @@ def mean(x, /, *, axis=None, keepdims=False):
intermediate_dtype=intermediate_dtype,
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
extra_func_kwargs=extra_func_kwargs,
)
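
A minimal sketch of exercising the new path through the array API (mirroring test_mean_axis_0 below; spec and executor are omitted here for brevity):

import cubed.array_api as xp

a = xp.asarray([[1.0, 2.0], [3.0, 4.0]], chunks=(1, 2))
b = xp.mean(a, axis=0, use_new_impl=True)  # dispatches to reduction_new
b.compute()  # array([2., 3.])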

Expand All @@ -64,7 +65,7 @@ def _mean_combine(a, **kwargs):
return {"n": n, "total": total}


def _mean_aggregate(a):
def _mean_aggregate(a, **kwargs):
return nxp.divide(a["total"], a["n"])
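
For context on the structured intermediate above: per-chunk means cannot simply be averaged, because chunks may hold different numbers of elements, so the combine step carries counts and totals instead. A plain-NumPy sketch of the idea (illustrative values, not part of this commit):

import numpy as np

# two chunks of different sizes: a naive mean-of-means is biased
a, b = np.array([1.0, 2.0]), np.array([3.0, 4.0, 5.0, 6.0])
naive = np.mean([a.mean(), b.mean()])  # 3.0, but the true mean is 3.5
# carrying {"n", "total"} through the combine step gives the exact result
n, total = a.size + b.size, a.sum() + b.sum()
assert total / n == np.concatenate([a, b]).mean()  # 3.5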


188 changes: 188 additions & 0 deletions cubed/core/ops.py
@@ -1,5 +1,8 @@
import builtins
import math
import numbers
from functools import partial
from itertools import product
from numbers import Integral, Number
from operator import add
from typing import TYPE_CHECKING, Any, Sequence, Union
@@ -793,9 +796,24 @@ def reduction(
intermediate_dtype=None,
dtype=None,
keepdims=False,
use_new_impl=False,
split_every=None,
extra_func_kwargs=None,
) -> "Array":
"""Apply a function to reduce an array along one or more axes."""
if use_new_impl:
return reduction_new(
x,
func,
combine_func,
            aggregate_func,
axis,
intermediate_dtype,
dtype,
keepdims,
split_every,
extra_func_kwargs,
)
if combine_func is None:
combine_func = func
if axis is None:
@@ -885,6 +903,176 @@ def reduction(
return result


def reduction_new(
x: "Array",
func,
combine_func=None,
    aggregate_func=None,
axis=None,
intermediate_dtype=None,
dtype=None,
keepdims=False,
split_every=None,
extra_func_kwargs=None,
) -> "Array":
"""Apply a function to reduce an array along one or more axes."""
if combine_func is None:
combine_func = func
if axis is None:
axis = tuple(range(x.ndim))
if isinstance(axis, Integral):
axis = (axis,)
axis = validate_axis(axis, x.ndim)
if intermediate_dtype is None:
intermediate_dtype = dtype

inds = tuple(range(x.ndim))

result = x

# reduce initial chunks
args = (result, inds)
adjust_chunks = {
i: (1,) * len(c) if i in axis else c for i, c in enumerate(result.chunks)
}
result = blockwise(
func,
inds,
*args,
axis=axis,
keepdims=True,
dtype=intermediate_dtype,
adjust_chunks=adjust_chunks,
extra_func_kwargs=extra_func_kwargs,
)

# combine intermediates
result = tree_reduce(
result,
partial(combine_func, **(extra_func_kwargs or {})),
axis=axis,
dtype=intermediate_dtype,
)

# aggregate final chunks
    if aggregate_func is not None:
        result = map_blocks(aggregate_func, result, dtype=dtype)

if not keepdims:
axis_to_squeeze = tuple(i for i in axis if result.shape[i] == 1)
if len(axis_to_squeeze) > 0:
result = squeeze(result, axis_to_squeeze)

from cubed.array_api import astype

result = astype(result, dtype, copy=False)

return result
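
Read end to end, reduction_new is a three-stage pipeline: a blockwise per-chunk reduce, a tree_reduce that repeatedly combines intermediates, and an optional aggregate step. A hypothetical direct call (normally reached via reduction(..., use_new_impl=True); array and chunking chosen for illustration):

import numpy as np
import cubed.array_api as xp
from cubed.core.ops import reduction_new

a = xp.asarray(np.arange(24).reshape((4, 6)), chunks=(2, 2))
# per-chunk np.sum, then pairwise combines; no separate aggregate step
s = reduction_new(a, np.sum, combine_func=np.sum, axis=0, dtype=np.int64)
s.compute()  # equals np.arange(24).reshape((4, 6)).sum(axis=0)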


def _normalize_split_every(split_every, axis):
split_every = split_every or 4
if isinstance(split_every, dict):
split_every = {k: split_every.get(k, 2) for k in axis}
elif isinstance(split_every, Integral):
n = builtins.max(int(split_every ** (1 / (len(axis) or 1))), 2)
split_every = dict.fromkeys(axis, n)
else:
        raise ValueError("split_every must be an int or a dict")
return split_every
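
A few illustrative cases, written as assertions against the normalization rules above (the default fan-in is 4; axes missing from a dict default to 2):

assert _normalize_split_every(None, axis=(0,)) == {0: 4}
assert _normalize_split_every(16, axis=(0, 1)) == {0: 4, 1: 4}  # 16 ** (1 / 2) per axis
assert _normalize_split_every({0: 8}, axis=(0, 1)) == {0: 8, 1: 2}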


def tree_reduce(
x,
func,
axis,
dtype,
split_every=None,
):
"""Apply a reduction function repeatedly across multiple axes."""
if axis is None:
axis = tuple(range(x.ndim))
if isinstance(axis, Integral):
axis = (axis,)
axis = validate_axis(axis, x.ndim)

split_every = _normalize_split_every(split_every, axis)

depth = 0
for i, n in enumerate(x.numblocks):
if i in split_every and split_every[i] != 1:
depth = int(builtins.max(depth, math.ceil(math.log(n, split_every[i]))))
for _ in range(depth):
x = partial_reduce(
x,
func,
split_every=split_every,
dtype=dtype,
)
return x
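
To make the round count concrete, take the shape used in test_tree_reduce below: an (11, 22) array with (3, 4) chunks has ceil(11 / 3) = 4 blocks along axis 0, so split_every={0: 2} yields two partial_reduce rounds, shrinking the axis-0 block count 4 -> 2 -> 1:

import math

assert math.ceil(math.log(4, 2)) == 2  # 4 blocks, fan-in 2 -> depth 2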


def partial_reduce(x, func, split_every, dtype=None):
"""Apply a reduction function to multiple blocks across multiple axes."""
# map over output chunks
chunks = [
(1,) * math.ceil(len(c) / split_every[i]) if i in split_every else c
for (i, c) in enumerate(x.chunks)
]
shape = tuple(map(sum, chunks))
axis = tuple(ax for ax in split_every.keys())

def block_function(out_key):
out_coords = out_key[1:]

# return a tuple with a single item that is the list of input keys to be merged
in_keys = [
list(
range(
bi * split_every.get(i, 1),
min((bi + 1) * split_every.get(i, 1), x.numblocks[i]),
)
)
for i, bi in enumerate(out_coords)
]
return ([(x.name,) + tuple(p) for p in product(*in_keys)],)

return general_blockwise(
_partial_reduce,
block_function,
x,
shape=shape,
dtype=dtype,
chunks=chunks,
extra_projected_mem=0,
reduce_func=func,
axis=axis,
)
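
To see what block_function produces, its logic can be replayed for a hypothetical input array named "x" with numblocks (4, 6) and split_every={0: 2}; output block (1, 3) is fed the two axis-0 input blocks it covers:

from itertools import product

split_every, numblocks, out_coords = {0: 2}, (4, 6), (1, 3)  # hypothetical values
in_keys = [
    list(range(bi * split_every.get(i, 1),
               min((bi + 1) * split_every.get(i, 1), numblocks[i])))
    for i, bi in enumerate(out_coords)
]
assert in_keys == [[2, 3], [3]]
assert [("x",) + tuple(p) for p in product(*in_keys)] == [("x", 2, 3), ("x", 3, 3)]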


def _partial_reduce(arrays, reduce_func=None, axis=None):
# reduce each array in turn, accumulating in result
result = None
for array in arrays:
reduced_chunk = reduce_func(array, axis=axis, keepdims=True)
if result is None:
result = reduced_chunk
elif isinstance(result, dict):
assert result.keys() == reduced_chunk.keys()
result = {
# only need to concatenate along first axis
k: np.concatenate([result[k], reduced_chunk[k]], axis=axis[0])
for k in result.keys()
}
result = reduce_func(result, axis=axis, keepdims=True)
else:
# only need to concatenate along first axis
result = np.concatenate([result, reduced_chunk], axis=axis[0])
result = reduce_func(result, axis=axis, keepdims=True)

return result
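
The loop reduces pairwise rather than stacking all inputs first, keeping peak memory near two reduced chunks. A plain-NumPy rehearsal of one merge step in the plain-array branch (illustrative shapes):

import numpy as np

r1 = np.sum(np.ones((3, 4)), axis=(0,), keepdims=True)  # shape (1, 4), values 3.0
r2 = np.sum(np.ones((2, 4)), axis=(0,), keepdims=True)  # shape (1, 4), values 2.0
merged = np.sum(np.concatenate([r1, r2], axis=0), axis=(0,), keepdims=True)
assert merged.shape == (1, 4) and (merged == 5.0).all()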


def arg_reduction(x, /, arg_func, axis=None, *, keepdims=False):
"""A reduction that returns the array indexes, not the values."""
dtype = np.int64 # index data type
5 changes: 3 additions & 2 deletions cubed/tests/test_array_api.py
@@ -551,11 +551,12 @@ def test_argmin_axis_0(spec):
# Statistical functions


def test_mean_axis_0(spec, executor):
@pytest.mark.parametrize("use_new_impl", [False, True])
def test_mean_axis_0(spec, executor, use_new_impl):
a = xp.asarray(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], chunks=(2, 2), spec=spec
)
b = xp.mean(a, axis=0)
b = xp.mean(a, axis=0, use_new_impl=use_new_impl)
assert_array_equal(
b.compute(executor=executor),
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]).mean(axis=0),
19 changes: 18 additions & 1 deletion cubed/tests/test_core.py
@@ -10,7 +10,7 @@
import cubed.array_api as xp
import cubed.random
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import merge_chunks
from cubed.core.ops import merge_chunks, partial_reduce, tree_reduce
from cubed.tests.utils import (
ALL_EXECUTORS,
MAIN_EXECUTORS,
@@ -352,6 +352,23 @@ def test_reduction_not_enough_memory(tmp_path):
xp.sum(a, axis=0, dtype=np.uint8)


def test_partial_reduce(spec):
a = xp.asarray(np.arange(242).reshape((11, 22)), chunks=(3, 4), spec=spec)
b = partial_reduce(a, np.sum, split_every={0: 2})
c = partial_reduce(b, np.sum, split_every={0: 2})
assert_array_equal(
c.compute(), np.arange(242).reshape((11, 22)).sum(axis=0, keepdims=True)
)


def test_tree_reduce(spec):
a = xp.asarray(np.arange(242).reshape((11, 22)), chunks=(3, 4), spec=spec)
b = tree_reduce(a, np.sum, axis=0, dtype=np.int64, split_every={0: 2})
assert_array_equal(
b.compute(), np.arange(242).reshape((11, 22)).sum(axis=0, keepdims=True)
)


@pytest.mark.parametrize(
"target_chunks, expected_chunksize",
[
