Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update with plot class for eddmaps,gbif,mm_earth,inaturalist,western_usa_live_fuel_mositure #2475

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 106 additions & 1 deletion torchgeo/datasets/eddmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@

"""Dataset for EDDMapS."""

import os
import sys
from typing import Any

import numpy as np
import pandas as pd
from rasterio.crs import CRS

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from typing import Tuple, Optional

Check failure on line 16 in torchgeo/datasets/eddmaps.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP035)

torchgeo/datasets/eddmaps.py:16:1: UP035 `typing.Tuple` is deprecated, use `tuple` instead

from .errors import DatasetNotFoundError
from .geo import GeoDataset
from .utils import BoundingBox, Path, disambiguate_timestamp
Expand Down Expand Up @@ -79,7 +83,7 @@
coords = (x, x, y, y, mint, maxt)
self.index.insert(i, coords)
i += 1

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
"""Retrieve metadata indexed by query.

Expand All @@ -103,3 +107,104 @@
sample = {'crs': self.crs, 'bounds': bboxes}

return sample


def plot(
self,

Check failure on line 113 in torchgeo/datasets/eddmaps.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ANN001)

torchgeo/datasets/eddmaps.py:113:5: ANN001 Missing type annotation for function argument `self`
query: Optional[BoundingBox] = None,

Check failure on line 114 in torchgeo/datasets/eddmaps.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP007)

torchgeo/datasets/eddmaps.py:114:12: UP007 Use `X | Y` for type annotations
title: str = "EDDMapS Dataset",
point_size: int = 20,
point_color: str = 'blue',
query_color: str = 'red',
annotate: bool = False,
figsize: Tuple[int, int] = (12, 8)

Check failure on line 120 in torchgeo/datasets/eddmaps.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP006)

torchgeo/datasets/eddmaps.py:120:14: UP006 Use `tuple` instead of `Tuple` for type annotation
) -> None:

""" Plot the dataset points with optional query bounding box
Args:

query: The query to look for points, in the form of a bounding box: (minx,maxx,miny,maxy,mint,maxt)
title: Title of the plot
point_size: The size of the points plotted
point_color: The default color of the points, in case no such map is provided
query_color: color for the points which fall into the query
annotate: Controls if the points with timestamps are annotated
figsize: Size of drawn figure in the shape: (width, height)

Raises:

ValueError: When no points could be plotted because none were valid.

"""

Check failure on line 138 in torchgeo/datasets/eddmaps.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D201)

torchgeo/datasets/eddmaps.py:123:5: D201 No blank lines allowed before function docstring (found 1)

Check failure on line 138 in torchgeo/datasets/eddmaps.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D202)

torchgeo/datasets/eddmaps.py:123:5: D202 No blank lines allowed after function docstring (found 1)

Check failure on line 138 in torchgeo/datasets/eddmaps.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D205)

torchgeo/datasets/eddmaps.py:123:5: D205 1 blank line required between summary line and description

Check failure on line 138 in torchgeo/datasets/eddmaps.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D210)

torchgeo/datasets/eddmaps.py:123:5: D210 No whitespaces allowed surrounding docstring text

Check failure on line 138 in torchgeo/datasets/eddmaps.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (D415)

torchgeo/datasets/eddmaps.py:123:5: D415 First line should end with a period, question mark, or exclamation point

# Filtering valid lat and long rows
valid_data = self.data.dropna(subset = [ 'Latitude' , 'Longitude'])
if valid_data.empty:
raise ValueError("No valid lat/long data to plot.")

fig, ax = plt.subplots(figsize=figsize)

# Plot-at-all points

ax.scatter(

valid_data['Longitude'],

valid_data['Latitude'],

s = point_size,

c = point_color,

edgecolor = 'k',

alpha = 0.6,

label = 'All Observations'

)

#highlighting queried points (If) provided bounding box query

if query:
minx, maxx, miny, maxy, mint, maxt = query
hits = self.index.intersection((minx,maxx,miny,maxy,mint, maxt))

# Get coordinates of hits to highlight
query_points = valid_data.iloc[[list(hits)]]
ax.scatter(
query_points['Longitude'],
query_points['Latitude'],
s = point_size * 1.5,
c = query_color,
edgecolor = 'white',
alpha = 0.8,
label = 'Query Results'
)

# Draw a bounding box
bbox_patch = patches.rectangle(
(minx, miny), maxx - minx, maxy - miny,
linewidth = 2, edgecolor = 'red', facecolor='none', linestyle = '--', label = "Query Bounding Box"
)
ax.add_patch(bbox_patch)

# Optional annotations
if annotate:
for _, row in valid_data.iterrows():
ax.annotate(
row['ObsDate'], (row['Longitude'], row['Latitude']),
fontsize=8, alpha=0.7, textcoords="offset points", xytext=(0, 5), ha='center'
)

# Plot styling
ax.set_title(title, fontsize=14)
ax.set_xlabel("Longitude", fontsize=12)
ax.set_ylabel("Latitude", fontsize=12)
ax.grid(True, linestyle='--', alpha=0.5)
ax.legend()

plt.show()



28 changes: 28 additions & 0 deletions torchgeo/datasets/gbif.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import pandas as pd
from rasterio.crs import CRS

import matplotlib.pyplot as plt


from .errors import DatasetNotFoundError
from .geo import GeoDataset
from .utils import BoundingBox, Path
Expand Down Expand Up @@ -140,3 +143,28 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
sample = {'crs': self.crs, 'bounds': bboxes}

return sample


def plot(self) -> None:
"""Represent in graphic any mentions as a map on a globe."""
# Extract latitude and longitude from the dataset
lat = self.data['decimalLatitude']
long = self.data['decimalLongitude']

# Remove all other rows except those that have latitude and longitude data available
valid = self.data.dropna(subset=['decimalLatitude', 'decimalLongitude'])

# Create a new figure
plt.figure(figsize=(10, 6))
plt.scatter(
valid['decimalLongitude'],
valid['decimalLatitude'],
c='b', # Color choice can be made here
s=10, # Change the size of the markers
alpha=0.5, # Change the transparency of the points
)
plt.title('Spatial Occurrence Distribution of GBIF Records')
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.grid(True)
plt.show()
87 changes: 87 additions & 0 deletions torchgeo/datasets/inaturalist.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
import pandas as pd
from rasterio.crs import CRS

import matplotlib.pyplot as plt
import geopandas as gpd
from matplotlib import cm
import numpy as np

from .errors import DatasetNotFoundError
from .geo import GeoDataset
from .utils import BoundingBox, Path, disambiguate_timestamp
Expand Down Expand Up @@ -110,3 +115,85 @@ def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
sample = {'crs': self.crs, 'bounds': bboxes}

return sample



def plot(self, query: BoundingBox = None, time_range: tuple = None, cmap: str = 'viridis') -> None:
""" Plot the observations in a map when given a geographical bounding box as well as time frame which is optional.

Args:
query: Optional case for bounding box which is meant to cut the observations (minx,maxx,miny,maxy,mint,maxt).
time_range: Optional filtering in terms of exact date and location for (start_time, end_time).
cmap: A color map for the time sequence.
"""
# Step 1: If a query and/or time_range is provided, filter the dataset based on this criteria.
data = self._filter_data(query, time_range)

# Step 2: Prepare a GeoDataFrame for geospatial visualization by geographic reference.
gdf = gpd.GeoDataFrame(
data,
geometry=gpd.points_from_xy(data['longitude'], data['latitude']),
crs=self._crs
)


# Step 3: Illustrate the observations that were made.
fig, ax = plt.subplots(figsize=(10, 8))
gdf.plot(ax=ax, color=self._get_color_by_time(data, cmap=cmap), markersize=10, alpha=0.7)


# Step 4: Incorporate cartographic refinements (for example, coastlines, gridlines) to the map.
ax.set_title('iNaturalist Observations', fontsize=16)
ax.set_xlabel('Longitude')
ax.set_ylabel('Latitude')
ax.grid(True)

# Insert a colorbar to show the progression of time
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=data['observed_on'].min(), vmax=data['observed_on'].max()))
sm.set_array([])
plt.colorbar(sm, ax=ax, label='Time of Observation')
plt.show()

def _filter_data(self, query: BoundingBox, time_range: tuple) -> pd.DataFrame:
"""This is a Helper function that helps to filter the dataset based on the bounding box provided by the query and a time range."""
# First stage filter on bounding box
if query:
data = self._filter_by_bbox(query)
else:
data = self._load_data()

# Now we filter on time range
if time_range:
data = data[(data['observed_on'] >= time_range[0]) & (data['observed_on'] <= time_range[1])]
return data


def _get_color_by_time(self, data: pd.DataFrame, cmap: str) -> np.ndarray:
"""Creates a mapping of time with color."""
norm = plt.Normalize(vmin=data['observed_on'].min(), vmax=data['observed_on'].max())
colormap = cm.get_cmap(cmap)
return colormap(norm(data['observed_on']))


def _filter_by_bbox(self, query: BoundingBox) -> pd.DataFrame:
"""Helper function that broadens filters with bounding box type parameters to filter the data frame."""
minx, maxx, miny, maxy, _, _ = query
data = self._load_data()
return data[(data['longitude'] >= minx) & (data['longitude'] <= maxx) &
(data['latitude'] >= miny) & (data['latitude'] <= maxy)]

def _load_data(self) -> pd.DataFrame:
"""Tries to get dataset from the CSV file."""
files = glob.glob(os.path.join(self.root, '**.csv'))
if not files:
raise DatasetNotFoundError(self)

data = pd.read_csv(
files[0],
engine='c',
usecols=['observed_on', 'latitude', 'longitude'],
)
return data



87 changes: 87 additions & 0 deletions torchgeo/datasets/mmearth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
import torch
from torch import Tensor


import matplotlib.pyplot as plt
import numpy as np
import torch

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, lazy_import
Expand Down Expand Up @@ -618,3 +623,85 @@ def __len__(self) -> int:
length of the dataset
"""
return len(self.indices)



def plot_mmearth_sample(sample: dict, modality: str, band: str = None, color_map: str = 'viridis',
title: str = None, figsize: tuple = (10, 8), show_axes: bool = False):

"""
The method is used to plot a sample from the database while taking into consideration the particular way to handle each modality.

Parameters:

sample (dict): A sample of a part of mmearth dataset.
modality (str): The modality that we want to use in the plot(give for example ‘Sentinel2’, stomp in!, or dynamic world or cloud mask etc.).
band (str, optional): A particular band within the above mentioned modality which is intended for plotting purposes. None by default.
color_map (str, optional): The color palette for the plotting. Default value is set to be ‘viridis’.
title (str, optional): The feature which maps the titles to the plots. Default value set to None.
figsize (tuple, optional): It specifies the height and width of the figure in inches. Its default value is (10,8).
show_axes (bool, optional): Sets whether the plot axes are displayed or not. Default value is set to False.
"""
# Determine whether sample includes the modality
if modality not in sample:
raise ValueError("Modality '{modality}' has not been located in the sample!")

# fetch the information regarding the certain modality
modality_data = sample[modality]

if modality == 'sentinel2':
if band:
if band not in sample['avail_bands'][modality]:
raise ValueError(f"Band '{band}' has not been motioned for the including of the modality '{modality}'!")
band_index = sample['avail_bands'][modality].index(band)
modality_data = modality_data[band_index] # Provides extraction of band

else:
modality_data = modality_data[0] # If specific band is not asked, first band is used instead.

modality_data = modality_data/modality_data.max()

elif modality == 'sentinel1':
if band:
raise ValueError ("Sentinel-1 modality does not allow using a band. Use either 'asc' or 'desc' instead.")

if 'asc' in sample['avail_bands'][modality]:
modality_data = sample[modality]['asc']
elif 'desc' in sample['avail_bands'][modality]:
modality_data = sample[modality]['desc']
else:
raise ValueError("Sentinel-1 data missing 'asc' or 'desc' information.")

modality_data = modality_data.div(modality_data.max());

elif modality == 'dynamic_world':
modality_data = modality_data.astype(int);
color_map = 'tab20'

elif modality == 'cloud_mask':
modality_data = modality_data.astype(int);
color_map = 'binary'

elif modality == 'temperature':
color_map = 'coolwarm'

elif modality == 'precipitation':
color_map = 'Blues'

else:
if isinstance(modality_data, torch.Tensor):
modality_data = modality_data.numpy()
modality_data = modality_data.div(modality_data.max());

#Plotting
plt.figure(figsize=figsize)
plt.imshow(modality_data, cmap=color_map)
plt.title(title if title else f'{modality} - {band if band else "Band 1"}', fontsize=16)
plt.colorbar(label='Intensity')
if not show_axes:
plt.axis('off') # Hide axes for better visualization
plt.show()

# Example usage:
# plot_mmearth_sample(sample, modality='sentinel2', band='B1', color_map='viridis', title='Sentinel-2 Band 1')
# plot_mmearth_sample(sample, modality='dynamic_world', color_map='tab20', title='Land Cover')
Loading
Loading