From be62472954d3b3d7985804b175b6e964330fb5a3 Mon Sep 17 00:00:00 2001 From: Floris Calkoen <44444001+FlorisCalkoen@users.noreply.github.com> Date: Mon, 2 Sep 2024 19:00:37 +0200 Subject: [PATCH] Fix gcts (#3) * filename * dask complete install * revised workflow * more flexible dask client manager * add features to gdf * only detect antimerdian crosses * revised gcts workflow and schema * lints * prepare run on slurm * revised launch settings * revised configs * . * . * . * avoid scattering * . * wip * wip * typo * wip * more workers * try adding countries around antimeridian * comment sampling * fix: geometry type as wkb and not bytes * revised stacs * revised generic buffer func * add offset rectangle * fix suffix error in name data * fix pd type checking * fix rotation angle --------- Co-authored-by: floriscalkoen Co-authored-by: Floris Calkoen --- analytics/hypsometry.ipynb | 19 +- ci/envs/312-coastal-full.yaml | 2 +- scripts/python/add_gcts_to_stac.py | 42 ++- scripts/python/make_gcts.py | 327 +++++++++---------- src/coastpy/geo/__init__.py | 3 +- src/coastpy/geo/ops.py | 157 +++++---- src/coastpy/geo/quadtiles_utils.py | 7 +- src/coastpy/geo/transect.py | 88 ++--- src/coastpy/geo/utils.py | 0 src/coastpy/io/partitioner.py | 50 ++- src/coastpy/io/utils.py | 120 ++++++- src/coastpy/libs/stac_table.py | 4 + src/coastpy/utils/__init__.py | 6 +- src/coastpy/utils/dask.py | 154 +++++++++ src/coastpy/utils/dask_utils.py | 51 --- src/coastpy/utils/pandas.py | 156 +++++++++ src/coastpy/utils/{size_utils.py => size.py} | 0 src/coastpy/utils/{xr_utils.py => xarray.py} | 61 ++-- 18 files changed, 844 insertions(+), 403 deletions(-) create mode 100644 src/coastpy/geo/utils.py create mode 100644 src/coastpy/utils/dask.py delete mode 100644 src/coastpy/utils/dask_utils.py create mode 100644 src/coastpy/utils/pandas.py rename src/coastpy/utils/{size_utils.py => size.py} (100%) rename src/coastpy/utils/{xr_utils.py => xarray.py} (85%) diff --git a/analytics/hypsometry.ipynb b/analytics/hypsometry.ipynb index 0d0e2e7..dcc381f 100644 --- a/analytics/hypsometry.ipynb +++ b/analytics/hypsometry.ipynb @@ -76,7 +76,7 @@ "GCTS_LANDWARD_CONTAINER = \"az://public/coastal-analytics/gcts-2000m-landward.parquet\"\n", "# NOTE: before we stored the results here, keep for a while as ref\n", "GCTS_ELEVATION_CONTAINER = \"az://public/coastal-analytics/gcts-2000m-elevation.parquet\"\n", - "# NOTE: Next iteration we will store it here and extract the profiles including the tr_name\n", + "# NOTE: Next iteration we will store it here and extract the profiles including the transect_id\n", "# GCTS_ELEVATION_CONTAINER = \"az://coastal-transect-repository/deltadtm-elevation.parquet\"\n", "H3_ELEVATION_CONTAINER = (\n", " f\"az://public/coastal-analytics/h3-l{H3_LEVEL}-pct-lt-{LOWER_THAN}m.parquet\"\n", @@ -131,11 +131,14 @@ "metadata": {}, "outputs": [], "source": [ - "from coastpy.utils.dask_utils import create_dask_client\n", + "from coastpy.utils.dask import DaskClientManager\n", "\n", "instance_type = configure_instance()\n", - "client = create_dask_client(instance_type)\n", - "client" + "client = DaskClientManager().create_client(\n", + " instance_type,\n", + " threads_per_worker=1,\n", + " processes=True,\n", + " )\n" ] }, { @@ -168,7 +171,7 @@ " return shapely.LineString([p1, p2])\n", "\n", " geoms = df.apply(extract_landward_side, axis=1)\n", - " return gpd.GeoDataFrame(df[\"tr_name\"], geometry=geoms, crs=4326)\n", + " return gpd.GeoDataFrame(df[\"transect_id\"], geometry=geoms, crs=4326)\n", "\n", " 
gcts_collection = coclico_catalog.get_child(\"gcts\")\n", " gcts_extents = read_items_extent(gcts_collection, columns=[\"geometry\", \"assets\"])\n", @@ -177,7 +180,7 @@ " # template GDF that matches what is retunred from map_extract_landward_side\n", " META = gpd.GeoDataFrame(\n", " {\n", - " \"tr_name\": gpd.GeoSeries([], dtype=str),\n", + " \"transect_id\": gpd.GeoSeries([], dtype=str),\n", " \"geometry\": gpd.GeoSeries([], dtype=GeometryDtype()),\n", " }\n", " )\n", @@ -185,7 +188,7 @@ " transects = dask_geopandas.read_parquet(\n", " gcts_hrefs,\n", " storage_options=storage_options,\n", - " columns=[\"tr_name\", \"geometry\", \"lon\", \"lat\"],\n", + " columns=[\"transect_id\", \"geometry\", \"lon\", \"lat\"],\n", " )\n", "\n", " transects = transects.map_partitions(map_extract_landward_side, meta=META)\n", @@ -241,7 +244,7 @@ " da = da.where(da != da.rio.nodata, np.nan)\n", " da = da.rio.write_nodata(np.nan)\n", "\n", - " # TODO: ensure that tr_name is tracked so that we can use the elevation data later at a transect level\n", + " # TODO: ensure that transect_id is tracked so that we can use the elevation data later at a transect level\n", " clipped = da.rio.clip(transects.geometry.to_list()).rename(\"band_data\")\n", "\n", " df = (\n", diff --git a/ci/envs/312-coastal-full.yaml b/ci/envs/312-coastal-full.yaml index 95e9d1e..3646ff5 100644 --- a/ci/envs/312-coastal-full.yaml +++ b/ci/envs/312-coastal-full.yaml @@ -127,7 +127,7 @@ dependencies: - cartopy - cfgrib - contextily - - dask + - dask[complete] - dask-geopandas - dask-image - dask-jobqueue diff --git a/scripts/python/add_gcts_to_stac.py b/scripts/python/add_gcts_to_stac.py index 287000a..cdf8e47 100644 --- a/scripts/python/add_gcts_to_stac.py +++ b/scripts/python/add_gcts_to_stac.py @@ -26,9 +26,13 @@ storage_account_name = "coclico" storage_options = {"account_name": storage_account_name, "credential": sas_token} +# NOTE: +TEST_RELEASE = True + # Container and URI configuration CONTAINER_NAME = "gcts" -PREFIX = "release/2024-03-18" +RELEASE_DATE = "2024-08-02" +PREFIX = f"release/{RELEASE_DATE}" CONTAINER_URI = f"az://{CONTAINER_NAME}/{PREFIX}" PARQUET_MEDIA_TYPE = "application/vnd.apache.parquet" LICENSE = "CC-BY-4.0" @@ -36,7 +40,6 @@ # Collection information COLLECTION_ID = "gcts" COLLECTION_TITLE = "Global Coastal Transect System (GCTS)" -DATE_TRANSECTS_CREATED = "2024-03-18" # Transect and zoom configuration TRANSECT_LENGTH = 2000 @@ -56,11 +59,14 @@ ASSET_DESCRIPTION = f"Parquet dataset with coastal transects ({TRANSECT_LENGTH} m) at 100 m alongshore resolution for this region." # GeoParquet STAC items -GEOPARQUET_STAC_ITEMS_HREF = f"az://items/{COLLECTION_ID}.parquet" +if TEST_RELEASE: + GEOPARQUET_STAC_ITEMS_HREF = f"az://items-test/{COLLECTION_ID}.parquet" +else: + GEOPARQUET_STAC_ITEMS_HREF = f"az://items/{COLLECTION_ID}.parquet" COLUMN_DESCRIPTIONS = [ { - "name": "tr_name", + "name": "transect_id", "type": "string", "description": "A unique identifier for each transect, constructed from three key components: the 'coastline_id', 'segment_id', and 'interpolated_distance'. The 'coastline_id' corresponds to the FID in OpenStreetMap (OSM) and is prefixed with 'cl'. The 'segment_id' indicates the segment of the OSM coastline split by a UTM grid, prefixed with 's'. The 'interpolated_distance' represents the distance from the starting point of the coastline to the transect, interpolated along the segment, and is prefixed with 'tr'. 
The complete structure is 'cl[coastline_id]s[segment_id]tr[interpolated_distance]', exemplified by 'cl32946s04tr08168547'. This composition ensures each transect name is a distinct and informative representation of its geographical and spatial attributes.", }, @@ -85,17 +91,17 @@ "description": "Well-Known Binary (WKB) representation of the transect as a linestring geometry.", }, { - "name": "coastline_is_closed", + "name": "osm_coastline_is_closed", "type": "bool", "description": "Indicates whether the source OpenStreetMap (OSM) coastline, from which the transects were derived, forms a closed loop. A value of 'true' suggests that the coastline represents an enclosed area, such as an island.", }, { - "name": "coastline_length", + "name": "osm_coastline_length", "type": "int32", "description": "Represents the total length of the source OpenStreetMap (OSM) coastline, that is summed across various UTM regions. It reflects the aggregate length of the original coastline from which the transects are derived.", }, { - "name": "utm_crs", + "name": "utm_epsg", "type": "int32", "description": "EPSG code representing the UTM Coordinate Reference System for the transect.", }, @@ -110,24 +116,24 @@ "description": "QuadKey corresponding to the transect origin location at zoom 12, following the Bing Maps Tile System for spatial indexing.", }, { - "name": "isoCountryCodeAlpha2", + "name": "continent", "type": "string", - "description": "ISO 3166-1 alpha-2 country code for the country in which the transect is located.", + "description": "Name of the continent in which the transect is located.", }, { - "name": "admin_level_1_name", + "name": "country", "type": "string", - "description": "Name of the first-level administrative division (e.g., country) in which the transect is located.", + "description": "ISO alpha-2 country code for the country in which the transect is located. The country data are extracted from Overture Maps (divisions).", }, { - "name": "isoSubCountryCode", + "name": "common_country_name", "type": "string", - "description": "ISO code for the sub-country or second-level administrative division in which the transect is located.", + "description": "Common country name (EN) in which the transect is located. The country data are extracted from Overture Maps (divisions).", }, { - "name": "admin_level_2_name", + "name": "common_region_name", "type": "string", - "description": "Name of the second-level administrative division (e.g., state or province) in which the transect is located.", + "description": "Common region name (EN) in which the transect is located. 
The regions are extracted from Overture Maps (divisions).", }, ] @@ -192,7 +198,7 @@ def create_collection( ), ] - start_datetime = datetime.datetime.strptime(DATE_TRANSECTS_CREATED, "%Y-%m-%d") + start_datetime = datetime.datetime.strptime(RELEASE_DATE, "%Y-%m-%d") extent = pystac.Extent( pystac.SpatialExtent([[-180.0, 90.0, 180.0, -90.0]]), @@ -276,7 +282,7 @@ def create_collection( collection.stac_extensions.append(stac_table.SCHEMA_URI) VersionExtension.add_to(collection) - collection.extra_fields["version"] = "1.0.0" + collection.extra_fields["version"] = RELEASE_DATE return collection @@ -304,7 +310,7 @@ def create_item( "description": ASSET_DESCRIPTION, } - dt = datetime.datetime.strptime(DATE_TRANSECTS_CREATED, "%Y-%m-%d") + dt = datetime.datetime.strptime(RELEASE_DATE, "%Y-%m-%d") # shape = shapely.box(*bbox) # geometry = shapely.geometry.mapping(shape) template = pystac.Item( diff --git a/scripts/python/make_gcts.py b/scripts/python/make_gcts.py index 52cfcf4..329a58b 100644 --- a/scripts/python/make_gcts.py +++ b/scripts/python/make_gcts.py @@ -1,21 +1,16 @@ -import dask - -# NOTE: explicitly set query-planning to False to avoid issues with dask-geopandas -dask.config.set({"dataframe.query-planning": False}) - +import datetime import logging import os import time import warnings from functools import partial +import dask import dask_geopandas import fsspec import geopandas as gpd import pandas as pd -import pyproj import shapely -from distributed import Client from dotenv import load_dotenv from geopandas.array import GeometryDtype from shapely.geometry import LineString, Point @@ -23,90 +18,73 @@ from coastpy.geo.ops import crosses_antimeridian from coastpy.geo.quadtiles_utils import add_geo_columns from coastpy.geo.transect import generate_transects_from_coastline -from coastpy.utils.dask_utils import ( +from coastpy.io.partitioner import QuadKeyEqualSizePartitioner +from coastpy.io.utils import rm_from_storage +from coastpy.utils.config import configure_instance +from coastpy.utils.dask import ( + DaskClientManager, silence_shapely_warnings, ) +from coastpy.utils.pandas import add_attributes_from_gdfs load_dotenv(override=True) sas_token = os.getenv("AZURE_STORAGE_SAS_TOKEN") storage_options = {"account_name": "coclico", "credential": sas_token} -utm_grid_url = "az://grid/utm.parquet" -osm_url = "az://coastlines-osm/release/2023-02-09/coast_3857_gen9.parquet" -countries_url = "az://public/countries.parquet" # From overture maps 2024-04-16 - -import datetime +# NOTE: The generalized coastline used here cannot be made publicly available contact +# authors for access. 
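A minimal sketch of how the `transect_id` naming convention described in the column schema above decomposes into its parts; the regular expression mirrors the one used by `zero_pad_transect_id` further down in `make_gcts.py`, while the function name and the returned keys are illustrative only.

```python
# Decompose a GCTS transect_id such as "cl32946s04tr08168547" into its components.
# The pattern matches the one used for zero-padding in make_gcts.py.
import re

TRANSECT_ID_PATTERN = re.compile(r"cl(\d+)s(\d+)tr(\d+)")


def split_transect_id(transect_id: str) -> dict[str, int]:
    match = TRANSECT_ID_PATTERN.fullmatch(transect_id)
    if match is None:
        msg = f"Not a valid transect_id: {transect_id}"
        raise ValueError(msg)
    coastline_id, segment_id, distance = match.groups()
    return {
        "coastline_id": int(coastline_id),  # FID of the source OSM coastline
        "segment_id": int(segment_id),  # segment of the coastline split by the UTM grid
        "interpolated_distance": int(distance),  # distance along the segment to the transect origin
    }


print(split_transect_id("cl32946s04tr08168547"))
# {'coastline_id': 32946, 'segment_id': 4, 'interpolated_distance': 8168547}
```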
+osm_coastline_uri = "az://coastlines-osm/release/2023-02-09/coast_3857_gen9.parquet" +utm_grid_uri = "az://grid/utm.parquet" +countries_uri = "az://public/countries.parquet" # From overture maps 2024-07-22 +regions_uri = "az://public/regions.parquet" # From overture maps 2024-07-22 today = datetime.datetime.now().strftime("%Y-%m-%d") -OUT_BASE_URI = f"az://gcts/release/{today}.parquet" +OUT_BASE_URI = f"az://gcts/release/{today}" TMP_BASE_URI = OUT_BASE_URI.replace("az://", "az://tmp/") -# DATA_DIR = pathlib.Path.home() / "data" -# TMP_DIR = DATA_DIR / "tmp" -# SRC_DIR = DATA_DIR / "src" -# PRC_DIR = DATA_DIR / "prc" -# RES_DIR = DATA_DIR / "res" -# LIVE_DIR = DATA_DIR / "live" - - # TODO: make cli using argsparse # transect configuration settings MIN_COASTLINE_LENGTH = 5000 SMOOTH_DISTANCE = 1.0e-3 -START_DISTANCE = 50 -COASTLINE_SEGMENT_LENGTH = 1e4 TRANSECT_LENGTH = 2000 +SPACING = 100 FILENAMER = "part.{numb2er}.parquet" COASTLINE_ID_COLUMN = "FID" # FID (OSM) or OBJECTID (Sayre) COLUMNS = [COASTLINE_ID_COLUMN, "geometry"] COASTLINE_ID_RENAME = "FID" -PRC_CRS = "EPSG:3857" -DST_CRS = "EPSG:4326" +PRC_EPSG = 3857 +DST_EPSG = 4326 -prc_epsg = pyproj.CRS.from_user_input(PRC_CRS).to_epsg() -dst_epsg = pyproj.CRS.from_user_input(DST_CRS).to_epsg() - -# # dataset specific settings -# COASTLINES_DIR = SRC_DIR / "coastlines_osm_generalized_v2023" / "coast_3857_gen9.shp" - -# UTM_GRID_FP = LIVE_DIR / "tiles" / "utm.parquet" -# ADMIN1_FP = LIVE_DIR / "overture" / "2024-02-15" / "admin_bounds_level_1.parquet" -# ADMIN2_FP = LIVE_DIR / "overture" / "2024-02-15" / "admin_bounds_level_2.parquet" - -# OUT_DIR = PRC_DIR / COASTLINES_DIR.stem.replace( -# "coast", f"transects_{TRANSECT_LENGTH}_test" -# ) -# OUT_DIR = PRC_DIR / "gcts" / "release" / "2024-07-25" - -# To drop transects at meridonal boundary -SPACING = 100 MAX_PARTITION_SIZE = ( "500MB" # compressed parquet is usually order two smaller, so multiply this ) + MIN_ZOOM_QUADKEY = 2 DTYPES = { - "tr_name": str, + "transect_id": str, "lon": "float32", "lat": "float32", "bearing": "float32", "geometry": GeometryDtype(), # NOTE: leave here because before we used to store the coastline name - # "coastline_name": str, - "coastline_is_closed": bool, - "coastline_length": "int32", - "utm_crs": "int32", + # "osm_coastline_id": str, + "osm_coastline_is_closed": bool, + "osm_coastline_length": "int32", + "utm_epsg": "int32", "bbox": object, "quadkey": str, # NOTE: leave here because before we used to store the bounding quadkey # "bounding_quadkey": str, - "isoCountryCodeAlpha2": str, - "admin_level_1_name": str, - "isoSubCountryCode": str, - "admin_level_2_name": str, + # NOTE: the object dtype are necessary because some rows do not contain contient, country or region data. + # However, it is terriibly inefficient, but for now leave it like this. Otherwise they would become "nan"? + "continent": object, + "country": object, + "common_country_name": object, + "common_region_name": object, } @@ -131,9 +109,16 @@ def silence_warnings(): r" version. Check `isinstance\(dtype, pd.DatetimeTZDtype\)` instead." ), ) + warnings.filterwarnings( + "ignore", + r"DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated," + r"and in a future version of pandas the grouping columns will be excluded from the operation." 
+ r"Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping" + r"action= columns after groupby to silence this warning.", + ) -def zero_pad_tr_name(tr_names: pd.Series) -> pd.Series: +def zero_pad_transect_id(transect_ids: pd.Series) -> pd.Series: """ Zero-pads the numerical parts of transect names to ensure logical sorting. @@ -143,13 +128,13 @@ def zero_pad_tr_name(tr_names: pd.Series) -> pd.Series: reconstructs the transect names with zero-padded ids. Args: - tr_names (pd.Series): A Series of transect names in the format "cl{coastline_id}tr{transect_id}". + transect_ids (pd.Series): A Series of transect names in the format "cl{coastline_id}tr{transect_id}". Returns: pd.Series: A Series of zero-padded transect names for logical sorting. """ # Extract and rename IDs - ids = tr_names.str.extract(r"cl(\d+)s(\d+)tr(\d+)").rename( + ids = transect_ids.str.extract(r"cl(\d+)s(\d+)tr(\d+)").rename( columns={0: "coastline_id", 1: "segment_id", 2: "transect_id"} ) ids = ids.astype({"coastline_id": str, "segment_id": str, "transect_id": str}) @@ -169,7 +154,7 @@ def zero_pad_tr_name(tr_names: pd.Series) -> pd.Series: "cl" + ids["coastline_id"] + "s" + ids["segment_id"] + "tr" + ids["transect_id"] ) - return pd.Series(zero_padded_names, index=tr_names.index) + return pd.Series(zero_padded_names, index=transect_ids.index) def sort_line_segments(segments, original_line): @@ -221,37 +206,33 @@ def sort_line_segments(segments, original_line): silence_warnings() logging.basicConfig(level=logging.INFO) + logging.getLogger("azure").setLevel(logging.WARNING) logging.info(f"Transects will be written to {OUT_BASE_URI}") - # if not OUT_DIR.exists(): - # OUT_DIR.mkdir(exist_ok=True, parents=True) - start_time = time.time() - client = Client( - threads_per_worker=1, processes=True, local_directory="/tmp", n_workers=8 + instance_type = configure_instance() + client = DaskClientManager().create_client( + instance_type, ) client.run(silence_shapely_warnings) logging.info(f"Client dashboard link: {client.dashboard_link}") - with fsspec.open(utm_grid_url, **storage_options) as f: + with fsspec.open(utm_grid_uri, **storage_options) as f: utm_grid = gpd.read_parquet(f) - with fsspec.open(countries_url, **storage_options) as f: - countries = gpd.read_parquet(f) - - utm_grid = utm_grid.dissolve("epsg").to_crs(prc_epsg).reset_index() + utm_grid = utm_grid.dissolve("epsg").to_crs(PRC_EPSG).reset_index() [utm_grid_scattered] = client.scatter( [utm_grid.loc[:, ["geometry", "epsg", "utm_code"]]], broadcast=True ) coastlines = ( - dask_geopandas.read_parquet(osm_url, storage_options=storage_options) + dask_geopandas.read_parquet(osm_coastline_uri, storage_options=storage_options) .repartition(npartitions=10) .persist() # .sample(frac=0.02) - .to_crs(prc_epsg) + .to_crs(PRC_EPSG) ) def is_closed(geometry): @@ -259,14 +240,14 @@ def is_closed(geometry): return geometry.is_closed def wrap_is_closed(df): - df["coastline_is_closed"] = df.geometry.astype(object).apply(is_closed) + df["osm_coastline_is_closed"] = df.geometry.astype(object).apply(is_closed) return df META = gpd.GeoDataFrame( { "FID": pd.Series([], dtype="i8"), "geometry": gpd.GeoSeries([], dtype=GeometryDtype), - "coastline_is_closed": pd.Series([], dtype="bool"), + "osm_coastline_is_closed": pd.Series([], dtype="bool"), } ) @@ -284,7 +265,7 @@ def wrap_is_closed(df): ) ], crs="EPSG:4326", - ).to_crs(prc_epsg) + ).to_crs(PRC_EPSG) [utm_extent_scattered] = client.scatter([utm_extent], broadcast=True) @@ -298,7 +279,7 @@ def 
overlay_by_grid(df, grid): META = gpd.GeoDataFrame( { "FID": pd.Series([], dtype="i8"), - "coastline_is_closed": pd.Series([], dtype="bool"), + "osm_coastline_is_closed": pd.Series([], dtype="bool"), "epsg": pd.Series([], dtype="i8"), "utm_code": pd.Series([], dtype=object), "geometry": gpd.GeoSeries([], dtype=GeometryDtype), @@ -320,11 +301,11 @@ def overlay_by_grid(df, grid): ) # type: ignore # TODO: use coastpy.geo.utils add_geometry_lengths - def add_lengths(df, utm_crs): + def add_lengths(df, utm_epsg): silence_shapely_warnings() # compute geometry length in local utm crs df = ( - df.to_crs(utm_crs) + df.to_crs(utm_epsg) .assign(geometry_length=lambda df: df.geometry.length) .to_crs(df.crs) ) @@ -332,7 +313,7 @@ def add_lengths(df, utm_crs): coastline_lengths = ( df.groupby("FID_osm")["geometry_length"] .sum() - .rename("coastline_length") + .rename("osm_coastline_length") .reset_index() ) # add to dataframe @@ -343,14 +324,15 @@ def add_lengths(df, utm_crs): META = gpd.GeoDataFrame( { "FID_osm": pd.Series([], dtype="i4"), - "coastline_is_closed": pd.Series([], dtype="bool"), + "osm_coastline_is_closed": pd.Series([], dtype="bool"), "epsg": pd.Series([], dtype="i4"), "utm_code": pd.Series([], dtype="string"), "geometry": gpd.GeoSeries([], dtype=GeometryDtype), - "coastline_length": pd.Series([], dtype="f8"), + "osm_coastline_length": pd.Series([], dtype="f8"), } ) + # NOTE: check how to handle the group keys with Pandas > 2.2.2 coastlines = coastlines.map_partitions( lambda partition: partition.groupby("epsg", group_keys=False).apply( lambda gr: add_lengths(gr, gr.name) @@ -364,18 +346,18 @@ def add_coastline_names(df): names = [ f"cl{fid}s{seg}" for fid, seg in zip(df.FID_osm, segment_ids, strict=False) ] - df["coastline_name"] = names + df["osm_coastline_id"] = names return df META = gpd.GeoDataFrame( { "FID_osm": pd.Series([], dtype="i4"), - "coastline_is_closed": pd.Series([], dtype="bool"), + "osm_coastline_is_closed": pd.Series([], dtype="bool"), "epsg": pd.Series([], dtype="i4"), "utm_code": pd.Series([], dtype="string"), "geometry": gpd.GeoSeries([], dtype=GeometryDtype), - "coastline_length": pd.Series([], dtype="f8"), - "coastline_name": pd.Series([], dtype="string"), + "osm_coastline_length": pd.Series([], dtype="f8"), + "osm_coastline_id": pd.Series([], dtype="string"), } ) coastlines = coastlines.map_partitions(add_coastline_names, meta=META).set_crs( @@ -383,27 +365,27 @@ def add_coastline_names(df): ) # coastlines = ( - # coastlines.assign(coastline_name=1) - # .assign(coastline_name=lambda df: df.coastline_name.cumsum()) + # coastlines.assign(osm_coastline_id=1) + # .assign(osm_coastline_id=lambda df: df.osm_coastline_id.cumsum()) # .persist() # ).set_crs(coastlines.crs) - # coastline_names = coastlines.coastline_name.value_counts().compute() + # coastline_names = coastlines.osm_coastline_id.value_counts().compute() # drop coastlines that are too short coastlines = coastlines.loc[ - coastlines.coastline_length > MIN_COASTLINE_LENGTH + coastlines.osm_coastline_length > MIN_COASTLINE_LENGTH ].persist() def generate_filtered_transects( coastline: LineString, transect_length: float, spacing: float | int, - coastline_name: str, - coastline_is_closed: bool, - coastline_length: int, + osm_coastline_id: str, + osm_coastline_is_closed: bool, + osm_coastline_length: int, src_crs: int, - utm_crs: int, + utm_epsg: int, dst_crs: int, smooth_distance: float = 1e-3, ) -> gpd.GeoDataFrame: @@ -411,11 +393,11 @@ def generate_filtered_transects( coastline, transect_length, spacing, - 
coastline_name, - coastline_is_closed, - coastline_length, + osm_coastline_id, + osm_coastline_is_closed, + osm_coastline_length, src_crs, - utm_crs, + utm_epsg, dst_crs, smooth_distance, ) @@ -431,9 +413,9 @@ def generate_filtered_transects( # tr_corrected = generate_transects_from_coastline_with_antimeridian_correction( # coastline, # transect_length, - # coastline_name, + # osm_coastline_id, # src_crs, - # utm_crs, + # utm_epsg, # dst_crs, # crosses=crosses, # utm_grid=utm_grid_scattered.result().set_index("epsg"), @@ -443,14 +425,14 @@ def generate_filtered_transects( return transects # Order of columns in the coastlines dataframe - # ['FID_osm', 'epsg', 'utm_code', 'geometry', 'coastline_length','coastline_name'] + # ['FID_osm', 'epsg', 'utm_code', 'geometry', 'osm_coastline_length','osm_coastline_id'] # create a partial function with arguments that do not change partial_generate_filtered_transects = partial( generate_filtered_transects, transect_length=TRANSECT_LENGTH, spacing=SPACING, src_crs=coastlines.crs.to_epsg(), - dst_crs=dst_epsg, + dst_crs=DST_EPSG, smooth_distance=SMOOTH_DISTANCE, ) @@ -460,10 +442,10 @@ def generate_filtered_transects( transects = bag.map( lambda b: partial_generate_filtered_transects( coastline=b[4], - coastline_name=b[6], - coastline_is_closed=b[1], - coastline_length=int(b[5]), - utm_crs=b[2], + osm_coastline_id=b[6], + osm_coastline_is_closed=b[1], + osm_coastline_length=int(b[5]), + utm_epsg=b[2], ) ) @@ -472,72 +454,89 @@ def generate_filtered_transects( transects, geo_columns=["bbox", "quadkey"], quadkey_zoom_level=12 ) - transects["tr_name"] = zero_pad_tr_name(transects["tr_name"]) - transects.to_parquet("/Users/calkoen/transects-test.gpkg") - - print("writing") - with fsspec.open(TMP_BASE_URI, "wb", **storage_options) as f: - transects.to_parquet(f, index=False) - - # transects.to_parquet( - # TMP_BASE_URI, - # index=False, - # storage_options=storage_options, - # ) - - # # NOTE: in next gcts release move this out of processing and add from countries (divisions) seperately - # admin1 = ( - # gpd.read_parquet(ADMIN1_FP) - # .to_crs(transects.crs) - # .drop(columns=["id"]) - # .rename(columns={"primary_name": "admin_level_1_name"}) - # ) - # admin2 = ( - # gpd.read_parquet(ADMIN2_FP) - # .to_crs(transects.crs) - # .drop(columns=["id", "isoCountryCodeAlpha2"]) - # .rename(columns={"primary_name": "admin_level_2_name"}) - # ) - - # # NOTE: zoom level 5 is hard-coded here because I believe spatial join will be faster - # quadkey_grouper = "quadkey_z5" - # transects[quadkey_grouper] = transects.apply( - # lambda r: mercantile.quadkey(mercantile.tile(r.lon, r.lat, 5)), axis=1 - # ) - - # def add_admin_bounds(df, admin_df, max_distance=20000): - # points = gpd.GeoDataFrame( - # df[["tr_name"]], geometry=gpd.GeoSeries.from_xy(df.lon, df.lat, crs=4326) - # ).to_crs(3857) - # joined = gpd.sjoin_nearest( - # points, admin_df.to_crs(3857), max_distance=max_distance - # ).drop(columns=["index_right", "geometry"]) - - # df = pd.merge(df, joined, on="tr_name", how="left") - # return df - - # transects = transects.groupby(quadkey_grouper, group_keys=False).apply( - # lambda gr: add_admin_bounds(gr, admin1), - # ) - # transects = transects.groupby(quadkey_grouper, group_keys=False).apply( - # lambda gr: add_admin_bounds(gr, admin2), - # ) + transects["transect_id"] = zero_pad_transect_id(transects["transect_id"]) + + logging.info(f"Removing files/bytes from {TMP_BASE_URI} if present.") + rm_from_storage( + pattern=(TMP_BASE_URI + "/*.parquet"), + 
storage_options=storage_options, + confirm=False, + verbose=False, + ) + partitioner = QuadKeyEqualSizePartitioner( + transects, + out_dir=TMP_BASE_URI, + max_size="1GB", + min_quadkey_zoom=4, + sort_by="quadkey", + geo_columns=["bbox", "quadkey"], + storage_options=storage_options, + ) + partitioner.process() + + # with fsspec.open(TMP_BASE_URI, "wb", **storage_options) as f: + # transects.to_parquet(f, index=False) + + logging.info(f"Transects written to {TMP_BASE_URI}") + + transects = dask_geopandas.read_parquet( + TMP_BASE_URI, storage_options=storage_options + ) + # zoom = 5 + # quadkey_grouper = f"quadkey_{zoom}" + # transects[quadkey_grouper] = transects.quadkey.str[:zoom] + + def process(transects_group, countries_uri, regions_uri, max_distance=20000): + with fsspec.open(countries_uri, **storage_options) as f: + countries = gpd.read_parquet( + f, columns=["country", "common_country_name", "continent", "geometry"] + ) + with fsspec.open(regions_uri, **storage_options) as f: + regions = gpd.read_parquet(f, columns=["common_region_name", "geometry"]) + r = add_attributes_from_gdfs( + transects_group, [countries, regions], max_distance=max_distance + ) + return r + + logging.info("Part 2: adding attributes to transects...") + # logging.info(f"Grouping the transects by quadkey zoom level {zoom}") + + tasks = [] + for group in transects.to_delayed(): + t = dask.delayed(process)(group, countries_uri, regions_uri, max_distance=20000) + tasks.append(t) + + logging.info("Computing the submitted tasks..") + transects = pd.concat(dask.compute(*tasks)) # transects = transects.drop(columns=[quadkey_grouper]) - # transects.to_parquet(OUT_DIR, index=False) - - # partitioner = QuadKeyEqualSizePartitioner( - # transects, - # out_dir=OUT_DIR, - # max_size=MAX_PARTITION_SIZE, - # min_quadkey_zoom=MIN_ZOOM_QUADKEY, - # sort_by="quadkey", - # geo_columns=["bbox", "quadkey"], - # column_order=list(DTYPES.keys()), - # dtypes=DTYPES, - # ) - # partitioner.process() + logging.info( + f"Partitioning into equal partitions by quadkey at zoom level {MIN_ZOOM_QUADKEY}" + ) + logging.info(f"Removing files/bytes from {OUT_BASE_URI} if present.") + rm_from_storage( + pattern=(OUT_BASE_URI + "/*.parquet"), + storage_options=storage_options, + confirm=False, + verbose=False, + ) + partitioner = QuadKeyEqualSizePartitioner( + transects, + out_dir=OUT_BASE_URI, + max_size=MAX_PARTITION_SIZE, + min_quadkey_zoom=MIN_ZOOM_QUADKEY, + sort_by="quadkey", + geo_columns=["bbox", "quadkey"], + column_order=list(DTYPES.keys()), + dtypes=DTYPES, + storage_options=storage_options, + naming_function_kwargs={"include_random_hex": True}, + ) + partitioner.process() + + logging.info("Closing client.") + client.close() logging.info("Done!") elapsed_time = time.time() - start_time diff --git a/src/coastpy/geo/__init__.py b/src/coastpy/geo/__init__.py index 3f5d387..949c762 100644 --- a/src/coastpy/geo/__init__.py +++ b/src/coastpy/geo/__init__.py @@ -1,4 +1,5 @@ +from .geoms import create_offset_rectangle from .quadtiles import make_mercantiles from .quadtiles_utils import add_geo_columns -__all__ = ["add_geo_columns", "make_mercantiles", ""] +__all__ = ["add_geo_columns", "make_mercantiles", "create_offset_rectangle"] diff --git a/src/coastpy/geo/ops.py b/src/coastpy/geo/ops.py index 1a30e02..5860846 100644 --- a/src/coastpy/geo/ops.py +++ b/src/coastpy/geo/ops.py @@ -15,11 +15,11 @@ MultiLineString, MultiPoint, Point, - Polygon, + base, ) from shapely.ops import snap, split -from coastpy.utils.dask_utils import 
silence_shapely_warnings +from coastpy.utils.dask import silence_shapely_warnings def shift_point( @@ -428,26 +428,24 @@ def generate_offset_line(line: LineString, offset: float) -> LineString: def determine_rotation_angle( pt1: Point | tuple[float, float], pt2: Point | tuple[float, float], - target_axis: Literal["closest", "vertical", "horizontal"] = "closest", -) -> float: + target_axis: Literal[ + "closest", "vertical", "horizontal", "horizontal-right-aligned" + ] = "closest", +) -> float | None: """ - Determines the correct rotation angle to align the orientation of a cross-shore transect - either vertically, horizontally or to the closest. + Determines the correct rotation angle to align a transect with a specified axis. Args: - pt1: The starting point of the transect. Can be either a Point object or a tuple of floats. - pt2: The ending point of the transect. Can be either a Point object or a tuple of floats. - target_axis: The target axis to align the transect to. Can be either "closest", "vertical" or "horizontal". + pt1 (Union[Point, Tuple[float, float]]): The starting point of the transect. + pt2 (Union[Point, Tuple[float, float]]): The ending point of the transect. + target_axis (Literal["closest", "vertical", "horizontal", "horizontal-right-aligned"], optional): + The target axis to align the transect. Defaults to "closest". Returns: float: The rotation angle in degrees. Positive values represent counterclockwise rotation. - Example: - >>> determine_rotation_angle((0,0), (1,1), target_axis="horizontal") - -45.0 - Raises: - - ValueError: If the computed angle is not within the expected range [-180, 180]. + ValueError: If an invalid target axis is provided or if the bearing is out of the expected range. """ x1, y1 = extract_coordinates(pt1) @@ -458,7 +456,7 @@ def determine_rotation_angle( logging.info(f"Angle between points: {angle} degrees.") logging.info(f"Bearing between points: {bearing} degrees.") - if x1 == x2 or y1 == y2: # combines the two conditions as the result is the same + if x1 == x2 or y1 == y2: return 0 if target_axis == "closest": @@ -474,10 +472,18 @@ def determine_rotation_angle( } elif target_axis == "horizontal": + angle_rotations = { + (0, 90): lambda b: -(90 - b), + (90, 180): lambda b: b - 90, + (180, 270): lambda b: -(270 - b), + (270, 360): lambda b: b - 270, + } + + elif target_axis == "horizontal-right-aligned": angle_rotations = { (0, 90): lambda b: 90 + b, (90, 180): lambda b: -(270 - b), - (180, 270): lambda b: b - 270, + (180, 270): lambda b: -(270 - b), (270, 360): lambda b: b - 270, } @@ -490,10 +496,7 @@ def determine_rotation_angle( } else: - msg = ( - f"Invalid target_axis: {target_axis}. Must be one of 'closest', 'vertical'" - " or 'horizontal'." - ) + msg = f"Invalid target_axis: {target_axis}. Must be one of 'closest', 'vertical', 'horizontal', or 'horizontal-right-aligned'." raise ValueError(msg) for (lower_bound, upper_bound), rotation_func in angle_rotations.items(): @@ -504,73 +507,103 @@ def determine_rotation_angle( raise ValueError(msg) -def crosses_antimeridian(df: gpd.GeoDataFrame) -> gpd.GeoSeries: +def crosses_antimeridian(df: gpd.GeoDataFrame) -> pd.Series: """ - Determines whether linestrings in a GeoDataFrame cross the International Date Line. + Determines whether LineStrings in a GeoDataFrame cross the International Date Line. Args: df (gpd.GeoDataFrame): Input GeoDataFrame with LineString geometries. Returns: - gpd.GeoSeries: Series indicating whether each LineString crosses the antimeridian. 
+ pd.Series: Series indicating whether each LineString crosses the antimeridian. Example: >>> df = gpd.read_file('path_to_file.geojson') - >>> df['crosses'] = crosses_antimeridian(df) - >>> print(df['crosses']) - - Note: - Assumes the input GeoDataFrame uses a coordinate system in meters. - If using a degree-based system like EPSG:4326, the results may not be accurate. + >>> df['crosses_antimeridian'] = crosses_antimeridian(df) + >>> print(df['crosses_antimeridian']) """ - TEMPLATE = pd.Series([], dtype="bool") + # Ensure the CRS is in degrees (longitude, latitude) + if df.crs.to_epsg() != 4326: + df = df.to_crs(4326) - if df.crs.to_epsg() != 3857: - df = df.to_crs(3857) + # Extract coordinates from the geometry + coords = df.geometry.apply(lambda geom: np.array(geom.coords.xy).T) - if df.empty: - return TEMPLATE + # Vectorized check for antimeridian crossing + def crosses(coords: np.ndarray) -> bool: + # Calculate differences between consecutive longitudes + longitudes = coords[:, 0] + lon_diff = np.diff(longitudes) - coords = df.geometry.astype(object).apply(lambda x: (x.coords[0], x.coords[-1])) - return coords.apply(lambda x: x[0][0] * x[1][0] < 0) + # Check if the difference is greater than 180 degrees (indicating a crossing) + crosses = np.abs(lon_diff) > 180 + return np.any(crosses) + # Apply the vectorized check across all geometries + return coords.apply(crosses) -def buffer_in_utm( - geom: Polygon, - src_crs: str | int, - buffer_distance: float | int, - utm_crs: str | int | None = None, -) -> Polygon: + +def _buffer_geometry( + geom: base.BaseGeometry, src_crs: str | int, buffer_dist: float +) -> base.BaseGeometry: """ - Apply a buffer to a geometry in its appropriate UTM projection. + Buffers a single geometry in its appropriate UTM projection and reprojects it back to the original CRS. Args: - geom (shapely.geometry.Polygon): Input geometry. - src_crs (str): The coordinate reference system of the input geometry in PROJ string format or EPSG code. - buffer_distance (float | int): Buffer distance in metres. - utm_crs (str): The UTM zone of the input geometry in PROJ String or EPSG code. - + geom (shapely.geometry.base.BaseGeometry): The geometry to buffer. + src_crs (Union[str, int]): The original CRS of the geometry. + buffer_dist (float): The buffer distance in meters. Returns: - shapely.geometry.Polygon: Buffered geometry. + base.BaseGeometry: The buffered geometry in the original CRS. + """ + # Estimate the UTM CRS based on the geometry's location + utm_crs = gpd.GeoSeries([geom], crs=src_crs).estimate_utm_crs() - Example: - from shapely.geometry import Point - buffered_geom = buffer_in_utm(Point(12.4924, 41.8902), src_crs="EPSG:4326", buffer_distance=-100) + # Reproject the geometry to UTM, apply the buffer, and reproject back to the original CRS + geom_utm = gpd.GeoSeries([geom], crs=src_crs).to_crs(utm_crs).iloc[0] + buffered_utm = geom_utm.buffer(buffer_dist) + buffered_geom = gpd.GeoSeries([buffered_utm], crs=utm_crs).to_crs(src_crs).iloc[0] + + return buffered_geom + + +def buffer_geometries_in_utm( + geo_data: gpd.GeoSeries | gpd.GeoDataFrame, buffer_dist: float +) -> gpd.GeoSeries | gpd.GeoDataFrame: + """ + Buffer all geometries in a GeoSeries or GeoDataFrame in their appropriate UTM projections and return + the buffered geometries in the original CRS. + + Args: + geo_data (Union[gpd.GeoSeries, gpd.GeoDataFrame]): Input GeoSeries or GeoDataFrame containing geometries. + buffer_dist (float): Buffer distance in meters. 
+ + Returns: + Union[gpd.GeoSeries, gpd.GeoDataFrame]: Buffered geometries in the original CRS. """ - if not utm_crs: - utm_crs = gpd.GeoSeries(geom, crs=src_crs).estimate_utm_crs() + # Determine if the input is a GeoDataFrame or a GeoSeries + is_geodataframe = isinstance(geo_data, gpd.GeoDataFrame) - # Set up the transformers for forward and reverse transformations - transformer_to_utm = Transformer.from_crs(src_crs, utm_crs, always_xy=True) - transformer_from_utm = Transformer.from_crs(utm_crs, src_crs, always_xy=True) + # Extract the geometry series from the GeoDataFrame, if necessary + geom_series = geo_data.geometry if is_geodataframe else geo_data - # Perform the transformations - geom_utm = transform(transformer_to_utm.transform, geom) # type: ignore - geom_buffered_utm = geom_utm.buffer(buffer_distance) - geom_buffered = transform(transformer_from_utm.transform, geom_buffered_utm) # type: ignore + # Ensure the input data has a defined CRS + if geom_series.crs is None: + msg = "Input GeoSeries or GeoDataFrame must have a defined CRS." + raise ValueError(msg) - return geom_buffered + # Buffer each geometry using the UTM projection and return to original CRS + buffered_geoms = geom_series.apply( + lambda geom: _buffer_geometry(geom, geom_series.crs, buffer_dist) + ) + + # Return the modified GeoDataFrame or GeoSeries with the buffered geometries + if is_geodataframe: + geo_data = geo_data.assign(geometry=buffered_geoms) + return geo_data + else: + return buffered_geoms def add_line_length( diff --git a/src/coastpy/geo/quadtiles_utils.py b/src/coastpy/geo/quadtiles_utils.py index bb2df89..dfd02cd 100644 --- a/src/coastpy/geo/quadtiles_utils.py +++ b/src/coastpy/geo/quadtiles_utils.py @@ -74,6 +74,8 @@ def quadkey_to_geojson(quadkey: str) -> dict: } +# NOTE: consider if it would be better to optionally run this function when the attributes +# are already present int he columns. def add_geo_columns( df: gpd.GeoDataFrame, geo_columns: list[Literal["bbox", "bounding_quadkey", "quadkey"]], @@ -138,11 +140,10 @@ def get_point_from_geometry(geom): # Add quadkey column if "quadkey" in geo_columns: if quadkey_zoom_level is None: - message = ( + msg = ( "quadkey_zoom_level must be provided when 'quadkey' is in geo_columns." 
) - raise ValueError(message) - + raise ValueError(msg) if "lon" in df.columns and "lat" in df.columns: points = gpd.GeoSeries( [Point(xy) for xy in zip(df.lon, df.lat, strict=False)], crs="EPSG:4326" diff --git a/src/coastpy/geo/transect.py b/src/coastpy/geo/transect.py index 69d45d3..fe04d70 100644 --- a/src/coastpy/geo/transect.py +++ b/src/coastpy/geo/transect.py @@ -22,11 +22,11 @@ class Transect: """Dataclass for transects""" - tr_name: str + transect_id: str tr_origin: Point tr_length: int tr_angle: float - utm_crs: int | str + utm_epsg: int | str dst_crs: int | str _geometry: LineString = field(init=False, repr=False) @@ -37,7 +37,7 @@ class Transect: @property # @lru_cache(maxsize=1) # does not work when using dask distributed def geometry(self): - utm_crs_epsg = pyproj.CRS.from_user_input(self.utm_crs).to_epsg() + utm_crs_epsg = pyproj.CRS.from_user_input(self.utm_epsg).to_epsg() dst_crs_epsg = pyproj.CRS.from_user_input(self.dst_crs).to_epsg() pt1 = calculate_point(self.tr_origin, self.tr_angle + 90, 0.5 * self.tr_length) @@ -61,23 +61,23 @@ def geometry(self): def to_dict(self): return { - "tr_name": self.tr_name, + "transect_id": self.transect_id, "geometry": self.geometry, "lon": self._lon, "lat": self._lat, "tr_origin": wkt.dumps(self._tr_origin), "bearing": self._bearing, - "utm_crs": self.utm_crs, + "utm_epsg": self.utm_epsg, "src_crs": self.dst_crs, } def __hash__(self): return hash( ( - self.tr_name, + self.transect_id, self.tr_origin, self.tr_length, - self.utm_crs, + self.utm_epsg, self.dst_crs, ) ) @@ -174,9 +174,9 @@ def make_transect_origins( def make_transects( coastline: gpd.GeoSeries, - coastline_name: str, + osm_coastline_id: str, src_crs: str, - utm_crs: str, + utm_epsg: str, dst_crs: str, spacing: float, transect_length: int = 2000, @@ -187,9 +187,9 @@ def make_transects( Args: coastline: GeoSeries representing the coastline. - coastline_name: Name of the coastline. + osm_coastline_id: Name of the coastline. src_crs: Source CRS of the coastline. - utm_crs: UTM CRS for local coordinate transformation. + utm_epsg: UTM CRS for local coordinate transformation. dst_crs: Destination CRS for the transects. spacing: Distance between transects. idx: Index of the function call. @@ -200,7 +200,7 @@ def make_transects( GeoDataFrame containing transects with their attributes and geometry. 
""" # Convert coastline to local UTM CRS - tf = Transformer.from_crs(src_crs, utm_crs, always_xy=True) + tf = Transformer.from_crs(src_crs, utm_epsg, always_xy=True) coastline = transform(tf.transform, coastline) origins = make_transect_origins(coastline, spacing) @@ -213,14 +213,14 @@ def make_transects( pt2 = coastline.interpolate(origin + smooth_distance) angle = get_angle(pt1, pt2) - tr_id = f"cl{int(coastline_name)}tr{int(origin)}" + tr_id = f"cl{int(osm_coastline_id)}tr{int(origin)}" tr = Transect( - tr_name=tr_id, + transect_id=tr_id, tr_angle=angle, tr_origin=tr_origin, tr_length=transect_length, - utm_crs=utm_crs, + utm_epsg=utm_epsg, dst_crs=dst_crs, ) @@ -231,25 +231,25 @@ def make_transects( # when pyarrow is more stable, use pyarrow dtypes instead # column_datatypes = { - # "tr_name": "string[pyarrow]", + # "transect_id": "string[pyarrow]", # "lon": "float64[pyarrow]", # "lat": "float64[pyarrow]", # "tr_origin": "string[pyarrow]", # "bearing": "float64[pyarrow]", - # "utm_crs": "int32[pyarrow]", + # "utm_epsg": "int32[pyarrow]", # "src_crs": "int32[pyarrow]", - # "coastline_name": "int32[pyarrow]", + # "osm_coastline_id": "int32[pyarrow]", # } column_datatypes = { - "tr_name": "string", + "transect_id": "string", "lon": "float64", "lat": "float64", "tr_origin": "string", "bearing": "float64", - "utm_crs": "int32", + "utm_epsg": "int32", "src_crs": "int32", - "coastline_name": "int32", + "osm_coastline_id": "int32", } # TODO: instead of dropping transects that cross date line, create MultiLinestrings? @@ -257,7 +257,7 @@ def make_transects( gpd.GeoDataFrame(transects, geometry="geometry", crs=dst_crs) .reset_index(drop=True) # .pipe(drop_transects_crossing_antimeridian) - .assign(coastline_name=coastline_name) + .assign(osm_coastline_id=osm_coastline_id) .astype(column_datatypes) ) @@ -266,11 +266,11 @@ def generate_transects_from_coastline( coastline: LineString, transect_length: float, spacing: float | int, - coastline_name: int, - coastline_is_closed: bool, - coastline_length: int, + osm_coastline_id: int, + osm_coastline_is_closed: bool, + osm_coastline_length: int, src_crs: str | int, - utm_crs: str | int, + utm_epsg: str | int, dst_crs: str | int, smooth_distance: float = 1e-3, ) -> gpd.GeoDataFrame: @@ -280,11 +280,11 @@ def generate_transects_from_coastline( Args: coastline (LineString): The coastline geometry. transect_length (float): Length of the transects. - coastline_name (int): ID for the coastline. - coastline_is_closed (bool): If the source OSM coastline is closed. - coastline_length (int): length of the coastline. + osm_coastline_id (int): ID for the coastline. + osm_coastline_is_closed (bool): If the source OSM coastline is closed. + osm_coastline_length (int): length of the coastline. src_crs (str): Source CRS of the coastline geometry. - utm_crs (str): UTM CRS for local coordinate transformation. + utm_epsg (str): UTM CRS for local coordinate transformation. dst_crs (str): Target CRS for the transects. smooth_distance (float, optional): Smoothing distance. Defaults to 1e-3. 
@@ -294,15 +294,15 @@ def generate_transects_from_coastline( # Define a template empty geodataframe with specified datatypes META = gpd.GeoDataFrame( { - "tr_name": pd.Series([], dtype="string"), + "transect_id": pd.Series([], dtype="string"), "lon": pd.Series([], dtype="float32"), "lat": pd.Series([], dtype="float32"), "bearing": pd.Series([], dtype="float32"), - "utm_crs": pd.Series([], dtype="int32"), + "utm_epsg": pd.Series([], dtype="int32"), # NOTE: leave here because before we used to store the coastline name - # "coastline_name": pd.Series([], dtype="string"), - "coastline_is_closed": pd.Series([], dtype="bool"), - "coastline_length": pd.Series([], dtype="int32"), + # "osm_coastline_id": pd.Series([], dtype="string"), + "osm_coastline_is_closed": pd.Series([], dtype="bool"), + "osm_coastline_length": pd.Series([], dtype="int32"), "geometry": gpd.GeoSeries([], dtype="geometry"), }, crs=dst_crs, @@ -314,7 +314,7 @@ def generate_transects_from_coastline( dtypes = META.dtypes.to_dict() column_order = META.columns.to_list() - tf = Transformer.from_crs(src_crs, utm_crs, always_xy=True) + tf = Transformer.from_crs(src_crs, utm_epsg, always_xy=True) coastline = transform(tf.transform, coastline) origins = make_transect_origins(coastline, spacing) @@ -337,32 +337,32 @@ def generate_transects_from_coastline( get_planar_bearing(pt1, pt2) for pt1, pt2 in zip(pt1s, pt2s, strict=True) ] - tf_4326 = Transformer.from_crs(utm_crs, 4326, always_xy=True) + tf_4326 = Transformer.from_crs(utm_epsg, 4326, always_xy=True) tr_origins_4326 = [ transform(tf_4326.transform, tr_origin) for tr_origin in tr_origins ] lons, lats = zip(*[extract_coordinates(p) for p in tr_origins_4326], strict=True) - tr_names = [f"{coastline_name}tr{int(o)}" for o in origins] + transect_ids = [f"{osm_coastline_id}tr{int(o)}" for o in origins] return ( gpd.GeoDataFrame( { - "tr_name": tr_names, + "transect_id": transect_ids, "lon": lons, "lat": lats, "bearing": bearings, "geometry": transects, }, - crs=utm_crs, + crs=utm_epsg, ) .to_crs(dst_crs) # NOTE: leave here because before we used to store the coastline name - # .assign(coastline_name=coastline_name) - .assign(utm_crs=utm_crs) - .assign(coastline_is_closed=coastline_is_closed) - .assign(coastline_length=coastline_length) + # .assign(osm_coastline_id=osm_coastline_id) + .assign(utm_epsg=utm_epsg) + .assign(osm_coastline_is_closed=osm_coastline_is_closed) + .assign(osm_coastline_length=osm_coastline_length) .loc[:, column_order] .astype(dtypes) ) diff --git a/src/coastpy/geo/utils.py b/src/coastpy/geo/utils.py new file mode 100644 index 0000000..e69de29 diff --git a/src/coastpy/io/partitioner.py b/src/coastpy/io/partitioner.py index b901221..268f238 100644 --- a/src/coastpy/io/partitioner.py +++ b/src/coastpy/io/partitioner.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Literal import dask_geopandas import fsspec @@ -7,9 +8,12 @@ from coastpy.geo.quadtiles_utils import add_geo_columns from coastpy.geo.size import estimate_memory_usage_per_row from coastpy.io.utils import name_data -from coastpy.utils.size_utils import size_to_bytes +from coastpy.utils.size import size_to_bytes +# NOTE: What about separating the partitioner and the equal size partitioner? - that +# way we can have a more generic partitioner tha does not have to compute the sizes of +# each row. 
class EqualSizePartitioner: def __init__( self, @@ -18,24 +22,31 @@ def __init__( max_size, sort_by="quadkey", quadkey_zoom_level=12, - geo_columns=None, + geo_columns: list[Literal["bbox", "bounding_quadkey", "quadkey"]] | None = None, column_order=None, dtypes=None, + storage_options=None, + naming_function_kwargs=None, ): if geo_columns is None: - geo_columns = ["bbox", "quadkey", "bounding_quadkey"] + geo_columns = ["bbox", "quadkey"] + + if storage_options is None: + storage_options = {} + + if naming_function_kwargs is None: + naming_function_kwargs = {} self.df = df - self.out_dir = Path(out_dir) + self.out_dir = out_dir self.max_size_bytes = size_to_bytes(max_size) self.sort_by = sort_by self.quadkey_zoom_level = quadkey_zoom_level self.geo_columns = geo_columns + self.storage_options = storage_options self.column_order = column_order self.dtypes = dtypes - - # Ensure output directory exists - self.out_dir.mkdir(parents=True, exist_ok=True) + self.naming_function_kwargs = naming_function_kwargs # Set the naming function for the output files self.naming_function = name_data @@ -85,11 +96,20 @@ def write_data(self, partition_df, column_order=None): if not partition_df.empty: partition_df = partition_df[column_order] if column_order else partition_df # Generate the output path using the naming function - outpath = self.naming_function(partition_df, prefix=str(self.out_dir)) + outpath = self.naming_function( + partition_df, prefix=str(self.out_dir), **self.naming_function_kwargs + ) + + # Initialize the filesystem object with storage options + + fs, _, [path] = fsspec.get_fs_token_paths( + outpath, storage_options=self.storage_options + ) - # Ensure the output directory exists - fs = fsspec.open(outpath, "wb").fs - fs.makedirs(fs._parent(outpath), exist_ok=True) + # Ensure the output directory exists (local filesystem specific) + if fs.protocol == ("file", "local"): + # For local filesystem, ensure the parent directory exists + fs.makedirs(fs._parent(path), exist_ok=True) if self.dtypes: partition_df = partition_df.astype(self.dtypes) @@ -98,7 +118,7 @@ def write_data(self, partition_df, column_order=None): partition_df = partition_df[self.column_order] # Use fsspec to write the DataFrame to parquet - with fs.open(outpath, "wb") as f: + with fs.open(path, "wb") as f: partition_df.to_parquet(f, index=False) @@ -111,9 +131,11 @@ def __init__( min_quadkey_zoom, sort_by, quadkey_zoom_level=12, - geo_columns=None, + geo_columns: list[Literal["bbox", "bounding_quadkey", "quadkey"]] | None = None, column_order=None, dtypes=None, + storage_options=None, + naming_function_kwargs=None, ): super().__init__( df, @@ -124,6 +146,8 @@ def __init__( geo_columns, column_order, dtypes, + storage_options, + naming_function_kwargs, ) self.min_quadkey_zoom = min_quadkey_zoom self.quadkey_grouper = f"quadkey_z{min_quadkey_zoom}" diff --git a/src/coastpy/io/utils.py b/src/coastpy/io/utils.py index e2fafa8..c154b33 100644 --- a/src/coastpy/io/utils.py +++ b/src/coastpy/io/utils.py @@ -1,4 +1,6 @@ +import copy import json +import logging import pathlib import uuid import warnings @@ -15,6 +17,8 @@ from shapely.geometry import box from shapely.ops import transform +logger = logging.getLogger(__name__) + def is_local_file_path(path: str | pathlib.Path) -> bool: """ @@ -255,22 +259,46 @@ def name_bounds(bounds, crs): def read_items_extent(collection, columns=None, storage_options=None): + """ + Reads the extent of items from a STAC collection and returns a GeoDataFrame with specified columns. 
+ + Args: + collection: A STAC collection object that contains assets. + columns: List of columns to return. Default is ["geometry", "assets", "href"]. + storage_options: Storage options to pass to fsspec. Default is None. + + Returns: + GeoDataFrame containing the specified columns. + """ if storage_options is None: storage_options = {} + # Set default columns if columns is None: - columns = ["geometry", "assets"] + columns = ["geometry", "assets", "href"] - required_cols = ["geometry", "assets"] + columns_ = copy.deepcopy(columns) - for col in required_cols: - if col not in columns: - columns = [*columns, col] + # Ensure 'assets' is always in the columns + if "assets" not in columns: + columns.append("assets") + logger.debug("'assets' column added to the list of columns") + # Open the parquet file and read the specified columns href = collection.assets["geoparquet-stac-items"].href with fsspec.open(href, mode="rb", **storage_options) as f: - extents = gpd.read_parquet(f, columns=columns) - extents["href"] = extents.assets.map(lambda x: x["data"]["href"]) + extents = gpd.read_parquet(f, columns=[c for c in columns if c != "href"]) + + # If 'href' is requested, extract it from the 'assets' column + if "href" in columns: + extents["href"] = extents["assets"].apply(lambda x: x["data"]["href"]) + logger.debug("'href' column extracted from 'assets'") + + # Drop 'assets' if it was not originally requested + if "assets" not in columns_: + extents = extents.drop(columns=["assets"]) + logger.debug("'assets' column dropped from the GeoDataFrame") + return extents @@ -376,3 +404,81 @@ def read_log_entries( df = df.sort_values(by="time", ascending=True) return df + + +def rm_from_storage( + pattern: str, + storage_options: dict[str, str] | None = None, + confirm: bool = True, + verbose: bool = True, +) -> None: + """ + Deletes all blobs/files in the specified storage location that match the given prefix. + + Args: + pattern (str): The pattern or path pattern (including wildcards) for the blobs/files to delete. + storage_options (Dict[str, str], optional): A dictionary containing storage connection details. + confirm (bool): Whether to prompt for confirmation before deletion. + verbose (bool): Whether to display detailed log messages. + + Returns: + None + """ + if storage_options is None: + storage_options = {} + + # Create a local logger + logger = logging.getLogger(__name__) + if verbose: + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + formatter = logging.Formatter("%(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + + if storage_options is None: + storage_options = {} + + # Get filesystem, token, and resolved paths + fs, _, paths = fsspec.get_fs_token_paths(pattern, storage_options=storage_options) + + if paths: + if verbose: + logger.info( + f"Warning: You are about to delete the following {len(paths)} blobs/files matching '{pattern}'." 
+ ) + for path in paths: + logger.info(path) + + if confirm: + confirmation = input( + f"Type 'yes' to confirm deletion of {len(paths)} blobs/files matching '{pattern}': " + ) + else: + confirmation = "yes" + + if confirmation.lower() == "yes": + for path in paths: + try: + if verbose: + logger.info(f"Deleting blob/file: {path}") + fs.rm(path) + if verbose: + logger.info(f"Blob/file {path} deleted successfully.") + except Exception as e: + if verbose: + logger.error(f"Failed to delete blob/file: {e}") + if verbose: + logger.info("All specified blobs/files have been deleted.") + else: + if verbose: + logger.info("Blob/file deletion cancelled.") + else: + if verbose: + logger.info(f"No blobs/files found matching '{pattern}'.") + + # Remove the handler after use + if verbose: + logger.removeHandler(handler) + handler.close() diff --git a/src/coastpy/libs/stac_table.py b/src/coastpy/libs/stac_table.py index b8cac61..7beaf5b 100644 --- a/src/coastpy/libs/stac_table.py +++ b/src/coastpy/libs/stac_table.py @@ -297,6 +297,10 @@ def get_columns(schema: pa.Schema, prefix: str = "") -> list: # For nested fields, recurse into the structure nested_columns = get_columns(field.type, prefix=field.name + ".") columns.extend(nested_columns) + elif field.name == "geometry": + column = {"name": "geometry", "type": "WKB"} + columns.append(column) + else: # Handle non-nested fields column = {"name": prefix + field.name, "type": str(field.type).lower()} diff --git a/src/coastpy/utils/__init__.py b/src/coastpy/utils/__init__.py index 14a7e25..5122ef1 100644 --- a/src/coastpy/utils/__init__.py +++ b/src/coastpy/utils/__init__.py @@ -1,11 +1,11 @@ from .config import configure_instance, detect_instance_type -from .dask_utils import create_dask_client -from .size_utils import readable_bytes, size_to_bytes +from .dask import DaskClientManager +from .size import readable_bytes, size_to_bytes __all__ = [ "detect_instance_type", "configure_instance", - "create_dask_client", + "DaskClientManager", "size_to_bytes", "readable_bytes", ] diff --git a/src/coastpy/utils/dask.py b/src/coastpy/utils/dask.py new file mode 100644 index 0000000..e952343 --- /dev/null +++ b/src/coastpy/utils/dask.py @@ -0,0 +1,154 @@ +import logging +from typing import Any + +import dask +from distributed import Client + +from coastpy.utils.config import ComputeInstance + + +class DaskClientManager: + """Manager for creating Dask clients based on compute instance type. + + This class supports the creation of local and SLURM Dask clusters, + with optional configuration from external files. + + Attributes: + config_path (Optional[str]): Path to a Dask configuration file. + """ + + def __init__(self): + """Initialize the DaskClientManager, optionally loading a config file. + + Args: + config_path (Optional[str]): Path to the configuration file. + """ + dask.config.refresh() + + def create_client(self, instance_type: ComputeInstance, *args: Any, **kwargs: Any): + """Create a Dask client based on the instance type. + + Args: + instance_type (ComputeInstance): The type of the compute instance. + *args: Additional positional arguments for client creation. + **kwargs: Additional keyword arguments for client creation. + + Returns: + Client: The Dask client. + + Raises: + ValueError: If the instance type is not recognized. + """ + + if instance_type.name == "LOCAL": + return self._create_local_client(*args, **kwargs) + elif instance_type.name == "SLURM": + return self._create_slurm_client(*args, **kwargs) + else: + msg = "Unknown compute instance type." 
+            raise ValueError(msg)
+
+    def _create_local_client(self, *args: Any, **kwargs: Any) -> Client:
+        """Create a local Dask client with potential overrides.
+
+        Args:
+            *args: Additional positional arguments for client creation.
+            **kwargs: Additional keyword arguments for client creation.
+
+        Returns:
+            Client: The Dask local client.
+        """
+        # Set default values
+        configs = {
+            "threads_per_worker": 1,
+            "processes": True,
+            "n_workers": 5,
+            "local_directory": "/tmp",
+        }
+
+        # Update defaults with any overrides provided in kwargs
+        configs.update(kwargs)
+
+        # Create and return the Dask Client using the updated parameters
+        return Client(*args, **configs)
+
+    def _create_slurm_client(self, *args: Any, **kwargs: Any) -> Client:
+        """Create a SLURM Dask client with potential overrides.
+
+        Args:
+            *args: Additional positional arguments for client creation.
+            **kwargs: Additional keyword arguments for client creation.
+
+        Returns:
+            Client: The Dask SLURM client.
+        """
+        from dask_jobqueue import SLURMCluster
+
+        slurm_configs = {
+            "cores": 10,  # Cores per worker
+            "processes": 10,  # Processes per worker
+            # "n_workers": 10,
+            "memory": "120GB",  # Memory per worker
+            # "local_directory": "/scratch/frcalkoen/tmp",
+            "walltime": "3:00:00",
+            "log_directory": "/scratch/frcalkoen/tmp",
+        }
+        # Update default values with any overrides provided in kwargs
+        slurm_configs.update(kwargs)
+
+        # Create the SLURM cluster
+        cluster = SLURMCluster(*args, **slurm_configs)
+        cluster.scale(jobs=2)
+
+        logging.info(f"{cluster.job_script()}")
+
+        # cluster.scale(jobs=5)
+
+        # min_jobs = kwargs.pop(
+        #     "minimum_jobs", dask.config.get("jobqueue.adaptive.minimum", 1)
+        # )
+        # max_jobs = kwargs.pop(
+        #     "maximum_jobs", dask.config.get("jobqueue.adaptive.maximum", 30)
+        # )
+
+        # cluster.adapt(minimum_jobs=min_jobs, maximum_jobs=max_jobs)
+        return Client(cluster)
+
+    # def _create_slurm_client(self, *args: Any, **kwargs: Any):
+    #     """Create a SLURM Dask client with potential overrides.
+
+    #     Args:
+    #         *args: Additional positional arguments for client creation.
+    #         **kwargs: Additional keyword arguments for client creation.
+
+    #     Returns:
+    #         Client: The Dask SLURM client.
+    #     """
+    #     from dask_jobqueue import SLURMCluster
+
+    #     min_jobs = kwargs.pop(
+    #         "minimum_jobs", dask.config.get("jobqueue.adaptive.minimum", 1)
+    #     )
+    #     max_jobs = kwargs.pop(
+    #         "maximum_jobs", dask.config.get("jobqueue.adaptive.maximum", 30)
+    #     )
+
+    #     cluster = SLURMCluster(*args, **kwargs)
+    #     cluster.adapt(minimum_jobs=min_jobs, maximum_jobs=max_jobs)
+    #     return cluster.get_client()
+
+
+def silence_shapely_warnings() -> None:
+    """Suppress specific warnings commonly encountered in Shapely geometry operations."""
+    import warnings
+
+    warnings_to_ignore: list[str] = [
+        "invalid value encountered in buffer",
+        "invalid value encountered in intersection",
+        "invalid value encountered in unary_union",
+    ]
+
+    for warning in warnings_to_ignore:
+        warnings.filterwarnings("ignore", message=warning)
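A minimal sketch of how the new DaskClientManager is intended to be used; keyword arguments are passed through as overrides to the local or SLURM defaults shown above:

    from coastpy.utils import DaskClientManager, configure_instance

    instance_type = configure_instance()  # resolves to a LOCAL or SLURM compute instance
    client = DaskClientManager().create_client(
        instance_type,
        threads_per_worker=1,
        processes=True,
    )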
- """ - - if instance_type.name == "LOCAL": - from distributed import Client - - return Client( - threads_per_worker=1, - processes=True, - local_directory="/tmp", - ) - elif instance_type.name == "SLURM": - from dask_jobqueue import SLURMCluster - - cluster = SLURMCluster(memory="16GB") - cluster.adapt(minimum_jobs=1, maximum_jobs=30) - return cluster.get_client() - else: - msg = "Unknown compute instance type." - raise ValueError(msg) - - -def silence_shapely_warnings() -> None: - """ - Suppress specific warnings commonly encountered in Shapely geometry operations. - - Warnings being suppressed: - - Invalid value encountered in buffer - - Invalid value encountered in intersection - - Invalid value encountered in unary_union - """ - - warning_messages = [ - "invalid value encountered in buffer", - "invalid value encountered in intersection", - "invalid value encountered in unary_union", - ] - import warnings - - for message in warning_messages: - warnings.filterwarnings("ignore", message=message) diff --git a/src/coastpy/utils/pandas.py b/src/coastpy/utils/pandas.py new file mode 100644 index 0000000..30e47eb --- /dev/null +++ b/src/coastpy/utils/pandas.py @@ -0,0 +1,156 @@ +import warnings + +import antimeridian +import geopandas as gpd +import pandas as pd +import shapely +from geopandas import GeoDataFrame +from shapely.geometry import LineString, box + + +def create_buffer_zone( + gdf: GeoDataFrame, + planar_crs: int | None = 3857, + buffer_factor: float = 1.5, + max_distance: float = 20000, + use_utm: bool = False, +) -> GeoDataFrame: + """ + Create a buffer zone around geometries in a GeoDataFrame. + + Args: + gdf (GeoDataFrame): Input GeoDataFrame with geometries. + planar_crs (Optional[int]): EPSG code for the planar projection system. Defaults to 3857. + buffer_factor (float): Factor to multiply the max distance by for the buffer. Defaults to 1.5. + max_distance (float): Maximum distance for the buffer in meters. Defaults to 20000. + use_utm (bool): Whether to compute the buffer in the UTM zone of the geometries. Defaults to False. + + Returns: + GeoDataFrame: GeoDataFrame with the buffer zone geometries. + """ + src_crs = gdf.crs + + if use_utm: + planar_crs = gdf.estimate_utm_crs() + + gdf = gdf.to_crs(planar_crs) + gdf["geometry"] = gdf.buffer(max_distance * buffer_factor) + gdf = gdf.to_crs(src_crs) + + # TODO: + # - Add antimeridian handling + + return gdf + + +def create_antimeridian_buffer( + max_distance: float, buffer_factor: float = 1.5 +) -> gpd.GeoDataFrame: + """ + Creates a buffered zone around the antimeridian to account for spatial + operations near the -180/180 longitude line. + + Args: + max_distance (float): Maximum distance in meters for the buffer around the antimeridian. + buffer_factor (float, optional): Factor to scale the buffer size. Defaults to 1.5. + + Returns: + gpd.GeoDataFrame: A GeoDataFrame containing the buffered zone around the antimeridian. 
+ """ + # Create a GeoDataFrame representing the antimeridian line + antimeridian_line = gpd.GeoDataFrame( + geometry=[LineString([[-180, -85], [-180, 85]])], crs=4326 + ) + buffer_zone = create_buffer_zone( + antimeridian_line, + planar_crs=3857, + buffer_factor=buffer_factor, + max_distance=max_distance, + use_utm=False, + ) + + # Suppress FixWindingWarning from the antimeridian package + with warnings.catch_warnings(): + warnings.simplefilter("ignore", antimeridian.FixWindingWarning) + fixed_geometry = shapely.geometry.shape( + antimeridian.fix_geojson( + shapely.geometry.mapping(buffer_zone.iloc[0].geometry) + ) + ) + + # Return as GeoDataFrame + return gpd.GeoDataFrame(geometry=[fixed_geometry], crs=4326) + + +def add_attributes_from_gdf( + df: gpd.GeoDataFrame, + other_gdf: gpd.GeoDataFrame, + max_distance: float = 20000, + buffer_factor: float = 1.5, +) -> gpd.GeoDataFrame: + """ + Adds attributes from a source GeoDataFrame to a target GeoDataFrame based on nearest spatial join. + + Args: + df (gpd.GeoDataFrame): The target GeoDataFrame to which attributes will be added. + other_gdf (gpd.GeoDataFrame): The other GeoDataFrame from which attributes will be extracted. + max_distance (float): The maximum distance for nearest neighbor consideration, in meters. + buffer_factor (float): Factor to increase the buffer area. Defaults to 1.5. + + Returns: + gpd.GeoDataFrame: The target GeoDataFrame with added attributes from the source GeoDataFrame. + """ + # Ensure the transect GeoDataFrame has a point geometry to avoid double intersection. + transect_origins = gpd.GeoDataFrame( + df[["transect_id"]], + geometry=gpd.points_from_xy(df.lon, df.lat, crs=4326), + ) + + antimeridian_buffer = create_antimeridian_buffer( + max_distance, buffer_factor=buffer_factor + ) + + # Optimization: define the region of interest with a buffer that only works in areas far away from + # the antimeridian + if gpd.overlay(transect_origins, antimeridian_buffer).empty: + roi = gpd.GeoDataFrame(geometry=[box(*transect_origins.total_bounds)], crs=4326) + roi = gpd.GeoDataFrame( + geometry=roi.to_crs(3857).buffer(max_distance * 1.5).to_crs(4326) + ) + + # Filter source GeoDataFrame within the region of interest + other_gdf = gpd.sjoin(other_gdf, roi).drop(columns=["index_right"]) + + # Perform nearest neighbor spatial join + joined = gpd.sjoin_nearest( + transect_origins.to_crs(3857), + other_gdf.to_crs(3857), + max_distance=max_distance, + ).drop(columns=["index_right", "geometry"]) + + # Merge the attributes into the original target GeoDataFrame + result = pd.merge(df, joined, on="transect_id", how="left").drop_duplicates( + "transect_id" + ) + return result + + +def add_attributes_from_gdfs( + df: gpd.GeoDataFrame, + other_gdfs: list[gpd.GeoDataFrame], + max_distance: float = 20000, +) -> gpd.GeoDataFrame: + """ + Adds attributes from multiple other GeoDataFrames to a target GeoDataFrame. + + Args: + df (gpd.GeoDataFrame): The target GeoDataFrame to which attributes will be added. + other_gdfs (List[gpd.GeoDataFrame]): A list of other GeoDataFrames from which attributes will be extracted. + max_distance (float): The maximum distance for nearest neighbor consideration, in meters. + + Returns: + gpd.GeoDataFrame: The target GeoDataFrame with added attributes from all source GeoDataFrames. 
+ """ + for source_gdf in other_gdfs: + df = add_attributes_from_gdf(df, source_gdf, max_distance) + return df diff --git a/src/coastpy/utils/size_utils.py b/src/coastpy/utils/size.py similarity index 100% rename from src/coastpy/utils/size_utils.py rename to src/coastpy/utils/size.py diff --git a/src/coastpy/utils/xr_utils.py b/src/coastpy/utils/xarray.py similarity index 85% rename from src/coastpy/utils/xr_utils.py rename to src/coastpy/utils/xarray.py index 5723160..ac5c628 100644 --- a/src/coastpy/utils/xr_utils.py +++ b/src/coastpy/utils/xarray.py @@ -1,8 +1,10 @@ -from typing import Literal +import warnings import numpy as np +import rasterio import xarray as xr from affine import Affine +from rasterio.enums import Resampling from shapely import Polygon @@ -115,35 +117,52 @@ def raster_center(ds: xr.Dataset) -> tuple[float, float]: def rotate_raster( - ds: xr.Dataset, rotation_angle: float, pivot: tuple[float, float] | None = None + ds: xr.Dataset, + rotation_angle: float, + resampling: Resampling, + pivot: tuple[float, float] | None = None, ) -> xr.Dataset: """ Rotate a raster dataset around a pivot point or its center. Args: - ds (xr.Dataset): The raster dataset to be rotated. + ds (xr.Dataset): Raster dataset to be rotated. rotation_angle (float): Angle to rotate the raster, in degrees. Positive values represent counterclockwise rotation. - pivot (Tuple[float, float], optional): The (x, y) coordinates of the pivot point. If not provided, raster's center is used. + resampling (Resampling): Resampling method to use during reprojection. + pivot (Optional[Tuple[float, float]]): (x, y) coordinates of the pivot point. If not provided, the raster's center is used. Returns: - xr.Dataset: The rotated raster dataset. + xr.Dataset: Rotated raster dataset. - Example: - >>> ds = xr.Dataset(...) - >>> rotated_ds = rotate_raster(ds, 45) + Raises: + UserWarning: If the absolute rotation angle is 45 degrees, which may result in a raster that is not of the expected shape, with a clipped view because the axis should also be swapped. """ + if abs(rotation_angle) > 45: + msg = "The absolute rotation angle larger than 45 degrees, which may result in a raster that clipped. Consider adjusting the rotation in the other direction." 
+        warnings.warn(
+            msg,
+            UserWarning,
+            stacklevel=2,
+        )
+
     src_transform = ds.rio.transform()
     rotation = Affine.rotation(rotation_angle, pivot=pivot)
+
+    # TODO: Compute the scaling factors for the new grid
+    # dst_transform = src_transform * Affine.scale(x_scale, y_scale)
+
     dst_transform = src_transform * rotation
 
     # Rescale the y-axis to correct the inversion
     rescale_y = Affine(1, 0, 0, 0, -1, ds.rio.height)
     dst_transform = dst_transform * rescale_y
 
-    ds = ds.rio.reproject(dst_crs=ds.rio.crs, transform=dst_transform)
+    ds = ds.rio.reproject(
+        dst_crs=ds.rio.crs, transform=dst_transform, resampling=resampling
+    )
     ds = ds.rio.write_transform(dst_transform)
 
     ds = ds.assign_coords(
-        {"y": ("y", range(ds.dims["y"])), "x": ("x", range(ds.dims["x"]))}
+        {"y": ("y", range(ds.sizes["y"])), "x": ("x", range(ds.sizes["x"]))}
     )
 
     return ds
 
@@ -152,20 +171,7 @@ def interpolate_raster(
     ds: xr.Dataset,
     y_shape: int,
     x_shape: int,
-    method: Literal[
-        "linear",
-        "nearest",
-        "zero",
-        "slinear",
-        "quadratic",
-        "cubic",
-        "polynomial",
-        "barycentric",
-        "krog",
-        "pchip",
-        "spline",
-        "akima",
-    ],
+    resampling: rasterio.enums.Resampling = rasterio.enums.Resampling.nearest,
 ) -> xr.Dataset:
     """
     Interpolates a given raster (xarray Dataset) to a specified resolution using the provided method.
@@ -174,8 +180,7 @@ def interpolate_raster(
         ds (xr.Dataset): The input raster to interpolate.
         y_shape (int): Desired number of grid points along y dimension.
         x_shape (int): Desired number of grid points along x dimension.
-        method (str, optional): The interpolation method to use. Defaults to "linear".
-            Other methods like "nearest", "cubic" can also be used.
+        resampling (rasterio.enums.Resampling): The resampling method to use. Defaults to nearest.
 
     Returns:
         xr.Dataset: Interpolated raster without geospatial metadata.
@@ -196,10 +201,10 @@ def interpolate_raster(
     new_y = np.linspace(ds.y.min(), ds.y.max(), y_shape)
     new_x = np.linspace(ds.x.min(), ds.x.max(), x_shape)
 
-    interpolated = ds.interp(y=new_y, x=new_x, method=method)
+    interpolated = ds.interp(y=new_y, x=new_x, method=resampling)
 
     # add new coords because the old ones are now two dimensional
-    interpolated = interpolated.assign_coords(y=new_y, x=new_x)
+    interpolated = interpolated.assign_coords(y=range(y_shape), x=range(x_shape))
 
     # the transformation matrix can be computed by scaling the src one
     src_dims = ds.dims