Identify non dimension coords (#156)
* test non dimension coordinates are properly detected and roundtripped

* ensure non dimension coordinates are read from kerchunk references

* ensure kerchunk references written out record non dimension coordinates

* release notes
TomNicholas authored Jun 24, 2024
1 parent ef2429a commit daf0377
Showing 4 changed files with 49 additions and 9 deletions.
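
For context: a non-dimension coordinate is a coordinate variable whose name does not match any of its dimensions (for example a 2-D lat defined on dims x and y). Under the CF conventions such variables are recorded by listing their names, space-separated, in a "coordinates" attribute, which is the marker this commit teaches the kerchunk reader and writer to handle. A minimal sketch of that encoding, using only stock xarray and a throwaway file name:

import numpy as np
import xarray as xr

# "lat" is a non-dimension coordinate: it is 2-D and matches no dimension name,
# so nothing in the on-disk variable itself marks it as a coordinate.
ds = xr.Dataset(coords={"lat": (["x", "y"], np.arange(6).reshape(2, 3))})

# When written to netCDF, xarray follows the CF conventions and records the
# coordinate names as a space-separated string in a "coordinates" attribute;
# a CF-aware reader uses that attribute to promote "lat" back to a coordinate.
ds.to_netcdf("non_dim_coords.nc")  # hypothetical path, for illustration only

reopened = xr.open_dataset("non_dim_coords.nc")
print(reopened.coords)  # "lat" comes back as a coordinate, not a data variable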
2 changes: 2 additions & 0 deletions docs/releases.rst
@@ -29,6 +29,8 @@ Bug fixes
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Ensure that `.attrs` on coordinate variables are preserved during round-tripping. (:issue:`155`, :pull:`154`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Ensure that non-dimension coordinate variables described via the CF conventions are preserved during round-tripping. (:issue:`105`, :pull:`156`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Documentation
~~~~~~~~~~~~~
9 changes: 8 additions & 1 deletion virtualizarr/kerchunk.py
@@ -217,11 +217,18 @@ def dataset_to_kerchunk_refs(ds: xr.Dataset) -> KerchunkStoreRefs:

all_arr_refs.update(prepended_with_var_name)

zattrs = ds.attrs
if ds.coords:
coord_names = list(ds.coords)
# storing the coordinate names as one space-separated string (rather than a list of strings) is inconsistent with how other features in the kerchunk references format are stored
# see https://github.com/zarr-developers/VirtualiZarr/issues/105#issuecomment-2187266739
zattrs["coordinates"] = " ".join(coord_names)

ds_refs = {
"version": 1,
"refs": {
".zgroup": '{"zarr_format":2}',
".zattrs": ujson.dumps(ds.attrs),
".zattrs": ujson.dumps(zattrs),
**all_arr_refs,
},
}
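With the change above, the group-level .zattrs written into the kerchunk references now carries the space-separated coordinate names alongside the dataset attributes. A rough sketch of the resulting store-level references for a dataset with a single non-dimension coordinate lat (the chunk key, offset, and length are placeholders, not real values):

refs = {
    "version": 1,
    "refs": {
        ".zgroup": '{"zarr_format":2}',
        # "coordinates" is the entry dataset_to_kerchunk_refs now adds; a reader
        # pops it to decide which variables to promote to coordinates.
        ".zattrs": '{"coordinates": "lat"}',
        "lat/.zarray": "...",  # abbreviated zarr array metadata
        "lat/0.0": ["non_dim_coords.nc", 8000, 48],
    },
}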
26 changes: 26 additions & 0 deletions virtualizarr/tests/test_integration.py
@@ -1,3 +1,4 @@
import numpy as np
import pytest
import xarray as xr
import xarray.testing as xrt
@@ -95,6 +96,31 @@ def test_kerchunk_roundtrip_concat(self, tmpdir, format):
# assert identical to original dataset
xrt.assert_identical(roundtrip, ds)

def test_non_dimension_coordinates(self, tmpdir, format):
# regression test for GH issue #105

# set up example xarray dataset containing non-dimension coordinate variables
ds = xr.Dataset(coords={"lat": (["x", "y"], np.arange(6).reshape(2, 3))})

# save it to disk as netCDF (in temporary directory)
ds.to_netcdf(f"{tmpdir}/non_dim_coords.nc")

vds = open_virtual_dataset(f"{tmpdir}/non_dim_coords.nc", indexes={})

assert "lat" in vds.coords
assert "coordinates" not in vds.attrs

# write those references to disk as kerchunk references format
vds.virtualize.to_kerchunk(f"{tmpdir}/refs.{format}", format=format)

# use fsspec to read the dataset from disk via the kerchunk references
roundtrip = xr.open_dataset(
f"{tmpdir}/refs.{format}", engine="kerchunk", decode_times=False
)

# assert equal to original dataset
xrt.assert_identical(roundtrip, ds)


def test_open_scalar_variable(tmpdir):
# regression test for GH issue #100
21 changes: 13 additions & 8 deletions virtualizarr/xarray.py
@@ -117,6 +117,7 @@ def open_virtual_dataset(
virtual_array_class=virtual_array_class,
)
ds_attrs = kerchunk.fully_decode_arr_refs(vds_refs["refs"]).get(".zattrs", {})
coord_names = ds_attrs.pop("coordinates", [])

if indexes is None or len(loadable_variables) > 0:
# TODO we are reading a bunch of stuff we know we won't need here, e.g. all of the data variables...
@@ -152,7 +153,7 @@

vars = {**virtual_vars, **loadable_vars}

data_vars, coords = separate_coords(vars, indexes)
data_vars, coords = separate_coords(vars, indexes, coord_names)

vds = xr.Dataset(
data_vars,
@@ -177,6 +178,7 @@ def open_virtual_dataset_from_v3_store(
_storepath = Path(storepath)

ds_attrs = attrs_from_zarr_group_json(_storepath / "zarr.json")
coord_names = ds_attrs.pop("coordinates", [])

# TODO recursive glob to create a datatree
# Note: this .is_file() check should not be necessary according to the pathlib docs, but tests fail on github CI without it
@@ -205,7 +207,7 @@
else:
indexes = dict(**indexes) # for type hinting: to allow mutation

data_vars, coords = separate_coords(vars, indexes)
data_vars, coords = separate_coords(vars, indexes, coord_names)

ds = xr.Dataset(
data_vars,
@@ -223,8 +225,10 @@ def virtual_vars_from_kerchunk_refs(
virtual_array_class=ManifestArray,
) -> Mapping[str, xr.Variable]:
"""
Translate a store-level kerchunk reference dict into a set of xarray Variables containing virtualized arrays.
Parameters
----------
drop_variables: list[str], default is None
Variables in the file to drop before returning.
virtual_array_class
@@ -263,12 +267,12 @@ def dataset_from_kerchunk_refs(
"""

vars = virtual_vars_from_kerchunk_refs(refs, drop_variables, virtual_array_class)
ds_attrs = kerchunk.fully_decode_arr_refs(refs["refs"]).get(".zattrs", {})
coord_names = ds_attrs.pop("coordinates", [])

if indexes is None:
indexes = {}
data_vars, coords = separate_coords(vars, indexes)

ds_attrs = kerchunk.fully_decode_arr_refs(refs["refs"]).get(".zattrs", {})
data_vars, coords = separate_coords(vars, indexes, coord_names)

vds = xr.Dataset(
data_vars,
@@ -301,6 +305,7 @@ def variable_from_kerchunk_refs(
def separate_coords(
vars: Mapping[str, xr.Variable],
indexes: MutableMapping[str, Index],
coord_names: Iterable[str] | None = None,
) -> tuple[Mapping[str, xr.Variable], xr.Coordinates]:
"""
Try to generate a set of coordinates that won't cause xarray to automatically build a pandas.Index for the 1D coordinates.
@@ -310,8 +315,8 @@
Will also preserve any loaded variables and indexes it is passed.
"""

# this would normally come from CF decoding, let's hope the fact we're skipping that doesn't cause any problems...
coord_names: list[str] = []
if coord_names is None:
coord_names = []

# split data and coordinate variables (promote dimension coordinates)
data_vars = {}
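The net effect in separate_coords is that names popped from the "coordinates" attribute are promoted to coordinates even when they match no dimension, while dimension coordinates are still promoted automatically and no pandas.Index is built eagerly. A simplified sketch of that promotion logic (a hypothetical stand-in, not the exact function body):

import xarray as xr

def split_vars(vars, coord_names):
    # Hypothetical, simplified stand-in for separate_coords, for illustration only.
    data_vars, coord_vars = {}, {}
    for name, var in vars.items():
        # promote dimension coordinates as well as anything named in coord_names
        if name in coord_names or (len(var.dims) == 1 and name == var.dims[0]):
            coord_vars[name] = var
        else:
            data_vars[name] = var
    # passing indexes={} stops xarray from eagerly building pandas.Index objects
    coords = xr.Coordinates(coord_vars, indexes={})
    return data_vars, coords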
