Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limited implementation of map_overlap #462

Merged
merged 2 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .core.gufunc import apply_gufunc
from .core.ops import from_array, from_zarr, map_blocks, store, to_zarr
from .nan_functions import nanmean, nansum
from .overlap import map_overlap
from .runtime.types import Callback, TaskEndEvent
from .spec import Spec

Expand All @@ -33,6 +34,7 @@
"from_array",
"from_zarr",
"map_blocks",
"map_overlap",
"measure_reserved_mem",
"nanmean",
"nansum",
Expand Down
139 changes: 139 additions & 0 deletions cubed/overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from typing import Tuple

from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import map_direct
from cubed.types import T_RectangularChunks
from cubed.utils import _cumsum
from cubed.vendor.dask.array.core import normalize_chunks
from cubed.vendor.dask.array.overlap import coerce_boundary, coerce_depth
from cubed.vendor.dask.utils import has_keyword


def map_overlap(
func,
*args,
dtype=None,
chunks=None,
depth=None,
boundary=None,
trim=False,
**kwargs,
):
"""Apply a function to corresponding blocks from multiple input arrays with some overlap.

Parameters
----------
func : callable
Function to apply to every block (with overlap) to produce the output array.
args : arrays
The Cubed arrays to map over. Note that currently only one array may be specified.
dtype : np.dtype
The ``dtype`` of the output array.
chunks : tuple
Chunk shape of blocks in the output array.
depth : int, tuple, dict or list
The number of elements that each block should share with its neighbors.
boundary : value type, tuple, dict or list
How to handle the boundaries. Note that this currently only supports constant values.
trim : bool
Whether or not to trim ``depth`` elements from each block after calling the map function.
Currently only ``False`` is supported.
**kwargs : dict
Extra keyword arguments to pass to function.
"""
if trim:
raise ValueError("trim is not supported")

chunks = normalize_chunks(chunks, dtype=dtype)
shape = tuple(map(sum, chunks))

# Coerce depth and boundary arguments to lists of individual
# specifications for each array argument
def coerce(xs, arg, fn):
if not isinstance(arg, list):
arg = [arg] * len(xs)
return [fn(x.ndim, a) for x, a in zip(xs, arg)]

depth = coerce(args, depth, coerce_depth)
boundary = coerce(args, boundary, coerce_boundary)

# memory allocated by reading one chunk from input array
# note that although the output chunk will overlap multiple input chunks, zarr will
# read the chunks in series, reusing the buffer
extra_projected_mem = args[0].chunkmem # TODO: support multiple

has_block_id_kw = has_keyword(func, "block_id")

return map_direct(
_overlap,
*args,
shape=shape,
dtype=dtype,
chunks=chunks,
extra_projected_mem=extra_projected_mem,
overlap_func=func,
depth=depth,
boundary=boundary,
has_block_id_kw=has_block_id_kw,
**kwargs,
)


def _overlap(
x,
*arrays,
overlap_func=None,
depth=None,
boundary=None,
has_block_id_kw=False,
block_id=None,
**kwargs,
):
a = arrays[0] # TODO: support multiple
depth = depth[0]
boundary = boundary[0]

# First read the chunk with overlaps determined by depth, then pad boundaries second.
# Do it this way round so we can do everything with one blockwise. The alternative,
# which pads the entire array first (via concatenate), would result in at least one extra copy.
out = a.zarray[get_item_with_depth(a.chunks, block_id, depth)]
out = _pad_boundaries(out, depth, boundary, a.numblocks, block_id)
if has_block_id_kw:
return overlap_func(out, block_id=block_id, **kwargs)
else:
return overlap_func(out, **kwargs)


def _clamp(minimum: int, x: int, maximum: int) -> int:
return max(minimum, min(x, maximum))


def get_item_with_depth(
chunks: T_RectangularChunks, idx: Tuple[int, ...], depth
) -> Tuple[slice, ...]:
"""Convert a chunk index to a tuple of slices with depth offsets."""
starts = tuple(_cumsum(c, initial_zero=True) for c in chunks)
loc = tuple(
(
_clamp(0, start[i] - depth[ax], start[-1]),
_clamp(0, start[i + 1] + depth[ax], start[-1]),
)
for ax, (i, start) in enumerate(zip(idx, starts))
)
return tuple(slice(*s, None) for s in loc)


def _pad_boundaries(x, depth, boundary, numblocks, block_id):
for i in range(x.ndim):
d = depth.get(i, 0)
if d == 0 or block_id[i] not in (0, numblocks[i] - 1):
continue
pad_shape = list(x.shape)
pad_shape[i] = d
pad_shape = tuple(pad_shape)
p = nxp.full_like(x, fill_value=boundary[i], shape=pad_shape)
if block_id[i] == 0: # first block on axis i
x = nxp.concat([p, x], axis=i)
elif block_id[i] == numblocks[i] - 1: # last block on axis i
x = nxp.concat([x, p], axis=i)
return x
79 changes: 79 additions & 0 deletions cubed/tests/test_overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
from numpy.testing import assert_array_equal

import cubed
import cubed.array_api as xp


def test_map_overlap_1d():
x = np.arange(6)
a = xp.asarray(x, chunks=(3,))

b = cubed.map_overlap(
lambda x: x,
a,
dtype=a.dtype,
chunks=((5, 5),),
depth=1,
boundary=0,
trim=False,
)

assert_array_equal(b.compute(), np.array([0, 0, 1, 2, 3, 2, 3, 4, 5, 0]))


def test_map_overlap_2d():
x = np.arange(36).reshape((6, 6))
a = xp.asarray(x, chunks=(3, 3))

b = cubed.map_overlap(
lambda x: x,
a,
dtype=a.dtype,
chunks=((7, 7), (5, 5)),
depth={0: 2, 1: 1},
boundary={0: 100, 1: 200},
trim=False,
)

expected = np.array(
[
[200, 100, 100, 100, 100, 100, 100, 100, 100, 200],
[200, 100, 100, 100, 100, 100, 100, 100, 100, 200],
[200, 0, 1, 2, 3, 2, 3, 4, 5, 200],
[200, 6, 7, 8, 9, 8, 9, 10, 11, 200],
[200, 12, 13, 14, 15, 14, 15, 16, 17, 200],
[200, 18, 19, 20, 21, 20, 21, 22, 23, 200],
[200, 24, 25, 26, 27, 26, 27, 28, 29, 200],
[200, 6, 7, 8, 9, 8, 9, 10, 11, 200],
[200, 12, 13, 14, 15, 14, 15, 16, 17, 200],
[200, 18, 19, 20, 21, 20, 21, 22, 23, 200],
[200, 24, 25, 26, 27, 26, 27, 28, 29, 200],
[200, 30, 31, 32, 33, 32, 33, 34, 35, 200],
[200, 100, 100, 100, 100, 100, 100, 100, 100, 200],
[200, 100, 100, 100, 100, 100, 100, 100, 100, 200],
]
)

assert_array_equal(b.compute(), expected)


def test_map_overlap_trim():
x = np.array([1, 1, 2, 3, 5, 8, 13, 21])
a = xp.asarray(x, chunks=5)

def derivative(x):
out = x - np.roll(x, 1)
return out[1:-1] # manual trim

b = cubed.map_overlap(
derivative,
a,
dtype=a.dtype,
chunks=a.chunks,
depth=1,
boundary=0,
trim=False,
)

assert_array_equal(b.compute(), np.array([1, 0, 1, 1, 2, 3, 5, 8]))
38 changes: 38 additions & 0 deletions cubed/vendor/dask/array/overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from numbers import Integral


def coerce_depth(ndim, depth):
default = 0
if depth is None:
depth = default
if isinstance(depth, Integral):
depth = (depth,) * ndim
if isinstance(depth, tuple):
depth = dict(zip(range(ndim), depth))
if isinstance(depth, dict):
depth = {ax: depth.get(ax, default) for ax in range(ndim)}
return coerce_depth_type(ndim, depth)


def coerce_depth_type(ndim, depth):
for i in range(ndim):
if isinstance(depth[i], tuple):
depth[i] = tuple(int(d) for d in depth[i])
else:
depth[i] = int(depth[i])
return depth


def coerce_boundary(ndim, boundary):
default = "none"
if boundary is None:
boundary = default
if not isinstance(boundary, (tuple, dict)):
boundary = (boundary,) * ndim
if isinstance(boundary, tuple):
boundary = dict(zip(range(ndim), boundary))
if isinstance(boundary, dict):
boundary = {ax: boundary.get(ax, default) for ax in range(ndim)}
return boundary
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Chunk-specific functions

apply_gufunc
map_blocks
map_overlap

Non-standardised functions
==========================
Expand Down
Loading