Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Handle netCDF _Unsigned and load in all attributes #575

Merged
merged 2 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,15 @@ jobs:
conda info

- name: pylint
if: matrix.python-version == '3.8'
shell: bash
run: |
source activate test
pylint rioxarray/

- name: mypy
shell: bash
if: matrix.python-version == '3.8'
run: |
source activate test
mypy rioxarray/
Expand Down
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ History

Latest
------
- BUG: Handle `_Unsigned` and load in all attributes (pull #575)

0.12.0
-------
Expand Down
159 changes: 121 additions & 38 deletions rioxarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
import re
import threading
import warnings
from typing import Any, Dict, Hashable, List, Optional, Tuple, Union
from typing import Any, Dict, Hashable, Iterable, List, Optional, Tuple, Union

import numpy as np
import rasterio
from numpy.typing import NDArray
from packaging import version
from rasterio.errors import NotGeoreferencedWarning
from rasterio.vrt import WarpedVRT
Expand All @@ -39,6 +40,18 @@
RasterioReader = Union[rasterio.io.DatasetReader, rasterio.vrt.WarpedVRT]


def _get_unsigned_dtype(unsigned, dtype):
"""
Based on: https://github.com/pydata/xarray/blob/abe1e613a96b000ae603c53d135828df532b952e/xarray/coding/variables.py#L306-L334
"""
dtype = np.dtype(dtype)
if unsigned is True and dtype.kind == "i":
return np.dtype(f"u{dtype.itemsize}")
if unsigned is False and dtype.kind == "u":
return np.dtype(f"i{dtype.itemsize}")
return None


class FileHandleLocal(threading.local):
"""
This contains the thread local ThreadURIManager
Expand Down Expand Up @@ -173,28 +186,31 @@ def __init__(
self._shape = (riods.count, riods.height, riods.width)

self._dtype = None
self._unsigned_dtype = None
self._fill_value = riods.nodata
dtypes = riods.dtypes
if not np.all(np.asarray(dtypes) == dtypes[0]):
raise ValueError("All bands should have the same dtype")

dtype = _rasterio_to_numpy_dtype(dtypes)

# handle unsigned case
if mask_and_scale and unsigned and dtype.kind == "i":
self._dtype = np.dtype(f"u{dtype.itemsize}")
elif mask_and_scale and unsigned:
warnings.warn(
f"variable {name!r} has _Unsigned attribute but is not "
"of integer type. Ignoring attribute.",
variables.SerializationWarning,
stacklevel=3,
if mask_and_scale and unsigned is not None:
self._unsigned_dtype = _get_unsigned_dtype(
unsigned=unsigned,
dtype=dtype,
)
self._fill_value = riods.nodata
if self._dtype is None:
if self.masked:
self._dtype, self._fill_value = maybe_promote(dtype)
else:
self._dtype = dtype
if self._unsigned_dtype is not None and self._fill_value is not None:
self._fill_value = self._unsigned_dtype.type(self._fill_value)
if self._unsigned_dtype is None and dtype.kind not in ("i", "u"):
warnings.warn(
f"variable {name!r} has _Unsigned attribute but is not "
"of integer type. Ignoring attribute.",
variables.SerializationWarning,
stacklevel=3,
)
if self.masked:
self._dtype, self._fill_value = maybe_promote(dtype)
else:
self._dtype = dtype

@property
def dtype(self):
Expand Down Expand Up @@ -288,6 +304,8 @@ def _getitem(self, key):
if self.vrt_params is not None:
riods = WarpedVRT(riods, **self.vrt_params)
out = riods.read(band_key, window=window, masked=self.masked)
if self._unsigned_dtype is not None:
out = out.astype(self._unsigned_dtype)
if self.masked:
out = np.ma.filled(out.astype(self.dtype), self.fill_value)
if self.mask_and_scale:
Expand Down Expand Up @@ -418,29 +436,44 @@ def _load_netcdf_attrs(tags: Dict, data_array: DataArray) -> None:
data_array.coords[variable_name].attrs.update({attr_name: value})


def _parse_netcdf_attr_array(attr: Union[NDArray, str], dtype=None) -> NDArray:
"""
Expected format: '{2,6}' or '[2. 6.]'
"""
value: Union[NDArray, str, List]
if isinstance(attr, str):
if attr.startswith("{"):
value = attr.strip("{}").split(",")
else:
value = attr.strip("[]").split()
elif not isinstance(attr, Iterable):
value = [attr]
else:
value = attr
return np.array(value, dtype=dtype)


def _load_netcdf_1d_coords(tags: Dict) -> Dict:
"""
Dimension information:
- NETCDF_DIM_EXTRA: '{time}' (comma separated list of dim names)
- NETCDF_DIM_time_DEF: '{2,6}' (dim size, dim dtype)
- NETCDF_DIM_time_VALUES: '{0,872712.659688}' (comma separated list of data)
- NETCDF_DIM_time_DEF: '{2,6}' or '[2. 6.]' (dim size, dim dtype)
- NETCDF_DIM_time_VALUES: '{0,872712.659688}' (comma separated list of data) or [ 0. 872712.659688]
"""
dim_names = tags.get("NETCDF_DIM_EXTRA")
if not dim_names:
return {}
dim_names = dim_names.strip("{}").split(",")
dim_names = _parse_netcdf_attr_array(dim_names)
coords = {}
for dim_name in dim_names:
dim_def = tags.get(f"NETCDF_DIM_{dim_name}_DEF")
if not dim_def:
if dim_def is None:
continue
# pylint: disable=unused-variable
dim_size, dim_dtype = dim_def.strip("{}").split(",")
dim_dtype = NETCDF_DTYPE_MAP.get(int(dim_dtype), object)
dim_values = tags[f"NETCDF_DIM_{dim_name}_VALUES"].strip("{}")
coords[dim_name] = IndexVariable(
dim_name, np.fromstring(dim_values, dtype=dim_dtype, sep=",")
)
dim_size, dim_dtype = _parse_netcdf_attr_array(dim_def)
dim_dtype = NETCDF_DTYPE_MAP.get(int(float(dim_dtype)), object)
dim_values = _parse_netcdf_attr_array(tags[f"NETCDF_DIM_{dim_name}_VALUES"])
coords[dim_name] = IndexVariable(dim_name, dim_values)
return coords


Expand Down Expand Up @@ -491,7 +524,7 @@ def _get_rasterio_attrs(riods: RasterioReader):
"""
# pylint: disable=too-many-branches
# Add rasterio attributes
attrs = _parse_tags(riods.tags(1))
attrs = _parse_tags({**riods.tags(), **riods.tags(1)})
if riods.nodata is not None:
# The nodata values for the raster bands
attrs["_FillValue"] = riods.nodata
Expand Down Expand Up @@ -591,6 +624,19 @@ def _parse_driver_tags(
attrs[key] = value


def _pop_global_netcdf_attrs_from_vars(dataset_to_clean: Dataset) -> Dataset:
# remove GLOBAL netCDF attributes from dataset variables
for coord in dataset_to_clean.coords:
for variable in dataset_to_clean.variables:
dataset_to_clean[variable].attrs = {
attr: value
for attr, value in dataset_to_clean[variable].attrs.items()
if attr not in dataset_to_clean.attrs
and not attr.startswith(f"{coord}#")
}
return dataset_to_clean


def _load_subdatasets(
riods: RasterioReader,
group: Optional[Union[str, List[str], Tuple[str, ...]]],
Expand All @@ -608,7 +654,7 @@ def _load_subdatasets(
"""
Load in rasterio subdatasets
"""
base_tags = _parse_tags(riods.tags())
global_tags = _parse_tags(riods.tags())
dim_groups = {}
subdataset_filter = None
if any((group, variable)):
Expand Down Expand Up @@ -638,12 +684,15 @@ def _load_subdatasets(

if len(dim_groups) > 1:
dataset: Union[Dataset, List[Dataset]] = [
Dataset(dim_group, attrs=base_tags) for dim_group in dim_groups.values()
_pop_global_netcdf_attrs_from_vars(Dataset(dim_group, attrs=global_tags))
for dim_group in dim_groups.values()
]
elif not dim_groups:
dataset = Dataset(attrs=base_tags)
dataset = Dataset(attrs=global_tags)
else:
dataset = Dataset(list(dim_groups.values())[0], attrs=base_tags)
dataset = _pop_global_netcdf_attrs_from_vars(
Dataset(list(dim_groups.values())[0], attrs=global_tags)
)
return dataset


Expand Down Expand Up @@ -689,7 +738,11 @@ def _prepare_dask(


def _handle_encoding(
result: DataArray, mask_and_scale: bool, masked: bool, da_name: Optional[Hashable]
result: DataArray,
mask_and_scale: bool,
masked: bool,
da_name: Optional[Hashable],
unsigned: Union[bool, None],
) -> None:
"""
Make sure encoding handled properly
Expand All @@ -711,6 +764,16 @@ def _handle_encoding(
result.attrs, result.encoding, "missing_value", name=da_name
)

if mask_and_scale and unsigned is not None and "_FillValue" in result.encoding:
unsigned_dtype = _get_unsigned_dtype(
unsigned=unsigned,
dtype=result.encoding["dtype"],
)
if unsigned_dtype is not None:
result.encoding["_FillValue"] = unsigned_dtype.type(
result.encoding["_FillValue"]
)


def open_rasterio(
filename: Union[
Expand Down Expand Up @@ -879,13 +942,19 @@ def open_rasterio(

# parse tags & load alternate coords
attrs = _get_rasterio_attrs(riods=riods)
coords = _load_netcdf_1d_coords(riods.tags())
coords = _load_netcdf_1d_coords(attrs)
_parse_driver_tags(riods=riods, attrs=attrs, coords=coords)
for coord in coords:
if f"NETCDF_DIM_{coord}" in attrs:
coord_name = coord
attrs.pop(f"NETCDF_DIM_{coord}")
break
if f"NETCDF_DIM_{coord}_VALUES" in attrs:
coord_name = coord
attrs.pop(f"NETCDF_DIM_{coord}_VALUES")
attrs.pop(f"NETCDF_DIM_{coord}_DEF", None)
attrs.pop("NETCDF_DIM_EXTRA", None)
break
else:
coord_name = "band"
coords[coord_name] = np.asarray(riods.indexes)
Expand All @@ -900,7 +969,7 @@ def open_rasterio(
_generate_spatial_coords(riods.transform, riods.width, riods.height)
)

unsigned = False
unsigned = None
encoding: Dict[Hashable, Any] = {}
if mask_and_scale and "_Unsigned" in attrs:
unsigned = variables.pop_to(attrs, encoding, "_Unsigned") == "true"
Expand Down Expand Up @@ -938,11 +1007,11 @@ def open_rasterio(
)

# make sure the _FillValue is correct dtype
if "_FillValue" in attrs:
attrs["_FillValue"] = result.dtype.type(attrs["_FillValue"])
if "_FillValue" in result.attrs:
result.attrs["_FillValue"] = result.dtype.type(result.attrs["_FillValue"])

# handle encoding
_handle_encoding(result, mask_and_scale, masked, da_name)
_handle_encoding(result, mask_and_scale, masked, da_name, unsigned=unsigned)
# Affine transformation matrix (always available)
# This describes coefficients mapping pixel coordinates to CRS
# For serialization store as tuple of 6 floats, the last row being
Expand All @@ -964,4 +1033,18 @@ def open_rasterio(
# add file path to encoding
result.encoding["source"] = riods.name
result.encoding["rasterio_dtype"] = str(riods.dtypes[0])
# remove duplicate coordinate information
for coord in result.coords:
result.attrs = {
attr: value
for attr, value in result.attrs.items()
if not attr.startswith(f"{coord}#")
}
# remove duplicate tags
if result.name:
result.attrs = {
attr: value
for attr, value in result.attrs.items()
if not attr.startswith(f"{result.name}#")
}
return result
21 changes: 20 additions & 1 deletion rioxarray/raster_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from rasterio.windows import Window
from xarray.conventions import encode_cf_variable

from rioxarray._io import _get_unsigned_dtype
from rioxarray.exceptions import RioXarrayError

try:
Expand All @@ -39,7 +40,11 @@ def _write_metatata_to_raster(raster_handle, xarray_dataset, tags):
"""
Write the metadata stored in the xarray object to raster metadata
"""
tags = xarray_dataset.attrs if tags is None else {**xarray_dataset.attrs, **tags}
tags = (
xarray_dataset.attrs.copy()
if tags is None
else {**xarray_dataset.attrs, **tags}
)

# write scales and offsets
try:
Expand Down Expand Up @@ -224,11 +229,25 @@ def to_raster(self, xarray_dataarray, tags, windowed, lock, compute, **kwargs):
**kwargs
Keyword arguments to pass into writing the raster.
"""
xarray_dataarray = xarray_dataarray.copy()
kwargs["dtype"], numpy_dtype = _get_dtypes(
kwargs["dtype"],
xarray_dataarray.encoding.get("rasterio_dtype"),
xarray_dataarray.encoding.get("dtype", str(xarray_dataarray.dtype)),
)
# there is no equivalent for netCDF _Unsigned
# across output GDAL formats. It is safest to convert beforehand.
# https://github.com/OSGeo/gdal/issues/6352#issuecomment-1245981837
if "_Unsigned" in xarray_dataarray.encoding:
unsigned_dtype = _get_unsigned_dtype(
unsigned=xarray_dataarray.encoding["_Unsigned"] == "true",
dtype=numpy_dtype,
)
if unsigned_dtype is not None:
numpy_dtype = unsigned_dtype
kwargs["dtype"] = unsigned_dtype
xarray_dataarray.encoding["rasterio_dtype"] = str(unsigned_dtype)
xarray_dataarray.encoding["dtype"] = str(unsigned_dtype)

if kwargs["nodata"] is not None:
# Ensure dtype of output data matches the expected dtype.
Expand Down
Loading