Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG:dataset: Fix writing tags for bands & prevent overwriting long_name attribute #616

Merged
merged 1 commit into from
Dec 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ History

Latest
------
- BUG:dataset: Fix writing tags for bands (issue #615)
- BUG:dataset: prevent overwriting long_name attribute (pull #616)

0.13.1
------
Expand Down
9 changes: 7 additions & 2 deletions rioxarray/raster_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,18 +497,23 @@ def to_raster(
"""
variable_dim = f"band_{uuid4()}"
data_array = self._obj.to_array(dim=variable_dim)
# write data array names to raster
data_array.attrs["long_name"] = data_array[variable_dim].values.tolist()
# ensure raster metadata preserved
scales = []
offsets = []
nodatavals = []
band_tags = []
long_name = []
for data_var in data_array[variable_dim].values:
scales.append(self._obj[data_var].attrs.get("scale_factor", 1.0))
offsets.append(self._obj[data_var].attrs.get("add_offset", 0.0))
long_name.append(self._obj[data_var].attrs.get("long_name", data_var))
nodatavals.append(self._obj[data_var].rio.nodata)
band_tags.append(self._obj[data_var].attrs.copy())
data_array.attrs["scales"] = scales
data_array.attrs["offsets"] = offsets
data_array.attrs["band_tags"] = band_tags
data_array.attrs["long_name"] = long_name

nodata = nodatavals[0]
if (
all(nodataval == nodata for nodataval in nodatavals)
Expand Down
67 changes: 42 additions & 25 deletions rioxarray/raster_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,32 +36,10 @@ def is_dask_collection(_) -> bool: # type: ignore
# Note: transform & crs are removed in write_transform/write_crs


def _write_metatata_to_raster(raster_handle, xarray_dataset, tags):
def _write_tags(raster_handle, tags):
"""
Write the metadata stored in the xarray object to raster metadata
Write tags to raster dataset
"""
tags = (
xarray_dataset.attrs.copy()
if tags is None
else {**xarray_dataset.attrs, **tags}
)

# write scales and offsets
try:
raster_handle.scales = tags["scales"]
except KeyError:
scale_factor = tags.get(
"scale_factor", xarray_dataset.encoding.get("scale_factor")
)
if scale_factor is not None:
raster_handle.scales = (scale_factor,) * raster_handle.count
try:
raster_handle.offsets = tags["offsets"]
except KeyError:
add_offset = tags.get("add_offset", xarray_dataset.encoding.get("add_offset"))
if add_offset is not None:
raster_handle.offsets = (add_offset,) * raster_handle.count

# filter out attributes that should be written in a different location
skip_tags = (
UNWANTED_RIO_ATTRS
Expand All @@ -80,10 +58,19 @@ def _write_metatata_to_raster(raster_handle, xarray_dataset, tags):
# in this case, it will be stored in the raster description
if not isinstance(tags.get("long_name"), str):
skip_tags += ("long_name",)
band_tags = tags.pop("band_tags", [])
tags = {key: value for key, value in tags.items() if key not in skip_tags}
raster_handle.update_tags(**tags)

# write band name information
if isinstance(band_tags, list):
for iii, band_tag in enumerate(band_tags):
raster_handle.update_tags(iii + 1, **band_tag)


def _write_band_description(raster_handle, xarray_dataset):
"""
Write band descriptions using the long name
"""
long_name = xarray_dataset.attrs.get("long_name")
if isinstance(long_name, (tuple, list)):
if len(long_name) != raster_handle.count:
Expand All @@ -100,6 +87,36 @@ def _write_metatata_to_raster(raster_handle, xarray_dataset, tags):
raster_handle.set_band_description(iii + 1, band_description)


def _write_metatata_to_raster(raster_handle, xarray_dataset, tags):
"""
Write the metadata stored in the xarray object to raster metadata
"""
tags = (
xarray_dataset.attrs.copy()
if tags is None
else {**xarray_dataset.attrs, **tags}
)

# write scales and offsets
try:
raster_handle.scales = tags["scales"]
except KeyError:
scale_factor = tags.get(
"scale_factor", xarray_dataset.encoding.get("scale_factor")
)
if scale_factor is not None:
raster_handle.scales = (scale_factor,) * raster_handle.count
try:
raster_handle.offsets = tags["offsets"]
except KeyError:
add_offset = tags.get("add_offset", xarray_dataset.encoding.get("add_offset"))
if add_offset is not None:
raster_handle.offsets = (add_offset,) * raster_handle.count

_write_tags(raster_handle=raster_handle, tags=tags)
_write_band_description(raster_handle=raster_handle, xarray_dataset=xarray_dataset)


def _ensure_nodata_dtype(original_nodata, new_dtype):
"""
Convert the nodata to the new datatype and raise warning
Expand Down
46 changes: 31 additions & 15 deletions test/integration/test_integration__io.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,25 +484,41 @@ def test_utm():
assert "y" not in rioda.coords


def test_band_as_variable(open_rasterio):
def test_band_as_variable(open_rasterio, tmp_path):
test_raster = tmp_path / "test.tif"

with create_tmp_geotiff() as (tmp_file, expected):
with open_rasterio(
tmp_file, band_as_variable=True, mask_and_scale=False
) as riods:
for band in (1, 2, 3):
band_name = f"band_{band}"
assert_allclose(riods[band_name], expected.sel(band=band).drop("band"))
assert riods[band_name].attrs["BAND"] == band
assert riods[band_name].attrs["scale_factor"] == 1.0
assert riods[band_name].attrs["add_offset"] == 0.0
assert riods[band_name].attrs["long_name"] == f"d{band}"
assert riods[band_name].attrs["units"] == f"u{band}"
assert riods[band_name].rio.crs == expected.rio.crs
assert_array_equal(
riods[band_name].rio.resolution(), expected.rio.resolution()
)
assert isinstance(riods[band_name].rio._cached_transform(), Affine)
assert riods[band_name].rio.nodata is None

def _check_raster(raster_ds):
for band in (1, 2, 3):
band_name = f"band_{band}"
assert_allclose(
raster_ds[band_name], expected.sel(band=band).drop("band")
)
assert raster_ds[band_name].attrs["BAND"] == band
assert raster_ds[band_name].attrs["scale_factor"] == 1.0
assert raster_ds[band_name].attrs["add_offset"] == 0.0
assert raster_ds[band_name].attrs["long_name"] == f"d{band}"
assert raster_ds[band_name].attrs["units"] == f"u{band}"
assert raster_ds[band_name].rio.crs == expected.rio.crs
assert_array_equal(
raster_ds[band_name].rio.resolution(), expected.rio.resolution()
)
assert isinstance(
raster_ds[band_name].rio._cached_transform(), Affine
)
assert raster_ds[band_name].rio.nodata is None

_check_raster(riods)
# test roundtrip
riods.rio.to_raster(test_raster)
with open_rasterio(
test_raster, band_as_variable=True, mask_and_scale=False
) as riods_round:
_check_raster(riods_round)


def test_platecarree():
Expand Down
2 changes: 1 addition & 1 deletion test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,7 +1740,7 @@ def test_to_raster__dataset__mask_and_scale(chunks, tmpdir):
with rioxarray.open_rasterio(str(output_raster)) as rdscompare:
assert rdscompare.scale_factor == 0.1
assert rdscompare.add_offset == 220.0
assert rdscompare.long_name == "air_temperature"
assert rdscompare.long_name == "tmmx"
assert rdscompare.rio.crs == rds.rio.crs
assert rdscompare.rio.nodata == rds.air_temperature.rio.nodata

Expand Down