Skip to content

Commit

Permalink
Fix for segmentation tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Jul 2, 2024
1 parent fad350c commit a0376b8
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)
return {"image": image, "crs": CRS.from_epsg(4326), "bbox": query}
return {'image': image, 'crs': CRS.from_epsg(4326), 'bbox': query}

def plot(self, *args: Any, **kwargs: Any) -> Figure:
return plt.figure()
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
self.length = length

def __getitem__(self, index: int) -> dict[str, Tensor]:
return {"image": torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)}
return {'image': torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)}

def __len__(self) -> int:
return self.length
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datamodules/spacenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def __init__(
K.RandomVerticalFlip(p=0.5),
K.RandomSharpness(p=0.5),
K.ColorJitter(p=0.5, brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
data_keys=['image', 'mask'],
data_keys=None,
)
self.aug = K.AugmentationSequential(
K.Normalize(mean=self.mean, std=self.std),
K.PadTo((448, 448)),
data_keys=['image', 'mask'],
data_keys=None,
)

def setup(self, stage: str) -> None:
Expand Down
8 changes: 8 additions & 0 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import torch.nn as nn
from einops import rearrange
from matplotlib.figure import Figure
from torch import Tensor
from torchmetrics import MetricCollection
Expand Down Expand Up @@ -225,6 +226,9 @@ def training_step(
Returns:
The loss tensor.
"""
if 'mask' in batch and batch['mask'].shape[1] == 1:
batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')

x = batch['image']
y = batch['mask']
batch_size = x.shape[0]
Expand All @@ -245,6 +249,8 @@ def validation_step(
batch_idx: Integer displaying index of this batch.
dataloader_idx: Index of the current dataloader.
"""
if 'mask' in batch and batch['mask'].shape[1] == 1:
batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')
x = batch['image']
y = batch['mask']
batch_size = x.shape[0]
Expand Down Expand Up @@ -289,6 +295,8 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None
batch_idx: Integer displaying index of this batch.
dataloader_idx: Index of the current dataloader.
"""
if 'mask' in batch and batch['mask'].shape[1] == 1:
batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')
x = batch['image']
y = batch['mask']
batch_size = x.shape[0]
Expand Down

0 comments on commit a0376b8

Please sign in to comment.