Skip to content

Commit

Permalink
BUG: Handle _Unsigned and load in all attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 committed Sep 9, 2022
1 parent f23d3d4 commit ea3ff6c
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 32 deletions.
121 changes: 92 additions & 29 deletions rioxarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@
RasterioReader = Union[rasterio.io.DatasetReader, rasterio.vrt.WarpedVRT]


def _get_unsigned_dtype(unsigned, dtype):
dtype = np.dtype(dtype)
if unsigned is True and dtype.kind == "i":
return np.dtype(f"u{dtype.itemsize}")
elif unsigned is False and dtype.kind == "u":
return np.dtype(f"i{dtype.itemsize}")
return None


class FileHandleLocal(threading.local):
"""
This contains the thread local ThreadURIManager
Expand Down Expand Up @@ -173,28 +182,29 @@ def __init__(
self._shape = (riods.count, riods.height, riods.width)

self._dtype = None
self._unsigned_dtype = None
self._fill_value = riods.nodata
dtypes = riods.dtypes
if not np.all(np.asarray(dtypes) == dtypes[0]):
raise ValueError("All bands should have the same dtype")

dtype = _rasterio_to_numpy_dtype(dtypes)

# handle unsigned case
if mask_and_scale and unsigned and dtype.kind == "i":
self._dtype = np.dtype(f"u{dtype.itemsize}")
elif mask_and_scale and unsigned:
warnings.warn(
f"variable {name!r} has _Unsigned attribute but is not "
"of integer type. Ignoring attribute.",
variables.SerializationWarning,
stacklevel=3,
if mask_and_scale and unsigned is not None:
self._unsigned_dtype = _get_unsigned_dtype(
unsigned=unsigned,
dtype=dtype,
)
self._fill_value = riods.nodata
if self._dtype is None:
if self.masked:
self._dtype, self._fill_value = maybe_promote(dtype)
else:
self._dtype = dtype
if self._unsigned_dtype is None:
warnings.warn(
f"variable {name!r} has _Unsigned attribute but is not "
"of integer type. Ignoring attribute.",
variables.SerializationWarning,
stacklevel=3,
)
if self.masked:
self._dtype, self._fill_value = maybe_promote(dtype)
else:
self._dtype = dtype

@property
def dtype(self):
Expand Down Expand Up @@ -288,6 +298,8 @@ def _getitem(self, key):
if self.vrt_params is not None:
riods = WarpedVRT(riods, **self.vrt_params)
out = riods.read(band_key, window=window, masked=self.masked)
if self._unsigned_dtype is not None:
out = out.astype(self._unsigned_dtype)
if self.masked:
out = np.ma.filled(out.astype(self.dtype), self.fill_value)
if self.mask_and_scale:
Expand Down Expand Up @@ -418,28 +430,37 @@ def _load_netcdf_attrs(tags: Dict, data_array: DataArray) -> None:
data_array.coords[variable_name].attrs.update({attr_name: value})


def _parse_netcdf_attr_array(value: str) -> List:
"""
Expected format: '{2,6}' or '[2. 6.]'
"""
if value.startswith("{"):
return value.strip("{}").split(",")
return [element.strip() for element in value.strip("[]").split()]


def _load_netcdf_1d_coords(tags: Dict) -> Dict:
    """
    Load 1D netCDF coordinate variables from GDAL metadata tags.

    Dimension information in the tags:
        - NETCDF_DIM_EXTRA: '{time}' (comma separated list of dim names)
        - NETCDF_DIM_time_DEF: '{2,6}' or '[2. 6.]' (dim size, dim dtype)
        - NETCDF_DIM_time_VALUES: '{0,872712.659688}' (comma separated list of data) or [ 0. 872712.659688]

    Parameters
    ----------
    tags: Dict
        GDAL dataset tags (e.g. from ``riods.tags()``).

    Returns
    -------
    Dict:
        Mapping of dimension name to ``IndexVariable``; empty when the
        tags carry no extra-dimension information.
    """
    dim_names = tags.get("NETCDF_DIM_EXTRA")
    if not dim_names:
        return {}
    dim_names = _parse_netcdf_attr_array(dim_names)
    coords = {}
    for dim_name in dim_names:
        dim_def = tags.get(f"NETCDF_DIM_{dim_name}_DEF")
        if not dim_def:
            continue
        # pylint: disable=unused-variable
        dim_size, dim_dtype = _parse_netcdf_attr_array(dim_def)
        # int(float(...)) tolerates both '6' and '6.0' style dtype codes
        dim_dtype = NETCDF_DTYPE_MAP.get(int(float(dim_dtype)), object)
        dim_values = _parse_netcdf_attr_array(tags[f"NETCDF_DIM_{dim_name}_VALUES"])
        coords[dim_name] = IndexVariable(
            dim_name, np.fromstring(",".join(dim_values), dtype=dim_dtype, sep=",")
        )
    return coords

Expand Down Expand Up @@ -491,7 +512,7 @@ def _get_rasterio_attrs(riods: RasterioReader):
"""
# pylint: disable=too-many-branches
# Add rasterio attributes
attrs = _parse_tags(riods.tags(1))
attrs = _parse_tags({**riods.tags(), **riods.tags(1)})
if riods.nodata is not None:
# The nodata values for the raster bands
attrs["_FillValue"] = riods.nodata
Expand Down Expand Up @@ -608,7 +629,7 @@ def _load_subdatasets(
"""
Load in rasterio subdatasets
"""
base_tags = _parse_tags(riods.tags())
global_tags = _parse_tags(riods.tags())
dim_groups = {}
subdataset_filter = None
if any((group, variable)):
Expand Down Expand Up @@ -638,12 +659,27 @@ def _load_subdatasets(

if len(dim_groups) > 1:
dataset: Union[Dataset, List[Dataset]] = [
Dataset(dim_group, attrs=base_tags) for dim_group in dim_groups.values()
Dataset(dim_group, attrs=global_tags) for dim_group in dim_groups.values()
]
elif not dim_groups:
dataset = Dataset(attrs=base_tags)
dataset = Dataset(attrs=global_tags)
else:
dataset = Dataset(list(dim_groups.values())[0], attrs=global_tags)

def _pop_duplicate_netcdf_attrs(dataset_to_clean):
for coord in dataset_to_clean.coords:
for variable in dataset_to_clean.variables:
dataset_to_clean[variable].attrs = {
attr: value
for attr, value in dataset_to_clean[variable].attrs.items()
if attr not in global_tags and not attr.startswith(f"{coord}#")
}

if isinstance(dataset, list):
for dataset_item in dataset:
_pop_duplicate_netcdf_attrs(dataset_item)
else:
dataset = Dataset(list(dim_groups.values())[0], attrs=base_tags)
_pop_duplicate_netcdf_attrs(dataset)
return dataset


Expand Down Expand Up @@ -886,6 +922,12 @@ def open_rasterio(
coord_name = coord
attrs.pop(f"NETCDF_DIM_{coord}")
break
elif f"NETCDF_DIM_{coord}_VALUES" in attrs:
coord_name = coord
attrs.pop(f"NETCDF_DIM_{coord}_VALUES")
attrs.pop(f"NETCDF_DIM_{coord}_DEF", None)
attrs.pop("NETCDF_DIM_EXTRA", None)
break
else:
coord_name = "band"
coords[coord_name] = np.asarray(riods.indexes)
Expand All @@ -900,7 +942,7 @@ def open_rasterio(
_generate_spatial_coords(riods.transform, riods.width, riods.height)
)

unsigned = False
unsigned = None
encoding: Dict[Hashable, Any] = {}
if mask_and_scale and "_Unsigned" in attrs:
unsigned = variables.pop_to(attrs, encoding, "_Unsigned") == "true"
Expand Down Expand Up @@ -939,6 +981,13 @@ def open_rasterio(

# make sure the _FillValue is correct dtype
if "_FillValue" in attrs:
if mask_and_scale and unsigned is not None:
unsigned_dtype = _get_unsigned_dtype(
unsigned=unsigned,
dtype=encoding["dtype"],
)
if unsigned_dtype is not None:
attrs["_FillValue"] = unsigned_dtype.type(attrs["_FillValue"])
attrs["_FillValue"] = result.dtype.type(attrs["_FillValue"])

# handle encoding
Expand All @@ -964,4 +1013,18 @@ def open_rasterio(
# add file path to encoding
result.encoding["source"] = riods.name
result.encoding["rasterio_dtype"] = str(riods.dtypes[0])
# remove duplicate coordinate information
for coord in result.coords:
result.attrs = {
attr: value
for attr, value in result.attrs.items()
if not attr.startswith(f"{coord}#")
}
# remove duplicate tags
if result.name:
result.attrs = {
attr: value
for attr, value in result.attrs.items()
if not attr.startswith(f"{result.name}#")
}
return result
9 changes: 8 additions & 1 deletion rioxarray/raster_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ def _write_metatata_to_raster(raster_handle, xarray_dataset, tags):
"""
Write the metadata stored in the xarray object to raster metadata
"""
tags = xarray_dataset.attrs if tags is None else {**xarray_dataset.attrs, **tags}
tags = (
xarray_dataset.attrs.copy()
if tags is None
else {**xarray_dataset.attrs, **tags}
)

# write scales and offsets
try:
Expand All @@ -57,6 +61,9 @@ def _write_metatata_to_raster(raster_handle, xarray_dataset, tags):
if add_offset is not None:
raster_handle.offsets = (add_offset,) * raster_handle.count

if "_Unsigned" in xarray_dataset.encoding:
tags["_Unsigned"] = xarray_dataset.encoding["_Unsigned"]

# filter out attributes that should be written in a different location
skip_tags = (
UNWANTED_RIO_ATTRS
Expand Down
8 changes: 6 additions & 2 deletions test/integration/test_integration__io.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,16 @@ def test_open_group_load_attrs(open_rasterio):
) as rds:
attrs = rds["sur_refl_b05_1"].attrs
assert sorted(attrs) == [
"Nadir Data Resolution",
"_FillValue",
"add_offset",
"add_offset_err",
"calibrated_nt",
"long_name",
"scale_factor",
"scale_factor_err",
"units",
"valid_range",
]
assert attrs["long_name"] == "500m Surface Reflectance Band 5 - first layer"
assert attrs["units"] == "reflectance"
Expand Down Expand Up @@ -299,6 +304,7 @@ def test_open_rasterio_mask_chunk_clip():
(3.0, 0.0, 425047.68381405267, 0.0, -3.0, 4615780.040546387),
)
assert attrs == {
"AREA_OR_POINT": "Area",
"add_offset": 0.0,
"scale_factor": 1.0,
}
Expand Down Expand Up @@ -989,7 +995,6 @@ def test_mask_and_scale(open_rasterio):
attrs = rds.air_temperature.attrs
assert attrs == {
"coordinates": "day",
"coordinate_system": "WGS84,EPSG:4326",
"description": "Daily Maximum Temperature",
"dimensions": "lon lat time",
"long_name": "tmmx",
Expand Down Expand Up @@ -1023,7 +1028,6 @@ def test_no_mask_and_scale(open_rasterio):
"_Unsigned": "true",
"add_offset": 220.0,
"coordinates": "day",
"coordinate_system": "WGS84,EPSG:4326",
"description": "Daily Maximum Temperature",
"dimensions": "lon lat time",
"long_name": "tmmx",
Expand Down
1 change: 1 addition & 0 deletions test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,6 +1558,7 @@ def test_to_raster(
def test_to_raster_3d(open_method, windowed, write_lock, compute, tmpdir):
tmp_raster = tmpdir.join("planet_3d_raster.tif")
with open_method(os.path.join(TEST_INPUT_DATA_DIR, "PLANET_SCOPE_3D.nc")) as mda:
assert sorted(mda.coords) == ["spatial_ref", "time", "x", "y"]
xds = mda.green.fillna(mda.green.rio.encoded_nodata)
xds.rio._nodata = mda.green.rio.encoded_nodata
delayed = xds.rio.to_raster(
Expand Down

0 comments on commit ea3ff6c

Please sign in to comment.