Skip to content

Commit

Permalink
Fix for classification & regression tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Jul 2, 2024
1 parent 705597e commit fad350c
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 5 deletions.
4 changes: 3 additions & 1 deletion torchgeo/datasets/cyclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,12 @@ def __getitem__(self, index: int) -> dict[str, Any]:
)

sample: dict[str, Any] = {'image': self._load_image(directory)}
sample.update(self._load_features(directory))
features = self._load_features(directory)
sample['label'] = features['label']

if self.transforms is not None:
sample = self.transforms(sample)
sample.update({x: features[x] for x in features if x != 'label'})

return sample

Expand Down
5 changes: 4 additions & 1 deletion torchgeo/datasets/inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,10 @@ def plot(
show_predictions = 'prediction' in sample

if show_mask:
mask = sample['mask'].numpy()
mask = sample['mask']
if mask.ndim == 3 and mask.shape[0] == 1:
mask = mask.squeeze(0)
mask = mask.numpy()
ncols += 1

if show_predictions:
Expand Down
3 changes: 2 additions & 1 deletion torchgeo/datasets/quakeset.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
label = torch.tensor(self.data[index]['label'])
magnitude = torch.tensor(self.data[index]['magnitude'])

sample = {'image': image, 'label': label, 'magnitude': magnitude}
sample = {'image': image, 'label': label}

if self.transforms is not None:
sample = self.transforms(sample)
sample['magnitude'] = magnitude

return sample

Expand Down
4 changes: 3 additions & 1 deletion torchgeo/datasets/skippd.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,12 @@ def __getitem__(self, index: int) -> dict[str, str | Tensor]:
data and label at that index
"""
sample: dict[str, str | Tensor] = {'image': self._load_image(index)}
sample.update(self._load_features(index))
features = self._load_features(index)
sample['label'] = features['label']

if self.transforms is not None:
sample = self.transforms(sample)
sample.update({x: features[x] for x in features if x != 'label'})

return sample

Expand Down
9 changes: 8 additions & 1 deletion torchgeo/datasets/sustainbench_crop_yield.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,17 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
data and label at that index
"""
sample: dict[str, Tensor] = {'image': self.images[index]}
sample.update(self.features[index])
sample['label'] = self.features[index]['label']

if self.transforms is not None:
sample = self.transforms(sample)
sample.update(
{
x: self.features[index][x]
for x in self.features[index]
if x != 'label'
}
)

return sample

Expand Down

0 comments on commit fad350c

Please sign in to comment.