Skip to content

Commit

Permalink
Implement int array indexing using map_selection (#604)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Nov 1, 2024
1 parent c372711 commit 535f08d
Showing 1 changed file with 32 additions and 50 deletions.
82 changes: 32 additions & 50 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,45 +583,25 @@ def merged_chunk_len_for_indexer(ia, c):

target_chunks = normalize_chunks(chunks, shape, dtype=dtype)

if _is_basic_selection(idx):
# use map_selection (which uses general_blockwise) to allow more opportunities for optimization than map_direct
# use map_selection (which uses general_blockwise) to allow more opportunities for optimization than map_direct

def selection_function(out_key):
out_coords = out_key[1:]
return _target_chunk_selection(target_chunks, out_coords, selection)
def selection_function(out_key):
out_coords = out_key[1:]
return _target_chunk_selection(target_chunks, out_coords, selection)

max_num_input_blocks = _index_num_input_blocks(
idx, x.chunksize, out_chunksizes, x.numblocks
)
max_num_input_blocks = _index_num_input_blocks(
idx, x.chunksize, out_chunksizes, x.numblocks
)

out = map_selection(
None, # no function to apply after selection
selection_function,
x,
shape,
x.dtype,
target_chunks,
max_num_input_blocks=max_num_input_blocks,
)
else:
# use map_direct, which can't be fused
# (note that it should be possible to re-write as general_blockwise with more work)

# memory allocated by reading one chunk from input array
# note that although the output chunk will overlap multiple input chunks, zarr will
# read the chunks in series, reusing the buffer
extra_projected_mem = x.chunkmem

out = map_direct(
_read_index_chunk,
x,
shape=shape,
dtype=dtype,
chunks=target_chunks,
extra_projected_mem=extra_projected_mem,
target_chunks=target_chunks,
selection=selection,
)
out = map_selection(
None, # no function to apply after selection
selection_function,
x,
shape,
x.dtype,
target_chunks,
max_num_input_blocks=max_num_input_blocks,
)

# merge chunks for any dims with step > 1 so they are
# the same size as the input (or slightly smaller due to rounding)
Expand All @@ -641,10 +621,6 @@ def selection_function(out_key):
return out


def _is_basic_selection(idx: ndindex.Tuple):
return all(isinstance(ia, (ndindex.Integer, ndindex.Slice)) for ia in idx.args)


def _index_num_input_blocks(
idx: ndindex.Tuple, in_chunksizes, out_chunksizes, numblocks
):
Expand All @@ -661,21 +637,27 @@ def _index_num_input_blocks(
# step is not a multiple of chunk size, and output chunks have more than one element
# so some output chunks will access two input chunks
num *= 2
elif isinstance(ia, ndindex.IntegerArray):
# in the worse case, elements could be retrieved from all blocks
# TODO: improve to calculate the actual max input blocks
num *= nb
else:
raise NotImplementedError("Only integer or slice indexes are supported.")
raise NotImplementedError(
"Only integer, slice, or int array indexes are supported."
)
return num


def create_basic_indexer(selection, shape, chunks):
def _create_zarr_indexer(selection, shape, chunks):
if zarr.__version__[0] == "3":
from zarr.core.chunk_grids import RegularChunkGrid
from zarr.core.indexing import BasicIndexer
from zarr.core.indexing import OrthogonalIndexer

return BasicIndexer(selection, shape, RegularChunkGrid(chunk_shape=chunks))
return OrthogonalIndexer(selection, shape, RegularChunkGrid(chunk_shape=chunks))
else:
from zarr.indexing import BasicIndexer
from zarr.indexing import OrthogonalIndexer

return BasicIndexer(selection, ZarrArrayIndexingAdaptor(shape, chunks))
return OrthogonalIndexer(selection, ZarrArrayIndexingAdaptor(shape, chunks))


@dataclass
Expand Down Expand Up @@ -706,8 +688,8 @@ def _assemble_index_chunk(
out_coords = block_id
in_sel = selection_function(("out",) + out_coords)

# use a Zarr BasicIndexer to convert this to input coordinates
indexer = create_basic_indexer(in_sel, in_shape, in_chunksize)
# use a Zarr indexer to convert this to input coordinates
indexer = _create_zarr_indexer(in_sel, in_shape, in_chunksize)

shape = indexer.shape
out = np.empty(shape, dtype=dtype)
Expand Down Expand Up @@ -793,8 +775,8 @@ def key_function(out_key):
# compute the selection on x required to get the relevant chunk for out_key
in_sel = selection_function(out_key)

# use a Zarr BasicIndexer to convert selection to input coordinates
indexer = create_basic_indexer(in_sel, x.shape, x.chunksize)
# use a Zarr indexer to convert selection to input coordinates
indexer = _create_zarr_indexer(in_sel, x.shape, x.chunksize)

return (
iter(tuple((x.name,) + chunk_coords for (chunk_coords, _, _) in indexer)),
Expand Down

0 comments on commit 535f08d

Please sign in to comment.