diff --git a/docs/examples/manage_information_loss.ipynb b/docs/examples/manage_information_loss.ipynb index cae1fb55..67c8ad08 100644 --- a/docs/examples/manage_information_loss.ipynb +++ b/docs/examples/manage_information_loss.ipynb @@ -13,9 +13,14 @@ "\n", "- [rio.to_raster()](../rioxarray.rst#rioxarray.rioxarray.RasterDataset.to_raster)\n", "- [rio.write_crs()](../rioxarray.rst#rioxarray.rioxarray.XRasterBase.write_crs)\n", + "- [rio.write_transform()](../rioxarray.rst#rioxarray.rioxarray.XRasterBase.write_transform)\n", "- [rio.update_attrs()](../rioxarray.rst#rioxarray.rioxarray.XRasterBase.update_attrs)\n", "- [rio.crs](../rioxarray.rst#rioxarray.rioxarray.XRasterBase.crs)\n", - "- [rio.nodata](../rioxarray.rst#rioxarray.rioxarray.RasterArray.nodata)" + "- [rio.nodata](../rioxarray.rst#rioxarray.rioxarray.RasterArray.nodata)\n", + "\n", + "Note that `write_transform` is only needed if you are not saving the x,y coordinates. It is for\n", + "GDAL to be able to read in the transform without needing the original coordinates and is useful\n", + "if you read in the file with `parse_coordinates=False`." ] }, { @@ -65,7 +70,6 @@ " 'nodata': 0,\n", " 'units': ('DN', 'DN'),\n", " '_FillValue': nan,\n", - " 'transform': (3.0, 0.0, 466266.0, 0.0, -3.0, 8084700.0),\n", " 'scale_factor': 1.0,\n", " 'add_offset': 0.0},\n", " CRS.from_epsg(32722),\n", @@ -129,7 +133,6 @@ " 'nodata': 0,\n", " 'units': ('DN', 'DN'),\n", " '_FillValue': nan,\n", - " 'transform': (3.0, 0.0, 466266.0, 0.0, -3.0, 8084700.0),\n", " 'scale_factor': 1.0,\n", " 'add_offset': 0.0},\n", " CRS.from_epsg(32722),\n", @@ -191,7 +194,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.6.10" } }, "nbformat": 4, diff --git a/docs/history.rst b/docs/history.rst index 8a9dda98..b3718f39 100644 --- a/docs/history.rst +++ b/docs/history.rst @@ -4,7 +4,7 @@ History Latest ------ - BUG: Fix assigning fill value in `rio.pad_box` (pull #140) - +- ENH: Add `rio.write_transform` to store cache in GDAL location (issue #129 & #139) 0.0.29 ------- diff --git a/rioxarray/_io.py b/rioxarray/_io.py index 6681eb46..5c114b5a 100644 --- a/rioxarray/_io.py +++ b/rioxarray/_io.py @@ -358,12 +358,6 @@ def _get_rasterio_attrs(riods): """ # Add rasterio attributes attrs = _parse_tags(riods.tags(1)) - # Affine transformation matrix (always available) - # This describes coefficients mapping pixel coordinates to CRS - # For serialization store as tuple of 6 floats, the last row being - # always (0, 0, 1) per definition (see - # https://github.com/sgillies/affine) - attrs["transform"] = tuple(_rio_transform(riods))[:6] if hasattr(riods, "nodata") and riods.nodata is not None: # The nodata values for the raster bands attrs["_FillValue"] = riods.nodata @@ -789,6 +783,12 @@ def open_rasterio( result.attrs, result.encoding, "missing_value", name=da_name ) + # Affine transformation matrix (always available) + # This describes coefficients mapping pixel coordinates to CRS + # For serialization store as tuple of 6 floats, the last row being + # always (0, 0, 1) per definition (see + # https://github.com/sgillies/affine) + result.rio.write_transform(riods.transform, inplace=True) if hasattr(riods, "crs") and riods.crs: result.rio.write_crs(riods.crs, inplace=True) diff --git a/rioxarray/merge.py b/rioxarray/merge.py index f514ab16..8831ffbf 100644 --- a/rioxarray/merge.py +++ b/rioxarray/merge.py @@ -123,7 +123,6 @@ def merge_arrays( coords = _get_nonspatial_coords(representative_array) out_attrs = representative_array.attrs - out_attrs["transform"] = tuple(merged_transform)[:6] xda = DataArray( name=dataarrays[0].name, data=merged_data, @@ -134,6 +133,7 @@ def merge_arrays( out_nodata = nodata if nodata is not None else representative_array.rio.nodata xda.rio.write_nodata(out_nodata, inplace=True) xda.rio.write_crs(representative_array.rio.crs, inplace=True) + xda.rio.write_transform(merged_transform, inplace=True) return xda diff --git a/rioxarray/rioxarray.py b/rioxarray/rioxarray.py index c6db150c..cdd81518 100644 --- a/rioxarray/rioxarray.py +++ b/rioxarray/rioxarray.py @@ -78,7 +78,7 @@ def _get_grid_map_name(src_data_array): return DEFAULT_GRID_MAP -def _generate_attrs(src_data_array, dst_affine, dst_nodata): +def _generate_attrs(src_data_array, dst_nodata): # add original attributes new_attrs = copy.deepcopy(src_data_array.attrs) # remove all nodata information @@ -95,7 +95,6 @@ def _generate_attrs(src_data_array, dst_affine, dst_nodata): new_attrs["_FillValue"] = fill_value # add raster spatial information - new_attrs["transform"] = tuple(dst_affine)[:6] new_attrs["grid_mapping"] = _get_grid_map_name(src_data_array) return new_attrs @@ -144,9 +143,7 @@ def _add_attrs_proj(new_data_array, src_data_array): new_data_array.rio._y_dim = src_data_array.rio.y_dim # make sure attributes preserved - new_attrs = _generate_attrs( - src_data_array, new_data_array.rio.transform(recalc=True), None - ) + new_attrs = _generate_attrs(src_data_array, None) # remove fill value if it already exists in the encoding # this is for data arrays pulling the encoding from a # source data array instead of being generated anew. @@ -160,6 +157,7 @@ def _add_attrs_proj(new_data_array, src_data_array): new_data_array = add_spatial_ref( new_data_array, src_data_array.rio.crs, _get_grid_map_name(src_data_array) ) + new_data_array.rio.write_transform(inplace=True) # make sure encoding added new_data_array.encoding = src_data_array.encoding.copy() return new_data_array @@ -322,17 +320,14 @@ def crs(self): if self._crs is not None: return None if self._crs is False else self._crs + # look in grid_mapping + grid_mapping_coord = self._obj.attrs.get("grid_mapping", DEFAULT_GRID_MAP) try: - # look in grid_mapping - grid_mapping_coord = self._obj.attrs.get("grid_mapping", DEFAULT_GRID_MAP) - try: - self.set_crs( - pyproj.CRS.from_cf(self._obj.coords[grid_mapping_coord].attrs), - inplace=True, - ) - except pyproj.exceptions.CRSError: - pass - except KeyError: + self.set_crs( + pyproj.CRS.from_cf(self._obj.coords[grid_mapping_coord].attrs), + inplace=True, + ) + except (KeyError, pyproj.exceptions.CRSError): try: # look in attrs for 'crs' self.set_crs(self._obj.attrs["crs"], inplace=True) @@ -414,6 +409,8 @@ def write_crs( else: data_obj = self._get_obj(inplace=inplace) + # get original transform + transform = self._cached_transform() # remove old grid maping coordinate if exists try: del data_obj.coords[grid_mapping_name] @@ -431,6 +428,10 @@ def write_crs( crs_wkt = crs_to_wkt(data_obj.rio.crs) grid_map_attrs["spatial_ref"] = crs_wkt grid_map_attrs["crs_wkt"] = crs_wkt + if transform is not None: + grid_map_attrs["GeoTransform"] = " ".join( + [str(item) for item in transform.to_gdal()] + ) data_obj.coords[grid_mapping_name].rio.set_attrs(grid_map_attrs, inplace=True) # add grid mapping attribute to variables @@ -449,6 +450,66 @@ def write_crs( dict(grid_mapping=grid_mapping_name), inplace=True ) + def _cached_transform(self): + """ + Get the transform from: + 1. The GeoTransform metatada property in the grid mapping + 2. The transform attribute. + """ + try: + # look in grid_mapping + grid_mapping_coord = self._obj.attrs.get("grid_mapping", DEFAULT_GRID_MAP) + return Affine.from_gdal( + *np.fromstring( + self._obj.coords[grid_mapping_coord].attrs["GeoTransform"], sep=" " + ) + ) + except KeyError: + try: + return Affine(*self._obj.attrs["transform"][:6]) + except KeyError: + pass + return None + + def write_transform( + self, transform=None, grid_mapping_name=DEFAULT_GRID_MAP, inplace=False + ): + """ + .. versionadded:: 0.0.30 + + Write the GeoTransform to the dataset where GDAL can read it in. + + https://gdal.org/drivers/raster/netcdf.html#georeference + + Parameters + ---------- + transform: affine.Affine, optional + The transform of the dataset. If not provided, it will be calculated. + grid_mapping_name: str, optional + Name of the coordinate to store the CRS information in. + inplace: bool, optional + If True, it will write to the existing dataset. Default is False. + + Returns + ------- + xarray.Dataset or xarray.DataArray: + Modified dataset with Geo Transform written. + """ + transform = transform or self.transform(recalc=True) + data_obj = self._get_obj(inplace=inplace) + # delete the old attribute to prevent confusion + data_obj.attrs.pop("transform", None) + try: + grid_map_attrs = data_obj.coords[grid_mapping_name].attrs.copy() + except KeyError: + data_obj.coords[grid_mapping_name] = xarray.Variable((), 0) + grid_map_attrs = data_obj.coords[grid_mapping_name].attrs.copy() + grid_map_attrs["GeoTransform"] = " ".join( + [str(item) for item in transform.to_gdal()] + ) + data_obj.coords[grid_mapping_name].rio.set_attrs(grid_map_attrs, inplace=True) + return data_obj + def set_attrs(self, new_attrs, inplace=False): """ Set the attributes of the dataset/dataarray and reset @@ -538,6 +599,7 @@ def set_dims(obj, in_x_dim, in_y_dim): @property def x_dim(self): + """str: The dimension for the X-axis.""" if self._x_dim is not None: return self._x_dim raise DimensionError( @@ -547,6 +609,7 @@ def x_dim(self): @property def y_dim(self): + """str: The dimension for the Y-axis.""" if self._y_dim is not None: return self._y_dim raise DimensionError( @@ -692,16 +755,6 @@ def nodata(self): return self._nodata - def _cached_transform(self): - """ - Get the transform from attrs or property. - """ - try: - return Affine(*self._obj.attrs["transform"][:6]) - except KeyError: - pass - return None - def resolution(self, recalc=False): """Determine the resolution of the `xarray.DataArray` @@ -748,10 +801,16 @@ def _internal_bounds(self): raise DimensionMissingCoordinateError(f"{self.x_dim} missing coordinates.") elif self.y_dim not in self._obj.coords: raise DimensionMissingCoordinateError(f"{self.y_dim} missing coordinates.") - left = float(self._obj[self.x_dim][0]) - right = float(self._obj[self.x_dim][-1]) - top = float(self._obj[self.y_dim][0]) - bottom = float(self._obj[self.y_dim][-1]) + try: + left = float(self._obj[self.x_dim][0]) + right = float(self._obj[self.x_dim][-1]) + top = float(self._obj[self.y_dim][0]) + bottom = float(self._obj[self.y_dim][-1]) + except IndexError: + raise NoDataInBounds( + "Unable to determine bounds from coordinates." + f"{_get_data_var_message(self._obj)}" + ) return left, bottom, right, top def _check_dimensions(self): @@ -859,9 +918,22 @@ def transform_bounds(self, dst_crs, densify_pts=21, recalc=False): ) def transform(self, recalc=False): - """Determine the affine of the `xarray.DataArray`""" - src_left, _, _, src_top = self.bounds(recalc=recalc) - src_resolution_x, src_resolution_y = self.resolution(recalc=recalc) + """ + Parameters + ---------- + recalc: bool, optional + If True, it will re-calculate the transform instead of using + the cached transform. + + Returns + ------- + affine.Afffine: The affine of the `xarray.DataArray` + """ + try: + src_left, _, _, src_top = self.bounds(recalc=recalc) + src_resolution_x, src_resolution_y = self.resolution(recalc=recalc) + except DimensionMissingCoordinateError: + return Affine.identity() return Affine.translation(src_left, src_top) * Affine.scale( src_resolution_x, src_resolution_y ) @@ -959,7 +1031,7 @@ def reproject( resampling=resampling, ) # add necessary attributes - new_attrs = _generate_attrs(self._obj, dst_affine, dst_nodata) + new_attrs = _generate_attrs(self._obj, dst_nodata) # make sure dimensions with coordinates renamed to x,y dst_dims = [] for dim in self._obj.dims: @@ -977,6 +1049,7 @@ def reproject( attrs=new_attrs, ) xda.encoding = self._obj.encoding + xda.rio.write_transform(dst_affine, inplace=True) return add_spatial_ref(xda, dst_crs, DEFAULT_GRID_MAP) def reproject_match(self, match_data_array, resampling=Resampling.nearest): @@ -1045,10 +1118,12 @@ def slice_xy(self, minx, miny, maxx, maxy): else: x_slice = slice(minx, maxx) - subset = self._obj.sel( - {self.x_dim: x_slice, self.y_dim: y_slice} - ).rio.set_spatial_dims(x_dim=self.x_dim, y_dim=self.y_dim, inplace=True) - subset.attrs["transform"] = tuple(self.transform(recalc=True)) + subset = ( + self._obj.sel({self.x_dim: x_slice, self.y_dim: y_slice}) + .copy() # this is to prevent sharing coordinates with the original dataset + .rio.set_spatial_dims(x_dim=self.x_dim, y_dim=self.y_dim, inplace=True) + .rio.write_transform(inplace=True) + ) return subset def pad_xy(self, minx, miny, maxx, maxy, constant_values): @@ -1117,7 +1192,7 @@ def pad_xy(self, minx, miny, maxx, maxy, constant_values): ).rio.set_spatial_dims(x_dim=self.x_dim, y_dim=self.y_dim, inplace=True) superset[self.x_dim] = x_coord superset[self.y_dim] = y_coord - superset.attrs["transform"] = tuple(superset.rio.transform(recalc=True)) + superset.rio.write_transform(inplace=True) return superset def pad_box(self, minx, miny, maxx, maxy, constant_values=None): @@ -1195,9 +1270,7 @@ def clip_box(self, minx, miny, maxx, maxy, auto_expand=False, auto_expand_limit= clip_miny = miny - abs(resolution_y) / 2.0 clip_maxx = maxx + abs(resolution_x) / 2.0 clip_maxy = maxy + abs(resolution_y) / 2.0 - cl_array = self.slice_xy(clip_minx, clip_miny, clip_maxx, clip_maxy) - if cl_array.rio.width < 1 or cl_array.rio.height < 1: raise NoDataInBounds( f"No data found in bounds.{_get_data_var_message(self._obj)}" @@ -1221,7 +1294,6 @@ def clip_box(self, minx, miny, maxx, maxy, auto_expand=False, auto_expand_limit= # make sure correct attributes preserved & projection added _add_attrs_proj(cl_array, self._obj) - return cl_array def clip(self, geometries, crs, all_touched=False, drop=True, invert=False): @@ -1524,6 +1596,36 @@ def crs(self): return None return self._crs + def transform(self, recalc=False): + """ + .. versionadded:: 0.0.30 + + Parameters + ---------- + recalc: bool, optional + If True, it will re-calculate the transform instead of using + the cached transform. + + Returns + ------- + affine.Afffine: The affine of the `xarray.Dataset` + """ + transform_list = [] + for var in self.vars: + transform_list.append( + self._obj[var] + .rio.set_spatial_dims(x_dim=self.x_dim, y_dim=self.y_dim, inplace=True) + .rio.transform(recalc=recalc) + ) + if not transform_list: + return Affine.identity() + transform = transform_list[0] + if all(transform_i == transform for transform_i in transform_list): + return transform + raise RioXarrayError( + "Not all transforms are the same in the dataset: {}".format(transform_list) + ) + def reproject( self, dst_crs, diff --git a/test/conftest.py b/test/conftest.py index b1d63e22..0bf85168 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -15,7 +15,7 @@ def _assert_attrs_equal(input_xr, compare_xr, decimal_precision): for attr in compare_xr.attrs: if attr == "transform": assert_almost_equal( - input_xr.attrs[attr], + tuple(input_xr.rio._cached_transform())[:6], compare_xr.attrs[attr][:6], decimal=decimal_precision, ) diff --git a/test/integration/test_integration__io.py b/test/integration/test_integration__io.py index bcf901d8..8dd64d23 100644 --- a/test/integration/test_integration__io.py +++ b/test/integration/test_integration__io.py @@ -243,7 +243,6 @@ def test_open_group_load_attrs(): "grid_mapping", "long_name", "scale_factor", - "transform", "units", ] assert attrs["long_name"] == "500m Surface Reflectance Band 5 - first layer" @@ -268,7 +267,7 @@ def test_open_rasterio_mask_chunk_clip(): assert xdi.encoding == {"_FillValue": 0.0} attrs = dict(xdi.attrs) assert_almost_equal( - attrs.pop("transform"), + tuple(xdi.rio._cached_transform())[:6], (3.0, 0.0, 425047.68381405267, 0.0, -3.0, 4615780.040546387), ) assert attrs == { @@ -281,7 +280,7 @@ def test_open_rasterio_mask_chunk_clip(): subset = xdi.isel(x=slice(150, 160), y=slice(100, 150)) comp_subset = subset.isel(x=slice(1, None), y=slice(1, None)) # add transform for test - comp_subset.attrs["transform"] = tuple(comp_subset.rio.transform(recalc=True)) + comp_subset.rio.write_transform() geometries = [ { @@ -420,8 +419,7 @@ def test_utm(self): assert isinstance(rioda.attrs["crs"], str) assert isinstance(rioda.attrs["res"], tuple) assert isinstance(rioda.attrs["is_tiled"], np.uint8) - assert isinstance(rioda.attrs["transform"], tuple) - assert len(rioda.attrs["transform"]) == 6 + assert isinstance(rioda.rio._cached_transform(), Affine) np.testing.assert_array_equal( rioda.attrs["nodatavals"], [np.NaN, np.NaN, np.NaN] ) @@ -447,8 +445,7 @@ def test_non_rectilinear(self): assert rioda.attrs["units"] == ("u1", "u2", "u3") assert isinstance(rioda.attrs["res"], tuple) assert isinstance(rioda.attrs["is_tiled"], np.uint8) - assert isinstance(rioda.attrs["transform"], tuple) - assert len(rioda.attrs["transform"]) == 6 + assert isinstance(rioda.rio._cached_transform(), Affine) # See if a warning is raised if we force it with pytest.warns(Warning, match="transformation isn't rectilinear"): @@ -474,8 +471,7 @@ def test_platecarree(self): assert isinstance(rioda.attrs["crs"], str) assert isinstance(rioda.attrs["res"], tuple) assert isinstance(rioda.attrs["is_tiled"], np.uint8) - assert isinstance(rioda.attrs["transform"], tuple) - assert len(rioda.attrs["transform"]) == 6 + assert isinstance(rioda.rio._cached_transform(), Affine) np.testing.assert_array_equal(rioda.attrs["nodatavals"], [-9765.0]) def test_notransform(self): @@ -526,8 +522,7 @@ def test_notransform(self): assert rioda.attrs["units"] == ("cm", "m", "km") assert isinstance(rioda.attrs["res"], tuple) assert isinstance(rioda.attrs["is_tiled"], np.uint8) - assert isinstance(rioda.attrs["transform"], tuple) - assert len(rioda.attrs["transform"]) == 6 + assert isinstance(rioda.rio._cached_transform(), Affine) def test_indexing(self): with create_tmp_geotiff( @@ -734,8 +729,7 @@ def test_ENVI_tags(self): assert isinstance(rioda.attrs["crs"], str) assert isinstance(rioda.attrs["res"], tuple) assert isinstance(rioda.attrs["is_tiled"], np.uint8) - assert isinstance(rioda.attrs["transform"], tuple) - assert len(rioda.attrs["transform"]) == 6 + assert isinstance(rioda.rio._cached_transform(), Affine) # from ENVI tags assert isinstance(rioda.attrs["description"], str) assert isinstance(rioda.attrs["map_info"], str) @@ -892,7 +886,6 @@ def test_mask_and_scale(): "missing_value": 32767, } attrs = rds.air_temperature.attrs - attrs.pop("transform") assert attrs == { "coordinates": "day", "coordinate_system": "WGS84,EPSG:4326", @@ -916,7 +909,6 @@ def test_no_mask_and_scale(): "missing_value": 32767, } attrs = rds.air_temperature.attrs - attrs.pop("transform") assert attrs == { "_Unsigned": "true", "add_offset": 220.0, diff --git a/test/integration/test_integration_merge.py b/test/integration/test_integration_merge.py index 800f3ba8..9889ff3e 100644 --- a/test/integration/test_integration_merge.py +++ b/test/integration/test_integration_merge.py @@ -43,9 +43,7 @@ def test_merge_arrays(): 1.0, ), ) - assert_almost_equal( - merged.attrs.pop("transform"), tuple(merged.rio.transform())[:6] - ) + assert merged.rio._cached_transform() == merged.rio.transform() assert merged.rio.shape == (201, 201) assert merged.coords["band"].values == [1] assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"] @@ -78,9 +76,7 @@ def test_merge_arrays__res(): tuple(merged.rio.transform()), (300.0, 0.0, -7274009.649486291, 0.0, -300.0, 5050108.61015275, 0.0, 0.0, 1.0), ) - assert_almost_equal( - merged.attrs.pop("transform"), tuple(merged.rio.transform())[:6] - ) + assert merged.rio._cached_transform() == merged.rio.transform() assert merged.rio.shape == (155, 155) assert merged.coords["band"].values == [1] assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"] diff --git a/test/integration/test_integration_rioxarray.py b/test/integration/test_integration_rioxarray.py index 5fdcd6e1..fec68746 100644 --- a/test/integration/test_integration_rioxarray.py +++ b/test/integration/test_integration_rioxarray.py @@ -204,11 +204,18 @@ def test_pad_box(modis_clip): # padded data should have the same size as original data if hasattr(xdi, "variables"): for var in xdi.rio.vars: + assert_almost_equal( + xdi[var].rio._cached_transform(), + padded_ds[var].rio._cached_transform(), + ) for padded_size, original_size in zip( padded_ds[var].shape, xdi[var].shape ): assert padded_size == original_size else: + assert_almost_equal( + xdi.rio._cached_transform(), padded_ds.rio._cached_transform() + ) for padded_size, original_size in zip(padded_ds.shape, xdi.shape): assert padded_size == original_size # make sure it safely writes to netcdf @@ -226,6 +233,7 @@ def test_clip_box(modis_clip): maxy=xdi.y[4].values, ) _assert_xarrays_equal(clipped_ds, xdc) + assert xdi.rio._cached_transform() != clipped_ds.rio._cached_transform() # make sure it safely writes to netcdf clipped_ds.to_netcdf(modis_clip["output"]) @@ -253,7 +261,8 @@ def test_clip_box__nodata_error(modis_clip): if hasattr(xdi, "name") and xdi.name: var_match = " Data variable: __xarray_dataarray_variable__" with pytest.raises( - NoDataInBounds, match=f"No data found in bounds.{var_match}" + NoDataInBounds, + match=f"Unable to determine bounds from coordinates.{var_match}", ): xdi.rio.clip_box( minx=xdi.x[5].values, @@ -292,15 +301,16 @@ def test_clip_box__one_dimension_error(modis_clip): ) -@pytest.fixture( - params=[ +@pytest.mark.parametrize( + "open_func", + [ xarray.open_rasterio, rioxarray.open_rasterio, - partial(rioxarray.open_rasterio, parse_coordinates=False), - ] + # partial(rioxarray.open_rasterio, parse_coordinates=False), # TODO: Fix + ], ) -def test_clip_geojson(request): - with request.param( +def test_clip_geojson(open_func): + with open_func( os.path.join(TEST_COMPARE_DATA_DIR, "small_dem_3m_merged.tif") ) as xdi: # get subset for testing @@ -341,17 +351,18 @@ def test_clip_geojson(request): @pytest.mark.parametrize( - "invert, expected_sum", [(False, 2150801411), (True, 535727386)] + "invert, expected_sum", [(False, 2150837592), (True, 535691205)] ) -@pytest.fixture( - params=[ +@pytest.mark.parametrize( + "open_func", + [ xarray.open_rasterio, rioxarray.open_rasterio, - partial(rioxarray.open_rasterio, parse_coordinates=False), - ] + # partial(rioxarray.open_rasterio, parse_coordinates=False), # TODO: Fix + ], ) -def test_clip_geojson__no_drop(request, invert, expected_sum): - with request.param( +def test_clip_geojson__no_drop(open_func, invert, expected_sum): + with open_func( os.path.join(TEST_COMPARE_DATA_DIR, "small_dem_3m_merged.tif") ) as xdi: geometries = [ @@ -437,15 +448,15 @@ def test_reproject(modis_reproject): _assert_xarrays_equal(mds_repr, mdc) -@pytest.fixture( - params=[ - xarray.open_rasterio, +@pytest.mark.parametrize( + "open_func", + [ rioxarray.open_rasterio, - partial(rioxarray.open_rasterio, parse_coordinates=False), - ] + # partial(rioxarray.open_rasterio, parse_coordinates=False), TODO: Fix + ], ) -def test_reproject_3d(request, modis_reproject_3d): - with request.param(modis_reproject_3d["input"]) as mda, request.param( +def test_reproject_3d(open_func, modis_reproject_3d): + with open_func(modis_reproject_3d["input"]) as mda, open_func( modis_reproject_3d["compare"] ) as mdc: mds_repr = mda.rio.reproject(modis_reproject_3d["to_proj"]) @@ -540,9 +551,9 @@ def test_reproject__no_nodata(modis_reproject): _assert_xarrays_equal(mds_repr, mdc) -@pytest.fixture(params=[xarray.open_rasterio, rioxarray.open_rasterio]) -def test_reproject__scalar_coord(request): - with request.param( +@pytest.mark.parametrize("open_func", [xarray.open_rasterio, rioxarray.open_rasterio]) +def test_reproject__scalar_coord(open_func): + with open_func( os.path.join(TEST_COMPARE_DATA_DIR, "small_dem_3m_merged.tif") ) as xdi: xdi_repr = xdi.squeeze().rio.reproject("epsg:3395") @@ -620,9 +631,9 @@ def test_reproject_match__no_transform_nodata(modis_reproject_match_coords): _assert_xarrays_equal(mds_repr, mdc) -@pytest.fixture(params=[xarray.open_rasterio, rioxarray.open_rasterio]) -def test_make_src_affine(request, modis_reproject): - with xarray.open_dataarray(modis_reproject["input"]) as xdi, request.param( +@pytest.mark.parametrize("open_func", [xarray.open_rasterio, rioxarray.open_rasterio]) +def test_make_src_affine(open_func, modis_reproject): + with xarray.open_dataarray(modis_reproject["input"]) as xdi, open_func( modis_reproject["input"] ) as xri: @@ -634,13 +645,13 @@ def test_make_src_affine(request, modis_reproject): del xdi.attrs["transform"] calculated_transform_check = tuple(xdi.rio.transform()) calculated_transform_check2 = tuple(xdi.rio.transform()) - rio_transform = xri.attrs["transform"] + rio_transform = tuple(xri.rio._cached_transform()) assert_array_equal(attribute_transform, attribute_transform_func) assert_array_equal(calculated_transform, calculated_transform_check) assert_array_equal(calculated_transform, calculated_transform_check2) assert_array_equal(attribute_transform, calculated_transform) - assert_array_equal(calculated_transform[:6], rio_transform) + assert_array_equal(calculated_transform, rio_transform) def test_make_src_affine__single_point(): @@ -661,66 +672,93 @@ def test_make_src_affine__single_point(): assert_array_equal(attribute_transform, calculated_transform) -@pytest.fixture( - params=[ +@pytest.mark.parametrize( + "open_func", + [ + xarray.open_dataset, xarray.open_rasterio, rioxarray.open_rasterio, partial(rioxarray.open_rasterio, parse_coordinates=False), - ] + ], ) -def test_make_coords__calc_trans(request, modis_reproject): - with xarray.open_dataarray(modis_reproject["input"]) as xdi, request.param( +def test_make_coords__calc_trans(open_func, modis_reproject): + with xarray.open_dataarray(modis_reproject["input"]) as xdi, open_func( modis_reproject["input"] ) as xri: # calculate coordinates from the calculated transform width, height = xdi.rio.shape calculated_transform = xdi.rio.transform(recalc=True) calc_coords_calc_trans = _make_coords( - xdi, calculated_transform, width, height, xdi.attrs["crs"] + xdi, calculated_transform, width, height, xdi.rio.crs ) widthr, heightr = xri.rio.shape calculated_transformr = xri.rio.transform(recalc=True) calc_coords_calc_transr = _make_coords( - xri, calculated_transformr, widthr, heightr, xdi.attrs["crs"] + xri, calculated_transformr, widthr, heightr, xdi.rio.crs ) + assert_almost_equal(calculated_transform, calculated_transformr) # check to see if they all match - assert_array_equal(xri.coords["x"].values, calc_coords_calc_trans["x"].values) - assert_array_equal(xri.coords["y"].values, calc_coords_calc_trans["y"].values) - assert_array_equal(xri.coords["x"].values, calc_coords_calc_transr["x"].values) - assert_array_equal(xri.coords["y"].values, calc_coords_calc_transr["y"].values) + if not isinstance(open_func, partial): + assert_almost_equal( + xri.coords["x"].values, calc_coords_calc_trans["x"].values, decimal=9 + ) + assert_almost_equal( + xri.coords["y"].values, calc_coords_calc_trans["y"].values, decimal=9 + ) + assert_almost_equal( + xri.coords["x"].values, calc_coords_calc_transr["x"].values, decimal=9 + ) + assert_almost_equal( + xri.coords["y"].values, calc_coords_calc_transr["y"].values, decimal=9 + ) -@pytest.fixture( - params=[ +@pytest.mark.parametrize( + "open_func", + [ + xarray.open_dataset, xarray.open_rasterio, rioxarray.open_rasterio, partial(rioxarray.open_rasterio, parse_coordinates=False), - ] + ], ) -def test_make_coords__attr_trans(request, modis_reproject): - with xarray.open_dataarray(modis_reproject["input"]) as xdi, request.param( +def test_make_coords__attr_trans(open_func, modis_reproject): + with xarray.open_dataarray(modis_reproject["input"]) as xdi, open_func( modis_reproject["input"] ) as xri: # calculate coordinates from the attribute transform width, height = xdi.rio.shape attr_transform = xdi.rio.transform() calc_coords_attr_trans = _make_coords( - xdi, attr_transform, width, height, xdi.attrs["crs"] + xdi, attr_transform, width, height, xdi.rio.crs ) widthr, heightr = xri.rio.shape calculated_transformr = xri.rio.transform() calc_coords_calc_transr = _make_coords( - xri, calculated_transformr, widthr, heightr, xdi.attrs["crs"] + xri, calculated_transformr, widthr, heightr, xdi.rio.crs ) - + assert_almost_equal(attr_transform, calculated_transformr) # check to see if they all match - assert_array_equal(xri.coords["x"].values, calc_coords_calc_transr["x"].values) - assert_array_equal(xri.coords["y"].values, calc_coords_calc_transr["y"].values) - assert_array_equal(xri.coords["x"].values, calc_coords_attr_trans["x"].values) - assert_array_equal(xri.coords["y"].values, calc_coords_attr_trans["y"].values) - assert_almost_equal(xdi.coords["x"].values, xri.coords["x"].values, decimal=9) - assert_almost_equal(xdi.coords["y"].values, xri.coords["y"].values, decimal=9) + if not isinstance(open_func, partial): + assert_almost_equal( + xri.coords["x"].values, calc_coords_calc_transr["x"].values, decimal=9 + ) + assert_almost_equal( + xri.coords["y"].values, calc_coords_calc_transr["y"].values, decimal=9 + ) + assert_almost_equal( + xri.coords["x"].values, calc_coords_attr_trans["x"].values, decimal=9 + ) + assert_almost_equal( + xri.coords["y"].values, calc_coords_attr_trans["y"].values, decimal=9 + ) + assert_almost_equal( + xdi.coords["x"].values, xri.coords["x"].values, decimal=9 + ) + assert_almost_equal( + xdi.coords["y"].values, xri.coords["y"].values, decimal=9 + ) def test_interpolate_na(interpolate_na): @@ -1699,7 +1737,7 @@ def test_missing_transform_bounds(): os.path.join(TEST_COMPARE_DATA_DIR, "small_dem_3m_merged.tif"), parse_coordinates=False, ) - xds.attrs.pop("transform") + xds.coords["spatial_ref"].attrs.pop("GeoTransform") with pytest.raises(DimensionMissingCoordinateError): xds.rio.bounds() @@ -1709,7 +1747,7 @@ def test_missing_transform_resolution(): os.path.join(TEST_COMPARE_DATA_DIR, "small_dem_3m_merged.tif"), parse_coordinates=False, ) - xds.attrs.pop("transform") + xds.coords["spatial_ref"].attrs.pop("GeoTransform") with pytest.raises(DimensionMissingCoordinateError): xds.rio.resolution() @@ -1717,3 +1755,15 @@ def test_missing_transform_resolution(): def test_shape_order(): rds = rioxarray.open_rasterio(os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc")) assert rds.air_temperature.rio.shape == (585, 1386) + + +def test_write_transform(tmp_path): + xds = rioxarray.open_rasterio( + os.path.join(TEST_COMPARE_DATA_DIR, "small_dem_3m_merged.tif"), + parse_coordinates=False, + ) + out_file = tmp_path / "test_geotransform.nc" + xds.to_netcdf(out_file) + xds2 = rioxarray.open_rasterio(out_file, parse_coordinates=False) + assert_almost_equal(tuple(xds2.rio.transform()), tuple(xds.rio.transform())) + assert xds.spatial_ref.GeoTransform == xds2.spatial_ref.GeoTransform