Skip to content

Commit

Permalink
Fixed consolidated Group getitem with multi-part key
Browse files Browse the repository at this point in the history
This fixes `Group.__getitem__` when indexing with a key
like 'subgroup/array'. The basic idea is to rewrite the indexing
operation as `group['subgroup']['array']` by splitting the key
and doing each operation independently. This is fine for consolidated
metadata which doesn't need to do IO.

There's a complication around unconsolidated metadata, though. What
if we encounter a node where `Group.getitem` returns a sub-Group
without consolidated metadata? Then we need to fall back to
non-consolidated metadata. We've written `_getitem_consolidated`
as a regular (non-async) function, so we need to pop back up to
the async caller and have *it* fall back.

Closes zarr-developers#2358
  • Loading branch information
TomAugspurger committed Oct 14, 2024
1 parent 9bbfd88 commit f4f8200
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 25 deletions.
119 changes: 95 additions & 24 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,24 @@ def _parse_async_node(
raise TypeError(f"Unknown node type, got {type(node)}")


class _MixedConsolidatedMetadataException(Exception):
    """
    A custom, *internal* exception for when we encounter mixed consolidated metadata.

    Typically, Consolidated Metadata will explicitly indicate that there are no
    additional children under a group with ``ConsolidatedMetadata(metadata={})``,
    as opposed to ``metadata=None``. This is the behavior of ``consolidated_metadata``.
    We rely on that "fact" to do I/O-free getitem: when a group's consolidated metadata
    doesn't contain a child, we can raise a ``KeyError`` without consulting the backing
    store.

    Users can potentially get themselves into situations where there's "mixed"
    consolidated metadata. For now, we'll raise this error, catch it internally, and
    silently fall back to the store (which will either succeed or raise its own
    KeyError, slowly). We might want to expose this in the future, in which case we
    should rename it and add it to ``zarr.errors``.
    """


@dataclass(frozen=True)
class ConsolidatedMetadata:
"""
Expand Down Expand Up @@ -571,7 +589,10 @@ def _from_bytes_v2(

@classmethod
def _from_bytes_v3(
cls, store_path: StorePath, zarr_json_bytes: Buffer, use_consolidated: bool | None
cls,
store_path: StorePath,
zarr_json_bytes: Buffer,
use_consolidated: bool | None,
) -> AsyncGroup:
group_metadata = json.loads(zarr_json_bytes.to_bytes())
if use_consolidated and group_metadata.get("consolidated_metadata") is None:
Expand Down Expand Up @@ -604,14 +625,22 @@ async def getitem(

# Consolidated metadata lets us avoid some I/O operations so try that first.
if self.metadata.consolidated_metadata is not None:
return self._getitem_consolidated(store_path, key, prefix=self.name)
try:
return self._getitem_consolidated(store_path, key, prefix=self.name)
except _MixedConsolidatedMetadataException:
logger.info(
"Mixed consolidated and unconsolidated metadata. key=%s, store_path=%s",
key,
store_path,
)
# now fall back to the non-consolidated variant

# Note:
# in zarr-python v2, we first check if `key` references an Array, else if `key` references
# a group, using standalone `contains_array` and `contains_group` functions. These functions
# are reusable, but for v3 they would perform redundant I/O operations.
# Not clear how much of that strategy we want to keep here.
elif self.metadata.zarr_format == 3:
if self.metadata.zarr_format == 3:
zarr_json_bytes = await (store_path / ZARR_JSON).get()
if zarr_json_bytes is None:
raise KeyError(key)
Expand Down Expand Up @@ -661,18 +690,39 @@ def _getitem_consolidated(
# getitem, in the special case where we have consolidated metadata.
# Note that this is a regular def (non async) function.
# This shouldn't do any additional I/O.
# All callers *must* catch _MixedConsolidatedMetadataException to ensure
# that we correctly handle the case where we need to fall back to doing
# additional I/O.

# the caller needs to verify this!
assert self.metadata.consolidated_metadata is not None

try:
metadata = self.metadata.consolidated_metadata.metadata[key]
except KeyError as e:
# The Group Metadata has consolidated metadata, but the key
# isn't present. We trust this to mean that the key isn't in
# the hierarchy, and *don't* fall back to checking the store.
msg = f"'{key}' not found in consolidated metadata."
raise KeyError(msg) from e
# we support nested getitems like group/subgroup/array
indexers = key.split("/")
indexers.reverse()
metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata = self.metadata

while indexers:
indexer = indexers.pop()
if isinstance(metadata, ArrayV2Metadata | ArrayV3Metadata):
# we've indexed into an array with group["array/subarray"]. Invalid.
raise KeyError(key)
try:
if metadata.consolidated_metadata is None:
# we've indexed into a group without consolidated metadata.
# Note that the `None` case is different from `metadata={}`
# where we explicitly know we have no children. In the None
# case we have to fall back to non-consolidated metadata.
raise _MixedConsolidatedMetadataException(key)
assert metadata.consolidated_metadata is not None

metadata = metadata.consolidated_metadata.metadata[indexer]
except KeyError as e:
# The Group Metadata has consolidated metadata, but the key
# isn't present. We trust this to mean that the key isn't in
# the hierarchy, and *don't* fall back to checking the store.
msg = f"'{key}' not found in consolidated metadata."
raise KeyError(msg) from e

# update store_path to ensure that AsyncArray/Group.name is correct
if prefix != "/":
Expand Down Expand Up @@ -1087,7 +1137,8 @@ async def members(
self,
max_depth: int | None = 0,
) -> AsyncGenerator[
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
None,
]:
"""
Returns an AsyncGenerator over the arrays and groups contained in this group.
Expand Down Expand Up @@ -1118,15 +1169,20 @@ async def members(
async def _members(
self, max_depth: int | None, current_depth: int
) -> AsyncGenerator[
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
None,
]:
if self.metadata.consolidated_metadata is not None:
# we should be able to do members without any additional I/O
members = self._members_consolidated(max_depth, current_depth)

for member in members:
yield member
return
try:
members = self._members_consolidated(max_depth, current_depth)
except _MixedConsolidatedMetadataException:
# we've already logged this. We'll fall back to the non-consolidated version.
pass
else:
for member in members:
yield member
return

if not self.store_path.store.supports_listing:
msg = (
Expand Down Expand Up @@ -1177,17 +1233,28 @@ async def _members(
def _members_consolidated(
self, max_depth: int | None, current_depth: int, prefix: str = ""
) -> Generator[
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
None,
]:
consolidated_metadata = self.metadata.consolidated_metadata

# we kind of just want the top-level keys.
if consolidated_metadata is not None:
for key in consolidated_metadata.metadata.keys():
obj = self._getitem_consolidated(
self.store_path, key, prefix=self.name
) # Metadata -> Group/Array
key = f"{prefix}/{key}".lstrip("/")
try:
obj = self._getitem_consolidated(
self.store_path, key, prefix=self.name
) # Metadata -> Group/Array
key = f"{prefix}/{key}".lstrip("/")
except _MixedConsolidatedMetadataException:
logger.info(
"Mixed consolidated and unconsolidated metadata. key=%s, depth=%d, prefix=%s",
key,
current_depth,
prefix,
)
# This isn't an async def function so we need to re-raise up one more level.
raise
yield key, obj

if ((max_depth is None) or (current_depth < max_depth)) and isinstance(
Expand Down Expand Up @@ -1262,7 +1329,11 @@ async def full(
self, *, name: str, shape: ChunkCoords, fill_value: Any | None, **kwargs: Any
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
return await async_api.full(
shape=shape, fill_value=fill_value, store=self.store_path, path=name, **kwargs
shape=shape,
fill_value=fill_value,
store=self.store_path,
path=name,
**kwargs,
)

async def empty_like(
Expand Down
15 changes: 14 additions & 1 deletion tests/v3/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,18 +300,31 @@ def test_group_getitem(store: Store, zarr_format: ZarrFormat, consolidated: bool
group = Group.from_store(store, zarr_format=zarr_format)
subgroup = group.create_group(name="subgroup")
subarray = group.create_array(name="subarray", shape=(10,), chunk_shape=(10,))
subsubarray = subgroup.create_array(name="subarray", shape=(10,), chunk_shape=(10,))

if consolidated:
group = zarr.api.synchronous.consolidate_metadata(store=store, zarr_format=zarr_format)
# we're going to assume that `group.metadata` is correct, and reuse that to focus
# on indexing in this test. Other tests verify the correctness of group.metadata
object.__setattr__(
subgroup.metadata, "consolidated_metadata", ConsolidatedMetadata(metadata={})
subgroup.metadata,
"consolidated_metadata",
ConsolidatedMetadata(
metadata={"subarray": group.metadata.consolidated_metadata.metadata["subarray"]}
),
)

assert group["subgroup"] == subgroup
assert group["subarray"] == subarray
assert subgroup["subarray"] == subsubarray
# assert group["subgroup/subarray"] == subsubarray

with pytest.raises(KeyError):
group["nope"]

with pytest.raises(KeyError, match="subarray/subsubarray"):
group["subarray/subsubarray"]


def test_group_get_with_default(store: Store, zarr_format: ZarrFormat) -> None:
group = Group.from_store(store, zarr_format=zarr_format)
Expand Down

0 comments on commit f4f8200

Please sign in to comment.