diff --git a/README.md b/README.md index baaedf7..da9e978 100644 --- a/README.md +++ b/README.md @@ -4,13 +4,13 @@ Python tools for cloud-native coastal analytics. ## Installation -If you have all GDAL dependencies installed correctly you can install with pip: +You can install coastpy with pip in a Python environment where [GDAL](https://pypi.org/project/GDAL/) and [pyproj](https://pypi.org/project/pyproj/) are already installed. ```bash pip install coastpy ``` -although it's probably easier to install with conda: +From scratch it's probably easier to install with [conda](https://github.com/conda-forge/miniforge): ```bash conda env create -f environment.yaml diff --git a/src/coastpy/eo/collection.py b/src/coastpy/eo/collection.py new file mode 100644 index 0000000..45bfdb4 --- /dev/null +++ b/src/coastpy/eo/collection.py @@ -0,0 +1,609 @@ +import abc +import logging +from collections.abc import Callable +from typing import Any + +import geopandas as gpd +import numpy as np +import odc.stac +import pyproj +import pystac +import pystac_client +import rioxarray # noqa +import stac_geoparquet +import xarray as xr + +from coastpy.eo.indices import calculate_indices +from coastpy.stac.utils import read_snapshot +from coastpy.utils.xarray import get_nodata, set_nodata + + +class ImageCollection: + """ + A generic class to manage image collections from a STAC-based catalog. + """ + + def __init__( + self, + catalog_url: str, + collection: str, + stac_cfg: dict | None = None, + ): + self.catalog_url = catalog_url + self.collection = collection + self.catalog = pystac_client.Client.open(self.catalog_url) + + # Configuration + self.search_params = {} + self.bands = [] + self.spectral_indices = [] + self.percentile = None + self.dst_crs = None + self.load_params = {} + self.stac_cfg = stac_cfg or {} + + # Internal state + self.items = None + self.dataset = None + + def search( + self, + roi: gpd.GeoDataFrame, + datetime_range: str, + query: dict | None = None, + filter_function: Callable[[list[pystac.Item]], list[pystac.Item]] | None = None, + ) -> "ImageCollection": + """ + Search the catalog for items and optionally apply a filter function. + + Args: + roi (gpd.GeoDataFrame): Region of interest. + datetime_range (str): Temporal range in 'YYYY-MM-DD/YYYY-MM-DD'. + query (dict, optional): Additional query parameters for search. + filter_function (Callable, optional): A function to filter/sort items. + Accepts and returns a list of pystac.Items. + + Returns: + ImageCollection: Updated instance with items populated. + """ + self.search_params = { + "collections": self.collection, + "intersects": roi.to_crs(4326).geometry.item(), + "datetime": datetime_range, + "query": query, + } + + # Perform the actual search + logging.info(f"Executing search with params: {self.search_params}") + search = self.catalog.search(**self.search_params) + self.items = list(search.items()) + + # Check if items were found + if not self.items: + msg = "No items found for the given search parameters." 
+ raise ValueError(msg) + + # Apply the filter function if provided + if filter_function: + try: + logging.info("Applying custom filter function.") + self.items = filter_function(self.items) + except Exception as e: + msg = f"Error in filter_function: {e}" + raise RuntimeError(msg) # noqa: B904 + + return self + + def load( + self, + bands: list[str], + percentile: int | None = None, + spectral_indices: list[str] | None = None, + chunks: dict[str, int | str] | None = None, + groupby: str = "solar_day", + resampling: str | dict[str, str] | None = None, + dtype: np.dtype | str | None = None, + crs: str | int | None = None, + resolution: float | int | None = None, + pool: int | None = None, + preserve_original_order: bool = False, + progress: bool | None = None, + fail_on_error: bool = True, + geobox: dict | None = None, + like: xr.Dataset | None = None, + patch_url: str | None = None, + dst_crs: Any | None = None, + ) -> "ImageCollection": + """ + Configure loading parameters. + + Args: + bands (List[str]): Bands to load. + percentile (int | None): Percentile value for compositing (e.g., 50 for median). + spectral_indices (List[str]): Spectral indices to calculate. + Additional args: Parameters for odc.stac.load. + + Returns: + ImageCollection: Updated instance. + + """ + if percentile is not None and not (0 <= percentile <= 100): + msg = "Composite percentile must be between 0 and 100." + raise ValueError(msg) + + self.bands = bands + self.spectral_indices = spectral_indices + self.percentile = percentile + self.dst_crs = dst_crs + + # ODC StaC load parameters + self.load_params = { + "chunks": chunks or {}, + "groupby": groupby, + "resampling": resampling, + "stac_cfg": self.stac_cfg, + "dtype": dtype, + "crs": crs, + "resolution": resolution, + "pool": pool, + "preserve_original_order": preserve_original_order, + "progress": progress, + "fail_on_error": fail_on_error, + "geobox": geobox, + "like": like, + "patch_url": patch_url, + } + return self + + def _load(self) -> xr.Dataset: + """ + Internal method to load data using odc.stac. + """ + if not self.items: + msg = "No items found. Perform a search first." + raise ValueError(msg) + + bbox = tuple(self.search_params["intersects"].bounds) + + ds = odc.stac.load( + self.items, + bands=self.bands, + bbox=bbox, + **self.load_params, + ) + + for band in self.bands: + nodata_value = get_nodata(ds[band]) + ds[band] = ds[band].where(ds[band] != nodata_value) + ds[band] = set_nodata(ds[band], np.nan) + + if self.dst_crs and ( + pyproj.CRS.from_user_input(self.dst_crs).to_epsg() != ds.rio.crs.to_epsg() + ): + ds = ds.rio.reproject(self.dst_crs) + ds = ds.odc.reproject(self.dst_crs, resampling="bilinear", nodata=np.nan) + + for band in self.bands: + ds[band] = set_nodata(ds[band], np.nan) + + return ds + + def add_spectral_indices( + self, indices: list[str], nodata: float | int | None = None + ) -> xr.Dataset: + """ + Add spectral indices to the current dataset. + + Args: + indices (List[str]): Spectral indices to calculate. + + Returns: + xr.Dataset: Updated dataset with spectral indices. + """ + if self.dataset is None: + msg = "No dataset loaded. Perform `execute` first." 
+ raise ValueError(msg) + + self.dataset = calculate_indices(self.dataset, indices) + + # Set nodata value for the new spectral indices + if nodata is not None: + for index in indices: + self.dataset[index] = set_nodata(self.dataset[index], nodata) + + return self.dataset + + def composite( + self, percentile: int = 50, nodata: float | int | None = None + ) -> xr.Dataset: + """ + Apply a composite operation to the dataset based on the given percentile. + + Args: + percentile (int): Percentile to calculate (e.g., 50 for median). + Values range between 0 and 100. + nodata (float | int | None): Value to assign for nodata pixels in the resulting composite. + + Returns: + xr.Dataset: Composited dataset. + """ + if self.dataset is None: + msg = "No dataset loaded. Perform `execute` first." + raise ValueError(msg) + + if not (0 <= percentile <= 100): + msg = "Percentile must be between 0 and 100." + raise ValueError(msg) + + logging.info(f"Applying {percentile}th percentile composite.") + + # Use median() if percentile is 50, otherwise quantile() + if percentile == 50: + composite = self.dataset.median(dim="time", skipna=True, keep_attrs=True) + logging.info("Using median() for composite.") + else: + composite = self.dataset.quantile( + percentile / 100, dim="time", skipna=True, keep_attrs=True + ) + logging.info("Using quantile() for composite.") + + # Set nodata values for each band if provided + if nodata is not None: + + def apply_nodata(da): + return set_nodata(da, nodata) + + composite = composite.map(apply_nodata) + + self.dataset = composite + return self.dataset + + def execute(self) -> xr.Dataset: + """ + Trigger the search and load process and return the dataset. + """ + # Perform search if not already done + if self.items is None: + logging.info(f"Executing search with params: {self.search_params}") + search = self.catalog.search(**self.search_params) + self.items = list(search.items()) + + # Perform load if not already done + if self.dataset is None: + logging.info("Loading dataset...") + self.dataset = self._load() + + if self.percentile: + logging.info("Compositing dataset...") + self.dataset = self.composite(percentile=self.percentile, nodata=np.nan) + + if self.spectral_indices: + logging.info(f"Calculating spectral indices: {self.spectral_indices}") + self.dataset = self.add_spectral_indices( + self.spectral_indices, nodata=np.nan + ) + + return self.dataset + + +class S2Collection(ImageCollection): + """ + A class to manage Sentinel-2 collections from the Planetary Computer catalog. + """ + + def __init__( + self, + catalog_url: str = "https://planetarycomputer.microsoft.com/api/stac/v1", + collection: str = "sentinel-2-l2a", + ): + stac_cfg = { + "sentinel-2-l2a": { + "assets": { + "*": {"data_type": None, "nodata": np.nan}, + "SCL": {"data_type": None, "nodata": np.nan}, + "visual": {"data_type": None, "nodata": np.nan}, + }, + }, + "*": {"warnings": "ignore"}, + } + + super().__init__(catalog_url, collection, stac_cfg) + + +class TileCollection: + """ + A generic class to manage tile collections from a STAC-based catalog. 
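+
+    Subclasses implement `search()`; `load()` configures the odc.stac load
+    parameters and `execute()` loads the items into an xarray Dataset and
+    applies `_post_process()`.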
+ """ + + def __init__( + self, + catalog_url: str, + collection: str, + stac_cfg: dict | None = None, + ): + self.catalog_url = catalog_url + self.collection = collection + self.catalog = pystac_client.Client.open(self.catalog_url) + + # Configuration + self.search_params = {} + self.bands = [] + self.dst_crs = None + self.load_params = {} + self.stac_cfg = stac_cfg or {} + + # Internal state + self.items = None + self.dataset = None + + @abc.abstractmethod + def search(self, roi: gpd.GeoDataFrame) -> "TileCollection": + """ + Search for DeltaDTM items based on a region of interest. + """ + + def load( + self, + chunks: dict[str, int | str] | None = None, + resampling: str | dict[str, str] | None = None, + dtype: np.dtype | str | None = None, + crs: str | int | None = None, + resolution: float | int | None = None, + pool: int | None = None, + preserve_original_order: bool = False, + progress: bool | None = None, + fail_on_error: bool = True, + geobox: dict | None = None, + like: xr.Dataset | None = None, + patch_url: str | None = None, + dst_crs: Any | None = None, + ) -> "TileCollection": + """ + Configure loading parameters. + + Args: + Additional args: Parameters for odc.stac.load. + + Returns: + DeltaDTMCollection: Updated instance. + """ + self.dst_crs = dst_crs + + self.load_params = { + "chunks": chunks or {}, + "resampling": resampling, + "dtype": dtype, + "crs": crs, + "resolution": resolution, + "pool": pool, + "preserve_original_order": preserve_original_order, + "progress": progress, + "fail_on_error": fail_on_error, + "geobox": geobox, + "like": like, + "patch_url": patch_url, + } + return self + + def _load(self) -> xr.Dataset: + """ + Internal method to load data using odc.stac. + """ + if not self.items: + msg = "No items found. Perform a search first." + raise ValueError(msg) + + bbox = tuple(self.search_params["intersects"].bounds) + + ds = odc.stac.load( + self.items, + bbox=bbox, + **self.load_params, + ).squeeze() + + if self.dst_crs and ( + pyproj.CRS.from_user_input(self.dst_crs).to_epsg() != ds.rio.crs.to_epsg() + ): + ds = ds.rio.reproject(self.dst_crs) + ds = ds.odc.reproject(self.dst_crs, resampling="cubic", nodata=np.nan) + + return ds + + def _post_process(self, ds: xr.Dataset) -> xr.Dataset: + """Post-process the dataset.""" + return ds + + def execute(self) -> xr.Dataset: + """ + Trigger the search and load process and return the dataset. + """ + # Perform search if not already done + if self.items is None: + msg = "No items found. Perform a search first." + raise ValueError(msg) + + # Perform load if not already done + if self.dataset is None: + logging.info("Loading dataset...") + self.dataset = self._load() + self.dataset = self._post_process(self.dataset) + + return self.dataset + + +class DeltaDTMCollection(TileCollection): + """ + A class to manage Delta DTM collections from the CoCliCo catalog. + """ + + def __init__( + self, + catalog_url: str = "https://coclico.blob.core.windows.net/stac/v1/catalog.json", + collection: str = "deltares-delta-dtm", + ): + super().__init__(catalog_url, collection) + + def search(self, roi: gpd.GeoDataFrame) -> "DeltaDTMCollection": + """ + Search for DeltaDTM items based on a region of interest. 
+ """ + + self.search_params = { + "collections": self.collection, + "intersects": roi.to_crs(4326).geometry.item(), + } + + col = self.catalog.get_collection(self.collection) + storage_options = col.extra_fields["item_assets"]["data"][ + "xarray:storage_options" + ] + ddtm_extents = read_snapshot( + col, + columns=None, + storage_options=storage_options, + ) + r = gpd.sjoin(ddtm_extents, roi.to_crs(ddtm_extents.crs)).drop( + columns="index_right" + ) + self.items = list(stac_geoparquet.to_item_collection(r)) + + # Check if items were found + if not self.items: + msg = "No items found for the given search parameters." + raise ValueError(msg) + + return self + + def _post_process(self, ds: xr.Dataset) -> xr.Dataset: + """Post-process the dataset.""" + ds["data"] = ds["data"].where(ds["data"] != ds["data"].attrs["nodata"], 0) + # NOTE: Idk if this is good practice + ds["data"].attrs["nodata"] = np.nan + return ds + + +class CopernicusDEMCollection(TileCollection): + """ + A class to manage Copernicus DEM collections from the Planetary Computer catalog. + """ + + def __init__( + self, + catalog_url: str = "https://planetarycomputer.microsoft.com/api/stac/v1", + collection: str = "cop-dem-glo-30", + ): + stac_cfg = { + "cop-dem-glo-30": { + "assets": { + "*": {"data_type": "int16", "nodata": -32768}, + }, + "*": {"warnings": "ignore"}, + } + } + + super().__init__(catalog_url, collection, stac_cfg) + + def search(self, roi: gpd.GeoDataFrame) -> "CopernicusDEMCollection": + """ + Search for Copernicus DEM items based on a region of interest. + """ + self.search_params = { + "collections": self.collection, + "intersects": roi.to_crs(4326).geometry.item(), + } + + # Perform the search + logging.info(f"Executing search with params: {self.search_params}") + search = self.catalog.search(**self.search_params) + self.items = list(search.items()) + + # Check if items were found + if not self.items: + msg = "No items found for the given search parameters." + raise ValueError(msg) + + return self + + +if __name__ == "__main__": + + def filter_and_sort_stac_items( + items: list[pystac.Item], + max_items: int, + group_by: str, + sort_by: str, + ) -> list[pystac.Item]: + """ + Filter and sort STAC items by grouping and ranking within each group. + + Args: + items (list[pystac.Item]): List of STAC items to process. + max_items (int): Maximum number of items to return per group. + group_by (str): Property to group by (e.g., 's2:mgrs_tile'). + sort_by (str): Property to sort by within each group (e.g., 'eo:cloud_cover'). + + Returns: + list[pystac.Item]: Filtered and sorted list of STAC items. 
+ """ + try: + # Convert STAC items to a DataFrame + df = ( + stac_geoparquet.arrow.parse_stac_items_to_arrow(items) + .read_all() + .to_pandas() + ) + + # Group by the specified property and sort within groups + df = ( + df.groupby(group_by, group_keys=False) + .apply(lambda group: group.sort_values(sort_by).head(max_items)) + .reset_index(drop=True) + ) + + # Reconstruct the filtered list of items from indices + return [items[idx] for idx in df.index] + + except Exception as err: + logging.error(f"Error filtering and sorting items: {err}") + return [] + + import planetary_computer as pc + import shapely + + west, south, east, north = ( + -1.4987754821777346, + 46.328320550966765, + -1.446976661682129, + 46.352022707044455, + ) + roi = gpd.GeoDataFrame( + geometry=[shapely.geometry.box(west, south, east, north)], crs=4326 + ) + + def filter_function(items): + return filter_and_sort_stac_items( + items, max_items=10, group_by="s2:mgrs_tile", sort_by="eo:cloud_cover" + ) + + s2 = ( + S2Collection() + .search( + roi, + datetime_range="2023-01-01/2023-12-31", + query={"eo:cloud_cover": {"lt": 20}}, + filter_function=filter_function, + ) + .load( + bands=["blue", "green", "red", "nir", "swir16"], + percentile=50, + spectral_indices=["NDWI", "NDVI"], + chunks={"x": 256, "y": 256}, + patch_url=pc.sign, + ) + .execute() + ) + + s2 = s2.compute() + deltadtm = DeltaDTMCollection().search(roi).load().execute() + cop_dem = CopernicusDEMCollection().search(roi).load().execute() + print("done") diff --git a/src/coastpy/eo/indices.py b/src/coastpy/eo/indices.py new file mode 100644 index 0000000..7c61292 --- /dev/null +++ b/src/coastpy/eo/indices.py @@ -0,0 +1,283 @@ +import inspect +import logging +import re +from collections.abc import Callable + +import numpy as np +import xarray as xr + +from coastpy.utils.xarray import set_nodata + +# Define logging +logger = logging.getLogger(__name__) + +INDEX_DICT = { + "NDVI": { + "formula": lambda ds: (ds.nir - ds.red) / (ds.nir + ds.red), + "description": "Normalized Difference Vegetation Index, Rouse 1973", + }, + "kNDVI": { + "formula": lambda ds: np.tanh(((ds.nir - ds.red) / (ds.nir + ds.red)) ** 2), + "description": "Non-linear Normalized Difference Vegetation Index, Camps-Valls et al. 2021", + }, + "EVI": { + "formula": lambda ds: ( + (2.5 * (ds.nir - ds.red)) / (ds.nir + 6 * ds.red - 7.5 * ds.blue + 1) + ), + "description": "Enhanced Vegetation Index, Huete 2002", + }, + "LAI": { + "formula": lambda ds: ( + 3.618 + * ((2.5 * (ds.nir - ds.red)) / (ds.nir + 6 * ds.red - 7.5 * ds.blue + 1)) + - 0.118 + ), + "description": "Leaf Area Index, Boegh 2002", + }, + "SAVI": { + "formula": lambda ds: ((1.5 * (ds.nir - ds.red)) / (ds.nir + ds.red + 0.5)), + "description": "Soil Adjusted Vegetation Index, Huete 1988", + }, + "MSAVI": { + "formula": lambda ds: ( + (2 * ds.nir + 1 - ((2 * ds.nir + 1) ** 2 - 8 * (ds.nir - ds.red)) ** 0.5) + / 2 + ), + "description": "Modified Soil Adjusted Vegetation Index, Qi et al. 
1994", + }, + "NDMI": { + "formula": lambda ds: (ds.nir - ds.swir1) / (ds.nir + ds.swir1), + "description": "Normalized Difference Moisture Index, Gao 1996", + }, + "NBR": { + "formula": lambda ds: (ds.nir - ds.swir2) / (ds.nir + ds.swir2), + "description": "Normalized Burn Ratio, Lopez Garcia 1991", + }, + "BAI": { + "formula": lambda ds: (1.0 / ((0.10 - ds.red) ** 2 + (0.06 - ds.nir) ** 2)), + "description": "Burn Area Index, Martin 1998", + }, + "NDCI": { + "formula": lambda ds: (ds.red_edge_1 - ds.red) / (ds.red_edge_1 + ds.red), + "description": "Normalized Difference Chlorophyll Index, Mishra & Mishra, 2012", + }, + "NDSI": { + "formula": lambda ds: (ds.green - ds.swir1) / (ds.green + ds.swir1), + "description": "Normalized Difference Snow Index, Hall 1995", + }, + "NDTI": { + "formula": lambda ds: (ds.swir1 - ds.swir2) / (ds.swir1 + ds.swir2), + "description": "Normalized Difference Tillage Index, Van Deventer et al. 1997", + }, + "NDWI": { + "formula": lambda ds: (ds.green - ds.nir) / (ds.green + ds.nir), + "description": "Normalized Difference Water Index, McFeeters 1996", + }, + "MNDWI": { + "formula": lambda ds: (ds.green - ds.swir1) / (ds.green + ds.swir1), + "description": "Modified Normalized Difference Water Index, Xu 2006", + }, + "NDBI": { + "formula": lambda ds: (ds.swir1 - ds.nir) / (ds.swir1 + ds.nir), + "description": "Normalized Difference Built-Up Index, Zha 2003", + }, + "BUI": { + "formula": lambda ds: ((ds.swir1 - ds.nir) / (ds.swir1 + ds.nir)) + - ((ds.nir - ds.red) / (ds.nir + ds.red)), + "description": "Built-Up Index, He et al. 2010", + }, + "BAEI": { + "formula": lambda ds: (ds.red + 0.3) / (ds.green + ds.swir1), + "description": "Built-Up Area Extraction Index, Bouzekri et al. 2015", + }, + "NBI": { + "formula": lambda ds: (ds.swir1 + ds.red) / ds.nir, + "description": "New Built-Up Index, Jieli et al. 2010", + }, + "BSI": { + "formula": lambda ds: ((ds.swir1 + ds.red) - (ds.nir + ds.blue)) + / ((ds.swir1 + ds.red) + (ds.nir + ds.blue)), + "description": "Bare Soil Index, Rikimaru et al. 
2002", + }, + "AWEI_ns": { + "formula": lambda ds: ( + 4 * (ds.green - ds.swir1) - (0.25 * ds.nir + 2.75 * ds.swir2) + ), + "description": "Automated Water Extraction Index (no shadows), Feyisa 2014", + }, + "AWEI_sh": { + "formula": lambda ds: ( + ds.blue + 2.5 * ds.green - 1.5 * (ds.nir + ds.swir1) - 0.25 * ds.swir2 + ), + "description": "Automated Water Extraction Index (shadows), Feyisa 2014", + }, + "WI": { + "formula": lambda ds: ( + 1.7204 + + 171 * ds.green + + 3 * ds.red + - 70 * ds.nir + - 45 * ds.swir1 + - 71 * ds.swir2 + ), + "description": "Water Index, Fisher 2016", + }, + "TCW": { + "formula": lambda ds: ( + 0.0315 * ds.blue + + 0.2021 * ds.green + + 0.3102 * ds.red + + 0.1594 * ds.nir + + -0.6806 * ds.swir1 + + -0.6109 * ds.swir2 + ), + "description": "Tasseled Cap Wetness, Crist 1985", + }, + "TCG": { + "formula": lambda ds: ( + -0.1603 * ds.blue + + -0.2819 * ds.green + + -0.4934 * ds.red + + 0.7940 * ds.nir + + -0.0002 * ds.swir1 + + -0.1446 * ds.swir2 + ), + "description": "Tasseled Cap Greeness, Crist 1985", + }, + "TCB": { + "formula": lambda ds: ( + 0.2043 * ds.blue + + 0.4158 * ds.green + + 0.5524 * ds.red + + 0.5741 * ds.nir + + 0.3124 * ds.swir1 + + -0.2303 * ds.swir2 + ), + "description": "Tasseled Cap Brightness, Crist 1985", + }, + "TCW_GSO": { + "formula": lambda ds: ( + 0.0649 * ds.blue + + 0.2802 * ds.green + + 0.3072 * ds.red + + -0.0807 * ds.nir + + -0.4064 * ds.swir1 + + -0.5602 * ds.swir2 + ), + "description": "Tasseled Cap Wetness, Nedkov 2017", + }, + "TCG_GSO": { + "formula": lambda ds: ( + -0.0635 * ds.blue + + -0.168 * ds.green + + -0.348 * ds.red + + 0.3895 * ds.nir + + -0.4587 * ds.swir1 + + -0.4064 * ds.swir2 + ), + "description": "Tasseled Cap Greeness, Nedkov 2017", + }, + "TCB_GSO": { + "formula": lambda ds: ( + 0.0822 * ds.blue + + 0.136 * ds.green + + 0.2611 * ds.red + + 0.5741 * ds.nir + + 0.3882 * ds.swir1 + + 0.1366 * ds.swir2 + ), + "description": "Tasseled Cap Brightness, Nedkov 2017", + }, + "CMR": { + "formula": lambda ds: (ds.swir1 / ds.swir2), + "description": "Clay Minerals Ratio, Drury 1987", + }, + "FMR": { + "formula": lambda ds: (ds.swir1 / ds.nir), + "description": "Ferrous Minerals Ratio, Segal 1982", + }, + "IOR": { + "formula": lambda ds: (ds.red / ds.blue), + "description": "Iron Oxide Ratio, Segal 1982", + }, + "BR": { + "formula": lambda ds: (ds.blue - ds.red) / (ds.blue + ds.red), + "description": "Blue-Red Index, CoastSat Classifier", + }, +} + + +def _get_fargs(func: Callable) -> set[str]: + """Returns a set of variables used in a provided function by inspecting its source code.""" + source_code = inspect.getsource(func) + vars_used = re.findall(r"ds\.([a-zA-Z0-9]+)", source_code) + return set(vars_used) + + +def calculate_indices( + ds: xr.Dataset, + index: str | list[str], + normalize: bool = True, + drop: bool = False, + nodata: float | None = None, +) -> xr.Dataset: + """ + Calculate spectral indices for an xarray dataset. + + Parameters + ---------- + ds : xarray.Dataset + The dataset containing the spectral bands required for index calculation. + index : str or list of str + The name(s) of the index or indices to calculate. + normalize : bool, optional + If True, normalize data by dividing by 10000. Defaults to True. + drop : bool, optional + If True, drop the original bands from the dataset. Defaults to False. + nodata: float or None, optional + If provided, replace nodata values with this value. Defaults to np.nan. 
+ + Returns + ------- + xr.Dataset + A dataset with the calculated indices added as new variables. + """ + # Ensure index is a list for consistent processing + indices = [index] if isinstance(index, str) else index + + # Validate indices + invalid_indices = [idx for idx in indices if idx not in INDEX_DICT] + if invalid_indices: + msg = ( + f"Invalid index/indices: {invalid_indices}. " + f"Valid options are: {list(INDEX_DICT.keys())}." + ) + raise ValueError(msg) + + # Normalize dataset if requested + if normalize: + ds = ds / 10000.0 + + # Compute indices + for idx in indices: + index_info = INDEX_DICT[idx] + formula = index_info["formula"] + required_bands = _get_fargs(formula) + + # Check required bands are present in the dataset + missing_bands = required_bands - set(ds.data_vars) + if missing_bands: + msg = f"Dataset is missing required bands for '{idx}': {missing_bands}." + raise ValueError(msg) + + # Calculate index and add to dataset + ds[idx] = formula(ds) + + if nodata is not None: + ds[idx] = set_nodata(ds[idx], nodata) + + # Drop original bands if requested + if drop: + ds = ds.drop_vars([var for var in ds.data_vars if var not in indices]) + + return ds diff --git a/src/coastpy/utils/xarray.py b/src/coastpy/utils/xarray.py index c0499e0..6d72d4e 100644 --- a/src/coastpy/utils/xarray.py +++ b/src/coastpy/utils/xarray.py @@ -1,12 +1,124 @@ import warnings +from typing import Literal import numpy as np +import rioxarray # noqa import xarray as xr from affine import Affine from rasterio.enums import Resampling from shapely import Polygon +def get_nodata( + da: xr.DataArray | xr.Dataset, band: str | None = None +) -> float | int | None: + """ + Find the nodata value in an Xarray DataArray or Dataset. + + This function checks for the nodata value in the following order: + 1. `nodata` attribute. + 2. `_FillValue` attribute. + 3. `rio.nodata` (if using rioxarray). + + Args: + da (xr.DataArray | xr.Dataset): Input Xarray object. + band (str, optional): Band name to check if input is a Dataset. + + Returns: + float | None: The nodata value if found, otherwise None. + """ + # Get the target DataArray + if isinstance(da, xr.Dataset): + if band is None: + msg = "For Dataset input, 'band' must be specified." + raise ValueError(msg) + if band not in da: + msg = f"Band '{band}' not found in the Dataset." + raise ValueError(msg) + da = da[band] + + # Check for 'nodata' in attrs + nodata_value = da.attrs.get("nodata", None) + if nodata_value is not None: + return nodata_value + + # Check for '_FillValue' in attrs + nodata_value = da.attrs.get("_FillValue", None) + if nodata_value is not None: + return nodata_value + + # Check for rioxarray nodata + try: + nodata_value = da.rio.nodata + if nodata_value is not None: + return nodata_value + except AttributeError: + # rioxarray is not available + pass + + # Raise a warning if no nodata value is found + warnings.warn( # noqa: B028 + "No nodata value found in 'nodata', '_FillValue', or 'rio.nodata'.", + UserWarning, + ) + return None + + +def set_nodata( + da: xr.DataArray | xr.Dataset, + nodata_value: float | int | None, + band: str | None = None, + target: Literal["nodata", "_FillValue"] = "nodata", +) -> xr.DataArray | xr.Dataset: + """ + Set the nodata value for an Xarray DataArray or Dataset. + + This function sets the nodata value in the specified attribute + (`nodata` or `_FillValue`) and attempts to set `rio.nodata` for compatibility + with rioxarray. + + Args: + da (xr.DataArray | xr.Dataset): Input Xarray object to modify. 
+ nodata_value (float | int | None): The nodata value to set. Use `None` to clear. + band (str, optional): Band name to modify if input is a Dataset. + target (str, optional): Target attribute to set, either 'nodata' (default) or '_FillValue'. + + Returns: + xr.DataArray | xr.Dataset: The modified DataArray or Dataset. + """ + if target not in ["nodata", "_FillValue"]: + msg = "The target parameter must be either 'nodata' or '_FillValue'." + raise ValueError(msg) + + # Handle Dataset case + if isinstance(da, xr.Dataset): + if band is None: + msg = "For Dataset input, 'band' must be specified." + raise ValueError(msg) + if band not in da: + msg = f"Band '{band}' not found in the Dataset." + raise ValueError(msg) + da = da[band] + + # Set the specified attribute + if nodata_value is not None: + da.attrs[target] = nodata_value + else: + da.attrs.pop(target, None) # Remove the attribute if nodata_value is None + + # Always attempt to set rioxarray nodata + try: + if nodata_value is not None: + da.rio.write_nodata(nodata_value, inplace=True) + else: + da.rio.update_attrs({"nodata": None}, inplace=True) + except AttributeError: + # rioxarray is not available or not in use + pass + + return da + + def make_template(data: xr.DataArray) -> xr.DataArray: """ Create a template DataArray with the same structure as `data` but filled with object data type. diff --git a/tutorials/Untitled.ipynb b/tutorials/Untitled.ipynb new file mode 100644 index 0000000..bf60671 --- /dev/null +++ b/tutorials/Untitled.ipynb @@ -0,0 +1,1894 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "import fsspec\n", + "import geopandas as gpd\n", + "import hvplot.xarray\n", + "import pystac\n", + "import rioxarray\n", + "import shapely\n", + "import xarray as xr\n", + "from ipyleaflet import Map, basemaps\n", + "\n", + "from coastpy.stac.utils import read_snapshot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [], + "source": [ + "m = Map(basemap=basemaps.Esri.WorldImagery, scroll_wheel_zoom=True)\n", + "m.center = 46.34, -1.47\n", + "m.zoom = 15\n", + "m.layout.height = \"800px\"\n", + "m" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "with fsspec.open(\n", + " \"https://coclico.blob.core.windows.net/tiles/S2A_OPER_GIP_TILPAR_MPC.parquet\", \"rb\"\n", + ") as f:\n", + " s2grid = gpd.read_parquet(f).to_crs(4326)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import dotenv\n", + "\n", + "dotenv.load_dotenv()\n", + "sas_token = os.getenv(\"AZURE_STORAGE_SAS_TOKEN\")\n", + "storage_options = {\"account_name\": \"coclico\", \"sas_token\": sas_token}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "import dask_geopandas as dgpd\n", + "import fsspec\n", + "import geopandas as gpd\n", + "\n", + "from coastpy.geo.quadtiles import make_mercantiles\n", + "\n", + "grid = make_mercantiles(5)\n", + "\n", + "# Load coastline buffer data (2km buffer around the coastline) and convert to WGS84\n", + "buffer = dask_geopandas.read_parquet(\n", + " \"az://coastline-buffer/osm-coastlines-buffer-2000m.parquet\",\n", + " storage_options=storage_options,\n", + ").compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"5", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "Script to generate a processing grid for scalable geospatial analysis.\n", + "This script:\n", + "1. Generates a Mercator grid based on a specified zoom level.\n", + "2. Reads and processes coastline buffer data.\n", + "3. Clips the buffer to the European region.\n", + "4. Filters tiles based on intersection with the coastline.\n", + "5. # TODO: Clips tiles based on a coastline buffer. \n", + "\n", + "Dependencies: dask-geopandas, geopandas, fsspec\n", + "\"\"\"\n", + "\n", + "import dask_geopandas as dgpd\n", + "import fsspec\n", + "import geopandas as gpd\n", + "\n", + "from coastpy.geo.quadtiles import make_mercantiles\n", + "\n", + "grid = make_mercantiles(5)\n", + "\n", + "# Load coastline buffer data (2km buffer around the coastline) and convert to WGS84\n", + "buffer = dask_geopandas.read_parquet(\n", + " \"az://coastline-buffer/osm-coastlines-buffer-2000m.parquet\",\n", + " storage_options=storage_options,\n", + ").compute()\n", + "\n", + "\n", + "# Load coastline data\n", + "coastline_urlpath = \"az://coastlines-osm/release/2023-02-09/coast_3857_gen9.parquet\"\n", + "with fsspec.open(coastline_urlpath, mode=\"rb\", **storage_options) as f:\n", + " coastline = gpd.read_parquet(f).to_crs(4326)\n", + "\n", + "\n", + "# Load countries dataset and filter for Europe\n", + "with fsspec.open(\n", + " \"https://coclico.blob.core.windows.net/public/countries.parquet\", \"rb\"\n", + ") as f:\n", + " countries = gpd.read_parquet(f)\n", + "\n", + "europe = countries[\n", + " (countries[\"continent\"] == \"EU\")\n", + " & (~countries[\"common_country_name\"].isin([\"Svalbard\", \"Russia\"]))\n", + "]\n", + "\n", + "# Clip the coastline buffer to the European region\n", + "european_coastline = gpd.clip(coastline, europe)\n", + "\n", + "# Filter the Mercator grid tiles by spatial join with the European buffer\n", + "filtered_tiles = grid.sjoin(european_coastline, how=\"inner\")\n", + "\n", + "# Drop unnecessary columns to clean the output\n", + "filtered_tiles = filtered_tiles.drop(columns=[\"index_right\"])\n", + "\n", + "# TODO: I want the tiles to be clipped by the buffer. So that within each tile\n", + "# I only have the buffer area (coastal zone)\n", + "... 
# < your code comes here.\n", + "clipped_tiles = gpd.overlay(filtered_tiles, european_buffer, how=\"intersection\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "def buffer_antimeridian(buffer_size):\n", + " \"\"\"\n", + " Create a buffer around the antimeridian to handle crossings.\n", + "\n", + " Args:\n", + " buffer_size (float): Buffer distance in degrees.\n", + "\n", + " Returns:\n", + " gpd.GeoDataFrame: Buffered antimeridian area.\n", + " \"\"\"\n", + " antimeridian = shapely.geometry.LineString([(180, -90), (180, 90)])\n", + " buffer_geom = shapely.geometry.Polygon(antimeridian.buffer(buffer_size))\n", + " return gpd.GeoDataFrame(geometry=[buffer_geom], crs=\"EPSG:4326\")\n", + "\n", + "\n", + "buffer_antimeridian()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "geom = gpd.GeoDataFrame(\n", + " geometry=[shapely.geometry.LineString([(180, -90), (180, 90)])], crs=4326\n", + ")\n", + "\n", + "\n", + "utm_northeast = 32660\n", + "utm_northwest = 32601\n", + "utm_southeast = ...\n", + "utm_southwest = ...\n", + "\n", + "print(list(geom.to_crs(utm_crs).geometry.item().coords))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "import geopandas as gpd\n", + "import shapely.geometry\n", + "from pyproj import CRS, Transformer\n", + "from shapely.ops import transform\n", + "\n", + "\n", + "def create_antimeridian_buffer(buffer_size_meters):\n", + " \"\"\"\n", + " Create a valid buffer area around the antimeridian.\n", + "\n", + " Args:\n", + " buffer_size_meters (float): Buffer distance in meters.\n", + "\n", + " Returns:\n", + " gpd.GeoDataFrame: Buffered area around the antimeridian in EPSG:4326.\n", + " \"\"\"\n", + " # Step 1: Define the antimeridian line\n", + " antimeridian_line = gpd.GeoDataFrame(\n", + " geometry=[shapely.geometry.LineString([(180, -90), (180, 90)])], crs=\"EPSG:4326\"\n", + " )\n", + "\n", + " # Step 2: Define UTM zones for each quadrant\n", + " utm_zones = {\n", + " \"northeast\": 32660, # UTM zone 60N\n", + " \"northwest\": 32601, # UTM zone 1N\n", + " \"southeast\": 32760, # UTM zone 60S\n", + " \"southwest\": 32701, # UTM zone 1S\n", + " }\n", + "\n", + " buffered_quadrants = []\n", + "\n", + " # Step 3 & 4: Transform to UTM, compute buffer, and convert back\n", + " for quadrant, utm_crs in utm_zones.items():\n", + " transformer_to_utm = Transformer.from_crs(\n", + " \"EPSG:4326\", CRS.from_epsg(utm_crs), always_xy=True\n", + " )\n", + " transformer_to_4326 = Transformer.from_crs(\n", + " CRS.from_epsg(utm_crs), \"EPSG:4326\", always_xy=True\n", + " )\n", + "\n", + " # Transform the antimeridian line to UTM CRS\n", + " utm_line = transform(\n", + " transformer_to_utm.transform, antimeridian_line.geometry.item()\n", + " )\n", + "\n", + " # Buffer in UTM CRS\n", + " utm_buffer = shapely.geometry.Polygon(\n", + " [\n", + " (utm_line.coords[0][0] - buffer_size_meters, utm_line.coords[0][1]),\n", + " (utm_line.coords[1][0] - buffer_size_meters, utm_line.coords[1][1]),\n", + " (utm_line.coords[1][0] + buffer_size_meters, utm_line.coords[1][1]),\n", + " (utm_line.coords[0][0] + buffer_size_meters, utm_line.coords[0][1]),\n", + " ]\n", + " )\n", + "\n", + " # Convert buffered geometry back to EPSG:4326\n", + " epsg4326_buffer = transform(transformer_to_4326.transform, utm_buffer)\n", + "\n", + " # Append to the list of buffered 
quadrants\n", + " buffered_quadrants.append(epsg4326_buffer)\n", + "\n", + " # Step 5: Combine buffered quadrants into one GeoDataFrame\n", + " combined_buffers = gpd.GeoDataFrame(geometry=buffered_quadrants, crs=\"EPSG:4326\")\n", + "\n", + " # Step 6: Validate geometries\n", + " combined_buffers[\"is_valid\"] = combined_buffers.geometry.is_valid\n", + " if not combined_buffers[\"is_valid\"].all():\n", + " print(\"Some geometries are invalid. Attempting to fix...\")\n", + " combined_buffers[\"geometry\"] = combined_buffers.geometry.apply(\n", + " lambda geom: geom.buffer(0) if not geom.is_valid else geom\n", + " )\n", + "\n", + " return combined_buffers\n", + "\n", + "\n", + "# Example Usage\n", + "buffer_size = 300000 # 300 km buffer\n", + "antimeridian_buffer = create_antimeridian_buffer(buffer_size)\n", + "print(antimeridian_buffer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [], + "source": [ + "from geopy.distance import geodesic\n", + "\n", + "\n", + "def compute_geodesic_point(lat, lon, distance_km, bearing):\n", + " \"\"\"\n", + " Compute a geodesic destination point.\n", + "\n", + " Args:\n", + " lat (float): Latitude of the starting point.\n", + " lon (float): Longitude of the starting point.\n", + " distance_km (float): Distance to travel in kilometers.\n", + " bearing (float): Direction of travel in degrees (e.g., 90 for east).\n", + "\n", + " Returns:\n", + " tuple: (latitude, longitude) of the destination point.\n", + " \"\"\"\n", + " start_point = (lat, lon)\n", + " destination = geodesic(kilometers=distance_km).destination(start_point, bearing)\n", + " return destination.latitude, destination.longitude\n", + "\n", + "\n", + "# Example Usage:\n", + "start_lat = 0\n", + "start_lon = -180\n", + "distance = 5 # in kilometers\n", + "bearing = 90 # eastward\n", + "\n", + "destination_point = compute_geodesic_point(start_lat, start_lon, distance, bearing)\n", + "print(f\"Destination Point: {destination_point}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "start_lat, start_lon = 0.1, -180 # Antimeridian equator point\n", + "shifted_lat, shifted_lon = compute_geodesic_point(start_lat, start_lon, 5, 90)\n", + "\n", + "p1 = shapely.Point(start_lon, start_lat)\n", + "p2 = shapely.Point(shifted_lon, shifted_lat)\n", + "\n", + "gdf1 = gpd.GeoDataFrame(geometry=[p1], crs=4326)\n", + "gdf2 = gpd.GeoDataFrame(geometry=[p2], crs=4326)\n", + "\n", + "m = gdf1.explore()\n", + "gdf2.explore(m=m, color=\"red\")\n", + "\n", + "utm_crs = gdf2.estimate_utm_crs()\n", + "\n", + "line = gpd.GeoDataFrame(\n", + " geometry=[shapely.LineString([shapely.Point(p2.x, 0), shapely.Point(p2.x, 84)])],\n", + " crs=4326,\n", + ")\n", + "line.assign(geometry=line.to_crs(utm_crs).buffer(2500)).to_file(\n", + " \"/Users/calkoen/tmp/test.gpkg\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "import antimeridian\n", + "from shapely.geometry import shape\n", + "from shapely.validation import make_valid\n", + "\n", + 
"\n", + "def correct_antimeridian_cross(row):\n", + " \"\"\"\n", + " Correct geometries crossing the antimeridian using the antimeridian library.\n", + "\n", + " Args:\n", + " row (pd.Series): Row containing the geometry to correct.\n", + "\n", + " Returns:\n", + " shapely.geometry.base.BaseGeometry: Corrected geometry.\n", + " \"\"\"\n", + " geom = row.geometry\n", + "\n", + " try:\n", + " # Convert GeoJSON-like geometries to Shapely if necessary\n", + " if isinstance(geom, dict):\n", + " geom = shape(geom)\n", + "\n", + " # Ensure geometry is valid\n", + " if not geom.is_valid:\n", + " geom = make_valid(geom)\n", + "\n", + " # Fix geometry using antimeridian library\n", + " return antimeridian.fix_polygon(geom, fix_winding=False)\n", + " except Exception as e:\n", + " # Log and return the original geometry if correction fails\n", + " print(e)\n", + " return geom\n", + "\n", + "\n", + "def correct_antimeridian_crosses_in_df(df):\n", + " \"\"\"\n", + " Correct geometries that cross the antimeridian.\n", + "\n", + " Args:\n", + " df (gpd.GeoDataFrame): Input GeoDataFrame with `crosses_antimeridian` column.\n", + " utm_grid (gpd.GeoDataFrame): UTM grid for overlay.\n", + "\n", + " Returns:\n", + " gpd.GeoDataFrame: Updated GeoDataFrame with corrected geometries.\n", + " \"\"\"\n", + " df = df.copy()\n", + "\n", + " # Create a boolean mask for rows to correct\n", + " rows_to_correct = df[\"crosses_antimeridian\"]\n", + "\n", + " # Apply the correction only to rows where `crosses_antimeridian` is True\n", + " df.loc[rows_to_correct, \"geometry\"] = df.loc[rows_to_correct].apply(\n", + " lambda row: correct_antimeridian_cross(row), axis=1\n", + " )\n", + " return df\n", + "\n", + "\n", + "gdf = gpd.read_parquet(\"/Users/calkoen/tmp/data.parquet\")\n", + "\n", + "gdf2 = correct_antimeridian_crosses_in_df(gdf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "buf = gpd.read_file(\"/Users/calkoen/data/prc/test_buffer_15000_coast_3857_gen9.gpkg\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "df = buf.copy()\n", + "df[(df.geom_type == \"Polygon\") | (df.geom_type == \"MultiPolygon\")]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [], + "source": [ + "buf.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "buf[buf.geom_type == \"LineString\"].explore()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [], + "source": [ + "def crosses_antimeridian(geometry):\n", + " \"\"\"\n", + " Check if a geometry crosses the antimeridian.\n", + "\n", + " Args:\n", + " geometry (shapely.geometry.base.BaseGeometry): The input geometry.\n", + "\n", + " Returns:\n", + " bool: True if the geometry crosses the antimeridian, False otherwise.\n", + " \"\"\"\n", + " minx, miny, maxx, maxy = geometry.bounds\n", + " return maxx - minx > 180\n", + "\n", + "\n", + "def map_crosses_antimeridian(df):\n", + " src_crs = df.crs\n", + " df = df.to_crs(4326)\n", + " df[\"crosses_antimeridian\"] = df[\"geometry\"].apply(crosses_antimeridian)\n", + " df = df.to_crs(src_crs)\n", + " return df\n", + "\n", + "\n", + "r = map_crosses_antimeridian(buf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [], + "source": [ 
+ "import antimeridian\n", + "from shapely.validation import make_valid\n", + "\n", + "def correct_antimeridian_cross(row):\n", + " \"\"\"\n", + " Correct geometries crossing the antimeridian using the antimeridian library.\n", + "\n", + " Args:\n", + " row (pd.Series): Row containing the geometry to correct.\n", + "\n", + " Returns:\n", + " shapely.geometry.base.BaseGeometry: Corrected geometry.\n", + " \"\"\"\n", + " geom = row.geometry\n", + "\n", + " try:\n", + " # Fix geometry using antimeridian library\n", + " import antimeridian\n", + "\n", + " if geom.geom_type == \"Polygon\":\n", + " fixed = antimeridian.fix_polygon(geom, fix_winding=True)\n", + " fixed = make_valid(fixed)\n", + " return fixed\n", + " elif geom.geom_type == \"MultiPolygon\":\n", + " fixed = antimeridian.fix_multi_polygon(geom, fix_winding=True)\n", + " fixed = make_valid(fixed)\n", + " return fixed\n", + "\n", + " except Exception as e:\n", + " print(e)\n", + " return None\n", + "\n", + "\n", + "def correct_antimeridian_crosses_in_df(df):\n", + " \"\"\"\n", + " Correct geometries that cross the antimeridian.\n", + "\n", + " Args:\n", + " df (gpd.GeoDataFrame): Input GeoDataFrame with `crosses_antimeridian` column.\n", + " utm_grid (gpd.GeoDataFrame): UTM grid for overlay.\n", + "\n", + " Returns:\n", + " gpd.GeoDataFrame: Updated GeoDataFrame with corrected geometries.\n", + " \"\"\"\n", + " df = df.copy()\n", + " crs = df.crs\n", + " df = df.to_crs(4326)\n", + "\n", + " # Create a boolean mask for rows to correct\n", + " rows_to_correct = df[\"crosses_antimeridian\"]\n", + "\n", + " # Apply the correction only to rows where `crosses_antimeridian` is True\n", + " df.loc[rows_to_correct, \"geometry\"] = df.loc[rows_to_correct].apply(\n", + " lambda row: correct_antimeridian_cross(row), axis=1\n", + " )\n", + " df = df.to_crs(crs)\n", + " return df\n", + "\n", + "r2 = correct_antimeridian_crosses_in_df(r)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [], + "source": [ + "r3 = r2[r2[\"crosses_antimeridian\"]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [], + "source": [ + "r3.explore()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "buf.iloc[[299]].explore()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [], + "source": [ + "m = gdf.explore()\n", + "gdf2.explore(m=m, color=\"red\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + ")\n", + "print(gdf.geom_type.unique())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28", + "metadata": {}, + "outputs": [], + "source": [ + "m = gdf2.explore()\n", + "gdf.explore(m=m, color=\"red\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": {}, + "outputs": [], + "source": [ + "gdf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32", + "metadata": {}, + "outputs": [], + "source": [] 
+ }, + { + "cell_type": "code", + "execution_count": null, + "id": "33", + "metadata": {}, + "outputs": [], + "source": [ + "import geopandas as gpd\n", + "from geopy.distance import geodesic\n", + "from shapely.geometry import LineString\n", + "\n", + "\n", + "def compute_geodesic_point(lat, lon, distance_km, bearing):\n", + " start_point = (lat, lon)\n", + " destination = geodesic(kilometers=distance_km).destination(start_point, bearing)\n", + " return destination.latitude, destination.longitude\n", + "\n", + "\n", + "def create_antimeridian_buffers(buffer_size_km=5):\n", + " # Define UTM zones\n", + " utm_zones = {\n", + " \"northeast\": 32660, # UTM zone 60N\n", + " \"northwest\": 32601, # UTM zone 1N\n", + " \"southeast\": 32760, # UTM zone 60S\n", + " \"southwest\": 32701, # UTM zone 1S\n", + " }\n", + "\n", + " # Initialize storage for polygons\n", + " buffered_geometries = []\n", + "\n", + " # Process each quadrant\n", + " for quadrant, utm_zone in utm_zones.items():\n", + " # Compute east/west starting point\n", + " direction = 90 if \"east\" in quadrant else 270 # East for NE/SE, West for NW/SW\n", + " start_lat, start_lon = 0, -180 # Antimeridian equator point\n", + " shifted_lat, shifted_lon = compute_geodesic_point(\n", + " start_lat, start_lon, 5, direction\n", + " )\n", + "\n", + " # Create north-south linestring\n", + " if \"north\" in quadrant:\n", + " line_coords = [(shifted_lon, 0), (shifted_lon, 90)] # Northward\n", + " else:\n", + " line_coords = [(shifted_lon, 0), (shifted_lon, -90)] # Southward\n", + "\n", + " line = LineString(line_coords)\n", + "\n", + " # Convert to UTM, buffer, and back to EPSG:4326\n", + " gdf = gpd.GeoDataFrame(geometry=[line], crs=\"EPSG:4326\").to_crs(utm_zone)\n", + " gdf[\"geometry\"] = gdf.buffer(buffer_size_km * 1000) # Buffer in meters\n", + " gdf = gdf.to_crs(\"EPSG:4326\")\n", + "\n", + " # Store result with metadata\n", + " gdf[\"quadrant\"] = quadrant\n", + " buffered_geometries.append(gdf)\n", + "\n", + " # Combine all buffered geometries into a single GeoDataFrame\n", + " result_gdf = gpd.GeoDataFrame(pd.concat(buffered_geometries, ignore_index=True))\n", + "\n", + " # Validate and fix geometries\n", + " result_gdf = result_gdf.set_geometry(\"geometry\")\n", + " if not result_gdf.is_valid.all():\n", + " result_gdf[\"geometry\"] = result_gdf[\"geometry\"].apply(\n", + " lambda geom: geom.buffer(0)\n", + " )\n", + "\n", + " return result_gdf\n", + "\n", + "\n", + "# Run the function\n", + "buffered_antimeridian = create_antimeridian_buffers(buffer_size_km=5)\n", + "\n", + "# Visualize the result\n", + "buffered_antimeridian.explore(tooltip=\"quadrant\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34", + "metadata": {}, + "outputs": [], + "source": [ + "import geopandas as gpd\n", + "import pandas as pd\n", + "from geopy.distance import geodesic\n", + "from shapely.geometry import Point\n", + "\n", + "\n", + "def compute_geodesic_point(lat, lon, distance_km, bearing):\n", + " \"\"\"\n", + " Compute a geodesic destination point.\n", + "\n", + " Args:\n", + " lat (float): Latitude of the starting point.\n", + " lon (float): Longitude of the starting point.\n", + " distance_km (float): Distance to travel in kilometers.\n", + " bearing (float): Direction of travel in degrees (e.g., 90 for east).\n", + "\n", + " Returns:\n", + " tuple: (latitude, longitude) of the destination point.\n", + " \"\"\"\n", + " start_point = (lat, lon)\n", + " destination = geodesic(kilometers=distance_km).destination(start_point, 
bearing)\n", + " return destination.latitude, destination.longitude\n", + "\n", + "\n", + "# Compute multiple geodesic points\n", + "start_lat = 0\n", + "start_lon = -180\n", + "distances_km = [5, 10, 15, 20, 25] # Distances in km\n", + "bearing = 90 # Eastward\n", + "\n", + "# Generate destination points\n", + "points = [\n", + " compute_geodesic_point(start_lat, start_lon, dist, bearing) for dist in distances_km\n", + "]\n", + "\n", + "# Create a DataFrame and GeoDataFrame\n", + "df = pd.DataFrame(points, columns=[\"latitude\", \"longitude\"])\n", + "df[\"geometry\"] = [Point(lon, lat) for lat, lon in zip(df[\"latitude\"], df[\"longitude\"])]\n", + "\n", + "gdf = gpd.GeoDataFrame(df, geometry=\"geometry\", crs=\"EPSG:4326\")\n", + "\n", + "# Visualize on a map\n", + "gdf.explore(\n", + " tooltip=[\"latitude\", \"longitude\"],\n", + " popup=True,\n", + " color=\"blue\",\n", + " marker_kwds={\"radius\": 5},\n", + " title=\"Geodesic Points East of Antimeridian\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35", + "metadata": {}, + "outputs": [], + "source": [ + "from geopy.distance import geodesic\n", + "\n", + "\n", + "def compute_geodesic_point(start_point, distance_km, bearing_degrees):\n", + " \"\"\"\n", + " Compute a geodesic point from a starting coordinate traveling a specified\n", + " distance in a specified direction.\n", + "\n", + " Args:\n", + " start_point (tuple): The starting coordinate (latitude, longitude) as a tuple.\n", + " distance_km (float): The distance to travel in kilometers.\n", + " bearing_degrees (float): The bearing in degrees (e.g., 90 for east, 270 for west).\n", + "\n", + " Returns:\n", + " tuple: The destination coordinate (latitude, longitude).\n", + " \"\"\"\n", + " # Use geopy to calculate the destination point\n", + " destination = geodesic(kilometers=distance_km).destination(\n", + " start_point, bearing_degrees\n", + " )\n", + " return destination.latitude, destination.longitude\n", + "\n", + "\n", + "# Example usage:\n", + "start_coord = (180, 0) # Antimeridian, equator\n", + "distance = 5 # Kilometers\n", + "bearing = 90 # East\n", + "\n", + "new_coord = compute_geodesic_point(start_coord, distance, bearing)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36", + "metadata": {}, + "outputs": [], + "source": [ + "p1 = shapely.Point(start_coord)\n", + "p2 = shapely.Point(new_coord)\n", + "gpd.GeoDataFrame(geometry=[p1, p2], crs=4326).explore()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38", + "metadata": {}, + "outputs": [], + "source": [ + "geom.explore()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39", + "metadata": {}, + "outputs": [], + "source": [ + "import geopandas as gpd\n", + "import shapely.geometry\n", + "from pyproj import CRS, Transformer\n", + "\n", + "\n", + "def buffer_antimeridian(buffer_size_meters):\n", + " \"\"\"\n", + " Create a valid buffer around the antimeridian to handle crossings.\n", + "\n", + " Args:\n", + " buffer_size_meters (float): Buffer size in meters.\n", + "\n", + " Returns:\n", + " gpd.GeoDataFrame: Buffered antimeridian area as valid polygons.\n", + " \"\"\"\n", + " # Define UTM zones for east and west of the antimeridian\n", + " utm_zone_east = 60 # UTM zone just east of the antimeridian\n", + " utm_zone_west = 1 # UTM zone just west of the antimeridian\n", + "\n", + " # 
Define UTM CRS for each zone\n", + " utm_crs_east = CRS.from_epsg(32660) # Northern hemisphere, UTM zone 60\n", + " utm_crs_west = CRS.from_epsg(32601) # Northern hemisphere, UTM zone 1\n", + "\n", + " # Define the antimeridian in latitude bounds\n", + " lat_min, lat_max = -90, 90\n", + "\n", + " # Create transformers to convert EPSG:4326 to UTM and back\n", + " to_utm_east = Transformer.from_crs(\"EPSG:4326\", utm_crs_east, always_xy=True)\n", + " to_utm_west = Transformer.from_crs(\"EPSG:4326\", utm_crs_west, always_xy=True)\n", + " to_epsg4326 = Transformer.from_crs(utm_crs_east, \"EPSG:4326\", always_xy=True)\n", + "\n", + " # Define east buffer polygon in UTM coordinates\n", + " east_poly_coords = [\n", + " to_utm_east.transform(180, lat_min), # Start at (180°, lat_min)\n", + " to_utm_east.transform(180 + buffer_size_meters / 111320, lat_min), # East edge\n", + " to_utm_east.transform(180 + buffer_size_meters / 111320, lat_max), # East edge\n", + " to_utm_east.transform(180, lat_max), # Back at (180°, lat_max)\n", + " to_utm_east.transform(180, lat_min), # Close the polygon\n", + " ]\n", + " east_polygon = shapely.geometry.Polygon(east_poly_coords)\n", + "\n", + " # Define west buffer polygon in UTM coordinates\n", + " west_poly_coords = [\n", + " to_utm_west.transform(-180, lat_min), # Start at (-180°, lat_min)\n", + " to_utm_west.transform(-180 - buffer_size_meters / 111320, lat_min), # West edge\n", + " to_utm_west.transform(-180 - buffer_size_meters / 111320, lat_max), # West edge\n", + " to_utm_west.transform(-180, lat_max), # Back at (-180°, lat_max)\n", + " to_utm_west.transform(-180, lat_min), # Close the polygon\n", + " ]\n", + " west_polygon = shapely.geometry.Polygon(west_poly_coords)\n", + "\n", + " # Convert polygons back to EPSG:4326\n", + " east_polygon_epsg4326 = shapely.ops.transform(to_epsg4326.transform, east_polygon)\n", + " west_polygon_epsg4326 = shapely.ops.transform(to_epsg4326.transform, west_polygon)\n", + "\n", + " # Combine both polygons into a GeoDataFrame\n", + " combined_polygons = gpd.GeoDataFrame(\n", + " geometry=[east_polygon_epsg4326, west_polygon_epsg4326], crs=\"EPSG:4326\"\n", + " )\n", + "\n", + " return combined_polygons\n", + "\n", + "\n", + "# Example usage\n", + "buffer_size = 300000 # Buffer size in meters\n", + "antimeridian_buffer = buffer_antimeridian(buffer_size)\n", + "print(antimeridian_buffer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40", + "metadata": {}, + "outputs": [], + "source": [ + "antimeridian_buffer.iloc[[1]].explore()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42", + "metadata": {}, + "outputs": [], + "source": [ + "filtered_tiles" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43", + "metadata": {}, + "outputs": [], + "source": [ + "r.explore()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44", + "metadata": {}, + "outputs": [], + "source": [ + "filtered_tiles.sample(1).explore()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "48", + "metadata": {}, + "outputs": [], + "source": [ + "import antimeridian\n", + "import geopandas as gpd\n", + "from shapely.geometry import mapping, shape\n", + "\n", + "\n", + "def fix_geom(geom):\n", + " \"\"\"\n", + " Fix geometries using the antimeridian library.\n", + "\n", + " Args:\n", + " geom (shapely.geometry.base.BaseGeometry): Input geometry.\n", + "\n", + " Returns:\n", + " shapely.geometry.base.BaseGeometry: Fixed geometry or the original geometry if it can't be fixed.\n", + " \"\"\"\n", + " try:\n", + " # Fix geometry using antimeridian library\n", + " if geom.is_empty:\n", + " return geom # Skip empty geometries\n", + " elif geom.is_valid:\n", + " return geom # Return if already valid\n", + "\n", + " # Try fixing specific geometry types\n", + " if geom.geom_type == \"Polygon\":\n", + " return antimeridian.fix_polygon(geom)\n", + " elif geom.geom_type == \"MultiPolygon\":\n", + " return antimeridian.fix_multipolygon(geom)\n", + " else:\n", + " # Attempt general fix for other shapes\n", + " return antimeridian.fix_shape(geom)\n", + " except Exception as e:\n", + " # Log an issue with fixing and return the original geometry\n", + " print(f\"Could not fix geometry: {e}\")\n", + " return geom\n", + "\n", + "\n", + "def fix_invalid_geometries(gdf):\n", + " \"\"\"\n", + " Fix invalid geometries in a GeoDataFrame.\n", + "\n", + " Args:\n", + " gdf (geopandas.GeoDataFrame): Input GeoDataFrame.\n", + "\n", + " Returns:\n", + " geopandas.GeoDataFrame: GeoDataFrame with fixed geometries.\n", + " \"\"\"\n", + " # Identify invalid geometries\n", + " invalid_mask = ~gdf.is_valid\n", + " invalid_geometries = gdf.loc[invalid_mask]\n", + "\n", + " print(f\"Found {len(invalid_geometries)} invalid geometries.\")\n", + "\n", + " # Apply the fix_geom function to invalid geometries\n", + " gdf.loc[invalid_mask, \"geometry\"] = invalid_geometries.geometry.map(fix_geom)\n", + "\n", + " # Re-check if geometries are now valid\n", + " still_invalid = gdf[~gdf.is_valid]\n", + " if len(still_invalid) > 0:\n", + " print(f\"Still invalid geometries: {len(still_invalid)}\")\n", + " else:\n", + " print(\"All geometries are now valid.\")\n", + "\n", + " return gdf\n", + "\n", + "\n", + "buffer = fix_invalid_geometries(coastline_buffer)\n", + "buffer = buffer[buffer.is_valid]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49", + "metadata": {}, + "outputs": [], + "source": [ + "from coastpy.io.utils import name_data\n", + "\n", + "name_data(\n", + " buffer, include_random_hex=False, filename_prefix=\"osm_coastline_buffer_2000m\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51", + "metadata": {}, + "outputs": [], + "source": [ + "tolerance = 0.01 # Adjust this value based on your acceptable resolution loss\n", + "simplified_buffer = buffer.copy()\n", + "simplified_buffer[\"geometry\"] = buffer.geometry.simplify(\n", + " tolerance, preserve_topology=True\n", + ")\n", + "\n", + "# Save the simplified buffer\n", + "simplified_buffer.to_parquet(\"/Users/calkoen/tmp/buffer_simplified.parquet\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52", + "metadata": {}, + "outputs": [], + "source": [ + "# simplified_buffer.to_file(\"/Users/calkoen/tmp/buffer_simplified.gpkg\")\n", + "buffer.to_file(\"/Users/calkoen/tmp/buffer.gpkg\")" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "53", + "metadata": {}, + "outputs": [], + "source": [ + "invals = buffer[~buffer.is_valid]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54", + "metadata": {}, + "outputs": [], + "source": [ + "buffer.to_parquet(\"/Users/calkoen/tmp/buffer.parquet\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56", + "metadata": {}, + "outputs": [], + "source": [ + "invals.explore()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57", + "metadata": {}, + "outputs": [], + "source": [ + "import antimeridian\n", + "import shapely\n", + "\n", + "\n", + "def fix_geom(geom):\n", + " try:\n", + " if geom.dtype == \n", + " geom = antimeridian.fix_shape(geom)\n", + " except:\n", + " print(\"not fixedj\")\n", + " return geom\n", + " \n", + "\n", + "invalids.geometry.map(fix_geom)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58", + "metadata": {}, + "outputs": [], + "source": [ + "invalids.explore()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60", + "metadata": {}, + "outputs": [], + "source": [ + "import geodatasets\n", + "\n", + "geodatasets.data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61", + "metadata": {}, + "outputs": [], + "source": [ + "gpd.read_file(gpd.datasets.get_path(\"natearth_lowres\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62", + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: LOAD DATA\n", + "s2_tiles = retrieve_s2_tiles().to_crs(4326)\n", + "rois = retrieve_rois().to_crs(4326)\n", + "# TODO: write STAC catalog for the coastal buffer\n", + "buffer = dask_geopandas.read_parquet(\n", + " \"az://coastline-buffer/osm-coastlines-buffer-2000m.parquet\",\n", + " storage_options=storage_options,\n", + ").compute()\n", + "quadtiles = make_mercantiles(zoom_level=MERCANTILES_ZOOM_LEVEL).to_crs(4326)\n", + "\n", + "with warnings.catch_warnings():\n", + " warnings.simplefilter(\"ignore\")\n", + " classifier = retrieve_coastsat_classifier()\n", + "\n", + "region_of_interest = infer_region_of_interest(ROI)\n", + "\n", + "# NOTE: make an overlay with a coastline buffer to avoid querying data we do not need\n", + "buffer_aoi = gpd.overlay(buffer, region_of_interest[[\"geometry\"]].to_crs(buffer.crs))\n", + "\n", + "# TODO: add heuristic to decide which s2 tiles to use\n", + "s2_tilenames_to_process = gpd.sjoin(\n", + " s2_tiles, buffer_aoi.to_crs(s2_tiles.crs)\n", + ").Name.unique()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63", + "metadata": {}, + "outputs": [], + "source": [ + "west, south, east, north = m.west, m.south, m.east, m.north\n", + "# Note: small little hack to ensure the notebook also works when running all cells at once\n", + "if not west:\n", + " west, south, east, north = (\n", + " -1.4987754821777346,\n", + " 46.328320550966765,\n", + " -1.446976661682129,\n", + " 46.352022707044455,\n", + " )\n", + "roi = gpd.GeoDataFrame(\n", + " geometry=[shapely.geometry.box(west, south, east, north)], crs=4326\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64", + "metadata": {}, + "outputs": [], + "source": [ + "import planetary_computer as pc\n", + "\n", + 
"from coastpy.eo.collection import (\n", + " CopernicusDEMCollection,\n", + " DeltaDTMCollection,\n", + " S2Collection,\n", + ")\n", + "\n", + "s2_ = (\n", + " S2Collection()\n", + " .search(\n", + " roi,\n", + " datetime_range=\"2023-05-01/2023-07-31\",\n", + " query={\"eo:cloud_cover\": {\"lt\": 10}},\n", + " )\n", + " .load(\n", + " bands=[\"blue\", \"green\", \"red\", \"nir\", \"swir16\"],\n", + " # composite=50,\n", + " spectral_indices=[\"NDWI\", \"NDVI\"],\n", + " chunks={\"x\": 256, \"y\": 256},\n", + " patch_url=pc.sign,\n", + " )\n", + " .execute()\n", + ")\n", + "\n", + "deltadtm = DeltaDTMCollection().search(roi).load().execute()\n", + "cop_dem = CopernicusDEMCollection().search(roi).load(patch_url=pc.sign).execute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65", + "metadata": {}, + "outputs": [], + "source": [ + "s2 = s2_.compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66", + "metadata": {}, + "outputs": [], + "source": [ + "def median_composite(ds):\n", + " \"\"\"\n", + " Compute the median composite along the 'time' dimension.\n", + "\n", + " Args:\n", + " ds (xr.Dataset): Input dataset with dimensions (time, y, x).\n", + "\n", + " Returns:\n", + " xr.Dataset: Median composite with dimensions (y, x).\n", + " \"\"\"\n", + " composite = xr.apply_ufunc(\n", + " np.nanmedian,\n", + " ds, # Input dataset\n", + " input_core_dims=[[\"time\"]], # Operate along \"time\" dimension\n", + " output_core_dims=[[]], # Collapsed \"time\" dimension\n", + " vectorize=True, # Vectorize to apply along (y, x)\n", + " dask=\"parallelized\", # Support for Dask arrays\n", + " output_dtypes=[np.float32], # Match input dtype\n", + " )\n", + " return composite\n", + "\n", + "\n", + "composite = median_composite(s2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67", + "metadata": {}, + "outputs": [], + "source": [ + "c2 = s2.median(\"time\", skipna=True, keep_attrs=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68", + "metadata": {}, + "outputs": [], + "source": [ + "s2.quantile(0.15, dim=\"time\", skipna=True, keep_attrs=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import xarray as xr\n", + "\n", + "\n", + "def benchmark(func, *args, **kwargs):\n", + " \"\"\"\n", + " Helper function to benchmark a given function.\n", + " \"\"\"\n", + " start_time = time.time()\n", + " result = func(*args, **kwargs)\n", + " elapsed_time = time.time() - start_time\n", + " return result, elapsed_time\n", + "\n", + "\n", + "def median_composite_builtin(ds):\n", + " \"\"\"\n", + " Compute the median composite along the 'time' dimension using Xarray's built-in median.\n", + " \"\"\"\n", + " return ds.median(\"time\", skipna=True, keep_attrs=True)\n", + "\n", + "\n", + "def median_composite_ufunc(ds):\n", + " \"\"\"\n", + " Compute the median composite along the 'time' dimension using NumPy's nanmedian and Xarray ufunc.\n", + " \"\"\"\n", + " composite = xr.apply_ufunc(\n", + " np.nanmedian,\n", + " ds,\n", + " input_core_dims=[[\"time\"]],\n", + " output_core_dims=[[]],\n", + " vectorize=True,\n", + " dask=\"parallelized\",\n", + " output_dtypes=[np.float32],\n", + " )\n", + " return composite\n", + "\n", + "\n", + "def quantile_composite_builtin(ds, quantile):\n", + " \"\"\"\n", + " Compute the quantile composite along the 'time' 
dimension using Xarray's built-in quantile.\n", + " \"\"\"\n", + " return ds.quantile(quantile, dim=\"time\", skipna=True, keep_attrs=True)\n", + "\n", + "\n", + "def quantile_composite_ufunc(ds, quantile):\n", + " \"\"\"\n", + " Compute the quantile composite along the 'time' dimension using NumPy's nanpercentile and Xarray ufunc.\n", + " \"\"\"\n", + " composite = xr.apply_ufunc(\n", + " np.nanpercentile,\n", + " ds,\n", + " kwargs={\"q\": quantile * 100},\n", + " input_core_dims=[[\"time\"]],\n", + " output_core_dims=[[]],\n", + " vectorize=True,\n", + " dask=\"parallelized\",\n", + " output_dtypes=[np.float32],\n", + " )\n", + " return composite\n", + "\n", + "\n", + "def compute_percentage_difference(array1, array2):\n", + " \"\"\"\n", + " Compute the percentage difference between two arrays.\n", + " \"\"\"\n", + " diff = np.abs(array1 - array2)\n", + " return np.mean(diff / (np.abs(array1) + np.abs(array2) + 1e-10)) * 100\n", + "\n", + "\n", + "def main():\n", + " # Simulated dataset for testing\n", + " time = pd.date_range(\"2023-01-01\", \"2023-01-10\")\n", + " y = np.linspace(-10, 10, 50)\n", + " x = np.linspace(-10, 10, 50)\n", + " data = np.random.rand(len(time), len(y), len(x))\n", + "\n", + " ds = xr.DataArray(\n", + " data,\n", + " coords={\"time\": time, \"y\": y, \"x\": x},\n", + " dims=(\"time\", \"y\", \"x\"),\n", + " name=\"test_data\",\n", + " )\n", + "\n", + " # Quantile to use for quantile composite tests\n", + " quantile = 0.15\n", + "\n", + " # Benchmark each method\n", + " methods = [\n", + " (\"Median (Xarray Built-in)\", median_composite_builtin, {\"ds\": ds}),\n", + " (\"Median (Xarray Ufunc)\", median_composite_ufunc, {\"ds\": ds}),\n", + " (\n", + " \"Quantile (Xarray Built-in)\",\n", + " quantile_composite_builtin,\n", + " {\"ds\": ds, \"quantile\": quantile},\n", + " ),\n", + " (\n", + " \"Quantile (Xarray Ufunc)\",\n", + " quantile_composite_ufunc,\n", + " {\"ds\": ds, \"quantile\": quantile},\n", + " ),\n", + " ]\n", + "\n", + " results = []\n", + " outputs = {}\n", + "\n", + " for method_name, method, kwargs in methods:\n", + " result, elapsed_time = benchmark(method, **kwargs)\n", + " outputs[method_name] = result\n", + " results.append({\"Method\": method_name, \"Time (s)\": elapsed_time})\n", + "\n", + " # Compare results and calculate percentage differences\n", + " comparisons = [\n", + " (\"Median (Xarray Built-in)\", \"Median (Xarray Ufunc)\"),\n", + " (\"Quantile (Xarray Built-in)\", \"Quantile (Xarray Ufunc)\"),\n", + " ]\n", + "\n", + " for method1, method2 in comparisons:\n", + " diff = compute_percentage_difference(\n", + " outputs[method1].values, outputs[method2].values\n", + " )\n", + " results.append(\n", + " {\n", + " \"Method\": f\"Comparison {method1} vs {method2}\",\n", + " \"Time (s)\": f\"Diff: {diff:.6f}%\",\n", + " }\n", + " )\n", + "\n", + " # Display results\n", + " results_df = pd.DataFrame(results)\n", + " print(results_df)\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70", + "metadata": {}, + "outputs": [], + "source": [ + "composite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71", + "metadata": {}, + "outputs": [], + "source": [ + "xx.median(\"time\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "73", + "metadata": {}, + "outputs": [], + "source": [] + }, + { 
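A caveat on the benchmark above: the test cube is a plain NumPy-backed DataArray, so `dask="parallelized"` has no effect and `vectorize=True` forces a Python-level loop over every (y, x) pixel, which will typically make the `apply_ufunc` variants look much slower than the built-in reductions. To exercise the Dask path instead, the same cube can be chunked first; a minimal sketch, assuming dask is installed:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Build the same synthetic cube, but Dask-backed.
time = pd.date_range("2023-01-01", "2023-01-10")
y = np.linspace(-10, 10, 50)
x = np.linspace(-10, 10, 50)
da = xr.DataArray(
    np.random.rand(len(time), len(y), len(x)),
    coords={"time": time, "y": y, "x": x},
    dims=("time", "y", "x"),
    name="test_data",
).chunk({"time": -1, "y": 25, "x": 25})  # keep "time" in one chunk so reductions stay per-block

lazy = da.median("time", skipna=True)  # builds a task graph only
result = lazy.compute()                # triggers the chunked computation
```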
+ "cell_type": "code", + "execution_count": null, + "id": "74", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "\n", + "\n", + "def percentile_composite(ds, q, nodata=np.nan):\n", + " \"\"\"\n", + " Create a percentile composite from a datacube.\n", + "\n", + " Args:\n", + " ds (xr.Dataset): Input Dataset with dimensions (time, y, x).\n", + " q (float): Percentile to compute (e.g., 15 for 15th percentile).\n", + " nodata (float, optional): Value representing nodata. Default is NaN.\n", + "\n", + " Returns:\n", + " xr.Dataset: Composite with dimensions (y, x).\n", + " \"\"\"\n", + " # Mask nodata values\n", + " ds_masked = ds.where(ds != nodata)\n", + "\n", + " # Apply percentile calculation\n", + " def nanpercentile(data, axis, q):\n", + " return np.nanpercentile(data, q=q, axis=axis)\n", + "\n", + " composite = xr.apply_ufunc(\n", + " nanpercentile,\n", + " ds_masked,\n", + " input_core_dims=[[\"time\"]], # Operate along the \"time\" dimension\n", + " output_core_dims=[[]], # Output has no \"time\" dimension\n", + " kwargs={\"q\": q}, # Percentile value\n", + " dask=\"parallelized\", # Enable Dask support\n", + " output_dtypes=[float], # Output type\n", + " keep_attrs=True, # Retain attributes\n", + " )\n", + "\n", + " # Preserve nodata value in output\n", + " for var in composite.data_vars:\n", + " composite[var].attrs[\"nodata\"] = nodata\n", + "\n", + " return composite\n", + "\n", + "\n", + "percentile_composite(xx, 15)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "\n", + "# Sample dataset for testing\n", + "xx = xr.Dataset(\n", + " {\n", + " \"nir\": ([\"time\", \"y\", \"x\"], np.random.random((10, 100, 100))),\n", + " \"green\": ([\"time\", \"y\", \"x\"], np.random.random((10, 100, 100))),\n", + " },\n", + " coords={\n", + " \"time\": np.arange(10),\n", + " \"y\": np.linspace(0, 100, 100),\n", + " \"x\": np.linspace(0, 100, 100),\n", + " },\n", + ")\n", + "\n", + "\n", + "def median_composite(ds):\n", + " \"\"\"\n", + " Compute a median composite over the 'time' dimension.\n", + "\n", + " Args:\n", + " ds (xr.Dataset): Input Dataset with dimensions (time, y, x).\n", + "\n", + " Returns:\n", + " xr.Dataset: Composite with dimensions (y, x).\n", + " \"\"\"\n", + " # Apply median calculation using xr.apply_ufunc\n", + " composite = xr.apply_ufunc(\n", + " np.nanmedian, # Use numpy's nanmedian\n", + " ds, # Input dataset\n", + " input_core_dims=[[\"time\"]], # Operate along the \"time\" dimension\n", + " output_core_dims=[[]], # Result has no \"time\" dimension\n", + " dask=\"parallelized\", # Support for Dask arrays\n", + " output_dtypes=[ds[\"nir\"].dtype], # Output dtype inferred from input\n", + " keep_attrs=True, # Retain attributes from the input dataset\n", + " )\n", + " return composite\n", + "\n", + "\n", + "# Test the function\n", + "composite = median_composite(xx)\n", + "print(composite)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76", + "metadata": {}, + "outputs": [], + "source": [ + "xr.apply_ufunc(\n", + " np.nanpercentile, # Handle NaNs during computation\n", + " xx, # Dataset to process\n", + " input_core_dims=[[\"time\"]], # Specify \"time\" as the axis for computation\n", + " output_core_dims=[[\"y\", \"x\"]], # Output should retain \"y\" and \"x\" dimensions\n", + " kwargs={\"q\": 15}, # Pass the percentile to compute\n", + " # 
dask=\"parallelized\", # Enable Dask support\n", + " # output_dtypes=[float], # Define the output data type\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "\n", + "\n", + "def composite(ds: xr.Dataset, percentile: float, nodata: float = np.nan) -> xr.Dataset:\n", + " \"\"\"\n", + " Create a percentile composite from a datacube.\n", + "\n", + " Args:\n", + " ds (xr.Dataset): The input datacube with dimensions (\"time\", \"y\", \"x\").\n", + " percentile (float): The percentile to compute (e.g., 15 for 15th percentile).\n", + " nodata (float): The value representing nodata. Default is NaN.\n", + "\n", + " Returns:\n", + " xr.Dataset: The composite with the specified percentile.\n", + " \"\"\"\n", + " # 1. Mask nodata values\n", + " ds_masked = ds.where(ds != nodata)\n", + "\n", + " # 2. Apply percentile calculation\n", + " composite = xr.apply_ufunc(\n", + " np.nanpercentile, # Handle NaNs during computation\n", + " ds_masked, # Dataset to process\n", + " input_core_dims=[[\"time\"]], # Specify \"time\" as the axis for computation\n", + " output_core_dims=[[\"y\", \"x\"]], # Output should retain \"y\" and \"x\" dimensions\n", + " kwargs={\"q\": percentile}, # Pass the percentile to compute\n", + " dask=\"parallelized\", # Enable Dask support\n", + " output_dtypes=[float], # Define the output data type\n", + " )\n", + "\n", + " # 3. Preserve nodata values\n", + " for var in ds.data_vars:\n", + " composite[var] = composite[var].where(~composite[var].isnull(), nodata)\n", + " composite[var].attrs[\"nodata\"] = nodata\n", + " composite[var] = composite[var].rio.write_nodata(nodata)\n", + "\n", + " # 4. Return the composite\n", + " return composite\n", + "\n", + "\n", + "composite(s2, 15)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "agg = s2.reduce(np.percentile, dim=\"time\", q=0.15)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79", + "metadata": {}, + "outputs": [], + "source": [ + "r.nir.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80", + "metadata": {}, + "outputs": [], + "source": [ + "s2.isel(time=10).blue.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81", + "metadata": {}, + "outputs": [], + "source": [ + "s2.nir.hvplot(x=\"x\", y=\"y\", geo=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82", + "metadata": {}, + "outputs": [], + "source": [ + "s2.nirhvplot(x=\"x\", y=\"y\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83", + "metadata": {}, + "outputs": [], + "source": [ + "cop_dem[\"data\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84", + "metadata": {}, + "outputs": [], + "source": [ + "cop_dem = cop_dem.compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85", + "metadata": {}, + "outputs": [], + "source": [ + "deltadtm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:coastal-full] *", + "language": "python", + "name": "conda-env-coastal-full-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": 
".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/deltadtm.ipynb b/tutorials/deltadtm.ipynb index 55f3509..ab02a32 100644 --- a/tutorials/deltadtm.ipynb +++ b/tutorials/deltadtm.ipynb @@ -110,9 +110,19 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "7", "metadata": {}, + "outputs": [], + "source": [ + "roi.explore()" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, "source": [ "## Find the tiles for your region of interest" ] @@ -120,7 +130,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "9", "metadata": {}, "outputs": [], "source": [ @@ -131,7 +141,7 @@ }, { "cell_type": "markdown", - "id": "9", + "id": "10", "metadata": {}, "source": [ "## Read data" @@ -140,7 +150,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -150,7 +160,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -162,7 +172,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "13", "metadata": {}, "outputs": [], "source": []