diff --git a/pangeo_forge_recipes/recipes/xarray_zarr.py b/pangeo_forge_recipes/recipes/xarray_zarr.py index 2b1ffb75..7f115653 100644 --- a/pangeo_forge_recipes/recipes/xarray_zarr.py +++ b/pangeo_forge_recipes/recipes/xarray_zarr.py @@ -3,6 +3,7 @@ """ import functools +import itertools import logging import os import warnings @@ -26,6 +27,7 @@ # use this filename to store global recipe metadata in the metadata_cache # it will be written once (by prepare_target) and read many times (by store_chunk) _GLOBAL_METADATA_KEY = "pangeo-forge-recipe-metadata.json" +_ARRAY_DIMENSIONS = "_ARRAY_DIMENSIONS" MAX_MEMORY = ( int(os.getenv("PANGEO_FORGE_MAX_MEMORY")) # type: ignore if os.getenv("PANGEO_FORGE_MAX_MEMORY") @@ -621,9 +623,41 @@ def store_chunk( zarr_array[zarr_region] = data -def finalize_target(target: CacheFSSpecTarget, consolidate_zarr: bool) -> None: +def _gather_coordinate_dimensions(group: zarr.Group) -> List[str]: + return list( + set(itertools.chain(*(group[var].attrs.get(_ARRAY_DIMENSIONS, []) for var in group))) + ) + + +def finalize_target( + target: CacheFSSpecTarget, + consolidate_zarr: bool, + consolidate_dimension_coordinates: bool = True, +) -> None: if target is None: raise ValueError("target has not been set.") + + if consolidate_dimension_coordinates: + logger.info("Consolidating dimension coordinate arrays") + target_mapper = target.get_mapper() + group = zarr.open(target_mapper, mode="a") + dims = _gather_coordinate_dimensions(group) + for dim in dims: + arr = group[dim] + attrs = dict(arr.attrs) + new = group.array( + dim, + arr[:], + chunks=arr.shape, + dtype=arr.dtype, + compressor=arr.compressor, + fill_value=arr.fill_value, + order=arr.order, + filters=arr.filters, + overwrite=True, + ) + new.attrs.update(attrs) + if consolidate_zarr: logger.info("Consolidating Zarr metadata") target_mapper = target.get_mapper() @@ -661,6 +695,9 @@ class XarrayZarrRecipe(BaseRecipe, FilePatternRecipeMixin): ``xr.open_dataset``. This is required for engines that can't open file-like objects (e.g. pynio). :param consolidate_zarr: Whether to consolidate the resulting Zarr dataset. + :param consolidate_dimension_coordinates: Whether to rewrite coordinate variables as a + single chunk. We recommend consolidating coordinate variables to avoid + many small read requests to get the coordinates in xarray. :param xarray_open_kwargs: Extra options for opening the inputs with Xarray. :param xarray_concat_kwargs: Extra options to pass to Xarray when concatenating the inputs to form a chunk. @@ -685,6 +722,7 @@ class XarrayZarrRecipe(BaseRecipe, FilePatternRecipeMixin): cache_inputs: Optional[bool] = None copy_input_to_local_file: bool = False consolidate_zarr: bool = True + consolidate_dimension_coordinates: bool = True xarray_open_kwargs: dict = field(default_factory=dict) xarray_concat_kwargs: dict = field(default_factory=dict) delete_input_encoding: bool = True @@ -851,7 +889,10 @@ def store_chunk(self) -> Callable[[Hashable], None]: @property def finalize_target(self) -> Callable[[], None]: return functools.partial( - finalize_target, target=self.target, consolidate_zarr=self.consolidate_zarr + finalize_target, + target=self.target, + consolidate_zarr=self.consolidate_zarr, + consolidate_dimension_coordinates=self.consolidate_dimension_coordinates, ) def iter_inputs(self) -> Iterator[InputKey]: diff --git a/tests/recipe_tests/test_XarrayZarrRecipe.py b/tests/recipe_tests/test_XarrayZarrRecipe.py index 38297456..547c8f64 100644 --- a/tests/recipe_tests/test_XarrayZarrRecipe.py +++ b/tests/recipe_tests/test_XarrayZarrRecipe.py @@ -4,6 +4,7 @@ import pytest import xarray as xr +import zarr # need to import this way (rather than use pytest.lazy_fixture) to make it work with dask from pytest_lazyfixture import lazy_fixture @@ -274,6 +275,10 @@ def do_actual_chunks_test( assert all([item == chunk_len for item in ds_actual.chunks[other_dim][:-1]]) ds_actual.load() + store = zarr.open_consolidated(target.get_mapper()) + for dim in ds_actual.dims: + assert store[dim].chunks == ds_actual[dim].shape + xr.testing.assert_identical(ds_actual, ds_expected) @@ -340,6 +345,19 @@ def test_chunks_distributed_locking( ) +def test_no_consolidate_dimension_coordinates(netCDFtoZarr_recipe): + RecipeClass, file_pattern, kwargs, ds_expected, target = netCDFtoZarr_recipe + + rec = RecipeClass(file_pattern, **kwargs) + rec.consolidate_dimension_coordinates = False + rec.to_function()() + ds_actual = xr.open_zarr(target.get_mapper()).load() + xr.testing.assert_identical(ds_actual, ds_expected) + + store = zarr.open_consolidated(target.get_mapper()) + assert store["time"].chunks == (file_pattern.nitems_per_input["time"],) + + def test_lock_timeout(netCDFtoZarr_recipe_sequential_only, execute_recipe_no_dask): RecipeClass, file_pattern, kwargs, ds_expected, target = netCDFtoZarr_recipe_sequential_only