Skip to content

Commit

Permalink
Implement 'stack'
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jun 16, 2022
1 parent d872cb1 commit 8e6aaf5
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 15 deletions.
4 changes: 2 additions & 2 deletions api_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ This table shows which [Array API functions](https://data-apis.org/array-api/lat
| | `vecdot` | | 1 | Express using `tensordot` |
| Manipulation | `broadcast_arrays` | :white_check_mark: | | |
| | `broadcast_to` | :white_check_mark: | | Primitive (Zarr view) |
| | `concat` | | 3 | Primitive (Zarr view) |
| | `concat` | | 3 | Like `stack` |
| | `expand_dims` | :white_check_mark: | | |
| | `flip` | | 3 | Needs indexing |
| | `permute_dims` | :white_check_mark: | | |
| | `reshape` | :white_check_mark: | | Partial implementation |
| | `roll` | | 3 | Needs `concat` and `reshape` |
| | `squeeze` | :white_check_mark: | | |
| | `stack` | | 2 | Primitive (Zarr view) |
| | `stack` | :white_check_mark: | | |
| Searching | `argmax` | | 2 | `argreduction` primitive |
| | `argmin` | | 2 | `argreduction` primitive |
| | `nonzero` | | 3 | Shape is data dependent |
Expand Down
2 changes: 2 additions & 0 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
reshape,
result_type,
squeeze,
stack,
sum,
uint8,
uint16,
Expand Down Expand Up @@ -95,6 +96,7 @@
"result_type",
"Spec",
"squeeze",
"stack",
"sum",
"to_zarr",
"TqdmProgressBar",
Expand Down
1 change: 1 addition & 0 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
permute_dims,
reshape,
squeeze,
stack,
)
from .statistical_functions import mean, sum
from .utility_functions import all, any
43 changes: 41 additions & 2 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from math import prod
from operator import mul

import numpy as np
Expand All @@ -10,10 +11,10 @@

from cubed.core import squeeze # noqa: F401
from cubed.core import Array, Plan, blockwise, gensym, rechunk, unify_chunks
from cubed.core.ops import map_blocks
from cubed.core.ops import map_blocks, map_direct
from cubed.primitive.broadcast import broadcast_to as primitive_broadcast_to
from cubed.primitive.reshape import reshape_chunks as primitive_reshape_chunks
from cubed.utils import to_chunksize
from cubed.utils import get_item, to_chunksize


def broadcast_arrays(*arrays):
Expand Down Expand Up @@ -102,3 +103,41 @@ def reshape(x, /, shape):
target = primitive_reshape_chunks(x2.zarray, shape, outchunks)
plan = Plan(name, "reshape", target, spec, None, None, None, x2)
return Array(name, target, plan)


def stack(arrays, /, *, axis=0):
    """Join a sequence of arrays along a new axis.

    Parameters
    ----------
    arrays : sequence of Array
        Arrays to stack. All must have the same shape.
    axis : int
        Axis in the result along which the inputs are stacked. May be
        negative; it is validated against ``ndim + 1`` of the inputs.

    Returns
    -------
    Array
        Array with an extra dimension of size ``len(arrays)`` at ``axis``.

    Raises
    ------
    ValueError
        If no arrays are given, or the arrays do not all share one shape.
    """
    if not arrays:
        raise ValueError("Need array(s) to stack")

    # TODO: type promotion
    # TODO: unify chunks

    a = arrays[0]

    # every input becomes one slice of the result, so all shapes must agree
    if any(arr.shape != a.shape for arr in arrays):
        raise ValueError("all input arrays must have the same shape")

    axis = validate_axis(axis, a.ndim + 1)
    shape = a.shape[:axis] + (len(arrays),) + a.shape[axis:]
    dtype = a.dtype
    # the new stacked axis has exactly one chunk per input array
    chunks = a.chunks[:axis] + ((1,) * len(arrays),) + a.chunks[axis:]

    # memory allocated by reading one chunk from an input array
    # (output is already catered for in blockwise)
    extra_required_mem = np.dtype(a.dtype).itemsize * prod(to_chunksize(a.chunks))

    return map_direct(
        _read_stack_chunk,
        *arrays,
        shape=shape,
        dtype=dtype,
        chunks=chunks,
        extra_required_mem=extra_required_mem,
        axis=axis,
    )


def _read_stack_chunk(x, *arrays, axis=None, block_id=None):
    """Read the chunk of the output block ``block_id`` directly from the
    side-input array it originates from, and give it the new stacked axis."""
    # the index along the stacked axis selects which input array to read from
    source = arrays[block_id[axis]]
    # drop the stacked axis from the block index to get the source chunk index
    source_idx = tuple(v for i, v in enumerate(block_id) if i != axis)
    chunk = source.zarray[get_item(source.chunks, source_idx)]
    return np.expand_dims(chunk, axis=axis)
61 changes: 60 additions & 1 deletion cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def blockwise(
zargs = list(args)
zargs[::2] = [a.zarray for a in arrays]

extra_source_arrays = kwargs.pop("extra_source_arrays", [])
source_arrays = list(arrays) + list(extra_source_arrays)

extra_required_mem = kwargs.pop("extra_required_mem", 0)

name = gensym()
spec = arrays[0].plan.spec
target_store = new_temp_store(name=name, spec=spec)
Expand All @@ -127,7 +132,14 @@ def blockwise(
**kwargs,
)
plan = Plan(
name, "blockwise", target, spec, pipeline, required_mem, num_tasks, *arrays
name,
"blockwise",
target,
spec,
pipeline,
required_mem + extra_required_mem,
num_tasks,
*source_arrays,
)
return Array(name, target, plan)

Expand Down Expand Up @@ -260,6 +272,53 @@ def _map_blocks(
)


def map_direct(func, *args, shape, dtype, chunks, extra_required_mem, **kwargs):
    """
    Map a function across blocks of a new array, using side-input arrays to read directly from.

    Parameters
    ----------
    func : callable
        Function to apply to every block to produce the output array.
        Must accept ``block_id`` as a keyword argument (with same meaning as for ``map_blocks``).
    args : arrays
        The side-input arrays that may be accessed directly in the function.
    shape : tuple
        Shape of the output array.
    dtype : np.dtype
        The ``dtype`` of the output array.
    chunks : tuple
        Chunk shape of blocks in the output array.
    extra_required_mem : int
        Extra memory required (in bytes) for each map task. This should take into account the
        memory requirements for any reads from the side-input arrays (``args``).
    """

    # imported here to avoid a circular import at module load time
    from cubed.array_api.creation_functions import empty

    # an empty output array whose blocks ``func`` will populate
    out = empty(shape, dtype=dtype, chunks=chunks, spec=args[0].plan.spec)

    # thread the side-input arrays through map_blocks' kwargs so the wrapper
    # can append them to the positional arguments at call time
    kwargs["arrays"] = args

    def wrapped(*block_args, block_id=None, **block_kwargs):
        side_inputs = block_kwargs.pop("arrays")
        return func(*(block_args + side_inputs), block_id=block_id, **block_kwargs)

    return map_blocks(
        wrapped,
        out,
        dtype=dtype,
        chunks=chunks,
        extra_source_arrays=args,
        extra_required_mem=extra_required_mem,
        **kwargs,
    )


def rechunk(x, chunks, target_store=None):
name = gensym()
spec = x.plan.spec
Expand Down
11 changes: 1 addition & 10 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from dask.array.core import normalize_chunks
from dask.blockwise import make_blockwise_graph
from dask.core import flatten
from dask.utils import cached_cumsum
from rechunker.api import _zarr_empty
from rechunker.types import ArrayProxy, Pipeline, Stage
from toolz import map

from cubed.utils import to_chunksize
from cubed.utils import get_item, to_chunksize

sym_counter = 0

Expand Down Expand Up @@ -48,14 +47,6 @@ def apply_blockwise_structured(graph_item, *, config=BlockwiseSpec):
config.write.array.set_basic_selection(out_chunk_key, v, fields=k)


def get_item(chunks, idx):
    """Convert a chunk index ``idx`` to the tuple of slices selecting that
    chunk from the array described by ``chunks``.

    NOTE(review): this copy is removed in this commit in favour of the
    identical helper in ``cubed.utils``.
    """
    # cumulative chunk boundaries per dimension, with a leading 0
    starts = tuple(cached_cumsum(c, initial_zero=True) for c in chunks)

    # (start, stop) pair for the selected chunk in each dimension
    loc = tuple((start[i], start[i + 1]) for i, start in zip(idx, starts))
    return tuple(slice(*s, None) for s in loc)


def blockwise(
func,
out_ind,
Expand Down
11 changes: 11 additions & 0 deletions cubed/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ def test_reshape(spec, executor):
)


def test_stack(spec, executor):
    # three same-shaped constant arrays, stacked along a new leading axis
    fill_values = [1, 2, 3]
    arrays = [xp.full((4, 6), v, chunks=(2, 3), spec=spec) for v in fill_values]
    stacked = xp.stack(arrays, axis=0)
    expected = np.stack([np.full((4, 6), v) for v in fill_values], axis=0)
    assert_array_equal(stacked.compute(executor=executor), expected)


def test_squeeze_1d(spec, executor):
a = xp.asarray([[1, 2, 3]], chunks=(1, 2), spec=spec)
b = xp.squeeze(a, 0)
Expand Down
8 changes: 8 additions & 0 deletions cubed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,18 @@
from urllib.parse import quote, unquote, urlsplit, urlunsplit

from dask.array.core import _check_regular_chunks
from dask.utils import cached_cumsum

PathType = Union[str, Path]


def get_item(chunks, idx):
    """Convert a chunk index to a tuple of slices."""
    selection = []
    for block_index, dim_chunks in zip(idx, chunks):
        # cumulative chunk boundaries for this dimension, with a leading 0
        bounds = cached_cumsum(dim_chunks, initial_zero=True)
        selection.append(slice(bounds[block_index], bounds[block_index + 1], None))
    return tuple(selection)


def join_path(dir_url: PathType, child_path: str) -> str:
"""Combine a URL for a directory with a child path"""
parts = urlsplit(str(dir_url))
Expand Down

0 comments on commit 8e6aaf5

Please sign in to comment.