From 6c22bfaa953117867bf5a1d9d490c423c1f1f9ca Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 14 Oct 2024 10:06:17 -0500 Subject: [PATCH] Fixed consolidated Group getitem with multi-part key This fixes `Group.__getitem__` when indexing with a key like 'subgroup/array'. The basic idea is to rewrite the indexing operation as `group['subgroup']['array']` by splitting the key and doing each operation independently. This is fine for consolidated metadata which doesn't need to do IO. There's a complication around unconsolidated metadata, though. What if we encounter a node where `Group.getitem` returns a sub Group without consolidated metadata. Then we need to fall back to non-consolidated metadata. We've written _getitem_consolidated as a regular (non-async) function so we need to pop back up to the async caller and have *it* fall back. Closes https://github.com/zarr-developers/zarr-python/issues/2358 --- src/zarr/core/group.py | 132 ++++++++++++++++++++++++++++++++--------- tests/v3/test_group.py | 15 ++++- 2 files changed, 119 insertions(+), 28 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 0b15e2f08..187c78def 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -76,7 +76,9 @@ def parse_attributes(data: Any) -> dict[str, Any]: @overload -def _parse_async_node(node: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]) -> Array: ... +def _parse_async_node( + node: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata], +) -> Array: ... @overload @@ -97,6 +99,24 @@ def _parse_async_node( raise TypeError(f"Unknown node type, got {type(node)}") +class _MixedConsolidatedMetadataException(Exception): + """ + A custom, *internal* exception for when we encounter mixed consolidated metadata. + + Typically, Consolidated Metadata will explicitly indicate that there are no + additional children under a group with ``ConsolidatedMetadata(metadata={})``, + as opposed to ``metadata=None``. This is the behavior of ``consolidated_metadata``. + We rely on that "fact" to do I/O-free getitem: when a group's consolidated metadata + doesn't contain a child we can raise a ``KeyError`` without consulting the backing + store. + + Users can potentially get themselves in situations where there's "mixed" consolidated + metadata. For now, we'll raise this error, catch it internally, and silently fall back + to the store (which will either succeed or raise its own KeyError, slowly). We might + want to expose this in the future, in which case rename it add it to zarr.errors. + """ + + @dataclass(frozen=True) class ConsolidatedMetadata: """ @@ -235,7 +255,9 @@ def _flat_to_nested( ) @property - def flattened_metadata(self) -> dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata]: + def flattened_metadata( + self, + ) -> dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata]: """ Return the flattened representation of Consolidated Metadata. @@ -513,7 +535,10 @@ async def open( maybe_consolidated_metadata_bytes = None return cls._from_bytes_v2( - store_path, zgroup_bytes, zattrs_bytes, maybe_consolidated_metadata_bytes + store_path, + zgroup_bytes, + zattrs_bytes, + maybe_consolidated_metadata_bytes, ) else: # V3 groups are comprised of a zarr.json object @@ -571,7 +596,10 @@ def _from_bytes_v2( @classmethod def _from_bytes_v3( - cls, store_path: StorePath, zarr_json_bytes: Buffer, use_consolidated: bool | None + cls, + store_path: StorePath, + zarr_json_bytes: Buffer, + use_consolidated: bool | None, ) -> AsyncGroup: group_metadata = json.loads(zarr_json_bytes.to_bytes()) if use_consolidated and group_metadata.get("consolidated_metadata") is None: @@ -604,14 +632,22 @@ async def getitem( # Consolidated metadata lets us avoid some I/O operations so try that first. if self.metadata.consolidated_metadata is not None: - return self._getitem_consolidated(store_path, key, prefix=self.name) + try: + return self._getitem_consolidated(store_path, key, prefix=self.name) + except _MixedConsolidatedMetadataException: + logger.info( + "Mixed consolidated and unconsolidated metadata. key=%s, store_path=%s", + key, + store_path, + ) + # now fall back to the non-consolidated variant # Note: # in zarr-python v2, we first check if `key` references an Array, else if `key` references # a group,using standalone `contains_array` and `contains_group` functions. These functions # are reusable, but for v3 they would perform redundant I/O operations. # Not clear how much of that strategy we want to keep here. - elif self.metadata.zarr_format == 3: + if self.metadata.zarr_format == 3: zarr_json_bytes = await (store_path / ZARR_JSON).get() if zarr_json_bytes is None: raise KeyError(key) @@ -661,18 +697,39 @@ def _getitem_consolidated( # getitem, in the special case where we have consolidated metadata. # Note that this is a regular def (non async) function. # This shouldn't do any additional I/O. + # All callers *must* catch _MixedConsolidatedMetadataException to ensure + # that we correctly handle the case where we need to fall back to doing + # additional I/O. # the caller needs to verify this! assert self.metadata.consolidated_metadata is not None - try: - metadata = self.metadata.consolidated_metadata.metadata[key] - except KeyError as e: - # The Group Metadata has consolidated metadata, but the key - # isn't present. We trust this to mean that the key isn't in - # the hierarchy, and *don't* fall back to checking the store. - msg = f"'{key}' not found in consolidated metadata." - raise KeyError(msg) from e + # we support nested getitems like group/subgroup/array + indexers = key.split("/") + indexers.reverse() + metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata = self.metadata + + while indexers: + indexer = indexers.pop() + if isinstance(metadata, ArrayV2Metadata | ArrayV3Metadata): + # we've indexed into an array with group["array/subarray"]. Invalid. + raise KeyError(key) + try: + if metadata.consolidated_metadata is None: + # we've indexed into a group without consolidated metadata. + # Note that the `None` case is different from `metadata={}` + # where we explicitly know we have no children. In the None + # case we have to fall back to non-consolidated metadata. + raise _MixedConsolidatedMetadataException(key) + assert metadata.consolidated_metadata is not None + + metadata = metadata.consolidated_metadata.metadata[indexer] + except KeyError as e: + # The Group Metadata has consolidated metadata, but the key + # isn't present. We trust this to mean that the key isn't in + # the hierarchy, and *don't* fall back to checking the store. + msg = f"'{key}' not found in consolidated metadata." + raise KeyError(msg) from e # update store_path to ensure that AsyncArray/Group.name is correct if prefix != "/": @@ -1087,7 +1144,8 @@ async def members( self, max_depth: int | None = 0, ) -> AsyncGenerator[ - tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None + tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], + None, ]: """ Returns an AsyncGenerator over the arrays and groups contained in this group. @@ -1118,15 +1176,20 @@ async def members( async def _members( self, max_depth: int | None, current_depth: int ) -> AsyncGenerator[ - tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None + tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], + None, ]: if self.metadata.consolidated_metadata is not None: # we should be able to do members without any additional I/O - members = self._members_consolidated(max_depth, current_depth) - - for member in members: - yield member - return + try: + members = self._members_consolidated(max_depth, current_depth) + except _MixedConsolidatedMetadataException: + # we've already logged this. We'll fall back to the non-consolidated version. + pass + else: + for member in members: + yield member + return if not self.store_path.store.supports_listing: msg = ( @@ -1177,17 +1240,28 @@ async def _members( def _members_consolidated( self, max_depth: int | None, current_depth: int, prefix: str = "" ) -> Generator[ - tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None + tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], + None, ]: consolidated_metadata = self.metadata.consolidated_metadata # we kind of just want the top-level keys. if consolidated_metadata is not None: for key in consolidated_metadata.metadata.keys(): - obj = self._getitem_consolidated( - self.store_path, key, prefix=self.name - ) # Metadata -> Group/Array - key = f"{prefix}/{key}".lstrip("/") + try: + obj = self._getitem_consolidated( + self.store_path, key, prefix=self.name + ) # Metadata -> Group/Array + key = f"{prefix}/{key}".lstrip("/") + except _MixedConsolidatedMetadataException: + logger.info( + "Mixed consolidated and unconsolidated metadata. key=%s, depth=%d, prefix=%s", + key, + current_depth, + prefix, + ) + # This isn't an async def function so we need to re-raise up one more level. + raise yield key, obj if ((max_depth is None) or (current_depth < max_depth)) and isinstance( @@ -1262,7 +1336,11 @@ async def full( self, *, name: str, shape: ChunkCoords, fill_value: Any | None, **kwargs: Any ) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]: return await async_api.full( - shape=shape, fill_value=fill_value, store=self.store_path, path=name, **kwargs + shape=shape, + fill_value=fill_value, + store=self.store_path, + path=name, + **kwargs, ) async def empty_like( diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index 90933abea..c7bf2ff0b 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -300,18 +300,31 @@ def test_group_getitem(store: Store, zarr_format: ZarrFormat, consolidated: bool group = Group.from_store(store, zarr_format=zarr_format) subgroup = group.create_group(name="subgroup") subarray = group.create_array(name="subarray", shape=(10,), chunk_shape=(10,)) + subsubarray = subgroup.create_array(name="subarray", shape=(10,), chunk_shape=(10,)) if consolidated: group = zarr.api.synchronous.consolidate_metadata(store=store, zarr_format=zarr_format) + # we're going to assume that `group.metadata` is correct, and reuse that to focus + # on indexing in this test. Other tests verify the correctness of group.metadata object.__setattr__( - subgroup.metadata, "consolidated_metadata", ConsolidatedMetadata(metadata={}) + subgroup.metadata, + "consolidated_metadata", + ConsolidatedMetadata( + metadata={"subarray": group.metadata.consolidated_metadata.metadata["subarray"]} + ), ) assert group["subgroup"] == subgroup assert group["subarray"] == subarray + assert subgroup["subarray"] == subsubarray + # assert group["subgroup/subarray"] == subsubarray + with pytest.raises(KeyError): group["nope"] + with pytest.raises(KeyError, match="subarray/subsubarray"): + group["subarray/subsubarray"] + def test_group_get_with_default(store: Store, zarr_format: ZarrFormat) -> None: group = Group.from_store(store, zarr_format=zarr_format)