Skip to content

Commit

Permalink
BUG: Support writing GCPs to netCDF
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Jun 20, 2024
1 parent 98abd84 commit 4f5a974
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History
Latest
------
- BUG: Raise OverflowError when nodata data type conversion is unsafe (pull #782)
- BUG: Support writing GCPs to netCDF (issue #778)

0.15.5
------
Expand Down
14 changes: 10 additions & 4 deletions rioxarray/rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
to xarray datasets/dataarrays.
"""
# pylint: disable=too-many-lines
import json
import math
import warnings
from collections.abc import Hashable, Iterable
Expand Down Expand Up @@ -289,6 +290,7 @@ def __init__(self, xarray_obj: Union[xarray.DataArray, xarray.Dataset]):
self._height: Optional[int] = None
self._width: Optional[int] = None
self._crs: Union[rasterio.crs.CRS, None, Literal[False]] = None
self._gcps: Optional[dict] = None

@property
def crs(self) -> Optional[rasterio.crs.CRS]:
Expand Down Expand Up @@ -347,6 +349,7 @@ def _get_obj(self, inplace: bool) -> Union[xarray.Dataset, xarray.DataArray]:
obj_copy.rio._width = self._width
obj_copy.rio._height = self._height
obj_copy.rio._crs = self._crs
obj_copy.rio._gcps = self._gcps
return obj_copy

def set_crs(
Expand Down Expand Up @@ -1235,7 +1238,8 @@ def write_gcps(
gcp_crs, grid_mapping_name=grid_mapping_name, inplace=inplace
)
geojson_gcps = _convert_gcps_to_geojson(gcps)
data_obj.coords[grid_mapping_name].attrs["gcps"] = geojson_gcps
data_obj.coords[grid_mapping_name].attrs["gcps"] = json.dumps(geojson_gcps)
self._gcps = gcps
return data_obj

def get_gcps(self) -> Optional[list[GroundControlPoint]]:
Expand All @@ -1249,8 +1253,10 @@ def get_gcps(self) -> Optional[list[GroundControlPoint]]:
list of :obj:`rasterio.control.GroundControlPoint` or None
The Ground Control Points from the dataset or None if not applicable
"""
if self._gcps is not None:
return self._gcps
try:
geojson_gcps = self._obj.coords[self.grid_mapping].attrs["gcps"]
geojson_gcps = json.loads(self._obj.coords[self.grid_mapping].attrs["gcps"])
except (KeyError, AttributeError):
return None

Expand All @@ -1267,8 +1273,8 @@ def _parse_gcp(gcp) -> GroundControlPoint:
info=gcp["properties"]["info"],
)

gcps = [_parse_gcp(gcp) for gcp in geojson_gcps["features"]]
return gcps
self._gcps = [_parse_gcp(gcp) for gcp in geojson_gcps["features"]]
return self._gcps


def _convert_gcps_to_geojson(
Expand Down
28 changes: 28 additions & 0 deletions test/integration/test_integration__io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1459,6 +1459,34 @@ def test_writing_gcps(tmp_path):
_check_rio_gcps(darr, *gdal_gcps)


def test_writing_gcps__to_netcdf(tmp_path):
"""
Test writing gcps to a netCDF file.
"""
tiffname = tmp_path / "test.tif"
nc_name = tmp_path / "test_written.nc"

src_gcps, crs = _create_gdal_gcps()

with rasterio.open(
tiffname,
mode="w",
height=800,
width=800,
count=3,
dtype=numpy.uint8,
driver="GTiff",
) as source:
source.gcps = (src_gcps, crs)

with rioxarray.open_rasterio(tiffname) as darr:
darr.to_netcdf(nc_name)

with xarray.open_dataset(nc_name, decode_coords="all") as darr:
assert "gcps" in darr.coords["spatial_ref"].attrs
_check_rio_gcps(darr, src_gcps=src_gcps, crs=crs)


def test_read_file_handle_with_dask():
with open(
os.path.join(TEST_COMPARE_DATA_DIR, "small_dem_3m_merged.tif"), "rb"
Expand Down
2 changes: 1 addition & 1 deletion test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2987,7 +2987,7 @@ def _check_rio_gcps(darr, src_gcps, crs):
assert "y" not in darr.coords
assert darr.rio.crs == crs
assert "gcps" in darr.spatial_ref.attrs
gcps = darr.spatial_ref.attrs["gcps"]
gcps = json.loads(darr.spatial_ref.attrs["gcps"])
assert gcps["type"] == "FeatureCollection"
assert len(gcps["features"]) == len(src_gcps)
for feature, gcp in zip(gcps["features"], src_gcps):
Expand Down

0 comments on commit 4f5a974

Please sign in to comment.