Allow step size > 1 (#270)
tomwhite authored Jul 21, 2023
1 parent f957a46 commit 6f9d958
Showing 4 changed files with 112 additions and 20 deletions.
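
The change in a sentence: orthogonal indexing now accepts any positive slice step, and only zero or negative steps still raise NotImplementedError. A minimal usage sketch of what this enables; the xp alias, chunks= keyword, and the expected chunk tuple are taken from the tests in this commit, while the import path and spec-less defaults are assumptions:

# Sketch: slicing a cubed array with step > 1 (previously unsupported).
import numpy as np
import cubed.array_api as xp  # alias as used in cubed's test suite

a = xp.arange(20, chunks=5)
b = a[3:14:2]  # step 2: raised NotImplementedError before this commit
assert b.chunks == ((4, 2),)  # chunks merged back toward the input chunking of 5
np.testing.assert_array_equal(b.compute(), np.arange(20)[3:14:2])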
63 changes: 48 additions & 15 deletions cubed/core/ops.py
@@ -11,6 +11,7 @@
 from zarr.indexing import (
     IntDimIndexer,
     OrthogonalIndexer,
+    SliceDimIndexer,
     is_integer_list,
     is_slice,
     replace_ellipsis,
@@ -335,26 +336,46 @@ def index(x, key):
     # Replace ellipsis with slices
     selection = replace_ellipsis(selection, x.shape)
 
+    # Check selection is supported
+    if any(s.step is not None and s.step < 1 for s in selection if is_slice(s)):
+        raise NotImplementedError(f"Slice step must be >= 1: {key}")
+    assert all(isinstance(s, (slice, list, Integral)) for s in selection)
+    where_list = [i for i, ind in enumerate(selection) if is_integer_list(ind)]
+    if len(where_list) > 1:
+        raise NotImplementedError("Only one integer array index is allowed.")
+
     # Use a Zarr indexer just to find the resulting array shape and chunks
     # Need to use an in-memory representation since the Zarr file has not been written yet
     inmem = zarr.empty_like(x.zarray_maybe_lazy, store=zarr.storage.MemoryStore())
     indexer = OrthogonalIndexer(selection, inmem)
 
+    def chunk_len_for_indexer(s):
+        if not isinstance(s, SliceDimIndexer):
+            return s.dim_chunk_len
+        return max(s.dim_chunk_len // s.step, 1)
+
+    def merged_chunk_len_for_indexer(s):
+        if not isinstance(s, SliceDimIndexer):
+            return s.dim_chunk_len
+        if s.step is None or s.step == 1:
+            return s.dim_chunk_len
+        if (s.dim_chunk_len // s.step) < 1:
+            return s.dim_chunk_len
+        # note that this may not be the same as s.dim_chunk_len
+        # but it is guaranteed to be a multiple of the corresponding
+        # value returned by chunk_len_for_indexer, which is required
+        # by merge_chunks
+        return (s.dim_chunk_len // s.step) * s.step
+
     shape = indexer.shape
     dtype = x.dtype
     chunks = tuple(
-        s.dim_chunk_len
+        chunk_len_for_indexer(s)
         for s in indexer.dim_indexers
         if not isinstance(s, IntDimIndexer)
     )
-    chunks = normalize_chunks(chunks, shape, dtype=dtype)
-
-    assert all(isinstance(s, (slice, list, Integral)) for s in selection)
-    if any(s.step is not None and s.step != 1 for s in selection if is_slice(s)):
-        raise NotImplementedError(f"Slice step must be 1: {key}")
-    where_list = [i for i, ind in enumerate(selection) if is_integer_list(ind)]
-    if len(where_list) > 1:
-        raise NotImplementedError("Only one integer array index is allowed.")
+    target_chunks = normalize_chunks(chunks, shape, dtype=dtype)
 
     # memory allocated by reading one chunk from input array
     # note that although the output chunk will overlap multiple input chunks, zarr will
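
For intuition, the two new helpers decide the output chunking per dimension: chunk_len_for_indexer gives the per-input-chunk output length, and merged_chunk_len_for_indexer gives a larger length (always a multiple of the first) to merge back up to. A standalone sketch of the same arithmetic, with illustrative dim_chunk_len and step values:

# Standalone sketch of the chunk-length arithmetic in the two helpers above.
def chunk_len(dim_chunk_len, step):
    # each input chunk of dim_chunk_len elements yields about
    # dim_chunk_len // step output elements, but never fewer than 1
    return max(dim_chunk_len // step, 1)

def merged_chunk_len(dim_chunk_len, step):
    if step == 1 or dim_chunk_len // step < 1:
        return dim_chunk_len
    # a multiple of chunk_len(...) no bigger than dim_chunk_len,
    # which is the precondition merge_chunks checks
    return (dim_chunk_len // step) * step

# input chunks of 5, step 2: each chunk yields 2 elements, merged back up to 4
assert (chunk_len(5, 2), merged_chunk_len(5, 2)) == (2, 4)
# step larger than the chunk: floor of 1 element per chunk, no merging (5 stays 5)
assert (chunk_len(5, 7), merged_chunk_len(5, 7)) == (1, 5)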
@@ -366,12 +387,22 @@ def index(x, key):
         x,
         shape=shape,
         dtype=dtype,
-        chunks=chunks,
+        chunks=target_chunks,
         extra_projected_mem=extra_projected_mem,
-        target_chunks=chunks,
+        target_chunks=target_chunks,
         selection=selection,
     )
 
+    # merge chunks for any dims with step > 1 so they are
+    # the same size as the input (or slightly smaller due to rounding)
+    merged_chunks = tuple(
+        merged_chunk_len_for_indexer(s)
+        for s in indexer.dim_indexers
+        if not isinstance(s, IntDimIndexer)
+    )
+    if chunks != merged_chunks:
+        out = merge_chunks(out, merged_chunks)
+
     for axis in where_none:
         from cubed.array_api.manipulation_functions import expand_dims
 
@@ -396,9 +427,12 @@ def _target_chunk_selection(target_chunks, idx, selection):
     for s in selection:
         if is_slice(s):
             offset = s.start or 0
-            start = tuple(accumulate(add, target_chunks[i], offset))
+            step = s.step if s.step is not None else 1
+            start = tuple(
+                accumulate(add, tuple(x * step for x in target_chunks[i]), offset)
+            )
             j = idx[i]
-            sel.append(slice(start[j], start[j + 1]))
+            sel.append(slice(start[j], start[j + 1], step))
             i += 1
         elif is_integer_list(s):
             # find the cumulative chunk starts
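
To trace the modified start computation: scaling each target chunk length by step before accumulating gives the positions in the input coordinate system where each output chunk's selection begins. An illustrative trace follows; itertools.accumulate with initial= stands in for the toolz-style accumulate(add, seq, offset) call in the diff, and the values match the (20, 5, slice(3, 14, 2)) test case below:

# Illustrative trace of the scaled cumulative starts for one dimension.
from itertools import accumulate
from operator import add

target_chunks_dim = (4, 2)  # output chunk lengths along this dimension
step, offset = 2, 3         # from slice(3, 14, 2)

scaled = tuple(x * step for x in target_chunks_dim)     # (8, 4)
start = tuple(accumulate(scaled, add, initial=offset))  # cumulative starts
assert start == (3, 11, 15)

# chunk 0 reads slice(3, 11, 2)  -> input elements 3, 5, 7, 9 (4 values)
# chunk 1 reads slice(11, 15, 2) -> input elements 11, 13    (2 values)
lengths = [
    len(range(*slice(start[j], start[j + 1], step).indices(20))) for j in range(2)
]
assert lengths == [4, 2]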
@@ -639,9 +673,7 @@ def rechunk(x, chunks, target_store=None):


 def merge_chunks(x, chunks):
-    target_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
-
-    target_chunksize = to_chunksize(target_chunks)
+    target_chunksize = chunks
     if len(target_chunksize) != x.ndim:
         raise ValueError(
             f"Chunks {target_chunksize} must have same number of dimensions as array ({x.ndim})"
@@ -651,6 +683,7 @@
f"Chunks {target_chunksize} must be a multiple of array's chunks {x.chunksize}"
)

target_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
return map_direct(
_copy_chunk,
x,
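
The merge_chunks changes tighten its contract: the requested chunks are now taken literally as a chunksize (no to_chunksize round-trip), must be exact per-dimension multiples of the array's current chunksize, and are clipped to the array shape by normalize_chunks afterwards. A usage sketch mirroring the test_core.py cases below; the import paths and spec-less defaults are assumptions:

# Sketch of the revised merge_chunks contract, mirroring the tests below.
import numpy as np
import cubed.array_api as xp
from cubed.core.ops import merge_chunks  # defined in the file shown above

a = xp.ones((10, 10), dtype=np.uint8, chunks=(2, 3))

b = merge_chunks(a, (4, 6))    # exact multiples of (2, 3): accepted
assert b.chunksize == (4, 6)

c = merge_chunks(a, (12, 12))  # also multiples; clipped to the 10 x 10 shape
assert c.chunksize == (10, 10)

try:
    merge_chunks(a, (5, 5))    # 5 is not a multiple of 2 (or 3)
except ValueError as e:
    print(e)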
43 changes: 42 additions & 1 deletion cubed/tests/test_array_api.py
@@ -230,10 +230,51 @@ def test_index_2d(spec, ind):
     assert_array_equal(a[ind].compute(), x[ind])
 
 
+@pytest.mark.parametrize(
+    "shape, chunks, ind, new_chunks_expected",
+    [
+        # step divides chunks exactly
+        (20, 4, slice(3, 14, 2), ((4, 2),)),
+        # step doesn't divide chunks exactly
+        # chunks is prime (so merge_chunks can't restore chunks to 5)
+        (20, 5, slice(3, 14, 2), ((4, 2),)),
+        # step doesn't divide chunks exactly
+        (20, 8, slice(5, 18, 3), ((5,),)),
+        # step is bigger than chunks
+        (50, 5, slice(3, 50, 7), ((5, 2),)),
+    ],
+)
+def test_index_1d_step(spec, shape, chunks, ind, new_chunks_expected):
+    a = xp.arange(shape, chunks=chunks, spec=spec)
+    b = a[ind]
+    assert_array_equal(b.compute(), np.arange(shape)[ind])
+    assert b.chunks == new_chunks_expected
+
+
+# fmt: off
+@pytest.mark.parametrize(
+    "shape, chunks, ind, new_chunks_expected",
+    [
+        (
+            (20, 20),
+            (4, 4),
+            (slice(3, 14, 2), slice(3, 14, 3)),
+            ((4, 2), (3, 1),),
+        ),
+    ],
+)
+# fmt: on
+def test_index_2d_step(spec, shape, chunks, ind, new_chunks_expected):
+    a = xp.ones(shape, chunks=chunks, spec=spec)
+    b = a[ind]
+    assert_array_equal(b.compute(), np.ones(shape)[ind])
+    assert b.chunks == new_chunks_expected
+
+
 def test_index_slice_unsupported_step(spec):
     with pytest.raises(NotImplementedError):
         a = xp.arange(12, chunks=(4,), spec=spec)
-        a[3:10:2]
+        a[::-1]
 
 
 @pytest.mark.xfail(reason="not currently compatible with lazy zarr arrays")
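
For readers checking the expected chunk tuples in test_index_1d_step: they follow from the helper arithmetic in ops.py. A quick standalone check of the (20, 5, slice(3, 14, 2)) -> ((4, 2),) case, using numpy only:

# Verify the (20, 5, slice(3, 14, 2)) -> ((4, 2),) expectation by hand.
import numpy as np

x = np.arange(20)[3:14:2]
assert x.tolist() == [3, 5, 7, 9, 11, 13]  # 6 elements survive the slice

merged = (5 // 2) * 2                # merged_chunk_len_for_indexer: 4
chunks = (merged, len(x) - merged)   # one full merged chunk plus the remainder
assert chunks == (4, 2)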
17 changes: 13 additions & 4 deletions cubed/tests/test_core.py
@@ -336,16 +336,25 @@ def test_reduction_not_enough_memory(tmp_path):
         xp.sum(a, axis=0, dtype=np.uint8)
 
 
-@pytest.mark.parametrize("target_chunks", [(2, 3), (4, 3), (2, 6), (4, 6)])
-def test_merge_chunks(spec, target_chunks):
+@pytest.mark.parametrize(
+    "target_chunks, expected_chunksize",
+    [
+        ((2, 3), None),
+        ((4, 3), None),
+        ((2, 6), None),
+        ((4, 6), None),
+        ((12, 12), (10, 10)),
+    ],
+)
+def test_merge_chunks(spec, target_chunks, expected_chunksize):
     a = xp.ones((10, 10), dtype=np.uint8, chunks=(2, 3), spec=spec)
     b = merge_chunks(a, target_chunks)
-    assert b.chunksize == target_chunks
+    assert b.chunksize == (expected_chunksize or target_chunks)
     assert_array_equal(b.compute(), np.ones((10, 10)))
 
 
 @pytest.mark.parametrize(
-    "target_chunks", [(2,), (2, 3, 1), (3, 2), (1, 3), (5, 5), (12, 12)]
+    "target_chunks", [(2,), (2, 3, 1), (3, 2), (1, 3), (5, 5), (10, 10)]
 )
 def test_merge_chunks_fails(spec, target_chunks):
     a = xp.ones((10, 10), dtype=np.uint8, chunks=(2, 3), spec=spec)
9 changes: 9 additions & 0 deletions cubed/tests/test_mem_utilization.py
@@ -41,6 +41,15 @@ def test_index(spec):
     run_operation("index", b)
 
 
+@pytest.mark.slow
+def test_index_step(spec):
+    a = cubed.random.random(
+        (10000, 10000), chunks=(5000, 5000), spec=spec
+    )  # 200MB chunks
+    b = a[::2, :]
+    run_operation("index_step", b)
+
+
 # Creation Functions


