Implement int array indexing using map_selection #604

Merged (1 commit) on Nov 1, 2024
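For context, the user-facing operation this PR reroutes through map_selection is integer-array (fancy) indexing. A minimal sketch of the kind of call that now takes this path, assuming cubed and numpy are installed; the array API namespace, the chunks= argument, and compute() follow cubed's documented usage, and the values are arbitrary:

    import numpy as np
    import cubed.array_api as xp

    # a small 1-D array split into chunks of 3 elements
    a = xp.asarray(np.arange(10), chunks=(3,))

    # integer-array indexing: previously handled by map_direct, now by map_selection
    b = a[np.array([7, 0, 3])]

    print(b.compute())  # -> [7 0 3]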
82 changes: 32 additions & 50 deletions cubed/core/ops.py
@@ -583,45 +583,25 @@ def merged_chunk_len_for_indexer(ia, c):

    target_chunks = normalize_chunks(chunks, shape, dtype=dtype)

-    if _is_basic_selection(idx):
-        # use map_selection (which uses general_blockwise) to allow more opportunities for optimization than map_direct
+    # use map_selection (which uses general_blockwise) to allow more opportunities for optimization than map_direct

-        def selection_function(out_key):
-            out_coords = out_key[1:]
-            return _target_chunk_selection(target_chunks, out_coords, selection)
+    def selection_function(out_key):
+        out_coords = out_key[1:]
+        return _target_chunk_selection(target_chunks, out_coords, selection)

-        max_num_input_blocks = _index_num_input_blocks(
-            idx, x.chunksize, out_chunksizes, x.numblocks
-        )
+    max_num_input_blocks = _index_num_input_blocks(
+        idx, x.chunksize, out_chunksizes, x.numblocks
+    )

-        out = map_selection(
-            None,  # no function to apply after selection
-            selection_function,
-            x,
-            shape,
-            x.dtype,
-            target_chunks,
-            max_num_input_blocks=max_num_input_blocks,
-        )
-    else:
-        # use map_direct, which can't be fused
-        # (note that it should be possible to re-write as general_blockwise with more work)
-
-        # memory allocated by reading one chunk from input array
-        # note that although the output chunk will overlap multiple input chunks, zarr will
-        # read the chunks in series, reusing the buffer
-        extra_projected_mem = x.chunkmem
-
-        out = map_direct(
-            _read_index_chunk,
-            x,
-            shape=shape,
-            dtype=dtype,
-            chunks=target_chunks,
-            extra_projected_mem=extra_projected_mem,
-            target_chunks=target_chunks,
-            selection=selection,
-        )
+    out = map_selection(
+        None,  # no function to apply after selection
+        selection_function,
+        x,
+        shape,
+        x.dtype,
+        target_chunks,
+        max_num_input_blocks=max_num_input_blocks,
+    )

    # merge chunks for any dims with step > 1 so they are
    # the same size as the input (or slightly smaller due to rounding)
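Illustrative sketch (not part of the diff) of what the selection function above must produce for an integer-array index: an output block's coordinates pick out the slice of the index array that the block covers, and that slice is the selection to apply to the input. The helper name below is hypothetical, not cubed's _target_chunk_selection:

    import numpy as np

    def block_selection(index_array, out_chunksize, out_block):
        # the output block covers positions [start, stop) of the overall output
        start = out_block * out_chunksize
        stop = start + out_chunksize
        # so the selection on the input is just that slice of the index array
        return index_array[start:stop]

    idx = np.array([7, 0, 3, 3, 9, 1])  # integer-array index into the input
    print(block_selection(idx, 2, 1))   # -> [3 3], the input positions needed by output block 1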
@@ -641,10 +621,6 @@ def selection_function(out_key):
    return out


-def _is_basic_selection(idx: ndindex.Tuple):
-    return all(isinstance(ia, (ndindex.Integer, ndindex.Slice)) for ia in idx.args)
-
-
def _index_num_input_blocks(
    idx: ndindex.Tuple, in_chunksizes, out_chunksizes, numblocks
):
@@ -661,21 +637,27 @@ def _index_num_input_blocks(
            # step is not a multiple of chunk size, and output chunks have more than one element
            # so some output chunks will access two input chunks
            num *= 2
+        elif isinstance(ia, ndindex.IntegerArray):
+            # in the worst case, elements could be retrieved from all blocks
+            # TODO: improve to calculate the actual max input blocks
+            num *= nb
        else:
-            raise NotImplementedError("Only integer or slice indexes are supported.")
+            raise NotImplementedError(
+                "Only integer, slice, or int array indexes are supported."
+            )
    return num
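Illustrative sketch (not part of the diff) of why the integer-array branch multiplies by the number of blocks: index values can fall anywhere along the axis, so in the worst case one output chunk needs elements from every input block, which is exactly what the conservative num *= nb bound assumes.

    import math
    import numpy as np

    in_size, in_chunksize = 20, 5           # 4 input blocks along this axis
    nb = math.ceil(in_size / in_chunksize)  # numblocks along the axis
    idx = np.array([0, 5, 10, 15])          # one element from each input block
    blocks_touched = len({int(i) // in_chunksize for i in idx})
    print(nb, blocks_touched)               # 4 4 -> the bound is tight for this index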


-def create_basic_indexer(selection, shape, chunks):
+def _create_zarr_indexer(selection, shape, chunks):
    if zarr.__version__[0] == "3":
        from zarr.core.chunk_grids import RegularChunkGrid
-        from zarr.core.indexing import BasicIndexer
+        from zarr.core.indexing import OrthogonalIndexer

-        return BasicIndexer(selection, shape, RegularChunkGrid(chunk_shape=chunks))
+        return OrthogonalIndexer(selection, shape, RegularChunkGrid(chunk_shape=chunks))
    else:
-        from zarr.indexing import BasicIndexer
+        from zarr.indexing import OrthogonalIndexer

-        return BasicIndexer(selection, ZarrArrayIndexingAdaptor(shape, chunks))
+        return OrthogonalIndexer(selection, ZarrArrayIndexingAdaptor(shape, chunks))


@dataclass
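Illustrative use of the _create_zarr_indexer helper above (not part of the diff; assumes zarr is installed and the helper is called positionally, exactly as at the call sites below): iterating the returned OrthogonalIndexer yields one (chunk_coords, chunk_selection, out_selection) projection per input chunk touched by the selection, which is what the downstream code unpacks.

    import numpy as np

    # orthogonal selection: rows 0, 2 and 3, all columns, of a (4, 4) array chunked as (2, 2)
    indexer = _create_zarr_indexer((np.array([0, 2, 3]), slice(None)), (4, 4), (2, 2))

    for chunk_coords, chunk_sel, out_sel in indexer:
        # chunk_coords names an input chunk; chunk_sel selects within that chunk;
        # out_sel says where the selected elements land in the output
        print(chunk_coords, chunk_sel, out_sel)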
@@ -706,8 +688,8 @@ def _assemble_index_chunk(
    out_coords = block_id
    in_sel = selection_function(("out",) + out_coords)

-    # use a Zarr BasicIndexer to convert this to input coordinates
-    indexer = create_basic_indexer(in_sel, in_shape, in_chunksize)
+    # use a Zarr indexer to convert this to input coordinates
+    indexer = _create_zarr_indexer(in_sel, in_shape, in_chunksize)

    shape = indexer.shape
    out = np.empty(shape, dtype=dtype)
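A runnable toy version (not the PR's code, which is truncated above) of what assembling one output chunk amounts to for an integer-array index: fetch each requested element from whichever input block holds it and write it to the matching output position. The blocks dict stands in for chunk reads from the input array:

    import numpy as np

    in_chunksize = 2
    blocks = {  # block index -> chunk data
        0: np.array([10, 11]),
        1: np.array([12, 13]),
        2: np.array([14, 15]),
    }
    selection = np.array([5, 0, 3])  # integer-array index for one output chunk
    out = np.empty(len(selection), dtype=blocks[0].dtype)
    for out_pos, i in enumerate(selection):
        block, offset = divmod(int(i), in_chunksize)
        out[out_pos] = blocks[block][offset]
    print(out)  # -> [15 10 13]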
@@ -793,8 +775,8 @@ def key_function(out_key):
        # compute the selection on x required to get the relevant chunk for out_key
        in_sel = selection_function(out_key)

-        # use a Zarr BasicIndexer to convert selection to input coordinates
-        indexer = create_basic_indexer(in_sel, x.shape, x.chunksize)
+        # use a Zarr indexer to convert selection to input coordinates
+        indexer = _create_zarr_indexer(in_sel, x.shape, x.chunksize)

        return (
            iter(tuple((x.name,) + chunk_coords for (chunk_coords, _, _) in indexer)),