From 56e42640485663cdf11074f925866e1174beda48 Mon Sep 17 00:00:00 2001 From: preethatr07 Date: Tue, 17 Dec 2024 19:11:40 +0100 Subject: [PATCH] Update western_usa_live_fuel_moisture.py with plot class --- .../western_usa_live_fuel_moisture.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/torchgeo/datasets/western_usa_live_fuel_moisture.py b/torchgeo/datasets/western_usa_live_fuel_moisture.py index fe51f6ade8f..6052221644c 100644 --- a/torchgeo/datasets/western_usa_live_fuel_moisture.py +++ b/torchgeo/datasets/western_usa_live_fuel_moisture.py @@ -12,6 +12,9 @@ import pandas as pd import torch +import matplotlib.pyplot as plt +import seaborn as sns + from .errors import DatasetNotFoundError from .geo import NonGeoDataset from .utils import Path, which @@ -297,3 +300,77 @@ def _download(self) -> None: os.makedirs(self.root, exist_ok=True) azcopy = which('azcopy') azcopy('sync', self.url, self.root, '--recursive=true') + + + def plot( + self, + x_feature: str = None, + y_feature: str = None, + kind: str = "scatter", + title: str = None, + save_path: str = None, + ) -> None: + """Plot features or relationships within the dataset. + + Args: + x_feature: Name of the feature to plot on the x-axis. + y_feature: Name of the feature to plot on the y-axis. + Defaults to the label if not specified. + kind: Type of plot ('scatter', 'hist', 'box', or 'geo'). + title: Title of the plot. + save_path: If provided, save the plot to the given path. + """ + if x_feature not in self.input_features: + raise ValueError(f"'{x_feature}' is not a valid input feature.") + if y_feature is None: + y_feature = self.label_name + if y_feature not in self.input_features and y_feature != self.label_name: + raise ValueError(f"'{y_feature}' is not a valid feature or label.") + + plt.figure(figsize=(10, 6)) + + if kind == "scatter": + # Scatter plot for feature relationships + sns.scatterplot( + x=self.dataframe[x_feature], + y=self.dataframe[y_feature], + alpha=0.7, + ) + plt.xlabel(x_feature) + plt.ylabel(y_feature) + plt.title(title or f"Scatter plot: {x_feature} vs {y_feature}") + + elif kind == "hist": + # Histogram for a single feature + sns.histplot(self.dataframe[x_feature], kde=True, bins=30, color="blue") + plt.xlabel(x_feature) + plt.title(title or f"Distribution of {x_feature}") + + elif kind == "box": + # Boxplot for feature distributions + sns.boxplot(y=self.dataframe[x_feature]) + plt.title(title or f"Boxplot of {x_feature}") + + elif kind == "geo": + # Spatial scatter plot using latitude and longitude + if "lat" not in self.input_features or "lon" not in self.input_features: + raise ValueError("Latitude ('lat') and longitude ('lon') must be input features for geo plots.") + sns.scatterplot( + x=self.dataframe["lon"], + y=self.dataframe["lat"], + hue=self.dataframe[self.label_name], + palette="viridis", + alpha=0.7, + ) + plt.xlabel("Longitude") + plt.ylabel("Latitude") + plt.title(title or "Geographic Distribution of Fuel Moisture") + + else: + raise ValueError(f"Plot kind '{kind}' is not supported.") + + plt.tight_layout() + if save_path: + plt.savefig(save_path, dpi=300) + else: + plt.show()