Commit
Add encode_cf, decode_cf (#69)
* Add encode_cf, decode_cf

* cleanup

* Update gitignore

* Add test

* Add cf_xarray as dependency

* Update tests

* Handle multiple CRS

* Updates

* Use crs_wkt directly in decode

* fix tests

* Check indexes for equality

* Add comment

* Don't set crs attribute

* Revert "Don't set crs attribute"

This reverts commit 2a7cf38.

* fix

* Add cf-xarray to conda env

* Update docs

* Add docstring

* Typing fixes: Disallow dataarrays

* Add to api.rst

* Another fix.
dcherian authored Jul 17, 2024
1 parent 6167014 commit 97118ff
Showing 8 changed files with 281 additions and 66 deletions.
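
In brief, the commit adds a CF encode/decode pair to the Dataset .xvec accessor. A minimal usage sketch, assuming a Dataset ds whose geometry coordinates carry a GeometryIndex (the file name below is illustrative, not part of this commit):

    import xarray as xr
    import xvec  # registers the .xvec accessor

    encoded = ds.xvec.encode_cf()        # shapely geometries -> CF arrays, CRS -> grid_mapping variables
    encoded.to_netcdf("geometries.nc")   # now safe to write to netCDF or Zarr

    roundtripped = xr.open_dataset("geometries.nc").xvec.decode_cf()

    # the in-memory invariant documented in the new docstrings:
    assert ds.xvec.encode_cf().xvec.decode_cf().identical(ds)
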
3 changes: 2 additions & 1 deletion .gitignore
@@ -139,10 +139,11 @@ dmypy.json

# sphinx
doc/source/generated
doc/source/geo-encoded*

# ruff
.ruff_cache
doc/source/cube.joblib.compressed
doc/source/cube.pickle

cache/
cache/
4 changes: 3 additions & 1 deletion doc/source/api.rst
@@ -56,6 +56,8 @@ Methods
Dataset.xvec.to_geopandas
Dataset.xvec.extract_points
Dataset.xvec.zonal_stats
Dataset.xvec.encode_cf
Dataset.xvec.decode_cf


DataArray.xvec
@@ -91,4 +93,4 @@ Methods
DataArray.xvec.to_geodataframe
DataArray.xvec.to_geopandas
DataArray.xvec.extract_points
DataArray.xvec.zonal_stats
DataArray.xvec.zonal_stats
142 changes: 81 additions & 61 deletions doc/source/io.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions environment.yml
@@ -6,6 +6,7 @@ dependencies:
# required
- shapely=2
- xarray
- cf_xarray
# testing
- pytest
- pytest-cov
1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
"xarray >= 2022.12.0",
"pyproj >= 3.0.0",
"shapely >= 2.0b1",
"cf_xarray >= 0.9.2",
]

[project.urls]
122 changes: 121 additions & 1 deletion xvec/accessor.py
@@ -164,7 +164,7 @@ def geom_coords(self) -> Mapping[Hashable, xr.DataArray]:
).coords

@property
def geom_coords_indexed(self) -> Mapping[Hashable, xr.DataArray]:
def geom_coords_indexed(self) -> xr.Coordinates:
"""Returns a dictionary of xarray.DataArray objects corresponding to
coordinate variables using :class:`~xvec.GeometryIndex`.
@@ -1290,6 +1290,126 @@ def extract_points(
)
return result

def encode_cf(self) -> xr.Dataset:
"""
Encode all geometry variables and associated CRS with CF conventions.
Use this method prior to writing an Xarray dataset to any array format
(e.g. netCDF or Zarr).
The following invariant is satisfied:
``assert ds.xvec.encode_cf().xvec.decode_cf().identical(ds) is True``
CRS information on the ``GeometryIndex`` is encoded using CF's ``grid_mapping`` convention.
This function uses ``cf_xarray.geometry.encode_geometries`` under the hood and will only
work on Datasets.
Returns
-------
Dataset
"""
import cf_xarray as cfxr

if not isinstance(self._obj, xr.Dataset):
raise ValueError(
"CF encoding is only valid on Datasets. Convert to a dataset using `.to_dataset()` first."
)

ds = self._obj.copy()
coords = self.geom_coords_indexed

# TODO: this could use geoxarray, but is quite simple in any case
# Adapted from rioxarray
# 1. First find all unique CRS objects
# preserve ordering for roundtripping
unique_crs = []
for _, xi in sorted(coords.xindexes.items()):
if xi.crs not in unique_crs:
unique_crs.append(xi.crs)
if len(unique_crs) == 1:
grid_mappings = {unique_crs.pop(): "spatial_ref"}
else:
grid_mappings = {
crs_: f"spatial_ref_{i}" for i, crs_ in enumerate(unique_crs)
}

# 2. Convert CRS to grid_mapping variables and assign them
for crs, grid_mapping in grid_mappings.items():
grid_mapping_attrs = crs.to_cf()
# TODO: not all CRS can be represented by CF grid_mappings
# For now, we allow this.
# if "grid_mapping_name" not in grid_mapping_attrs:
# raise ValueError
wkt_str = crs.to_wkt()
grid_mapping_attrs["spatial_ref"] = wkt_str
grid_mapping_attrs["crs_wkt"] = wkt_str
ds.coords[grid_mapping] = xr.Variable(
dims=(), data=0, attrs=grid_mapping_attrs
)

# 3. Associate other variables with appropriate grid_mapping variable
# We assume that this relation follows from dimension names being shared between
# the GeometryIndex and the variable being checked.
for name, coord in coords.items():
dims = set(coord.dims)
index = coords.xindexes[name]
varnames = (k for k, v in ds._variables.items() if dims & set(v.dims))
for name in varnames:
if TYPE_CHECKING:
assert isinstance(index, GeometryIndex)
ds._variables[name].attrs["grid_mapping"] = grid_mappings[index.crs]

encoded = cfxr.geometry.encode_geometries(ds)
return encoded

def decode_cf(self) -> xr.Dataset:
"""
Decode geometries stored as CF-compliant arrays to shapely geometries.
The following invariant is satisfied:
``assert ds.xvec.encode_cf().xvec.decode_cf().identical(ds) is True``
A ``GeometryIndex`` is created automatically and CRS information, if available
following CF's ``grid_mapping`` convention, will be associated with the ``GeometryIndex``.
This function uses ``cf_xarray.geometry.decode_geometries`` under the hood, and will only
work on Datasets.
Returns
-------
Dataset
"""
import cf_xarray as cfxr

if not isinstance(self._obj, xr.Dataset):
raise ValueError(
"CF decoding is only supported on Datasets. Convert to a Dataset using `.to_dataset()` first."
)

decoded = cfxr.geometry.decode_geometries(self._obj.copy())
crs = {
name: CRS.from_user_input(var.attrs["crs_wkt"])
for name, var in decoded._variables.items()
if "crs_wkt" in var.attrs or "grid_mapping_name" in var.attrs
}
dims = decoded.xvec.geom_coords.dims
for dim in dims:
decoded = (
decoded.set_xindex(dim) if dim not in decoded._indexes else decoded
)
decoded = decoded.xvec.set_geom_indexes(
dim, crs=crs.get(decoded[dim].attrs.get("grid_mapping", None))
)
for name in crs:
# remove spatial_ref so the coordinate system is only stored on the index
del decoded[name]
for var in decoded._variables.values():
if set(dims) & set(var.dims):
var.attrs.pop("grid_mapping", None)
return decoded


def _resolve_input(
positional: Mapping[Any, Any] | None,
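
For reference, the CRS handling implemented above can be exercised as follows — a sketch assuming a Dataset ds with two GeometryIndex coordinates in different CRS (as in the multi-CRS fixture added below):

    encoded = ds.xvec.encode_cf()

    # a single unique CRS is stored as a scalar coordinate named "spatial_ref";
    # several unique CRS are enumerated as "spatial_ref_0", "spatial_ref_1", ...
    print(encoded["spatial_ref_0"].attrs["crs_wkt"])
    print(encoded["spatial_ref_1"].attrs["crs_wkt"])

    # decode_cf reads crs_wkt back, drops the grid-mapping coordinates, and
    # re-attaches each CRS to the restored GeometryIndex
    decoded = encoded.xvec.decode_cf()
    assert "spatial_ref_0" not in decoded.coords
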
43 changes: 41 additions & 2 deletions xvec/tests/conftest.py
@@ -69,7 +69,7 @@ def multi_dataset(geom_array, geom_array_z):

@pytest.fixture(scope="session")
def multi_geom_dataset(geom_array, geom_array_z):
return (
ds = (
xr.Dataset(
coords={
"geom": geom_array,
@@ -80,11 +80,32 @@ def multi_geom_dataset(geom_array, geom_array_z):
.set_xindex("geom", GeometryIndex, crs=26915)
.set_xindex("geom_z", GeometryIndex, crs=26915)
)
ds["geom"].attrs["crs"] = ds.xindexes["geom"].crs
ds["geom_z"].attrs["crs"] = ds.xindexes["geom_z"].crs
return ds


@pytest.fixture(scope="session")
def multi_geom_multi_crs_dataset(geom_array, geom_array_z):
ds = (
xr.Dataset(
coords={
"geom": geom_array,
"geom_z": geom_array_z,
}
)
.drop_indexes(["geom", "geom_z"])
.set_xindex("geom", GeometryIndex, crs=26915)
.set_xindex("geom_z", GeometryIndex, crs="EPSG:4362")
)
ds["geom"].attrs["crs"] = ds.xindexes["geom"].crs
ds["geom_z"].attrs["crs"] = ds.xindexes["geom_z"].crs
return ds


@pytest.fixture(scope="session")
def multi_geom_no_index_dataset(geom_array, geom_array_z):
return (
ds = (
xr.Dataset(
coords={
"geom": geom_array,
@@ -96,6 +117,9 @@ def multi_geom_no_index_dataset(geom_array, geom_array_z):
.set_xindex("geom", GeometryIndex, crs=26915)
.set_xindex("geom_z", GeometryIndex, crs=26915)
)
ds["geom"].attrs["crs"] = ds.xindexes["geom"].crs
ds["geom_z"].attrs["crs"] = ds.xindexes["geom_z"].crs
return ds


@pytest.fixture(scope="session")
@@ -157,3 +181,18 @@ def traffic_dataset(geom_array):
"day": pd.date_range("2023-01-01", periods=10),
},
).xvec.set_geom_indexes(["origin", "destination"], crs=26915)


@pytest.fixture(
params=[
"first_geom_dataset",
"multi_dataset",
"multi_geom_dataset",
"multi_geom_no_index_dataset",
"multi_geom_multi_crs_dataset",
"traffic_dataset",
],
scope="session",
)
def all_datasets(request):
return request.getfixturevalue(request.param)
31 changes: 31 additions & 0 deletions xvec/tests/test_accessor.py
@@ -674,3 +674,34 @@ def test_extract_points_array():
geometry=4326
),
)


def test_cf_roundtrip(all_datasets):
ds = all_datasets
copy = ds.copy(deep=True)
encoded = ds.xvec.encode_cf()

if unique_crs := {
idx.crs for idx in ds.xvec.geom_coords_indexed.xindexes.values() if idx.crs
}:
nwkts = sum(1 for var in encoded._variables.values() if "crs_wkt" in var.attrs)
assert len(unique_crs) == nwkts
roundtripped = encoded.xvec.decode_cf()

xr.testing.assert_identical(ds, roundtripped)
assert_indexes_equals(ds, roundtripped)
# make sure we didn't modify the original dataset.
xr.testing.assert_identical(ds, copy)


def assert_indexes_equals(left, right):
# Till https://github.com/pydata/xarray/issues/5812 is resolved
# Also, we don't record whether an unindexed coordinate was serialized
# So just assert that the left ("expected") dataset has fewer indexes
# than the right.
# This isn't great...
assert sorted(left.xindexes.keys()) <= sorted(right.xindexes.keys())
for k in left.xindexes:
if not isinstance(left.xindexes[k], GeometryIndex):
continue
assert left.xindexes[k].equals(right.xindexes[k])
