From 77a954a820866e07954d2c64f06ea5e715b48735 Mon Sep 17 00:00:00 2001 From: Robbi Bishop-Taylor Date: Tue, 12 Mar 2024 14:11:12 +1100 Subject: [PATCH] Add updated xr_interpolate func and tests (#1205) * Add updated xr_interpolate func and tests * Rename test * Remove commented line * Update dea_tools.spatial.rst --- Tests/dea_tools/test_spatial.py | 114 ++++++++++++++++- Tools/dea_tools/spatial.py | 210 +++++++++++++++++++++++++++++++- Tools/gen/dea_tools.spatial.rst | 2 +- 3 files changed, 323 insertions(+), 3 deletions(-) diff --git a/Tests/dea_tools/test_spatial.py b/Tests/dea_tools/test_spatial.py index 5e9030d3f..18b2b6e8b 100644 --- a/Tests/dea_tools/test_spatial.py +++ b/Tests/dea_tools/test_spatial.py @@ -8,7 +8,13 @@ import datacube from datacube.utils.masking import mask_invalid_data -from dea_tools.spatial import subpixel_contours, xr_vectorize, xr_rasterize +from dea_tools.spatial import ( + subpixel_contours, + xr_vectorize, + xr_rasterize, + xr_interpolate, +) +from dea_tools.validation import eval_metrics @pytest.fixture( @@ -94,6 +100,19 @@ def categorical_da(request): return da +# Test set of points covering the extent of `dem_da` +@pytest.fixture() +def points_gdf(): + return gpd.GeoDataFrame( + data={"z": [400, 800, 900, 1100, 1200, 1500]}, + geometry=gpd.points_from_xy( + x=[149.06, 149.06, 149.10, 149.16, 149.20, 149.20], + y=[-35.36, -35.22, -35.29, -35.29, -35.36, -35.22], + crs="EPSG:4326", + ), + ) + + @pytest.mark.parametrize( "attribute_col, expected_col", [ @@ -323,3 +342,96 @@ def test_subpixel_contours_dim(satellite_da): # # Verify that no error is raised if we provide the correct CRS # subpixel_contours(dem_da.drop_vars("spatial_ref"), z_values=700, crs="EPSG:4326") + + +@pytest.mark.parametrize( + "method", + ["linear", "cubic", "nearest", "rbf", "idw"], +) +def test_xr_interpolate(dem_da, points_gdf, method): + # Run interpolation and verify that pixel grids are the same and + # output contains data + interpolated_ds = 
xr_interpolate( + dem_da, + gdf=points_gdf, + method=method, + k=5, + ) + assert interpolated_ds.odc.geobox == dem_da.odc.geobox + assert "z" in interpolated_ds.data_vars + assert interpolated_ds["z"].notnull().sum() > 0 + + # Sample interpolated values at each point, and verify that + # interpolated z values match our input z values + xs = xr.DataArray(points_gdf.to_crs(dem_da.odc.crs).geometry.x, dims="z") + ys = xr.DataArray(points_gdf.to_crs(dem_da.odc.crs).geometry.y, dims="z") + sampled = interpolated_ds["z"].interp(x=xs, y=ys, method="nearest") + val_stats = eval_metrics(points_gdf.z, sampled) + assert val_stats.Correlation > 0.9 + assert val_stats.MAE < 10 + + # Verify that a factor above 1 still returns expected results + interpolated_ds_factor10 = xr_interpolate( + dem_da, + gdf=points_gdf, + method=method, + k=5, + factor=10, + ) + assert interpolated_ds_factor10.odc.geobox == dem_da.odc.geobox + assert "z" in interpolated_ds_factor10.data_vars + assert interpolated_ds_factor10["z"].notnull().sum() > 0 + + # Verify that multiple columns can be processed, and that output + # includes only numeric vars + points_gdf["num_var"] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] + points_gdf["obj_var"] = ["a", "b", "c", "d", "e", "f"] + interpolated_ds_cols = xr_interpolate( + dem_da, + gdf=points_gdf, + method=method, + k=5, + ) + assert "z" in interpolated_ds_cols.data_vars + assert "num_var" in interpolated_ds_cols.data_vars + assert "obj_var" not in interpolated_ds_cols.data_vars + + # Verify that specific columns can be selected + interpolated_ds_cols2 = xr_interpolate( + dem_da, + gdf=points_gdf, + columns=["num_var"], + method=method, + k=5, + ) + assert "z" not in interpolated_ds_cols2.data_vars + assert "num_var" in interpolated_ds_cols2.data_vars + + # Verify that error is raised if no numeric columns exist + with pytest.raises(ValueError): + xr_interpolate( + dem_da, + gdf=points_gdf, + columns=["obj_var"], + method=method, + k=5, + ) + + # Verify that error is 
raised if `gdf` doesn't overlap with `ds` + with pytest.raises(ValueError): + xr_interpolate( + dem_da, + gdf=points_gdf.set_crs("EPSG:3577", allow_override=True), + method=method, + k=5, + ) + + # If IDW method, verify that k will fail if greater than points + if method == "idw": + with pytest.raises(ValueError): + xr_interpolate( + dem_da, + gdf=points_gdf, + method=method, + k=10, + ) diff --git a/Tools/dea_tools/spatial.py b/Tools/dea_tools/spatial.py index 1ca8b005f..4917de8ea 100644 --- a/Tools/dea_tools/spatial.py +++ b/Tools/dea_tools/spatial.py @@ -16,7 +16,7 @@ If you would like to report an issue with this script, file one on GitHub: https://github.com/GeoscienceAustralia/dea-notebooks/issues/new -Last modified: August 2023 +Last modified: March 2024 """ @@ -36,6 +36,7 @@ from odc.geo.geom import Geometry from odc.geo.crs import CRS from scipy import ndimage as nd +from scipy.spatial import cKDTree as KDTree from skimage.measure import label from rasterstats import zonal_stats from skimage.measure import find_contours @@ -566,6 +567,203 @@ def _time_format(i, time_format): return contours_gdf +def xr_interpolate( + ds, + gdf, + columns=None, + method="linear", + factor=1, + k=10, + crs=None, + **kwargs, +): + """ + This function takes a geopandas.GeoDataFrame points dataset + containing one or more numeric columns, and interpolates these points + into the spatial extent of an existing xarray dataset. This can be + useful for producing smooth raster surfaces from point data to + compare directly against satellite data. + + Supported interpolation methods include "linear", "nearest" and + "cubic" (using `scipy.interpolate.griddata`), "rbf" (using + `scipy.interpolate.Rbf`), and "idw" (Inverse Distance Weighted + interpolation using `k` nearest neighbours). Each numeric column + will be returned as a variable in the output xarray.Dataset. 
+ + Last modified: March 2024 + + Parameters + ---------- + ds : xarray.DataArray or xarray.Dataset + A two-dimensional or multi-dimensional array whose spatial extent + will be used to interpolate point data into. + gdf : geopandas.GeoDataFrame + A dataset of spatial points including at least one numeric column. + By default all numeric columns in this dataset will be spatially + interpolated into the extent of `ds`; specific columns can be + selected using `columns`. An error will be raised if the points + in `gdf` do not overlap with the extent of `ds`. + columns : list, optional + An optional list of specific columns in `gdf` to run the + interpolation on. These must all be of numeric data types. + method : string, optional + The method used to interpolate between point values. This string + is either passed to `scipy.interpolate.griddata` (for "linear", + "nearest" and "cubic" methods), or used to specify Radial Basis + Function interpolation using `scipy.interpolate.Rbf` ("rbf"), or + Inverse Distance Weighted interpolation ("idw"). + Defaults to 'linear'. + factor : int, optional + An optional integer that can be used to subsample the spatial + interpolation extent to obtain faster interpolation times, before + up-sampling the array back to the original dimensions of the + data as a final step. For example, `factor=10` will interpolate + data into a grid that has one tenth of the resolution of `ds`. + This will be significantly faster than interpolating at full + resolution, but will potentially produce less accurate results. + k : int, optional + The number of nearest neighbours used to calculate weightings if + `method` is "idw". Defaults to 10; setting `k=1` is equivalent to + "nearest" interpolation. + crs : string or CRS object, optional + If `ds`'s coordinate reference system (CRS) cannot be determined, + provide a CRS using this parameter (e.g. 'EPSG:3577'). 
+ **kwargs : + Optional keyword arguments to pass to either + `scipy.interpolate.griddata` (if `method` is "linear", "nearest" + or "cubic"), or `scipy.interpolate.Rbf` (if `method` is "rbf"). + + Returns + ------- + interpolated_ds : xarray.Dataset + An xarray.Dataset containing interpolated data with the same X + and Y coordinate pixel grid as `ds`, and a data variable for + each numeric column in `gdf`. + """ + + # Add GeoBox and odc.* accessor to array using `odc-geo`, and identify + # spatial dimension names from `ds` + ds = add_geobox(ds, crs) + y_dim, x_dim = ds.odc.spatial_dims + + # Reproject to match input `ds`, and raise error if there are no overlaps + gdf = gdf.to_crs(ds.odc.crs) + if not gdf.dissolve().intersects(ds.odc.geobox.extent.geom).item(): + raise ValueError("The supplied `gdf` does not overlap spatially with `ds`.") + + # Select subset of numeric columns (non-numeric are not supported) + numeric_gdf = gdf.select_dtypes("number") + + # Subset further to supplied `columns` + try: + numeric_gdf = numeric_gdf if columns is None else numeric_gdf[columns] + except KeyError: + raise ValueError( + "One or more of the provided columns either does " + "not exist in `gdf`, or is a non-numeric column. " + "Only numeric columns are supported by `xr_interpolate`." + ) + + # Raise an error if no numeric columns exist after selection + if len(numeric_gdf.columns) == 0: + raise ValueError( + "The provided `gdf` contains no numeric columns to interpolate." + ) + + # Identify spatial coordinates, and stack to use in interpolation + x_coords = gdf.geometry.x + y_coords = gdf.geometry.y + points_xy = np.vstack([x_coords, y_coords]).T + + # Identify x and y coordinates from `ds` to interpolate into. + # If `factor` is greater than 1, the coordinates will be subsampled + # for faster run-times. 
If the last x or y value in the subsampled + # grid aren't the same as the last x or y values in the original + # full resolution grid, add the final full resolution grid value to + # ensure data is interpolated up to the very edge of the array + if ds[x_dim][::factor][-1].item() == ds[x_dim][-1].item(): + x_grid_coords = ds[x_dim][::factor].values + else: + x_grid_coords = ds[x_dim][::factor].values.tolist() + [ds[x_dim][-1].item()] + + if ds[y_dim][::factor][-1].item() == ds[y_dim][-1].item(): + y_grid_coords = ds[y_dim][::factor].values + else: + y_grid_coords = ds[y_dim][::factor].values.tolist() + [ds[y_dim][-1].item()] + + # Create grid to interpolate into + grid_y, grid_x = np.meshgrid(x_grid_coords, y_grid_coords) + + # Output dict + correlation_outputs = {} + + # For each numeric column, run interpolation + for col, z_values in numeric_gdf.items(): + # Apply scipy.interpolate.griddata interpolation methods + if method in ("linear", "nearest", "cubic"): + # Interpolate x, y and z values + interp_2d = scipy.interpolate.griddata( + points=points_xy, + values=z_values, + xi=(grid_y, grid_x), + method=method, + **kwargs, + ) + + # Apply Radial Basis Function interpolation + elif method == "rbf": + # Interpolate x, y and z values + rbf = scipy.interpolate.Rbf(x_coords, y_coords, z_values, **kwargs) + interp_2d = rbf(grid_y, grid_x) + + # Apply Inverse Distance Weighted interpolation + # Code inspired by: https://github.com/DahnJ/REM-xarray + elif method == "idw": + # Verify k is smaller than total number of points + if k > len(z_values): + raise ValueError( + f"The requested number of nearest neighbours (`k={k}`) " + f"is smaller than the total number of points ({len(z_values)})." 
+ ) + + # Create KDTree to efficiently find nearest neighbours + tree = KDTree(points_xy) + + # IDW interpolation + grid_stacked = np.column_stack((grid_y.flatten(), grid_x.flatten())) + distances, indices = tree.query(grid_stacked, k=k) + + # Calculate weights based on distance to k nearest neighbours. + # If k == 1, then return the nearest value unweighted. + if k > 1: + weights = 1 / distances + weights = weights / weights.sum(axis=1).reshape(-1, 1) + interp_1d = (weights * z_values.values[indices]).sum(axis=1) + else: + interp_1d = z_values.values[indices] + + # Reshape to 2D + interp_2d = interp_1d.reshape(len(y_grid_coords), len(x_grid_coords)) + + # Add 2D interpolated array to output dictionary + correlation_outputs[col] = ((y_dim, x_dim), interp_2d) + + # Combine all outputs into a single xr.Dataset + interpolated_ds = xr.Dataset( + correlation_outputs, coords={y_dim: y_grid_coords, x_dim: x_grid_coords} + ) + + # If factor is greater than 1, resample the interpolated array to + # match the input `ds` array + if factor > 1: + interpolated_ds = interpolated_ds.interp_like(ds) + + # Ensure CRS is correctly set on output + interpolated_ds = interpolated_ds.odc.assign_crs(crs=ds.odc.crs) + + return interpolated_ds + + def interpolate_2d( ds, x_coords, y_coords, z_coords, method="linear", factor=1, verbose=False, **kwargs ): @@ -579,6 +777,9 @@ def interpolate_2d( Supported interpolation methods include 'linear', 'nearest' and 'cubic (using `scipy.interpolate.griddata`), and 'rbf' (using `scipy.interpolate.Rbf`). + + NOTE: This function is deprecated and will be retired in a future + release. Please use `xr_interpolate` instead. Last modified: February 2020 @@ -624,6 +825,13 @@ def interpolate_2d( from `ds_array`, and Z-values interpolated from the points data. """ + warnings.warn( + "This function is deprecated and will be retired in a future " + "release. 
Please use `xr_interpolate` instead.", + DeprecationWarning, + stacklevel=2, + ) + # Extract xy and elev points points_xy = np.vstack([x_coords, y_coords]).T diff --git a/Tools/gen/dea_tools.spatial.rst b/Tools/gen/dea_tools.spatial.rst index ef6a908a4..8747f0550 100644 --- a/Tools/gen/dea_tools.spatial.rst +++ b/Tools/gen/dea_tools.spatial.rst @@ -16,13 +16,13 @@ add_geobox contours_to_arrays hillshade - interpolate_2d largest_region points_on_line reverse_geocode subpixel_contours sun_angles transform_geojson_wgs_to_epsg + xr_interpolate xr_rasterize xr_vectorize zonal_stats_parallel