Skip to content

Commit

Permalink
Add updated xr_interpolate func and tests (#1205)
Browse files Browse the repository at this point in the history
* Add updated xr_interpolate func and tests

* Rename test

* Remove commented line

* Update dea_tools.spatial.rst
  • Loading branch information
robbibt authored Mar 12, 2024
1 parent 324c485 commit 77a954a
Show file tree
Hide file tree
Showing 3 changed files with 323 additions and 3 deletions.
114 changes: 113 additions & 1 deletion Tests/dea_tools/test_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
import datacube
from datacube.utils.masking import mask_invalid_data

from dea_tools.spatial import subpixel_contours, xr_vectorize, xr_rasterize
from dea_tools.spatial import (
subpixel_contours,
xr_vectorize,
xr_rasterize,
xr_interpolate,
)
from dea_tools.validation import eval_metrics


@pytest.fixture(
Expand Down Expand Up @@ -94,6 +100,19 @@ def categorical_da(request):
return da


# Test set of points covering the extent of `dem_da`
@pytest.fixture()
def points_gdf():
    """Return six EPSG:4326 points, each carrying a numeric `z` value."""
    # Longitudes and latitudes of the six sample points
    lons = [149.06, 149.06, 149.10, 149.16, 149.20, 149.20]
    lats = [-35.36, -35.22, -35.29, -35.29, -35.36, -35.22]
    point_geoms = gpd.points_from_xy(x=lons, y=lats, crs="EPSG:4326")

    # Numeric column that the interpolation tests operate on
    z_values = {"z": [400, 800, 900, 1100, 1200, 1500]}
    return gpd.GeoDataFrame(data=z_values, geometry=point_geoms)


@pytest.mark.parametrize(
"attribute_col, expected_col",
[
Expand Down Expand Up @@ -323,3 +342,96 @@ def test_subpixel_contours_dim(satellite_da):

# # Verify that no error is raised if we provide the correct CRS
# subpixel_contours(dem_da.drop_vars("spatial_ref"), z_values=700, crs="EPSG:4326")


@pytest.mark.parametrize(
    "method",
    ["linear", "cubic", "nearest", "rbf", "idw"],
)
def test_xr_interpolate(dem_da, points_gdf, method):
    """
    Exercise `xr_interpolate` for every supported interpolation method.

    Covers the output pixel grid, agreement between interpolated and
    input values at the point locations, the `factor` and `columns`
    parameters, and the error cases.
    """
    # Keyword arguments shared by every `xr_interpolate` call below
    common = dict(method=method, k=5)

    # Basic run: output grid must match the input and contain data
    result = xr_interpolate(dem_da, gdf=points_gdf, **common)
    assert result.odc.geobox == dem_da.odc.geobox
    assert "z" in result.data_vars
    assert result["z"].notnull().sum() > 0

    # Sample the interpolated surface at each input point and check
    # that the sampled values agree with the input z values
    points_native = points_gdf.to_crs(dem_da.odc.crs)
    xs = xr.DataArray(points_native.geometry.x, dims="z")
    ys = xr.DataArray(points_native.geometry.y, dims="z")
    sampled = result["z"].interp(x=xs, y=ys, method="nearest")
    val_stats = eval_metrics(points_gdf.z, sampled)
    assert val_stats.Correlation > 0.9
    assert val_stats.MAE < 10

    # A factor above 1 must still return data on the original grid
    result_factor10 = xr_interpolate(dem_da, gdf=points_gdf, factor=10, **common)
    assert result_factor10.odc.geobox == dem_da.odc.geobox
    assert "z" in result_factor10.data_vars
    assert result_factor10["z"].notnull().sum() > 0

    # With extra columns present, only numeric ones become variables
    points_gdf["num_var"] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
    points_gdf["obj_var"] = ["a", "b", "c", "d", "e", "f"]
    result_all_cols = xr_interpolate(dem_da, gdf=points_gdf, **common)
    assert "z" in result_all_cols.data_vars
    assert "num_var" in result_all_cols.data_vars
    assert "obj_var" not in result_all_cols.data_vars

    # Specific columns can be selected with `columns`
    result_one_col = xr_interpolate(
        dem_da, gdf=points_gdf, columns=["num_var"], **common
    )
    assert "z" not in result_one_col.data_vars
    assert "num_var" in result_one_col.data_vars

    # Requesting only a non-numeric column is an error
    with pytest.raises(ValueError):
        xr_interpolate(dem_da, gdf=points_gdf, columns=["obj_var"], **common)

    # Points that do not overlap `ds` are an error; overriding the CRS
    # label without reprojecting moves the points away from `dem_da`
    shifted_gdf = points_gdf.set_crs("EPSG:3577", allow_override=True)
    with pytest.raises(ValueError):
        xr_interpolate(dem_da, gdf=shifted_gdf, **common)

    # For IDW only, `k` larger than the number of points is an error
    if method == "idw":
        with pytest.raises(ValueError):
            xr_interpolate(dem_da, gdf=points_gdf, method=method, k=10)
210 changes: 209 additions & 1 deletion Tools/dea_tools/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
If you would like to report an issue with this script, file one on
GitHub: https://github.com/GeoscienceAustralia/dea-notebooks/issues/new
Last modified: August 2023
Last modified: March 2024
"""

Expand All @@ -36,6 +36,7 @@
from odc.geo.geom import Geometry
from odc.geo.crs import CRS
from scipy import ndimage as nd
from scipy.spatial import cKDTree as KDTree
from skimage.measure import label
from rasterstats import zonal_stats
from skimage.measure import find_contours
Expand Down Expand Up @@ -566,6 +567,203 @@ def _time_format(i, time_format):
return contours_gdf


def xr_interpolate(
    ds,
    gdf,
    columns=None,
    method="linear",
    factor=1,
    k=10,
    crs=None,
    **kwargs,
):
    """
    Interpolate point data into the spatial extent of an xarray object.

    This function takes a geopandas.GeoDataFrame points dataset
    containing one or more numeric columns, and interpolates these points
    into the spatial extent of an existing xarray dataset. This can be
    useful for producing smooth raster surfaces from point data to
    compare directly against satellite data.

    Supported interpolation methods include "linear", "nearest" and
    "cubic" (using `scipy.interpolate.griddata`), "rbf" (using
    `scipy.interpolate.Rbf`), and "idw" (Inverse Distance Weighted
    interpolation using `k` nearest neighbours). Each numeric column
    will be returned as a variable in the output xarray.Dataset.

    Last modified: March 2024

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        A two-dimensional or multi-dimensional array whose spatial extent
        will be used to interpolate point data into.
    gdf : geopandas.GeoDataFrame
        A dataset of spatial points including at least one numeric column.
        By default all numeric columns in this dataset will be spatially
        interpolated into the extent of `ds`; specific columns can be
        selected using `columns`. An error will be raised if the points
        in `gdf` do not overlap with the extent of `ds`.
    columns : list, optional
        An optional list of specific columns in `gdf` to run the
        interpolation on. These must all be of numeric data types.
    method : string, optional
        The method used to interpolate between point values. This string
        is either passed to `scipy.interpolate.griddata` (for "linear",
        "nearest" and "cubic" methods), or used to specify Radial Basis
        Function interpolation using `scipy.interpolate.Rbf` ("rbf"), or
        Inverse Distance Weighted interpolation ("idw").
        Defaults to 'linear'.
    factor : int, optional
        An optional integer that can be used to subsample the spatial
        interpolation extent to obtain faster interpolation times, before
        up-sampling the array back to the original dimensions of the
        data as a final step. For example, `factor=10` will interpolate
        data into a grid that has one tenth of the resolution of `ds`.
        This will be significantly faster than interpolating at full
        resolution, but will potentially produce less accurate results.
    k : int, optional
        The number of nearest neighbours used to calculate weightings if
        `method` is "idw". Defaults to 10; setting `k=1` is equivalent to
        "nearest" interpolation.
    crs : string or CRS object, optional
        If `ds`'s coordinate reference system (CRS) cannot be determined,
        provide a CRS using this parameter (e.g. 'EPSG:3577').
    **kwargs :
        Optional keyword arguments to pass to either
        `scipy.interpolate.griddata` (if `method` is "linear", "nearest"
        or "cubic"), or `scipy.interpolate.Rbf` (if `method` is "rbf").

    Returns
    -------
    interpolated_ds : xarray.Dataset
        An xarray.Dataset containing interpolated data with the same X
        and Y coordinate pixel grid as `ds`, and a data variable for
        each numeric column in `gdf`.

    Raises
    ------
    ValueError
        If `method` is not one of the supported methods; if `gdf` does
        not overlap spatially with `ds`; if any of `columns` is missing
        from `gdf` or is non-numeric; if no numeric columns remain to
        interpolate; or if `method` is "idw" and `k` is greater than
        the number of points in `gdf`.
    """

    # Fail early on an unsupported method; otherwise no branch below
    # would assign an interpolated array for the column being processed
    supported_methods = ("linear", "nearest", "cubic", "rbf", "idw")
    if method not in supported_methods:
        raise ValueError(
            f"Unsupported interpolation method '{method}'. "
            f"`method` must be one of {supported_methods}."
        )

    # Add GeoBox and odc.* accessor to array using `odc-geo`, and identify
    # spatial dimension names from `ds`
    ds = add_geobox(ds, crs)
    y_dim, x_dim = ds.odc.spatial_dims

    # Reproject to match input `ds`, and raise error if there are no overlaps
    gdf = gdf.to_crs(ds.odc.crs)
    if not gdf.dissolve().intersects(ds.odc.geobox.extent.geom).item():
        raise ValueError("The supplied `gdf` does not overlap spatially with `ds`.")

    # Select subset of numeric columns (non-numeric are not supported)
    numeric_gdf = gdf.select_dtypes("number")

    # Subset further to supplied `columns`
    try:
        numeric_gdf = numeric_gdf if columns is None else numeric_gdf[columns]
    except KeyError as err:
        raise ValueError(
            "One or more of the provided columns either does "
            "not exist in `gdf`, or is a non-numeric column. "
            "Only numeric columns are supported by `xr_interpolate`."
        ) from err

    # Raise an error if no numeric columns exist after selection
    if len(numeric_gdf.columns) == 0:
        raise ValueError(
            "The provided `gdf` contains no numeric columns to interpolate."
        )

    # Identify spatial coordinates, and stack to use in interpolation
    x_coords = gdf.geometry.x
    y_coords = gdf.geometry.y
    points_xy = np.vstack([x_coords, y_coords]).T

    # Identify x and y coordinates from `ds` to interpolate into.
    # If `factor` is greater than 1, the coordinates will be subsampled
    # for faster run-times. If the last x or y value in the subsampled
    # grid aren't the same as the last x or y values in the original
    # full resolution grid, add the final full resolution grid value to
    # ensure data is interpolated up to the very edge of the array
    if ds[x_dim][::factor][-1].item() == ds[x_dim][-1].item():
        x_grid_coords = ds[x_dim][::factor].values
    else:
        x_grid_coords = ds[x_dim][::factor].values.tolist() + [ds[x_dim][-1].item()]

    if ds[y_dim][::factor][-1].item() == ds[y_dim][-1].item():
        y_grid_coords = ds[y_dim][::factor].values
    else:
        y_grid_coords = ds[y_dim][::factor].values.tolist() + [ds[y_dim][-1].item()]

    # Create grid to interpolate into. `np.meshgrid(x, y)` returns the
    # X grid first and the Y grid second, both with shape (len(y), len(x))
    grid_x, grid_y = np.meshgrid(x_grid_coords, y_grid_coords)

    # For IDW, the neighbour search and weights depend only on the point
    # and grid locations, so compute them once for all columns
    # Code inspired by: https://github.com/DahnJ/REM-xarray
    if method == "idw":
        # Verify k is not greater than total number of points
        n_points = len(points_xy)
        if k > n_points:
            raise ValueError(
                f"The requested number of nearest neighbours (`k={k}`) "
                f"is greater than the total number of points ({n_points})."
            )

        # Create KDTree to efficiently find nearest neighbours
        tree = KDTree(points_xy)
        grid_stacked = np.column_stack((grid_x.flatten(), grid_y.flatten()))
        distances, indices = tree.query(grid_stacked, k=k)

        # Calculate weights based on distance to k nearest neighbours
        # (not needed when k == 1, as the nearest value is used as-is)
        if k > 1:
            with np.errstate(divide="ignore"):
                weights = 1 / distances

            # A grid cell located exactly on an input point has a
            # distance of zero and therefore an infinite weight, which
            # would normalise to NaN. Give such cells full weight on the
            # coincident point(s) and zero weight on all other neighbours
            coincident = np.isinf(weights)
            on_point = coincident.any(axis=1)
            weights[on_point] = coincident[on_point]
            weights = weights / weights.sum(axis=1).reshape(-1, 1)

    # Output dict
    interpolated_outputs = {}

    # For each numeric column, run interpolation
    for col, z_values in numeric_gdf.items():
        # Apply scipy.interpolate.griddata interpolation methods
        if method in ("linear", "nearest", "cubic"):
            # Interpolate x, y and z values
            interp_2d = scipy.interpolate.griddata(
                points=points_xy,
                values=z_values,
                xi=(grid_x, grid_y),
                method=method,
                **kwargs,
            )

        # Apply Radial Basis Function interpolation
        elif method == "rbf":
            # Interpolate x, y and z values
            rbf = scipy.interpolate.Rbf(x_coords, y_coords, z_values, **kwargs)
            interp_2d = rbf(grid_x, grid_y)

        # Apply Inverse Distance Weighted interpolation
        else:
            # If k == 1, then return the nearest value unweighted
            if k > 1:
                interp_1d = (weights * z_values.values[indices]).sum(axis=1)
            else:
                interp_1d = z_values.values[indices]

            # Reshape to 2D
            interp_2d = interp_1d.reshape(len(y_grid_coords), len(x_grid_coords))

        # Add 2D interpolated array to output dictionary
        interpolated_outputs[col] = ((y_dim, x_dim), interp_2d)

    # Combine all outputs into a single xr.Dataset
    interpolated_ds = xr.Dataset(
        interpolated_outputs, coords={y_dim: y_grid_coords, x_dim: x_grid_coords}
    )

    # If factor is greater than 1, resample the interpolated array to
    # match the input `ds` array
    if factor > 1:
        interpolated_ds = interpolated_ds.interp_like(ds)

    # Ensure CRS is correctly set on output
    interpolated_ds = interpolated_ds.odc.assign_crs(crs=ds.odc.crs)

    return interpolated_ds


def interpolate_2d(
ds, x_coords, y_coords, z_coords, method="linear", factor=1, verbose=False, **kwargs
):
Expand All @@ -579,6 +777,9 @@ def interpolate_2d(
Supported interpolation methods include 'linear', 'nearest' and
'cubic' (using `scipy.interpolate.griddata`), and 'rbf' (using
`scipy.interpolate.Rbf`).
NOTE: This function is deprecated and will be retired in a future
release. Please use `xr_interpolate` instead.
Last modified: February 2020
Expand Down Expand Up @@ -624,6 +825,13 @@ def interpolate_2d(
from `ds_array`, and Z-values interpolated from the points data.
"""

warnings.warn(
"This function is deprecated and will be retired in a future "
"release. Please use `xr_interpolate` instead.",
DeprecationWarning,
stacklevel=2,
)

# Extract xy and elev points
points_xy = np.vstack([x_coords, y_coords]).T

Expand Down
2 changes: 1 addition & 1 deletion Tools/gen/dea_tools.spatial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
add_geobox
contours_to_arrays
hillshade
interpolate_2d
largest_region
points_on_line
reverse_geocode
subpixel_contours
sun_angles
transform_geojson_wgs_to_epsg
xr_interpolate
xr_rasterize
xr_vectorize
zonal_stats_parallel
Expand Down

0 comments on commit 77a954a

Please sign in to comment.