Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added band_as_variable option to open_rasterio #592

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 66 additions & 23 deletions rioxarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from xarray.core.dtypes import maybe_promote
from xarray.core.utils import is_scalar
from xarray.core.variable import as_variable
import pyproj

from rioxarray.exceptions import RioXarrayError
from rioxarray.rioxarray import _generate_spatial_coords
Expand Down Expand Up @@ -690,6 +691,7 @@ def _load_subdatasets(
default_name=subdataset.split(":")[-1].lstrip("/").replace("/", "_"),
decode_times=decode_times,
decode_timedelta=decode_timedelta,
band_as_variable=False,
**open_kwargs,
)
if shape not in dim_groups:
Expand Down Expand Up @@ -800,6 +802,7 @@ def open_rasterio(
default_name: Optional[str] = None,
decode_times: bool = True,
decode_timedelta: Optional[bool] = None,
band_as_variable: bool = False,
**open_kwargs,
) -> Union[Dataset, DataArray, List[Dataset]]:
# pylint: disable=too-many-statements,too-many-locals,too-many-branches
Expand Down Expand Up @@ -866,6 +869,9 @@ def open_rasterio(
{“days”, “hours”, “minutes”, “seconds”, “milliseconds”, “microseconds”}
into timedelta objects. If False, leave them encoded as numbers.
If None (default), assume the same value of decode_time.
band_as_variable: bool, default=False
If True, try to use each band's description metadata as its variable name.
If False (default), decode the bands as integers.
**open_kwargs: kwargs, optional
Optional keyword arguments to pass into :func:`rasterio.open`.

Expand All @@ -877,6 +883,7 @@ def open_rasterio(
parse_coordinates = True if parse_coordinates is None else parse_coordinates
masked = masked or mask_and_scale
vrt_params = None
output_is_dataset = False
if isinstance(filename, rasterio.io.DatasetReader):
filename = filename.name
elif isinstance(filename, rasterio.vrt.WarpedVRT):
Expand Down Expand Up @@ -967,7 +974,26 @@ def open_rasterio(
break
else:
coord_name = "band"
coords[coord_name] = np.asarray(riods.indexes)
band_descriptions = riods.descriptions
# Assign the band names only if the description is available for all of them
if band_as_variable and not any(map(lambda ele: ele is None, band_descriptions)):
output_is_dataset = True
data_vars = {}
for i in riods.indexes:
data_var_attrs = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering if band specific tags should be added here?

"grid_mapping": "spatial_ref",
"scale_factor": riods.scales[i - 1],
"add_offset": riods.offsets[i - 1],
"_FillValue": riods.nodatavals[i - 1],
"description": riods.descriptions[i - 1]
}
data_vars[band_descriptions[i - 1]] = (("y", "x"), riods.read(i), data_var_attrs)
spatial_ref_attrs = pyproj.CRS.from_user_input(riods.crs).to_cf()
data_vars["spatial_ref"] = ((), 0, spatial_ref_attrs)

coords = Dataset(data_vars=data_vars)
else:
coords[coord_name] = np.asarray(riods.indexes)

has_gcps = riods.gcps[0]
if has_gcps:
Expand All @@ -978,7 +1004,6 @@ def open_rasterio(
coords.update(
_generate_spatial_coords(riods.transform, riods.width, riods.height)
)

unsigned = None
encoding: Dict[Hashable, Any] = {}
if mask_and_scale and "_Unsigned" in attrs:
Expand All @@ -1004,36 +1029,43 @@ def open_rasterio(
data = indexing.CopyOnWriteArray(data)
if cache and chunks is None:
data = indexing.MemoryCachedArray(data)

result = DataArray(
data=data, dims=(coord_name, "y", "x"), coords=coords, attrs=attrs, name=da_name
)
if output_is_dataset:
# Remove band specific attrs from Dataset attrs
band_attrs = riods.tags(1)
for k in band_attrs.keys():
attrs.pop(k)
result = coords.assign_attrs(attrs)
else:
result = DataArray(
data=data, dims=(coord_name, "y", "x"), coords=coords, attrs=attrs, name=da_name
)
result.encoding = encoding

# update attributes from NetCDF attributes
_load_netcdf_attrs(riods.tags(), result)
result = _decode_datetime_cf(
result, decode_times=decode_times, decode_timedelta=decode_timedelta
)

# make sure the _FillValue is correct dtype
if "_FillValue" in result.attrs:
if "_FillValue" in result.attrs and not output_is_dataset:
result.attrs["_FillValue"] = result.dtype.type(result.attrs["_FillValue"])

# handle encoding
_handle_encoding(result, mask_and_scale, masked, da_name, unsigned=unsigned)
# Affine transformation matrix (always available)
# This describes coefficients mapping pixel coordinates to CRS
# For serialization store as tuple of 6 floats, the last row being
# always (0, 0, 1) per definition (see
# https://github.com/sgillies/affine)
result.rio.write_transform(riods.transform, inplace=True)
rio_crs = riods.crs or result.rio.crs
if rio_crs:
result.rio.write_crs(rio_crs, inplace=True)
# If we have a Dataset with variables and each variable has already 'grid_mapping: spatial_ref'
Copy link
Member

@snowman2 snowman2 Oct 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend not adding the CRS to each variable as you did previously and instead add it to the dataset here.

# in the variable attrs, doing the next steps would add another 'spatial_ref' as a
# coordinate and will break the functionality of xarray.to_netcdf()
# TODO: Probably we would need to check if the transform is already available in each band?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The transform should be the same for each band.

if not output_is_dataset:
result.rio.write_transform(riods.transform, inplace=True)
rio_crs = riods.crs or result.rio.crs
if rio_crs:
result.rio.write_crs(rio_crs, inplace=True)
if has_gcps:
result.rio.write_gcps(*riods.gcps, inplace=True)

if chunks is not None:
result = _prepare_dask(result, riods, filename, chunks)

Expand All @@ -1050,11 +1082,22 @@ def open_rasterio(
for attr, value in result.attrs.items()
if not attr.startswith(f"{coord}#")
}
# remove duplicate tags
if result.name:
result.attrs = {
attr: value
for attr, value in result.attrs.items()
if not attr.startswith(f"{result.name}#")
}
return result
# remove duplicate tags for DataArray
if not output_is_dataset: # result.name is not available in a Dataset
if result.name:
result.attrs = {
attr: value
for attr, value in result.attrs.items()
if not attr.startswith(f"{result.name}#")
}
# remove duplicate tags for Dataset, otherwise there are issues opening the netCDF file in QGIS
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting. Mind adding a separate function called _clean_attrs and moving the duplicate cleanup there?

if output_is_dataset:
main_attrs = list(result.attrs.keys())
vars_attrs = [list(result[d].attrs.keys()) for d in result.data_vars]
out_list = []
for attr in vars_attrs:
out_list =list(set(attr) | set(out_list))
for k in out_list:
if k in main_attrs:
result.attrs.pop(k)
return result