From 50181e01ff951c96f255680ed34dd3f02082528b Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 17 Sep 2024 17:37:35 +0100 Subject: [PATCH] Implement `unstack` using multiple outputs (#575) --- cubed/__init__.py | 2 ++ cubed/array_api/__init__.py | 2 ++ cubed/array_api/manipulation_functions.py | 40 +++++++++++++++++++++++ cubed/tests/test_array_api.py | 22 +++++++++++++ cubed/tests/test_mem_utilization.py | 9 +++++ 5 files changed, 75 insertions(+) diff --git a/cubed/__init__.py b/cubed/__init__.py index f0089f3b..6aa7bf50 100644 --- a/cubed/__init__.py +++ b/cubed/__init__.py @@ -286,6 +286,7 @@ roll, squeeze, stack, + unstack, ) __all__ += [ @@ -300,6 +301,7 @@ "roll", "squeeze", "stack", + "unstack", ] from .array_api.searching_functions import argmax, argmin, where diff --git a/cubed/array_api/__init__.py b/cubed/array_api/__init__.py index b290674b..e7e6ffbf 100644 --- a/cubed/array_api/__init__.py +++ b/cubed/array_api/__init__.py @@ -228,6 +228,7 @@ roll, squeeze, stack, + unstack, ) __all__ += [ @@ -242,6 +243,7 @@ "roll", "squeeze", "stack", + "unstack", ] from .searching_functions import argmax, argmin, where diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index d226dd44..a18a133a 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -412,3 +412,43 @@ def key_function(out_key): def _read_stack_chunk(array, axis=None): return nxp.expand_dims(array, axis=axis) + + +def unstack(x, /, *, axis=0): + axis = validate_axis(axis, x.ndim) + + n_arrays = x.shape[axis] + + if n_arrays == 1: + return (x,) + + shape = x.shape[:axis] + x.shape[axis + 1 :] + dtype = x.dtype + chunks = x.chunks[:axis] + x.chunks[axis + 1 :] + + def key_function(out_key): + out_coords = out_key[1:] + all_in_coords = tuple( + out_coords[:axis] + (i,) + out_coords[axis:] + for i in range(x.numblocks[axis]) + ) + return tuple((x.name,) + in_coords for in_coords in all_in_coords) + + return general_blockwise( + _unstack_chunk, + key_function, + x, + shapes=[shape] * n_arrays, + dtypes=[dtype] * n_arrays, + chunkss=[chunks] * n_arrays, + target_stores=[None] * n_arrays, # filled in by general_blockwise + axis=axis, + ) + + +def _unstack_chunk(*arrs, axis=0): + # unstack each array in arrs and yield all in turn + for arr in arrs: + # TODO: replace with nxp.unstack(arr, axis=axis) when array-api-compat has unstack + for a in tuple(nxp.moveaxis(arr, axis, 0)): + yield a diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index e715184d..b34b8bbb 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -627,6 +627,28 @@ def test_stack(spec, executor): ) +@pytest.mark.parametrize("chunks", [(1, 2, 3), (2, 2, 3), (3, 2, 3)]) +def test_unstack(spec, executor, chunks): + a = xp.full((4, 6), 1, chunks=(2, 3), spec=spec) + b = xp.full((4, 6), 2, chunks=(2, 3), spec=spec) + c = xp.full((4, 6), 3, chunks=(2, 3), spec=spec) + d = xp.stack([a, b, c], axis=0) + + d = d.rechunk(chunks) + + au, bu, cu = cubed.compute(*xp.unstack(d), executor=executor, optimize_graph=False) + + assert_array_equal(au, np.full((4, 6), 1)) + assert_array_equal(bu, np.full((4, 6), 2)) + assert_array_equal(cu, np.full((4, 6), 3)) + + +def test_unstack_noop(spec): + a = xp.full((1, 4, 6), 1, chunks=(1, 2, 3), spec=spec) + (b,) = xp.unstack(a) + assert a is b + + # Searching functions diff --git a/cubed/tests/test_mem_utilization.py b/cubed/tests/test_mem_utilization.py index 6bb1e650..4ce3474e 100644 --- a/cubed/tests/test_mem_utilization.py +++ b/cubed/tests/test_mem_utilization.py @@ -279,6 +279,15 @@ def test_stack(tmp_path, spec, executor): run_operation(tmp_path, executor, "stack", c) +@pytest.mark.slow +def test_unstack(tmp_path, spec, executor): + a = cubed.random.random( + (2, 10000, 10000), chunks=(2, 5000, 5000), spec=spec + ) # 400MB chunks + b, c = xp.unstack(a) + run_operation(tmp_path, executor, "unstack", b, c) + + # Searching Functions