Skip to content

Commit

Permalink
Implement unstack using multiple outputs (#575)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Sep 17, 2024
1 parent 8a406dc commit 50181e0
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@
roll,
squeeze,
stack,
unstack,
)

__all__ += [
Expand All @@ -300,6 +301,7 @@
"roll",
"squeeze",
"stack",
"unstack",
]

from .array_api.searching_functions import argmax, argmin, where
Expand Down
2 changes: 2 additions & 0 deletions cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@
roll,
squeeze,
stack,
unstack,
)

__all__ += [
Expand All @@ -242,6 +243,7 @@
"roll",
"squeeze",
"stack",
"unstack",
]

from .searching_functions import argmax, argmin, where
Expand Down
40 changes: 40 additions & 0 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,43 @@ def key_function(out_key):

def _read_stack_chunk(array, axis=None):
return nxp.expand_dims(array, axis=axis)


def unstack(x, /, *, axis=0):
axis = validate_axis(axis, x.ndim)

n_arrays = x.shape[axis]

if n_arrays == 1:
return (x,)

shape = x.shape[:axis] + x.shape[axis + 1 :]
dtype = x.dtype
chunks = x.chunks[:axis] + x.chunks[axis + 1 :]

def key_function(out_key):
out_coords = out_key[1:]
all_in_coords = tuple(
out_coords[:axis] + (i,) + out_coords[axis:]
for i in range(x.numblocks[axis])
)
return tuple((x.name,) + in_coords for in_coords in all_in_coords)

return general_blockwise(
_unstack_chunk,
key_function,
x,
shapes=[shape] * n_arrays,
dtypes=[dtype] * n_arrays,
chunkss=[chunks] * n_arrays,
target_stores=[None] * n_arrays, # filled in by general_blockwise
axis=axis,
)


def _unstack_chunk(*arrs, axis=0):
# unstack each array in arrs and yield all in turn
for arr in arrs:
# TODO: replace with nxp.unstack(arr, axis=axis) when array-api-compat has unstack
for a in tuple(nxp.moveaxis(arr, axis, 0)):
yield a
22 changes: 22 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,28 @@ def test_stack(spec, executor):
)


@pytest.mark.parametrize("chunks", [(1, 2, 3), (2, 2, 3), (3, 2, 3)])
def test_unstack(spec, executor, chunks):
a = xp.full((4, 6), 1, chunks=(2, 3), spec=spec)
b = xp.full((4, 6), 2, chunks=(2, 3), spec=spec)
c = xp.full((4, 6), 3, chunks=(2, 3), spec=spec)
d = xp.stack([a, b, c], axis=0)

d = d.rechunk(chunks)

au, bu, cu = cubed.compute(*xp.unstack(d), executor=executor, optimize_graph=False)

assert_array_equal(au, np.full((4, 6), 1))
assert_array_equal(bu, np.full((4, 6), 2))
assert_array_equal(cu, np.full((4, 6), 3))


def test_unstack_noop(spec):
a = xp.full((1, 4, 6), 1, chunks=(1, 2, 3), spec=spec)
(b,) = xp.unstack(a)
assert a is b


# Searching functions


Expand Down
9 changes: 9 additions & 0 deletions cubed/tests/test_mem_utilization.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,15 @@ def test_stack(tmp_path, spec, executor):
run_operation(tmp_path, executor, "stack", c)


@pytest.mark.slow
def test_unstack(tmp_path, spec, executor):
a = cubed.random.random(
(2, 10000, 10000), chunks=(2, 5000, 5000), spec=spec
) # 400MB chunks
b, c = xp.unstack(a)
run_operation(tmp_path, executor, "unstack", b, c)


# Searching Functions


Expand Down

0 comments on commit 50181e0

Please sign in to comment.