Skip to content

Commit

Permalink
Limited implementation of map_overlap (#462)
Browse files Browse the repository at this point in the history
* Limited implementation of map_overlap

* Change to Array API function name `concat`
  • Loading branch information
tomwhite authored May 18, 2024
1 parent 6b99959 commit b57fa0c
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .core.gufunc import apply_gufunc
from .core.ops import from_array, from_zarr, map_blocks, store, to_zarr
from .nan_functions import nanmean, nansum
from .overlap import map_overlap
from .runtime.types import Callback, TaskEndEvent
from .spec import Spec

Expand All @@ -33,6 +34,7 @@
"from_array",
"from_zarr",
"map_blocks",
"map_overlap",
"measure_reserved_mem",
"nanmean",
"nansum",
Expand Down
139 changes: 139 additions & 0 deletions cubed/overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from typing import Tuple

from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import map_direct
from cubed.types import T_RectangularChunks
from cubed.utils import _cumsum
from cubed.vendor.dask.array.core import normalize_chunks
from cubed.vendor.dask.array.overlap import coerce_boundary, coerce_depth
from cubed.vendor.dask.utils import has_keyword


def map_overlap(
func,
*args,
dtype=None,
chunks=None,
depth=None,
boundary=None,
trim=False,
**kwargs,
):
"""Apply a function to corresponding blocks from multiple input arrays with some overlap.
Parameters
----------
func : callable
Function to apply to every block (with overlap) to produce the output array.
args : arrays
The Cubed arrays to map over. Note that currently only one array may be specified.
dtype : np.dtype
The ``dtype`` of the output array.
chunks : tuple
Chunk shape of blocks in the output array.
depth : int, tuple, dict or list
The number of elements that each block should share with its neighbors.
boundary : value type, tuple, dict or list
How to handle the boundaries. Note that this currently only supports constant values.
trim : bool
Whether or not to trim ``depth`` elements from each block after calling the map function.
Currently only ``False`` is supported.
**kwargs : dict
Extra keyword arguments to pass to function.
"""
if trim:
raise ValueError("trim is not supported")

chunks = normalize_chunks(chunks, dtype=dtype)
shape = tuple(map(sum, chunks))

# Coerce depth and boundary arguments to lists of individual
# specifications for each array argument
def coerce(xs, arg, fn):
if not isinstance(arg, list):
arg = [arg] * len(xs)
return [fn(x.ndim, a) for x, a in zip(xs, arg)]

depth = coerce(args, depth, coerce_depth)
boundary = coerce(args, boundary, coerce_boundary)

# memory allocated by reading one chunk from input array
# note that although the output chunk will overlap multiple input chunks, zarr will
# read the chunks in series, reusing the buffer
extra_projected_mem = args[0].chunkmem # TODO: support multiple

has_block_id_kw = has_keyword(func, "block_id")

return map_direct(
_overlap,
*args,
shape=shape,
dtype=dtype,
chunks=chunks,
extra_projected_mem=extra_projected_mem,
overlap_func=func,
depth=depth,
boundary=boundary,
has_block_id_kw=has_block_id_kw,
**kwargs,
)


def _overlap(
x,
*arrays,
overlap_func=None,
depth=None,
boundary=None,
has_block_id_kw=False,
block_id=None,
**kwargs,
):
a = arrays[0] # TODO: support multiple
depth = depth[0]
boundary = boundary[0]

# First read the chunk with overlaps determined by depth, then pad boundaries second.
# Do it this way round so we can do everything with one blockwise. The alternative,
# which pads the entire array first (via concatenate), would result in at least one extra copy.
out = a.zarray[get_item_with_depth(a.chunks, block_id, depth)]
out = _pad_boundaries(out, depth, boundary, a.numblocks, block_id)
if has_block_id_kw:
return overlap_func(out, block_id=block_id, **kwargs)
else:
return overlap_func(out, **kwargs)


def _clamp(minimum: int, x: int, maximum: int) -> int:
return max(minimum, min(x, maximum))


def get_item_with_depth(
chunks: T_RectangularChunks, idx: Tuple[int, ...], depth
) -> Tuple[slice, ...]:
"""Convert a chunk index to a tuple of slices with depth offsets."""
starts = tuple(_cumsum(c, initial_zero=True) for c in chunks)
loc = tuple(
(
_clamp(0, start[i] - depth[ax], start[-1]),
_clamp(0, start[i + 1] + depth[ax], start[-1]),
)
for ax, (i, start) in enumerate(zip(idx, starts))
)
return tuple(slice(*s, None) for s in loc)


def _pad_boundaries(x, depth, boundary, numblocks, block_id):
for i in range(x.ndim):
d = depth.get(i, 0)
if d == 0 or block_id[i] not in (0, numblocks[i] - 1):
continue
pad_shape = list(x.shape)
pad_shape[i] = d
pad_shape = tuple(pad_shape)
p = nxp.full_like(x, fill_value=boundary[i], shape=pad_shape)
if block_id[i] == 0: # first block on axis i
x = nxp.concat([p, x], axis=i)
elif block_id[i] == numblocks[i] - 1: # last block on axis i
x = nxp.concat([x, p], axis=i)
return x
79 changes: 79 additions & 0 deletions cubed/tests/test_overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
from numpy.testing import assert_array_equal

import cubed
import cubed.array_api as xp


def test_map_overlap_1d():
x = np.arange(6)
a = xp.asarray(x, chunks=(3,))

b = cubed.map_overlap(
lambda x: x,
a,
dtype=a.dtype,
chunks=((5, 5),),
depth=1,
boundary=0,
trim=False,
)

assert_array_equal(b.compute(), np.array([0, 0, 1, 2, 3, 2, 3, 4, 5, 0]))


def test_map_overlap_2d():
x = np.arange(36).reshape((6, 6))
a = xp.asarray(x, chunks=(3, 3))

b = cubed.map_overlap(
lambda x: x,
a,
dtype=a.dtype,
chunks=((7, 7), (5, 5)),
depth={0: 2, 1: 1},
boundary={0: 100, 1: 200},
trim=False,
)

expected = np.array(
[
[200, 100, 100, 100, 100, 100, 100, 100, 100, 200],
[200, 100, 100, 100, 100, 100, 100, 100, 100, 200],
[200, 0, 1, 2, 3, 2, 3, 4, 5, 200],
[200, 6, 7, 8, 9, 8, 9, 10, 11, 200],
[200, 12, 13, 14, 15, 14, 15, 16, 17, 200],
[200, 18, 19, 20, 21, 20, 21, 22, 23, 200],
[200, 24, 25, 26, 27, 26, 27, 28, 29, 200],
[200, 6, 7, 8, 9, 8, 9, 10, 11, 200],
[200, 12, 13, 14, 15, 14, 15, 16, 17, 200],
[200, 18, 19, 20, 21, 20, 21, 22, 23, 200],
[200, 24, 25, 26, 27, 26, 27, 28, 29, 200],
[200, 30, 31, 32, 33, 32, 33, 34, 35, 200],
[200, 100, 100, 100, 100, 100, 100, 100, 100, 200],
[200, 100, 100, 100, 100, 100, 100, 100, 100, 200],
]
)

assert_array_equal(b.compute(), expected)


def test_map_overlap_trim():
x = np.array([1, 1, 2, 3, 5, 8, 13, 21])
a = xp.asarray(x, chunks=5)

def derivative(x):
out = x - np.roll(x, 1)
return out[1:-1] # manual trim

b = cubed.map_overlap(
derivative,
a,
dtype=a.dtype,
chunks=a.chunks,
depth=1,
boundary=0,
trim=False,
)

assert_array_equal(b.compute(), np.array([1, 0, 1, 1, 2, 3, 5, 8]))
38 changes: 38 additions & 0 deletions cubed/vendor/dask/array/overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from numbers import Integral


def coerce_depth(ndim, depth):
default = 0
if depth is None:
depth = default
if isinstance(depth, Integral):
depth = (depth,) * ndim
if isinstance(depth, tuple):
depth = dict(zip(range(ndim), depth))
if isinstance(depth, dict):
depth = {ax: depth.get(ax, default) for ax in range(ndim)}
return coerce_depth_type(ndim, depth)


def coerce_depth_type(ndim, depth):
for i in range(ndim):
if isinstance(depth[i], tuple):
depth[i] = tuple(int(d) for d in depth[i])
else:
depth[i] = int(depth[i])
return depth


def coerce_boundary(ndim, boundary):
default = "none"
if boundary is None:
boundary = default
if not isinstance(boundary, (tuple, dict)):
boundary = (boundary,) * ndim
if isinstance(boundary, tuple):
boundary = dict(zip(range(ndim), boundary))
if isinstance(boundary, dict):
boundary = {ax: boundary.get(ax, default) for ax in range(ndim)}
return boundary
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Chunk-specific functions

apply_gufunc
map_blocks
map_overlap

Non-standardised functions
==========================
Expand Down

0 comments on commit b57fa0c

Please sign in to comment.