diff --git a/.pylintrc b/.pylintrc index d7961212..f7feea52 100644 --- a/.pylintrc +++ b/.pylintrc @@ -15,7 +15,7 @@ enable=c-extension-no-member [FORMAT] -max-module-lines=1200 +max-module-lines=1250 [DESIGN] max-locals=20 diff --git a/docs/history.rst b/docs/history.rst index 354d7534..c4c52875 100644 --- a/docs/history.rst +++ b/docs/history.rst @@ -3,6 +3,7 @@ History Latest ------ +- ENH: Added band_as_variable option to open_rasterio (issue #296) - BUG: Pass warp_extras dictionary to raster.vrt.WarpedVRT (issue #598) 0.12.3 diff --git a/rioxarray/_io.py b/rioxarray/_io.py index fb2db0c1..b12ff03b 100644 --- a/rioxarray/_io.py +++ b/rioxarray/_io.py @@ -7,6 +7,7 @@ """ import contextlib +import functools import importlib.metadata import os import re @@ -38,7 +39,91 @@ RASTERIO_LOCK = SerializableLock() NO_LOCK = contextlib.nullcontext() -RasterioReader = Union[rasterio.io.DatasetReader, rasterio.vrt.WarpedVRT] + +class SingleBandDatasetReader: + """ + Hack to have a DatasetReader behave like it only has one band + """ + + def __init__(self, riods, bidx) -> None: + self._riods = riods + self._bidx = bidx + + def __getattr__(self, __name: str) -> Any: + return getattr(self._riods, __name) + + @property + def count(self): + """ + int: band count + """ + return 1 + + @property + def nodata(self): + """ + Nodata value for the band + """ + return self._riods.nodatavals[self._bidx] + + @property + def offsets(self): + """ + Offset value for the band + """ + return [self._riods.offsets[self._bidx]] + + @property + def scales(self): + """ + Scale value for the band + """ + return [self._riods.scales[self._bidx]] + + @property + def units(self): + """ + Unit for the band + """ + return [self._riods.units[self._bidx]] + + @property + def descriptions(self): + """ + Description for the band + """ + return [self._riods.descriptions[self._bidx]] + + @property + def dtypes(self): + """ + dtype for the band + """ + return [self._riods.dtypes[self._bidx]] + + @property + def indexes(self): + """ + indexes for the band + """ + return [self._riods.indexes[self._bidx]] + + def read(self, indexes=None, **kwargs): # pylint: disable=unused-argument + """ + read data for the band + """ + return self._riods.read(indexes=self._bidx + 1, **kwargs) + + def tags(self, bidx=None, **kwargs): # pylint: disable=unused-argument + """ + read tags for the band + """ + return self._riods.tags(bidx=self._bidx + 1, **kwargs) + + +RasterioReader = Union[ + rasterio.io.DatasetReader, rasterio.vrt.WarpedVRT, SingleBandDatasetReader +] try: @@ -711,6 +796,49 @@ def _load_subdatasets( return dataset +def _load_bands_as_variables( + riods: RasterioReader, + parse_coordinates: bool, + chunks: Optional[Union[int, Tuple, Dict]], + cache: Optional[bool], + lock: Any, + masked: bool, + mask_and_scale: bool, + decode_times: bool, + decode_timedelta: Optional[bool], + **open_kwargs, +) -> Union[Dataset, List[Dataset]]: + """ + Load in rasterio bands as variables + """ + global_tags = _parse_tags(riods.tags()) + data_vars = {} + for band in riods.indexes: + band_riods = SingleBandDatasetReader( + riods=riods, + bidx=band - 1, + ) + band_name = f"band_{band}" + data_vars[band_name] = ( + open_rasterio( # type: ignore + band_riods, + parse_coordinates=band == 1 and parse_coordinates, + chunks=chunks, + cache=cache, + lock=lock, + masked=masked, + mask_and_scale=mask_and_scale, + default_name=band_name, + decode_times=decode_times, + decode_timedelta=decode_timedelta, + **open_kwargs, + ) + .squeeze() # type: ignore + .drop("band") # type: ignore + ) + return Dataset(data_vars, attrs=global_tags) + + def _prepare_dask( result: DataArray, riods: RasterioReader, @@ -785,9 +913,23 @@ def _handle_encoding( ) +def _single_band_open(*args, bidx=0, **kwargs): + """ + Open file as if it only has a single band + """ + return SingleBandDatasetReader( + riods=rasterio.open(*args, **kwargs), + bidx=bidx, + ) + + def open_rasterio( filename: Union[ - str, os.PathLike, rasterio.io.DatasetReader, rasterio.vrt.WarpedVRT + str, + os.PathLike, + rasterio.io.DatasetReader, + rasterio.vrt.WarpedVRT, + SingleBandDatasetReader, ], parse_coordinates: Optional[bool] = None, chunks: Optional[Union[int, Tuple, Dict]] = None, @@ -800,6 +942,7 @@ def open_rasterio( default_name: Optional[str] = None, decode_times: bool = True, decode_timedelta: Optional[bool] = None, + band_as_variable: bool = False, **open_kwargs, ) -> Union[Dataset, DataArray, List[Dataset]]: # pylint: disable=too-many-statements,too-many-locals,too-many-branches @@ -812,6 +955,8 @@ def open_rasterio( `_ for more information). + .. versionadded:: 0.13 band_as_variable + Parameters ---------- filename: str, rasterio.io.DatasetReader, or rasterio.vrt.WarpedVRT @@ -866,6 +1011,8 @@ def open_rasterio( {“days”, “hours”, “minutes”, “seconds”, “milliseconds”, “microseconds”} into timedelta objects. If False, leave them encoded as numbers. If None (default), assume the same value of decode_time. + band_as_variable: bool, default=False + If True, will load bands in a raster to separate variables. **open_kwargs: kwargs, optional Optional keyword arguments to pass into :func:`rasterio.open`. @@ -877,7 +1024,13 @@ def open_rasterio( parse_coordinates = True if parse_coordinates is None else parse_coordinates masked = masked or mask_and_scale vrt_params = None - if isinstance(filename, rasterio.io.DatasetReader): + file_opener = rasterio.open + if isinstance(filename, SingleBandDatasetReader): + file_opener = functools.partial( + _single_band_open, + bidx=filename._bidx, + ) + if isinstance(filename, (rasterio.io.DatasetReader, SingleBandDatasetReader)): filename = filename.name elif isinstance(filename, rasterio.vrt.WarpedVRT): vrt = filename @@ -909,13 +1062,27 @@ def open_rasterio( with warnings.catch_warnings(record=True) as rio_warnings: if lock is not NO_LOCK and isinstance(filename, (str, os.PathLike)): manager: FileManager = CachingFileManager( - rasterio.open, filename, lock=lock, mode="r", kwargs=open_kwargs + file_opener, filename, lock=lock, mode="r", kwargs=open_kwargs ) else: - manager = URIManager(rasterio.open, filename, mode="r", kwargs=open_kwargs) + manager = URIManager(file_opener, filename, mode="r", kwargs=open_kwargs) riods = manager.acquire() captured_warnings = rio_warnings.copy() + if band_as_variable: + return _load_bands_as_variables( + riods=riods, + parse_coordinates=parse_coordinates, + chunks=chunks, + cache=cache, + lock=lock, + masked=masked, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + decode_timedelta=decode_timedelta, + **open_kwargs, + ) + # raise the NotGeoreferencedWarning if applicable for rio_warning in captured_warnings: if not riods.subdatasets or not isinstance( diff --git a/test/integration/test_integration__io.py b/test/integration/test_integration__io.py index 72b418a3..45095880 100644 --- a/test/integration/test_integration__io.py +++ b/test/integration/test_integration__io.py @@ -426,6 +426,8 @@ def create_tmp_geotiff( ) as s: for attr, val in additional_attrs.items(): setattr(s, attr, val) + for band in range(1, nz + 1): + s.update_tags(band, BAND=band) s.write(data, **write_kwargs) dx, dy = s.res[0], -s.res[1] tt = s.transform @@ -480,6 +482,25 @@ def test_utm(): assert "y" not in rioda.coords +def test_band_as_variable(): + with create_tmp_geotiff() as (tmp_file, expected): + with rioxarray.open_rasterio(tmp_file, band_as_variable=True) as riods: + for band in (1, 2, 3): + band_name = f"band_{band}" + assert_allclose(riods[band_name], expected.sel(band=band).drop("band")) + assert riods[band_name].attrs["BAND"] == band + assert riods[band_name].attrs["scale_factor"] == 1.0 + assert riods[band_name].attrs["add_offset"] == 0.0 + assert riods[band_name].attrs["long_name"] == f"d{band}" + assert riods[band_name].attrs["units"] == f"u{band}" + assert riods[band_name].rio.crs == expected.rio.crs + assert_array_equal( + riods[band_name].rio.resolution(), expected.rio.resolution() + ) + assert isinstance(riods[band_name].rio._cached_transform(), Affine) + assert riods[band_name].rio.nodata is None + + def test_platecarree(): with create_tmp_geotiff( 8,