From 6d83e43013ea09068abd85b930dd87f3e391306d Mon Sep 17 00:00:00 2001 From: preethatr07 Date: Tue, 17 Dec 2024 18:55:21 +0100 Subject: [PATCH 1/5] Update with plot class for eddmaps.py --- torchgeo/datasets/eddmaps.py | 107 ++++++++++++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/eddmaps.py b/torchgeo/datasets/eddmaps.py index d3a046993a..c651a06b18 100644 --- a/torchgeo/datasets/eddmaps.py +++ b/torchgeo/datasets/eddmaps.py @@ -11,6 +11,10 @@ import pandas as pd from rasterio.crs import CRS +import matplotlib.pyplot as plt +import matplotlib.patches as patches +from typing import Tuple, Optional + from .errors import DatasetNotFoundError from .geo import GeoDataset from .utils import BoundingBox, Path, disambiguate_timestamp @@ -79,7 +83,7 @@ def __init__(self, root: Path = 'data') -> None: coords = (x, x, y, y, mint, maxt) self.index.insert(i, coords) i += 1 - + def __getitem__(self, query: BoundingBox) -> dict[str, Any]: """Retrieve metadata indexed by query. @@ -103,3 +107,104 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: sample = {'crs': self.crs, 'bounds': bboxes} return sample + + +def plot( + self, + query: Optional[BoundingBox] = None, + title: str = "EDDMapS Dataset", + point_size: int = 20, + point_color: str = 'blue', + query_color: str = 'red', + annotate: bool = False, + figsize: Tuple[int, int] = (12, 8) +) -> None: + + """ Plot the dataset points with optional query bounding box + Args: + + query: The query to look for points, in the form of a bounding box: (minx,maxx,miny,maxy,mint,maxt) + title: Title of the plot + point_size: The size of the points plotted + point_color: The default color of the points, in case no such map is provided + query_color: color for the points which fall into the query + annotate: Controls if the points with timestamps are annotated + figsize: Size of drawn figure in the shape: (width, height) + + Raises: + + ValueError: When no points could be plotted because none were valid. + + """ + + # Filtering valid lat and long rows + valid_data = self.data.dropna(subset = [ 'Latitude' , 'Longitude']) + if valid_data.empty: + raise ValueError("No valid lat/long data to plot.") + + fig, ax = plt.subplots(figsize=figsize) + + # Plot-at-all points + + ax.scatter( + + valid_data['Longitude'], + + valid_data['Latitude'], + + s = point_size, + + c = point_color, + + edgecolor = 'k', + + alpha = 0.6, + + label = 'All Observations' + + ) + + #highlighting queried points (If) provided bounding box query + + if query: + minx, maxx, miny, maxy, mint, maxt = query + hits = self.index.intersection((minx,maxx,miny,maxy,mint, maxt)) + + # Get coordinates of hits to highlight + query_points = valid_data.iloc[[list(hits)]] + ax.scatter( + query_points['Longitude'], + query_points['Latitude'], + s = point_size * 1.5, + c = query_color, + edgecolor = 'white', + alpha = 0.8, + label = 'Query Results' + ) + + # Draw a bounding box + bbox_patch = patches.rectangle( + (minx, miny), maxx - minx, maxy - miny, + linewidth = 2, edgecolor = 'red', facecolor='none', linestyle = '--', label = "Query Bounding Box" + ) + ax.add_patch(bbox_patch) + + # Optional annotations + if annotate: + for _, row in valid_data.iterrows(): + ax.annotate( + row['ObsDate'], (row['Longitude'], row['Latitude']), + fontsize=8, alpha=0.7, textcoords="offset points", xytext=(0, 5), ha='center' + ) + + # Plot styling + ax.set_title(title, fontsize=14) + ax.set_xlabel("Longitude", fontsize=12) + ax.set_ylabel("Latitude", fontsize=12) + ax.grid(True, linestyle='--', alpha=0.5) + ax.legend() + + plt.show() + + + From 81315b355a0f2d45a59af8757493954fece89e77 Mon Sep 17 00:00:00 2001 From: preethatr07 Date: Tue, 17 Dec 2024 18:59:00 +0100 Subject: [PATCH 2/5] Update gbif.py with plot class --- torchgeo/datasets/gbif.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/torchgeo/datasets/gbif.py b/torchgeo/datasets/gbif.py index 3e8cfb6c88..d34123a9af 100644 --- a/torchgeo/datasets/gbif.py +++ b/torchgeo/datasets/gbif.py @@ -13,6 +13,9 @@ import pandas as pd from rasterio.crs import CRS +import matplotlib.pyplot as plt + + from .errors import DatasetNotFoundError from .geo import GeoDataset from .utils import BoundingBox, Path @@ -140,3 +143,28 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: sample = {'crs': self.crs, 'bounds': bboxes} return sample + + +def plot(self) -> None: + """Represent in graphic any mentions as a map on a globe.""" + # Extract latitude and longitude from the dataset + lat = self.data['decimalLatitude'] + long = self.data['decimalLongitude'] + + # Remove all other rows except those that have latitude and longitude data available + valid = self.data.dropna(subset=['decimalLatitude', 'decimalLongitude']) + + # Create a new figure + plt.figure(figsize=(10, 6)) + plt.scatter( + valid['decimalLongitude'], + valid['decimalLatitude'], + c='b', # Color choice can be made here + s=10, # Change the size of the markers + alpha=0.5, # Change the transparency of the points + ) + plt.title('Spatial Occurrence Distribution of GBIF Records') + plt.xlabel('Longitude') + plt.ylabel('Latitude') + plt.grid(True) + plt.show() From befe1f24707dcead50f213da31eed33cc1ada701 Mon Sep 17 00:00:00 2001 From: preethatr07 Date: Tue, 17 Dec 2024 19:08:17 +0100 Subject: [PATCH 3/5] Update inaturalist.py with plot class --- torchgeo/datasets/inaturalist.py | 87 ++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/torchgeo/datasets/inaturalist.py b/torchgeo/datasets/inaturalist.py index bb5cfe3c8d..2b213987e1 100644 --- a/torchgeo/datasets/inaturalist.py +++ b/torchgeo/datasets/inaturalist.py @@ -11,6 +11,11 @@ import pandas as pd from rasterio.crs import CRS +import matplotlib.pyplot as plt +import geopandas as gpd +from matplotlib import cm +import numpy as np + from .errors import DatasetNotFoundError from .geo import GeoDataset from .utils import BoundingBox, Path, disambiguate_timestamp @@ -110,3 +115,85 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]: sample = {'crs': self.crs, 'bounds': bboxes} return sample + + + + def plot(self, query: BoundingBox = None, time_range: tuple = None, cmap: str = 'viridis') -> None: + """ Plot the observations in a map when given a geographical bounding box as well as time frame which is optional. + + Args: + query: Optional case for bounding box which is meant to cut the observations (minx,maxx,miny,maxy,mint,maxt). + time_range: Optional filtering in terms of exact date and location for (start_time, end_time). + cmap: A color map for the time sequence. + """ + # Step 1: If a query and/or time_range is provided, filter the dataset based on this criteria. + data = self._filter_data(query, time_range) + + # Step 2: Prepare a GeoDataFrame for geospatial visualization by geographic reference. + gdf = gpd.GeoDataFrame( + data, + geometry=gpd.points_from_xy(data['longitude'], data['latitude']), + crs=self._crs + ) + + + # Step 3: Illustrate the observations that were made. + fig, ax = plt.subplots(figsize=(10, 8)) + gdf.plot(ax=ax, color=self._get_color_by_time(data, cmap=cmap), markersize=10, alpha=0.7) + + + # Step 4: Incorporate cartographic refinements (for example, coastlines, gridlines) to the map. + ax.set_title('iNaturalist Observations', fontsize=16) + ax.set_xlabel('Longitude') + ax.set_ylabel('Latitude') + ax.grid(True) + + # Insert a colorbar to show the progression of time + sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=data['observed_on'].min(), vmax=data['observed_on'].max())) + sm.set_array([]) + plt.colorbar(sm, ax=ax, label='Time of Observation') + plt.show() + + def _filter_data(self, query: BoundingBox, time_range: tuple) -> pd.DataFrame: + """This is a Helper function that helps to filter the dataset based on the bounding box provided by the query and a time range.""" + # First stage filter on bounding box + if query: + data = self._filter_by_bbox(query) + else: + data = self._load_data() + + # Now we filter on time range + if time_range: + data = data[(data['observed_on'] >= time_range[0]) & (data['observed_on'] <= time_range[1])] + return data + + + def _get_color_by_time(self, data: pd.DataFrame, cmap: str) -> np.ndarray: + """Creates a mapping of time with color.""" + norm = plt.Normalize(vmin=data['observed_on'].min(), vmax=data['observed_on'].max()) + colormap = cm.get_cmap(cmap) + return colormap(norm(data['observed_on'])) + + + def _filter_by_bbox(self, query: BoundingBox) -> pd.DataFrame: + """Helper function that broadens filters with bounding box type parameters to filter the data frame.""" + minx, maxx, miny, maxy, _, _ = query + data = self._load_data() + return data[(data['longitude'] >= minx) & (data['longitude'] <= maxx) & + (data['latitude'] >= miny) & (data['latitude'] <= maxy)] + + def _load_data(self) -> pd.DataFrame: + """Tries to get dataset from the CSV file.""" + files = glob.glob(os.path.join(self.root, '**.csv')) + if not files: + raise DatasetNotFoundError(self) + + data = pd.read_csv( + files[0], + engine='c', + usecols=['observed_on', 'latitude', 'longitude'], + ) + return data + + + From bc1f44165b2a41a5294d162b4757d21b7e396edb Mon Sep 17 00:00:00 2001 From: preethatr07 Date: Tue, 17 Dec 2024 19:10:19 +0100 Subject: [PATCH 4/5] Update mmearth.py with plot class --- torchgeo/datasets/mmearth.py | 87 ++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/torchgeo/datasets/mmearth.py b/torchgeo/datasets/mmearth.py index f363276c40..a023606ee3 100644 --- a/torchgeo/datasets/mmearth.py +++ b/torchgeo/datasets/mmearth.py @@ -13,6 +13,11 @@ import torch from torch import Tensor + +import matplotlib.pyplot as plt +import numpy as np +import torch + from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import Path, lazy_import @@ -618,3 +623,85 @@ def __len__(self) -> int: length of the dataset """ return len(self.indices) + + + +def plot_mmearth_sample(sample: dict, modality: str, band: str = None, color_map: str = 'viridis', + title: str = None, figsize: tuple = (10, 8), show_axes: bool = False): + + """ + The method is used to plot a sample from the database while taking into consideration the particular way to handle each modality. + + Parameters: + + sample (dict): A sample of a part of mmearth dataset. + modality (str): The modality that we want to use in the plot(give for example ‘Sentinel2’, stomp in!, or dynamic world or cloud mask etc.). + band (str, optional): A particular band within the above mentioned modality which is intended for plotting purposes. None by default. + color_map (str, optional): The color palette for the plotting. Default value is set to be ‘viridis’. + title (str, optional): The feature which maps the titles to the plots. Default value set to None. + figsize (tuple, optional): It specifies the height and width of the figure in inches. Its default value is (10,8). + show_axes (bool, optional): Sets whether the plot axes are displayed or not. Default value is set to False. + """ + # Determine whether sample includes the modality + if modality not in sample: + raise ValueError("Modality '{modality}' has not been located in the sample!") + + # fetch the information regarding the certain modality + modality_data = sample[modality] + + if modality == 'sentinel2': + if band: + if band not in sample['avail_bands'][modality]: + raise ValueError(f"Band '{band}' has not been motioned for the including of the modality '{modality}'!") + band_index = sample['avail_bands'][modality].index(band) + modality_data = modality_data[band_index] # Provides extraction of band + + else: + modality_data = modality_data[0] # If specific band is not asked, first band is used instead. + + modality_data = modality_data/modality_data.max() + + elif modality == 'sentinel1': + if band: + raise ValueError ("Sentinel-1 modality does not allow using a band. Use either 'asc' or 'desc' instead.") + + if 'asc' in sample['avail_bands'][modality]: + modality_data = sample[modality]['asc'] + elif 'desc' in sample['avail_bands'][modality]: + modality_data = sample[modality]['desc'] + else: + raise ValueError("Sentinel-1 data missing 'asc' or 'desc' information.") + + modality_data = modality_data.div(modality_data.max()); + + elif modality == 'dynamic_world': + modality_data = modality_data.astype(int); + color_map = 'tab20' + + elif modality == 'cloud_mask': + modality_data = modality_data.astype(int); + color_map = 'binary' + + elif modality == 'temperature': + color_map = 'coolwarm' + + elif modality == 'precipitation': + color_map = 'Blues' + + else: + if isinstance(modality_data, torch.Tensor): + modality_data = modality_data.numpy() + modality_data = modality_data.div(modality_data.max()); + + #Plotting + plt.figure(figsize=figsize) + plt.imshow(modality_data, cmap=color_map) + plt.title(title if title else f'{modality} - {band if band else "Band 1"}', fontsize=16) + plt.colorbar(label='Intensity') + if not show_axes: + plt.axis('off') # Hide axes for better visualization + plt.show() + +# Example usage: +# plot_mmearth_sample(sample, modality='sentinel2', band='B1', color_map='viridis', title='Sentinel-2 Band 1') +# plot_mmearth_sample(sample, modality='dynamic_world', color_map='tab20', title='Land Cover') From 56e42640485663cdf11074f925866e1174beda48 Mon Sep 17 00:00:00 2001 From: preethatr07 Date: Tue, 17 Dec 2024 19:11:40 +0100 Subject: [PATCH 5/5] Update western_usa_live_fuel_moisture.py with plot class --- .../western_usa_live_fuel_moisture.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/torchgeo/datasets/western_usa_live_fuel_moisture.py b/torchgeo/datasets/western_usa_live_fuel_moisture.py index fe51f6ade8..6052221644 100644 --- a/torchgeo/datasets/western_usa_live_fuel_moisture.py +++ b/torchgeo/datasets/western_usa_live_fuel_moisture.py @@ -12,6 +12,9 @@ import pandas as pd import torch +import matplotlib.pyplot as plt +import seaborn as sns + from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import Path, which @@ -297,3 +300,77 @@ def _download(self) -> None: os.makedirs(self.root, exist_ok=True) azcopy = which('azcopy') azcopy('sync', self.url, self.root, '--recursive=true') + + + def plot( + self, + x_feature: str = None, + y_feature: str = None, + kind: str = "scatter", + title: str = None, + save_path: str = None, + ) -> None: + """Plot features or relationships within the dataset. + + Args: + x_feature: Name of the feature to plot on the x-axis. + y_feature: Name of the feature to plot on the y-axis. + Defaults to the label if not specified. + kind: Type of plot ('scatter', 'hist', 'box', or 'geo'). + title: Title of the plot. + save_path: If provided, save the plot to the given path. + """ + if x_feature not in self.input_features: + raise ValueError(f"'{x_feature}' is not a valid input feature.") + if y_feature is None: + y_feature = self.label_name + if y_feature not in self.input_features and y_feature != self.label_name: + raise ValueError(f"'{y_feature}' is not a valid feature or label.") + + plt.figure(figsize=(10, 6)) + + if kind == "scatter": + # Scatter plot for feature relationships + sns.scatterplot( + x=self.dataframe[x_feature], + y=self.dataframe[y_feature], + alpha=0.7, + ) + plt.xlabel(x_feature) + plt.ylabel(y_feature) + plt.title(title or f"Scatter plot: {x_feature} vs {y_feature}") + + elif kind == "hist": + # Histogram for a single feature + sns.histplot(self.dataframe[x_feature], kde=True, bins=30, color="blue") + plt.xlabel(x_feature) + plt.title(title or f"Distribution of {x_feature}") + + elif kind == "box": + # Boxplot for feature distributions + sns.boxplot(y=self.dataframe[x_feature]) + plt.title(title or f"Boxplot of {x_feature}") + + elif kind == "geo": + # Spatial scatter plot using latitude and longitude + if "lat" not in self.input_features or "lon" not in self.input_features: + raise ValueError("Latitude ('lat') and longitude ('lon') must be input features for geo plots.") + sns.scatterplot( + x=self.dataframe["lon"], + y=self.dataframe["lat"], + hue=self.dataframe[self.label_name], + palette="viridis", + alpha=0.7, + ) + plt.xlabel("Longitude") + plt.ylabel("Latitude") + plt.title(title or "Geographic Distribution of Fuel Moisture") + + else: + raise ValueError(f"Plot kind '{kind}' is not supported.") + + plt.tight_layout() + if save_path: + plt.savefig(save_path, dpi=300) + else: + plt.show()