From 6f8d75c4ce3c263ea3ba585cbe20b0c8e18f9328 Mon Sep 17 00:00:00 2001 From: snowman2 Date: Tue, 20 Apr 2021 08:41:09 -0500 Subject: [PATCH] BUG: Preserve original data type for writing to disk --- docs/history.rst | 1 + rioxarray/_io.py | 2 ++ rioxarray/raster_array.py | 5 +++++ test/integration/test_integration__io.py | 19 ++++++++++++++++--- .../integration/test_integration_rioxarray.py | 1 + 5 files changed, 25 insertions(+), 3 deletions(-) diff --git a/docs/history.rst b/docs/history.rst index 2732124a..5448ef84 100644 --- a/docs/history.rst +++ b/docs/history.rst @@ -12,6 +12,7 @@ Latest - BUG: Return correct transform in `rio.transform` with non-rectilinear transform (discussions #280) - BUG: Update to handle WindowError in rasterio 1.2.2 (issue #286) - BUG: Don't generate x,y coords in `rio` methods if not previously there (pull #294) +- BUG: Preserve original data type for writing to disk (issue #305) 0.3.2 ----- diff --git a/rioxarray/_io.py b/rioxarray/_io.py index d4e664ab..78e8043a 100644 --- a/rioxarray/_io.py +++ b/rioxarray/_io.py @@ -812,6 +812,8 @@ def open_rasterio( encoding = {} if mask_and_scale and "_Unsigned" in attrs: unsigned = variables.pop_to(attrs, encoding, "_Unsigned") == "true" + elif masked: + encoding["dtype"] = str(riods.dtypes[0]) da_name = attrs.pop("NETCDF_VARNAME", default_name) data = indexing.LazilyOuterIndexedArray( diff --git a/rioxarray/raster_array.py b/rioxarray/raster_array.py index ce9a0cd5..8fa70633 100644 --- a/rioxarray/raster_array.py +++ b/rioxarray/raster_array.py @@ -882,6 +882,11 @@ def to_raster( if driver is None and LooseVersion(rasterio.__version__) < LooseVersion("1.2"): driver = "GTiff" + dtype = ( + self._obj.encoding.get("dtype", str(self._obj.dtype)) + if dtype is None + else dtype + ) dtype = str(self._obj.dtype) if dtype is None else dtype # get the output profile from the rasterio object # if opened with xarray.open_rasterio() diff --git a/test/integration/test_integration__io.py b/test/integration/test_integration__io.py index 54447852..069c2990 100644 --- a/test/integration/test_integration__io.py +++ b/test/integration/test_integration__io.py @@ -270,7 +270,11 @@ def test_open_rasterio_mask_chunk_clip(): assert np.isnan(xdi.values).sum() == 52119 test_encoding = dict(xdi.encoding) assert test_encoding.pop("source").endswith("small_dem_3m_merged.tif") - assert test_encoding == {"_FillValue": 0.0, "grid_mapping": "spatial_ref"} + assert test_encoding == { + "_FillValue": 0.0, + "grid_mapping": "spatial_ref", + "dtype": "uint16", + } attrs = dict(xdi.attrs) assert_almost_equal( tuple(xdi.rio._cached_transform())[:6], @@ -307,7 +311,11 @@ def test_open_rasterio_mask_chunk_clip(): _assert_xarrays_equal(clipped, comp_subset) test_encoding = dict(clipped.encoding) assert test_encoding.pop("source").endswith("small_dem_3m_merged.tif") - assert test_encoding == {"_FillValue": 0.0, "grid_mapping": "spatial_ref"} + assert test_encoding == { + "_FillValue": 0.0, + "grid_mapping": "spatial_ref", + "dtype": "uint16", + } # test dataset clipped_ds = xdi.to_dataset(name="test_data").rio.clip( @@ -317,7 +325,11 @@ def test_open_rasterio_mask_chunk_clip(): _assert_xarrays_equal(clipped_ds, comp_subset_ds) test_encoding = dict(clipped.encoding) assert test_encoding.pop("source").endswith("small_dem_3m_merged.tif") - assert test_encoding == {"_FillValue": 0.0, "grid_mapping": "spatial_ref"} + assert test_encoding == { + "_FillValue": 0.0, + "grid_mapping": "spatial_ref", + "dtype": "uint16", + } ############################################################################## @@ -912,6 +924,7 @@ def test_no_mask_and_scale(open_rasterio): "_FillValue": 32767.0, "missing_value": 32767, "grid_mapping": "crs", + "dtype": "uint16", } attrs = rds.air_temperature.attrs assert attrs == { diff --git a/test/integration/test_integration_rioxarray.py b/test/integration/test_integration_rioxarray.py index ed4ae546..ebe6a735 100644 --- a/test/integration/test_integration_rioxarray.py +++ b/test/integration/test_integration_rioxarray.py @@ -1160,6 +1160,7 @@ def test_to_raster( assert_array_equal(rds.read(1), xds.fillna(xds.rio.encoded_nodata).values) assert rds.count == 1 assert rds.tags() == {"AREA_OR_POINT": "Area", **test_tags, **xds_attrs} + assert rds.dtypes == ("int16",) @pytest.mark.parametrize(