Skip to content

Commit

Permalink
Update mmearth.py with plot class
Browse files Browse the repository at this point in the history
  • Loading branch information
preethatr07 authored Dec 17, 2024
1 parent befe1f2 commit bc1f441
Showing 1 changed file with 87 additions and 0 deletions.
87 changes: 87 additions & 0 deletions torchgeo/datasets/mmearth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
import torch
from torch import Tensor


import matplotlib.pyplot as plt
import numpy as np
import torch

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, lazy_import
Expand Down Expand Up @@ -618,3 +623,85 @@ def __len__(self) -> int:
length of the dataset
"""
return len(self.indices)



def plot_mmearth_sample(sample: dict, modality: str, band: str = None, color_map: str = 'viridis',
title: str = None, figsize: tuple = (10, 8), show_axes: bool = False):

"""
The method is used to plot a sample from the database while taking into consideration the particular way to handle each modality.
Parameters:
sample (dict): A sample of a part of mmearth dataset.
modality (str): The modality that we want to use in the plot(give for example ‘Sentinel2’, stomp in!, or dynamic world or cloud mask etc.).
band (str, optional): A particular band within the above mentioned modality which is intended for plotting purposes. None by default.
color_map (str, optional): The color palette for the plotting. Default value is set to be ‘viridis’.
title (str, optional): The feature which maps the titles to the plots. Default value set to None.
figsize (tuple, optional): It specifies the height and width of the figure in inches. Its default value is (10,8).
show_axes (bool, optional): Sets whether the plot axes are displayed or not. Default value is set to False.
"""
# Determine whether sample includes the modality
if modality not in sample:
raise ValueError("Modality '{modality}' has not been located in the sample!")

# fetch the information regarding the certain modality
modality_data = sample[modality]

if modality == 'sentinel2':
if band:
if band not in sample['avail_bands'][modality]:
raise ValueError(f"Band '{band}' has not been motioned for the including of the modality '{modality}'!")
band_index = sample['avail_bands'][modality].index(band)
modality_data = modality_data[band_index] # Provides extraction of band

else:
modality_data = modality_data[0] # If specific band is not asked, first band is used instead.

modality_data = modality_data/modality_data.max()

elif modality == 'sentinel1':
if band:
raise ValueError ("Sentinel-1 modality does not allow using a band. Use either 'asc' or 'desc' instead.")

if 'asc' in sample['avail_bands'][modality]:
modality_data = sample[modality]['asc']
elif 'desc' in sample['avail_bands'][modality]:
modality_data = sample[modality]['desc']
else:
raise ValueError("Sentinel-1 data missing 'asc' or 'desc' information.")

modality_data = modality_data.div(modality_data.max());

elif modality == 'dynamic_world':
modality_data = modality_data.astype(int);
color_map = 'tab20'

elif modality == 'cloud_mask':
modality_data = modality_data.astype(int);
color_map = 'binary'

elif modality == 'temperature':
color_map = 'coolwarm'

elif modality == 'precipitation':
color_map = 'Blues'

else:
if isinstance(modality_data, torch.Tensor):
modality_data = modality_data.numpy()
modality_data = modality_data.div(modality_data.max());

#Plotting
plt.figure(figsize=figsize)
plt.imshow(modality_data, cmap=color_map)
plt.title(title if title else f'{modality} - {band if band else "Band 1"}', fontsize=16)
plt.colorbar(label='Intensity')
if not show_axes:
plt.axis('off') # Hide axes for better visualization
plt.show()

# Example usage:
# plot_mmearth_sample(sample, modality='sentinel2', band='B1', color_map='viridis', title='Sentinel-2 Band 1')
# plot_mmearth_sample(sample, modality='dynamic_world', color_map='tab20', title='Land Cover')

0 comments on commit bc1f441

Please sign in to comment.