From fad350c44993c2773296a9659972043ae5c10f07 Mon Sep 17 00:00:00 2001 From: Ashwin Nair Date: Mon, 1 Jul 2024 17:17:16 +0400 Subject: [PATCH] Fix for classification & regression tests --- torchgeo/datasets/cyclone.py | 4 +++- torchgeo/datasets/inria.py | 5 ++++- torchgeo/datasets/quakeset.py | 3 ++- torchgeo/datasets/skippd.py | 4 +++- torchgeo/datasets/sustainbench_crop_yield.py | 9 ++++++++- 5 files changed, 20 insertions(+), 5 deletions(-) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index eccca9d7314..e8d4ed2aa97 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -125,10 +125,12 @@ def __getitem__(self, index: int) -> dict[str, Any]: ) sample: dict[str, Any] = {'image': self._load_image(directory)} - sample.update(self._load_features(directory)) + features = self._load_features(directory) + sample['label'] = features['label'] if self.transforms is not None: sample = self.transforms(sample) + sample.update({x: features[x] for x in features if x != 'label'}) return sample diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index 5b3db228499..6b6320b9ba0 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -217,7 +217,10 @@ def plot( show_predictions = 'prediction' in sample if show_mask: - mask = sample['mask'].numpy() + mask = sample['mask'] + if mask.ndim == 3 and mask.shape[0] == 1: + mask = mask.squeeze(0) + mask = mask.numpy() ncols += 1 if show_predictions: diff --git a/torchgeo/datasets/quakeset.py b/torchgeo/datasets/quakeset.py index ce5d9a3bd2c..eb3a1b4ddec 100644 --- a/torchgeo/datasets/quakeset.py +++ b/torchgeo/datasets/quakeset.py @@ -113,10 +113,11 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: label = torch.tensor(self.data[index]['label']) magnitude = torch.tensor(self.data[index]['magnitude']) - sample = {'image': image, 'label': label, 'magnitude': magnitude} + sample = {'image': image, 'label': label} if self.transforms is not None: sample = self.transforms(sample) + sample['magnitude'] = magnitude return sample diff --git a/torchgeo/datasets/skippd.py b/torchgeo/datasets/skippd.py index 0d111ae15b9..03bbf870edd 100644 --- a/torchgeo/datasets/skippd.py +++ b/torchgeo/datasets/skippd.py @@ -144,10 +144,12 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]: data and label at that index """ sample: dict[str, str | Tensor] = {'image': self._load_image(index)} - sample.update(self._load_features(index)) + features = self._load_features(index) + sample['label'] = features['label'] if self.transforms is not None: sample = self.transforms(sample) + sample.update({x: features[x] for x in features if x != 'label'}) return sample diff --git a/torchgeo/datasets/sustainbench_crop_yield.py b/torchgeo/datasets/sustainbench_crop_yield.py index 8eb410297e9..4f9b2362b4e 100644 --- a/torchgeo/datasets/sustainbench_crop_yield.py +++ b/torchgeo/datasets/sustainbench_crop_yield.py @@ -149,10 +149,17 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ sample: dict[str, Tensor] = {'image': self.images[index]} - sample.update(self.features[index]) + sample['label'] = self.features[index]['label'] if self.transforms is not None: sample = self.transforms(sample) + sample.update( + { + x: self.features[index][x] + for x in self.features[index] + if x != 'label' + } + ) return sample