Skip to content

Commit

Permalink
Implement 'concat'
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jun 27, 2022
1 parent 21fc7a5 commit 3d35cca
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ jobs:
array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_creation_functions.py::test_full
array_api_tests/test_linalg.py::test_matmul
array_api_tests/test_manipulation_functions.py::test_concat
array_api_tests/test_manipulation_functions.py::test_stack
# not implemented
Expand All @@ -76,7 +77,6 @@ jobs:
array_api_tests/test_linalg.py::test_matrix_transpose
array_api_tests/test_linalg.py::test_tensordot
array_api_tests/test_linalg.py::test_vecdot
array_api_tests/test_manipulation_functions.py::test_concat
array_api_tests/test_manipulation_functions.py::test_flip
array_api_tests/test_manipulation_functions.py::test_roll
array_api_tests/test_searching_functions.py::test_argmax
Expand Down
2 changes: 1 addition & 1 deletion api_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ This table shows which parts of the [Array API](https://data-apis.org/array-
| | `vecdot` | | 1 | Express using `tensordot` |
| Manipulation | `broadcast_arrays` | :white_check_mark: | | |
| | `broadcast_to` | :white_check_mark: | | |
| | `concat` | | 3 | Like `stack` |
| | `concat` | :white_check_mark: | | |
| | `expand_dims` | :white_check_mark: | | |
| | `flip` | | 3 | Needs indexing |
| | `permute_dims` | :white_check_mark: | | |
Expand Down
2 changes: 2 additions & 0 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
broadcast_arrays,
broadcast_to,
can_cast,
concat,
divide,
e,
empty,
Expand Down Expand Up @@ -73,6 +74,7 @@
"broadcast_to",
"can_cast",
"Callback",
"concat",
"divide",
"e",
"empty",
Expand Down
1 change: 1 addition & 0 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .manipulation_functions import (
broadcast_arrays,
broadcast_to,
concat,
expand_dims,
permute_dims,
reshape,
Expand Down
74 changes: 71 additions & 3 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from bisect import bisect
from itertools import product
from math import prod
from operator import mul
from operator import add, mul

import numpy as np
import tlz
from dask.array.core import broadcast_chunks, normalize_chunks
from dask.array.reshape import reshape_rechunk
from dask.array.slicing import sanitize_index
from dask.array.utils import validate_axis
from tlz import concat
from toolz import reduce

from cubed.array_api.creation_functions import empty
Expand All @@ -22,7 +23,7 @@ def broadcast_arrays(*arrays):

# Unify uneven chunking
inds = [list(reversed(range(x.ndim))) for x in arrays]
uc_args = concat(zip(arrays, inds))
uc_args = tlz.concat(zip(arrays, inds))
_, args = unify_chunks(*uc_args, warn=False)

shape = np.broadcast_shapes(*(e.shape for e in args))
Expand Down Expand Up @@ -74,6 +75,73 @@ def _broadcast_like(x, template):
return np.broadcast_to(x, template.shape)


def concat(arrays, /, *, axis=0):
    """Concatenate arrays along the given axis.

    Parameters
    ----------
    arrays : sequence of Array
        Arrays to concatenate. Must be non-empty. Chunking of the output is
        taken from the first array.
    axis : int, optional
        Axis along which to join the arrays (default 0). ``None`` is not
        supported.

    Returns
    -------
    Array
        The concatenated array.

    Raises
    ------
    ValueError
        If ``arrays`` is empty.
    NotImplementedError
        If ``axis`` is ``None``.
    """
    if not arrays:
        raise ValueError("Need array(s) to concat")

    if axis is None:
        raise NotImplementedError("None axis not supported in concat")

    # TODO: check arrays all have same shape (except in the dimension specified by axis)
    # TODO: type promotion
    # TODO: unify chunks

    a = arrays[0]

    # Normalize and validate axis up front (e.g. map -1 to ndim - 1) so an
    # invalid axis fails fast with a clear error, before it is used to index
    # shapes or compute offsets below.
    axis = validate_axis(axis, a.ndim)

    # offsets along axis for the start of each array
    offsets = [0] + list(tlz.accumulate(add, [a.shape[axis] for a in arrays]))

    shape = a.shape[:axis] + (offsets[-1],) + a.shape[axis + 1 :]
    dtype = a.dtype
    chunks = normalize_chunks(to_chunksize(a.chunks), shape=shape, dtype=dtype)

    # memory allocated by reading one chunk from input array
    # note that although the output chunk will overlap multiple input chunks,
    # the chunks are read in series, reusing memory
    extra_required_mem = np.dtype(a.dtype).itemsize * prod(to_chunksize(a.chunks))

    return map_direct(
        _read_concat_chunk,
        *arrays,
        shape=shape,
        dtype=dtype,
        chunks=chunks,
        extra_required_mem=extra_required_mem,
        axis=axis,
        offsets=offsets,
    )


def _read_concat_chunk(x, *arrays, axis=None, offsets=None, block_id=None):
    """Assemble one output chunk by reading slices from the input arrays."""
    # work out which range of the concatenation axis this output block covers
    chunk_sizes = arrays[0].zarray.chunks
    start = block_id[axis] * chunk_sizes[axis]
    stop = start + x.shape[axis]

    # build a base selection of slices for this block; the axis entry is a
    # placeholder that is replaced per input array below
    placeholder_idx = tuple(0 if dim == axis else b for dim, b in enumerate(block_id))
    base_key = get_item(arrays[0].chunks, placeholder_idx)

    # read the relevant slice from each overlapping input array, then join them
    parts = [
        arrays[array_index].zarray[base_key[:axis] + (sl,) + base_key[axis + 1 :]]
        for array_index, sl in _array_slices(offsets, start, stop)
    ]
    return np.concatenate(parts, axis=axis)


def _array_slices(offsets, start, stop):
"""Return pairs of array index and slice to slice from start to stop in the concatenated array."""
slice_start = start
while slice_start < stop:
# find array that slice_start falls in
i = bisect(offsets, slice_start) - 1
slice_stop = min(stop, offsets[i + 1])
yield i, slice(slice_start - offsets[i], slice_stop - offsets[i])
slice_start = slice_stop


def expand_dims(x, /, *, axis):
if not isinstance(axis, tuple):
axis = (axis,)
Expand Down
14 changes: 14 additions & 0 deletions cubed/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,20 @@ def test_broadcast_to(
assert_array_equal(b.compute(executor=executor), np.broadcast_to(x, new_shape))


def test_concat(spec, executor):
    # note: middle chunk of output reads from three input chunks
    arrays = [
        xp.full((4, 5), 1, chunks=(3, 2), spec=spec),
        xp.full((1, 5), 2, chunks=(3, 2), spec=spec),
        xp.full((3, 5), 3, chunks=(3, 2), spec=spec),
    ]
    result = xp.concat(arrays, axis=0)
    expected = np.concatenate(
        [np.full((4, 5), 1), np.full((1, 5), 2), np.full((3, 5), 3)], axis=0
    )
    assert_array_equal(result.compute(executor=executor), expected)


def test_expand_dims(spec, executor):
a = xp.asarray([1, 2, 3], chunks=(2,), spec=spec)
b = xp.expand_dims(a, axis=0)
Expand Down

0 comments on commit 3d35cca

Please sign in to comment.