Skip to content

Commit

Permalink
Wip/typology (#6)
Browse files Browse the repository at this point in the history
* enhancements

* wip typology

* add idea for pycharm

---------

Co-authored-by: floriscalkoen <floris_calkoen@hotmail.com>
  • Loading branch information
FlorisCalkoen and floriscalkoen authored Sep 11, 2024
1 parent 3180de0 commit cf6eab9
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 41 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,5 @@ logs/*
.vscode

src/coastpy/_version.py

.idea
5 changes: 3 additions & 2 deletions src/coastpy/geo/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,15 +425,15 @@ def generate_offset_line(line: LineString, offset: float) -> LineString:
return line.offset_curve(offset) if offset != 0 else line


def determine_rotation_angle(
def get_rotation_angle(
pt1: Point | tuple[float, float],
pt2: Point | tuple[float, float],
target_axis: Literal[
"closest", "vertical", "horizontal", "horizontal-right-aligned"
] = "closest",
) -> float | None:
"""
Determines the correct rotation angle to align a transect with a specified axis.
Computes the correct rotation angle to align with a specified axis.
Args:
pt1 (Union[Point, Tuple[float, float]]): The starting point of the transect.
Expand Down Expand Up @@ -479,6 +479,7 @@ def determine_rotation_angle(
(270, 360): lambda b: b - 270,
}

# TODO: rename to landward right
elif target_axis == "horizontal-right-aligned":
angle_rotations = {
(0, 90): lambda b: 90 + b,
Expand Down
87 changes: 48 additions & 39 deletions src/coastpy/utils/xarray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import warnings

import numpy as np
import rasterio
import xarray as xr
from affine import Affine
from rasterio.enums import Resampling
Expand Down Expand Up @@ -171,69 +170,71 @@ def interpolate_raster(
ds: xr.Dataset,
y_shape: int,
x_shape: int,
resampling: rasterio.enums.Resampling = rasterio.enums.Resampling.nearest,
resampling: Resampling,
) -> xr.Dataset:
"""
Interpolates a given raster (xarray Dataset) to a specified resolution using the provided method.
Interpolates a given raster (xarray Dataset) to a specified shape using Rasterio resampling methods.
Args:
ds (xr.Dataset): The input raster to interpolate.
y_shape (int): Desired number of grid points along y dimension.
x_shape (int): Desired number of grid points along x dimension.
y_shape (int): Desired number of grid points along the y dimension.
x_shape (int): Desired number of grid points along the x dimension.
resampling: rasterio.enums.Resampling: The interpolation method to use.
Returns:
xr.Dataset: Interpolated raster without geospatial metadata.
Example:
>>> ds = xr.Dataset(data_vars={"var": (("x", "y"), np.random.randn(10, 10))},
... coords={"x": np.linspace(0, 9, 10), "y": np.linspace(0, 9, 10)})
>>> interpolated_ds = interpolate_raster(ds, 20, 20)
>>> print(interpolated_ds.dims)
{'x': 20, 'y': 20}
xr.Dataset: Interpolated raster with updated geospatial metadata.
"""

# swap dims if y is longer than x
if ds.dims["x"] < ds.dims["y"]:
y_shape, x_shape = x_shape, y_shape
# Compute the target transformation based on the desired shape

# this will be the new grid
new_y = np.linspace(ds.y.min(), ds.y.max(), y_shape)
new_x = np.linspace(ds.x.min(), ds.x.max(), x_shape)
transform = ds.rio.transform()

interpolated = ds.interp(y=new_y, x=new_x, method=resampling)
# Define the target transform for the new resolution
target_transform = transform * transform.scale(
(ds.sizes["x"] / x_shape), (ds.sizes["y"] / y_shape)
)

# add new coords because the old ones are now two dimensional
interpolated = interpolated.assign_coords(y=range(y_shape), x=range(x_shape))
# Create a template for the new shape
out_shape = (
(ds.sizes["band"], y_shape, x_shape)
if "band" in ds.dims
else (y_shape, x_shape)
)

# the transformation matrix can be computed by scaling the src one
src_dims = ds.dims
src_transform = ds.rio.transform()
x_scale = src_dims["x"] / x_shape
y_scale = src_dims["y"] / y_shape
# Reproject the dataset to the new shape using rasterio
interpolated = ds.rio.reproject(
ds.rio.crs,
shape=out_shape,
transform=target_transform,
resampling=resampling,
nodata=np.nan,
)

# write new transformation matrix to the interplated dataset
dst_transform = src_transform * Affine.scale(x_scale, y_scale)
interpolated = interpolated.rio.write_transform(dst_transform)
interpolated = interpolated.rio.write_transform(target_transform)

return interpolated


import xarray as xr


def trim_outer_nans(
data: xr.DataArray | xr.Dataset,
crop_size: int = 1,
nodata: float | int | None = None,
crop_size: int = 0,
) -> xr.DataArray | xr.Dataset:
"""
Trim the outer nodata or NaN values from an xarray DataArray or Dataset, returning a bounding box around the data.
Args:
data (xr.DataArray | xr.Dataset): Input DataArray or Dataset with potential outer NaN or nodata values.
crop_size (int): The number of pixels to crop from the outer edges of the data. Defaults to 1.
nodata (float | int | None): Optional no-data value to use for trimming. Defaults to None, which uses NaN.
crop_size (int, optional): The number of pixels to crop from the outer edges of the data after trimming.
Defaults to 0 (no additional cropping).
Returns:
(xr.DataArray | xr.Dataset): A DataArray or Dataset trimmed of its outer NaN or nodata values.
(xr.DataArray | xr.Dataset): A DataArray or Dataset trimmed of its outer NaN or nodata values, with optional
additional cropping applied.
"""

# Determine the representative DataArray for NaN or nodata calculation
Expand All @@ -251,18 +252,24 @@ def trim_outer_nans(
if not y_valid.size or not x_valid.size:
return data

# Compute bounding indices (adjusting by the crop_size in pixels as before)
y_min, y_max = y_valid.min() + crop_size, y_valid.max() - crop_size
x_min, x_max = x_valid.min() + crop_size, x_valid.max() - crop_size
# Compute bounding indices with optional additional cropping
y_min, y_max = (
max(y_valid.min() + crop_size, 0),
min(y_valid.max() - crop_size, ref_data_array.shape[0] - 1),
)
x_min, x_max = (
max(x_valid.min() + crop_size, 0),
min(x_valid.max() - crop_size, ref_data_array.shape[1] - 1),
)

# In Python slicing is end exclusive, so we need to add 1 to the max indices
# In Python slicing is end-exclusive, so we need to add 1 to the max indices
trimmed_data = data.isel(y=slice(y_min, y_max + 1), x=slice(x_min, x_max + 1))

# Adjust the x,y offsets, also taking into account the rotation
# Adjust the x,y offsets, taking into account the rotation and translation
new_c = transform.c + x_min * transform.a + y_min * transform.b
new_f = transform.f + x_min * transform.d + y_min * transform.e

# Make the new transformation matrix and write to the array
# Create the new transformation matrix
new_transform = Affine(
transform.a,
transform.b,
Expand All @@ -271,6 +278,8 @@ def trim_outer_nans(
transform.e,
new_f,
)

# Apply the new transformation to the trimmed data
trimmed_data = trimmed_data.rio.write_transform(new_transform)

return trimmed_data
Expand Down

0 comments on commit cf6eab9

Please sign in to comment.