diff --git a/.gitignore b/.gitignore index 7a1caf6..a083585 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,5 @@ logs/* .vscode src/coastpy/_version.py + +.idea diff --git a/src/coastpy/geo/ops.py b/src/coastpy/geo/ops.py index 5860846..67fd73b 100644 --- a/src/coastpy/geo/ops.py +++ b/src/coastpy/geo/ops.py @@ -425,7 +425,7 @@ def generate_offset_line(line: LineString, offset: float) -> LineString: return line.offset_curve(offset) if offset != 0 else line -def determine_rotation_angle( +def get_rotation_angle( pt1: Point | tuple[float, float], pt2: Point | tuple[float, float], target_axis: Literal[ @@ -433,7 +433,7 @@ def determine_rotation_angle( ] = "closest", ) -> float | None: """ - Determines the correct rotation angle to align a transect with a specified axis. + Computes the correct rotation angle to align with a specified axis. Args: pt1 (Union[Point, Tuple[float, float]]): The starting point of the transect. @@ -479,6 +479,7 @@ def determine_rotation_angle( (270, 360): lambda b: b - 270, } + # TODO: rename to landward right elif target_axis == "horizontal-right-aligned": angle_rotations = { (0, 90): lambda b: 90 + b, diff --git a/src/coastpy/utils/xarray.py b/src/coastpy/utils/xarray.py index ac5c628..c0499e0 100644 --- a/src/coastpy/utils/xarray.py +++ b/src/coastpy/utils/xarray.py @@ -1,7 +1,6 @@ import warnings import numpy as np -import rasterio import xarray as xr from affine import Affine from rasterio.enums import Resampling @@ -171,69 +170,71 @@ def interpolate_raster( ds: xr.Dataset, y_shape: int, x_shape: int, - resampling: rasterio.enums.Resampling = rasterio.enums.Resampling.nearest, + resampling: Resampling, ) -> xr.Dataset: """ - Interpolates a given raster (xarray Dataset) to a specified resolution using the provided method. + Interpolates a given raster (xarray Dataset) to a specified shape using Rasterio resampling methods. Args: ds (xr.Dataset): The input raster to interpolate. - y_shape (int): Desired number of grid points along y dimension. - x_shape (int): Desired number of grid points along x dimension. + y_shape (int): Desired number of grid points along the y dimension. + x_shape (int): Desired number of grid points along the x dimension. resampling: rasterio.enums.Resampling: The interpolation method to use. Returns: - xr.Dataset: Interpolated raster without geospatial metadata. - - Example: - >>> ds = xr.Dataset(data_vars={"var": (("x", "y"), np.random.randn(10, 10))}, - ... coords={"x": np.linspace(0, 9, 10), "y": np.linspace(0, 9, 10)}) - >>> interpolated_ds = interpolate_raster(ds, 20, 20) - >>> print(interpolated_ds.dims) - {'x': 20, 'y': 20} + xr.Dataset: Interpolated raster with updated geospatial metadata. """ - # swap dims if y is longer than x - if ds.dims["x"] < ds.dims["y"]: - y_shape, x_shape = x_shape, y_shape + # Compute the target transformation based on the desired shape - # this will be the new grid - new_y = np.linspace(ds.y.min(), ds.y.max(), y_shape) - new_x = np.linspace(ds.x.min(), ds.x.max(), x_shape) + transform = ds.rio.transform() - interpolated = ds.interp(y=new_y, x=new_x, method=resampling) + # Define the target transform for the new resolution + target_transform = transform * transform.scale( + (ds.sizes["x"] / x_shape), (ds.sizes["y"] / y_shape) + ) - # add new coords because the old ones are now two dimensional - interpolated = interpolated.assign_coords(y=range(y_shape), x=range(x_shape)) + # Create a template for the new shape + out_shape = ( + (ds.sizes["band"], y_shape, x_shape) + if "band" in ds.dims + else (y_shape, x_shape) + ) - # the transformation matrix can be computed by scaling the src one - src_dims = ds.dims - src_transform = ds.rio.transform() - x_scale = src_dims["x"] / x_shape - y_scale = src_dims["y"] / y_shape + # Reproject the dataset to the new shape using rasterio + interpolated = ds.rio.reproject( + ds.rio.crs, + shape=out_shape, + transform=target_transform, + resampling=resampling, + nodata=np.nan, + ) - # write new transformation matrix to the interplated dataset - dst_transform = src_transform * Affine.scale(x_scale, y_scale) - interpolated = interpolated.rio.write_transform(dst_transform) + interpolated = interpolated.rio.write_transform(target_transform) return interpolated +import xarray as xr + + def trim_outer_nans( data: xr.DataArray | xr.Dataset, - crop_size: int = 1, nodata: float | int | None = None, + crop_size: int = 0, ) -> xr.DataArray | xr.Dataset: """ Trim the outer nodata or NaN values from an xarray DataArray or Dataset, returning a bounding box around the data. Args: data (xr.DataArray | xr.Dataset): Input DataArray or Dataset with potential outer NaN or nodata values. - crop_size (int): The number of pixels to crop from the outer edges of the data. Defaults to 1. nodata (float | int | None): Optional no-data value to use for trimming. Defaults to None, which uses NaN. + crop_size (int, optional): The number of pixels to crop from the outer edges of the data after trimming. + Defaults to 0 (no additional cropping). Returns: - (xr.DataArray | xr.Dataset): A DataArray or Dataset trimmed of its outer NaN or nodata values. + (xr.DataArray | xr.Dataset): A DataArray or Dataset trimmed of its outer NaN or nodata values, with optional + additional cropping applied. """ # Determine the representative DataArray for NaN or nodata calculation @@ -251,18 +252,24 @@ def trim_outer_nans( if not y_valid.size or not x_valid.size: return data - # Compute bounding indices (adjusting by the crop_size in pixels as before) - y_min, y_max = y_valid.min() + crop_size, y_valid.max() - crop_size - x_min, x_max = x_valid.min() + crop_size, x_valid.max() - crop_size + # Compute bounding indices with optional additional cropping + y_min, y_max = ( + max(y_valid.min() + crop_size, 0), + min(y_valid.max() - crop_size, ref_data_array.shape[0] - 1), + ) + x_min, x_max = ( + max(x_valid.min() + crop_size, 0), + min(x_valid.max() - crop_size, ref_data_array.shape[1] - 1), + ) - # In Python slicing is end exclusive, so we need to add 1 to the max indices + # In Python slicing is end-exclusive, so we need to add 1 to the max indices trimmed_data = data.isel(y=slice(y_min, y_max + 1), x=slice(x_min, x_max + 1)) - # Adjust the x,y offsets, also taking into account the rotation + # Adjust the x,y offsets, taking into account the rotation and translation new_c = transform.c + x_min * transform.a + y_min * transform.b new_f = transform.f + x_min * transform.d + y_min * transform.e - # Make the new transformation matrix and write to the array + # Create the new transformation matrix new_transform = Affine( transform.a, transform.b, @@ -271,6 +278,8 @@ def trim_outer_nans( transform.e, new_f, ) + + # Apply the new transformation to the trimmed data trimmed_data = trimmed_data.rio.write_transform(new_transform) return trimmed_data