From b6b48cf5f366fbf7866f7b3d89a3e99140f590f7 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 12:21:12 -0600 Subject: [PATCH 001/108] DataModules: run all data augmentation on the GPU --- torchgeo/datamodules/deepglobelandcover.py | 7 -- torchgeo/datamodules/eurosat.py | 119 ++++++++++----------- torchgeo/datamodules/gid15.py | 9 -- torchgeo/datamodules/potsdam.py | 7 -- torchgeo/datamodules/vaihingen.py | 7 -- torchgeo/datasets/eurosat.py | 2 +- torchgeo/transforms/transforms.py | 35 +++--- 7 files changed, 81 insertions(+), 105 deletions(-) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 08871f38ad7..ee2da98a5d9 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -7,7 +7,6 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -from einops import rearrange from kornia.augmentation import Normalize from torch import Tensor from torch.utils.data import DataLoader @@ -137,18 +136,12 @@ def on_after_batch_transfer( Returns: A batch of data """ - # Kornia requires masks to have a channel dimension - batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") - if self.trainer: if self.trainer.training: batch = self.train_transform(batch) elif self.trainer.validating or self.trainer.testing: batch = self.test_transform(batch) - # Torchmetrics does not support masks with a channel dimension - batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - return batch def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index cb672011fb6..b4b84180439 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -8,10 +8,12 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl import torch +from kornia.augmentation import Normalize +from torch import Tensor from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize from ..datasets import EuroSAT +from ..transforms import AugmentationSequential class EuroSATDataModule(pl.LightningDataModule): @@ -22,46 +24,42 @@ class EuroSATDataModule(pl.LightningDataModule): .. versionadded:: 0.2 """ - band_means = torch.tensor( - [ - 1354.40546513, - 1118.24399958, - 1042.92983953, - 947.62620298, - 1199.47283961, - 1999.79090914, - 2369.22292565, - 2296.82608323, - 732.08340178, - 12.11327804, - 1819.01027855, - 1118.92391149, - 2594.14080798, - ] - ) - - band_stds = torch.tensor( - [ - 245.71762908, - 333.00778264, - 395.09249139, - 593.75055589, - 566.4170017, - 861.18399006, - 1086.63139075, - 1117.98170791, - 404.91978886, - 4.77584468, - 1002.58768311, - 761.30323499, - 1231.58581042, - ] - ) + band_means = [ + 1354.40546513, + 1118.24399958, + 1042.92983953, + 947.62620298, + 1199.47283961, + 1999.79090914, + 2369.22292565, + 2296.82608323, + 732.08340178, + 12.11327804, + 1819.01027855, + 1118.92391149, + 2594.14080798, + ] + + band_stds = [ + 245.71762908, + 333.00778264, + 395.09249139, + 593.75055589, + 566.4170017, + 861.18399006, + 1086.63139075, + 1117.98170791, + 404.91978886, + 4.77584468, + 1002.58768311, + 761.30323499, + 1231.58581042, + ] def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for EuroSAT based DataLoaders. + """Initialize a new LightningDataModule instance. 
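# A minimal sketch (not from the patch) of what the per-band stats above feed
# into: Kornia's Normalize computes (x - mean) / std per channel over a
# (B, C, H, W) batch. The values are truncated EuroSAT stats for two bands.
import torch
from kornia.augmentation import Normalize

means = torch.tensor([1354.41, 1118.24])
stds = torch.tensor([245.72, 333.01])
norm = Normalize(mean=means, std=stds)
x = torch.rand(4, 2, 64, 64) * 2000  # fake 2-band Sentinel-2 batch
assert norm(x).shape == (4, 2, 64, 64)  # shape is preserved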
Args: batch_size: The batch size to use in all created DataLoaders @@ -74,20 +72,9 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] = self.norm(sample["image"]) - return sample + self.transform = AugmentationSequential( + Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] + ) def prepare_data(self) -> None: """Make sure that the dataset is downloaded. @@ -105,15 +92,11 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up """ - transforms = Compose([self.preprocess]) + self.train_dataset = EuroSAT(split="train", **self.kwargs) + self.val_dataset = EuroSAT(split="val", **self.kwargs) + self.test_dataset = EuroSAT(split="test", **self.kwargs) - self.train_dataset = EuroSAT( - split="train", transforms=transforms, **self.kwargs - ) - self.val_dataset = EuroSAT(split="val", transforms=transforms, **self.kwargs) - self.test_dataset = EuroSAT(split="test", transforms=transforms, **self.kwargs) - - def train_dataloader(self) -> DataLoader[Any]: + def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for training. Returns: @@ -126,7 +109,7 @@ def train_dataloader(self) -> DataLoader[Any]: shuffle=True, ) - def val_dataloader(self) -> DataLoader[Any]: + def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for validation. Returns: @@ -139,7 +122,7 @@ def val_dataloader(self) -> DataLoader[Any]: shuffle=False, ) - def test_dataloader(self) -> DataLoader[Any]: + def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for testing. Returns: @@ -152,6 +135,20 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. 
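# The hook pattern used throughout this series, as a minimal sketch: Lightning
# first moves the batch to the target device, then calls
# on_after_batch_transfer, so the Kornia ops below run on the GPU when one is
# available. The class and key names here are illustrative, not torchgeo code.
import pytorch_lightning as pl
from kornia.augmentation import Normalize

class SketchDataModule(pl.LightningDataModule):
    def __init__(self) -> None:
        super().__init__()
        self.transform = Normalize(mean=0.0, std=255.0)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # batch["image"] already lives on the training device here
        batch["image"] = self.transform(batch["image"])
        return batch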
+ + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + return self.transform(batch) + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.EuroSAT.plot`.""" return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 045509809e2..7f4c1dc962e 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -7,7 +7,6 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -from einops import rearrange from kornia.augmentation import Normalize from torch import Tensor from torch.utils.data import DataLoader @@ -154,10 +153,6 @@ def on_after_batch_transfer( Returns: A batch of data """ - # Kornia requires masks to have a channel dimension - if "mask" in batch: - batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") - if self.trainer: if self.trainer.training: batch = self.train_transform(batch) @@ -166,10 +161,6 @@ def on_after_batch_transfer( elif self.trainer.predicting: batch = self.predict_transform(batch) - # Torchmetrics does not support masks with a channel dimension - if "mask" in batch: - batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - return batch def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index d5024f197e9..84108ef8603 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -7,7 +7,6 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -from einops import rearrange from kornia.augmentation import Normalize from torch import Tensor from torch.utils.data import DataLoader @@ -139,18 +138,12 @@ def on_after_batch_transfer( Returns: A batch of data """ - # Kornia requires masks to have a channel dimension - batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") - if self.trainer: if self.trainer.training: batch = self.train_transform(batch) elif self.trainer.validating or self.trainer.testing: batch = self.test_transform(batch) - # Torchmetrics does not support masks with a channel dimension - batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - return batch def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 2df86acdfd4..69ea79d243b 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -7,7 +7,6 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -from einops import rearrange from kornia.augmentation import Normalize from torch import Tensor from torch.utils.data import DataLoader @@ -139,18 +138,12 @@ def on_after_batch_transfer( Returns: A batch of data """ - # Kornia requires masks to have a channel dimension - batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") - if self.trainer: if self.trainer.training: batch = self.train_transform(batch) elif self.trainer.validating or self.trainer.testing: batch = self.test_transform(batch) - # Torchmetrics does not support masks with a channel dimension - batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - return batch def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 85133f178c6..fea0ae73f8f 100644 --- a/torchgeo/datasets/eurosat.py +++ 
b/torchgeo/datasets/eurosat.py @@ -170,7 +170,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: """ image, label = self._load_image(index) - image = torch.index_select(image, dim=0, index=self.band_indices) + image = torch.index_select(image, dim=0, index=self.band_indices).float() sample = {"image": image, "label": label} if self.transforms is not None: diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 3990b9e82dc..8116301b195 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -7,6 +7,7 @@ import kornia import torch +from einops import rearrange from kornia.augmentation import GeometricAugmentationBase2D from kornia.augmentation.random_generator import CropGenerator from kornia.contrib import compute_padding, extract_tensor_patches @@ -17,7 +18,11 @@ # TODO: contribute these to Kornia and delete this file class AugmentationSequential(Module): - """Wrapper around kornia AugmentationSequential to handle input dicts.""" + """Wrapper around kornia AugmentationSequential to handle input dicts. + + .. deprecated:: 0.4 + Use :class:`kornia.augmentation.AugmentationSequential` instead. + """ def __init__(self, *args: Module, data_keys: List[str]) -> None: """Initialize a new augmentation sequential instance. @@ -49,13 +54,15 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: Returns: the augmented input """ - # Kornia augmentations require masks & boxes to be float - if "mask" in self.data_keys: - mask_dtype = sample["mask"].dtype - sample["mask"] = sample["mask"].to(torch.float) - if "boxes" in self.data_keys: - boxes_dtype = sample["boxes"].dtype - sample["boxes"] = sample["boxes"].to(torch.float) + # Kornia augmentations require all inputs to be float + dtypes = {} + for key in self.data_keys: + dtypes[key] = sample[key].dtype + sample[key] = sample[key].float() + + # Kornia requires masks to have a channel dimension + if "mask" in sample: + sample["mask"] = rearrange(sample["mask"], "b h w -> b () h w") inputs = [sample[k] for k in self.data_keys] outputs_list: Union[Tensor, List[Tensor]] = self.augs(*inputs) @@ -67,11 +74,13 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: } sample.update(outputs) - # Convert masks & boxes to previous dtype - if "mask" in self.data_keys: - sample["mask"] = sample["mask"].to(mask_dtype) - if "boxes" in self.data_keys: - sample["boxes"] = sample["boxes"].to(boxes_dtype) + # Convert all inputs back to their previous dtype + for key in self.data_keys: + sample[key] = sample[key].to(dtypes[key]) + + # Torchmetrics does not support masks with a channel dimension + if "mask" in sample: + sample["mask"] = rearrange(sample["mask"], "b () h w -> b h w") return sample From d5d3f589c18887229452cc892b93dd091d82f649 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 12:38:27 -0600 Subject: [PATCH 002/108] Passing tests --- tests/transforms/test_transforms.py | 16 ++++++++-------- torchgeo/datamodules/eurosat.py | 4 ++-- torchgeo/transforms/transforms.py | 8 ++++---- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 68d9632d1b9..4b0c2630f99 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -24,7 +24,7 @@ def batch_gray() -> Dict[str, Tensor]: return { "image": torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float), - "mask": torch.tensor([[[[0, 0, 1], [0, 1, 1], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -43,7 +43,7 @@ def batch_rgb() -> Dict[str, Tensor]: ], dtype=torch.float, ), - "mask": torch.tensor([[[[0, 0, 1], [0, 1, 1], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -64,7 +64,7 @@ def batch_multispectral() -> Dict[str, Tensor]: ], dtype=torch.float, ), - "mask": torch.tensor([[[[0, 0, 1], [0, 1, 1], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -80,7 +80,7 @@ def assert_matching(output: Dict[str, Tensor], expected: Dict[str, Tensor]) -> N def test_augmentation_sequential_gray(batch_gray: Dict[str, Tensor]) -> None: expected = { "image": torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float), - "mask": torch.tensor([[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -103,7 +103,7 @@ def test_augmentation_sequential_rgb(batch_rgb: Dict[str, Tensor]) -> None: ], dtype=torch.float, ), - "mask": torch.tensor([[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -130,7 +130,7 @@ def test_augmentation_sequential_multispectral( ], dtype=torch.float, ), - "mask": torch.tensor([[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -157,7 +157,7 @@ def test_augmentation_sequential_image_only( ], dtype=torch.float, ), - "mask": torch.tensor([[[[0, 0, 1], [0, 1, 1], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -189,7 +189,7 @@ def test_sequential_transforms_augmentations( ], dtype=torch.float, ), - "mask": torch.tensor([[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], 
dtype=torch.long), "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index b4b84180439..8ace880c321 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -7,7 +7,6 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -import torch from kornia.augmentation import Normalize from torch import Tensor from torch.utils.data import DataLoader @@ -147,7 +146,8 @@ def on_after_batch_transfer( Returns: A batch of data """ - return self.transform(batch) + batch = self.transform(batch) + return batch def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.EuroSAT.plot`.""" diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 8116301b195..f874e266ff4 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -21,7 +21,7 @@ class AugmentationSequential(Module): """Wrapper around kornia AugmentationSequential to handle input dicts. .. deprecated:: 0.4 - Use :class:`kornia.augmentation.AugmentationSequential` instead. + Use :class:`kornia.augmentation.container.AugmentationSequential` instead. """ def __init__(self, *args: Module, data_keys: List[str]) -> None: @@ -55,9 +55,9 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: the augmented input """ # Kornia augmentations require all inputs to be float - dtypes = {} + dtype = {} for key in self.data_keys: - dtypes[key] = sample[key].dtype + dtype[key] = sample[key].dtype sample[key] = sample[key].float() # Kornia requires masks to have a channel dimension @@ -76,7 +76,7 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: # Convert all inputs back to their previous dtype for key in self.data_keys: - sample[key] = sample[key].to(dtypes[key]) + sample[key] = sample[key].to(dtype[key]) # Torchmetrics does not support masks with a channel dimension if "mask" in sample: From 0d557b9eb5c7b6d56744a701068bf90c0de387c5 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 14:05:39 -0600 Subject: [PATCH 003/108] Update BigEarthNet --- torchgeo/datamodules/bigearthnet.py | 177 +++++++++++++++++----------- 1 file changed, 110 insertions(+), 67 deletions(-) diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py index 348f7773fb0..d398515f9ab 100644 --- a/torchgeo/datamodules/bigearthnet.py +++ b/torchgeo/datamodules/bigearthnet.py @@ -7,11 +7,12 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -import torch +from kornia.augmentation import Normalize +from torch import Tensor from torch.utils.data import DataLoader -from torchvision.transforms import Compose from ..datasets import BigEarthNet +from ..transforms import AugmentationSequential class BigEarthNetDataModule(pl.LightningDataModule): @@ -22,51 +23,73 @@ class BigEarthNetDataModule(pl.LightningDataModule): # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) # min/max band statistics computed on 100k random samples - band_mins_raw = torch.tensor( - [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] - ) - band_maxs_raw = torch.tensor( - [ - 31.0, - 35.0, - 18556.0, - 20528.0, - 18976.0, - 17874.0, - 16611.0, - 16512.0, - 16394.0, - 16672.0, - 16141.0, - 16097.0, - 15336.0, - 15203.0, - ] - ) + band_mins_raw = [ + -70.0, + -72.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + ] + band_maxs_raw = [ + 31.0, + 35.0, + 18556.0, + 20528.0, + 18976.0, + 17874.0, + 16611.0, + 16512.0, + 16394.0, + 16672.0, + 16141.0, + 16097.0, + 15336.0, + 15203.0, + ] # min/max band statistics computed by percentile clipping the # above to samples to [2, 98] - band_mins = torch.tensor( - [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] - ) - band_maxs = torch.tensor( - [ - 6.0, - 16.0, - 9859.0, - 12872.0, - 13163.0, - 14445.0, - 12477.0, - 12563.0, - 12289.0, - 15596.0, - 12183.0, - 9458.0, - 5897.0, - 5544.0, - ] - ) + band_mins = [ + -48.0, + -42.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ] + band_maxs = [ + 6.0, + 16.0, + 9859.0, + 12872.0, + 13163.0, + 14445.0, + 12477.0, + 12563.0, + 12289.0, + 15596.0, + 12183.0, + 9458.0, + 5897.0, + 5544.0, + ] def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any @@ -95,12 +118,9 @@ def __init__( self.mins = self.band_mins[2:, None, None] self.maxs = self.band_maxs[2:, None, None] - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset.""" - sample["image"] = sample["image"].float() - sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins) - sample["image"] = torch.clip(sample["image"], min=0.0, max=1.0) - return sample + self.transform = AugmentationSequential( + Normalize(mean=self.mins, std=self.maxs - self.mins), data_keys=["image"] + ) def prepare_data(self) -> None: """Make sure that the dataset is downloaded. @@ -111,23 +131,23 @@ def prepare_data(self) -> None: BigEarthNet(split="train", **self.kwargs) def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + """Initialize the main Dataset objects. This method is called once per GPU per run. 
+ + Args: + stage: stage to set up """ - transforms = Compose([self.preprocess]) - self.train_dataset = BigEarthNet( - split="train", transforms=transforms, **self.kwargs - ) - self.val_dataset = BigEarthNet( - split="val", transforms=transforms, **self.kwargs - ) - self.test_dataset = BigEarthNet( - split="test", transforms=transforms, **self.kwargs - ) + self.train_dataset = BigEarthNet(split="train", **self.kwargs) + self.val_dataset = BigEarthNet(split="val", **self.kwargs) + self.test_dataset = BigEarthNet(split="test", **self.kwargs) + + def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Return a DataLoader for training. - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training.""" + Returns: + training data loader + """ return DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -135,8 +155,12 @@ def train_dataloader(self) -> DataLoader[Any]: shuffle=True, ) - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation.""" + def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Return a DataLoader for validation. + + Returns: + validation data loader + """ return DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -144,8 +168,12 @@ def val_dataloader(self) -> DataLoader[Any]: shuffle=False, ) - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing.""" + def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Return a DataLoader for testing. + + Returns: + testing data loader + """ return DataLoader( self.test_dataset, batch_size=self.batch_size, @@ -153,6 +181,21 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.BigEarthNet.plot`. From 73f567e1dccdc2fbf668f9ae87d813d9522bd1bc Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 14:25:46 -0600 Subject: [PATCH 004/108] Break ChesapeakeCVPR --- torchgeo/datamodules/chesapeake.py | 231 +++++++---------------------- 1 file changed, 50 insertions(+), 181 deletions(-) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 3163e2683c7..d618fe9a264 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -3,21 +3,21 @@ """Chesapeake Bay High-Resolution Land Cover Project datamodule.""" -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional -import torch -import torch.nn.functional as F -from pytorch_lightning.core.datamodule import LightningDataModule +import matplotlib.pyplot as plt +import pytorch_lightning as pl +from kornia.augmentation import CenterCrop, Normalize from torch import Tensor from torch.utils.data import DataLoader -from torchvision.transforms import Compose from ..datasets import ChesapeakeCVPR, stack_samples from ..samplers.batch import RandomBatchGeoSampler from ..samplers.single import GridGeoSampler +from ..transforms import AugmentationSequential -class ChesapeakeCVPRDataModule(LightningDataModule): +class ChesapeakeCVPRDataModule(pl.LightningDataModule): """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset. Uses the random splits defined per state to partition tiles into train, val, @@ -29,25 +29,27 @@ def __init__( train_splits: List[str], val_splits: List[str], test_splits: List[str], - patches_per_tile: int = 200, + num_tiles_per_batch: int = 64, + num_patches_per_tile: int = 200, patch_size: int = 256, - batch_size: int = 64, num_workers: int = 0, class_set: int = 7, use_prior_labels: bool = False, prior_smoothing_constant: float = 1e-4, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for Chesapeake CVPR based DataLoaders. + """Initialize a new LightningDataModule instance. Args: train_splits: The splits used to train the model, e.g. ["ny-train"] val_splits: The splits used to validate the model, e.g. ["ny-val"] test_splits: The splits used to test the model, e.g. ["ny-test"] - patches_per_tile: The number of patches per tile to sample - patch_size: The size of each patch in pixels (test patches will be 1.5 times - this size) - batch_size: The batch size to use in all created DataLoaders + num_tiles_per_batch: The number of image tiles to sample from during + training + num_patches_per_tile: The number of patches to randomly sample from each + image tile during training + patch_size: The size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures num_workers: The number of workers to use in all created DataLoaders class_set: The high-resolution land cover class set to use - 5 or 7 use_prior_labels: Flag for using a prior over high-resolution classes @@ -72,12 +74,12 @@ def __init__( self.train_splits = train_splits self.val_splits = val_splits self.test_splits = test_splits - self.patches_per_tile = patches_per_tile + self.num_tiles_per_batch = num_tiles_per_batch + self.num_patches_per_tile = num_patches_per_tile self.patch_size = patch_size # This is a rough estimate of how large of a patch we will need to sample in # EPSG:3857 in order to guarantee a large enough patch in the local CRS. 
self.original_patch_size = patch_size * 2 - self.batch_size = batch_size self.num_workers = num_workers self.class_set = class_set self.use_prior_labels = use_prior_labels @@ -92,130 +94,9 @@ def __init__( else: self.layers = ["naip-new", "lc"] - def pad_to( - self, size: int = 512, image_value: int = 0, mask_value: int = 0 - ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: - """Returns a function to perform a padding transform on a single sample. - - Args: - size: output image size - image_value: value to pad image with - mask_value: value to pad mask with - - Returns: - function to perform padding - """ - - def pad_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - _, height, width = sample["image"].shape - assert height <= size and width <= size - - height_pad = size - height - width_pad = size - width - - # See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - # for a description of the format of the padding tuple - sample["image"] = F.pad( - sample["image"], - (0, width_pad, 0, height_pad), - mode="constant", - value=image_value, - ) - sample["mask"] = F.pad( - sample["mask"], - (0, width_pad, 0, height_pad), - mode="constant", - value=mask_value, - ) - return sample - - return pad_inner - - def center_crop( - self, size: int = 512 - ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: - """Returns a function to perform a center crop transform on a single sample. - - Args: - size: output image size - - Returns: - function to perform center crop - """ - - def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - _, height, width = sample["image"].shape - - y1 = round((height - size) / 2) - x1 = round((width - size) / 2) - sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size] - sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size] - - return sample - - return center_crop_inner - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Preprocesses a single sample. - - Args: - sample: sample dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - - if "mask" in sample: - sample["mask"] = sample["mask"].squeeze() - if self.use_prior_labels: - sample["mask"] = F.normalize(sample["mask"].float(), p=1, dim=0) - sample["mask"] = F.normalize( - sample["mask"] + self.prior_smoothing_constant, p=1, dim=0 - ) - else: - if self.class_set == 5: - sample["mask"][sample["mask"] == 5] = 4 - sample["mask"][sample["mask"] == 6] = 4 - sample["mask"] = sample["mask"].long() - - return sample - - def remove_bbox(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Removes the bounding box property from a sample. - - Args: - sample: dictionary with geographic metadata - - Returns - sample without the bbox property - """ - del sample["bbox"] - return sample - - def nodata_check( - self, size: int = 512 - ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: - """Returns a function to check for nodata or mis-sized input. 
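# What the removed prior-smoothing branch above computed, as a worked example:
# L1-normalize the class prior, add a small constant, then renormalize so each
# pixel's class probabilities still sum to one.
import torch
import torch.nn.functional as F

prior = torch.tensor([[0.0], [0.0], [1.0]])  # (num_classes, num_pixels)
smoothed = F.normalize(F.normalize(prior, p=1, dim=0) + 1e-4, p=1, dim=0)
print(smoothed.squeeze())  # ~[1e-4, 1e-4, 0.9998]; columns sum to 1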
- - Args: - size: output image size - - Returns: - function to check for nodata values - """ - - def nodata_check_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - num_channels, height, width = sample["image"].shape - - if height < size or width < size: - sample["image"] = torch.zeros((num_channels, size, size)) - sample["mask"] = torch.zeros((size, size)) - - return sample - - return nodata_check_inner + self.transform = AugmentationSequential( + CenterCrop(patch_size), Normalize(mean=0, std=255) + ) def prepare_data(self) -> None: """Confirms that the dataset is downloaded on the local node. @@ -226,55 +107,21 @@ def prepare_data(self) -> None: ChesapeakeCVPR(splits=self.train_splits, layers=self.layers, **self.kwargs) def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. + """Initialize the main Dataset objects. - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + This method is called once per GPU per run. Args: stage: stage to set up """ - train_transforms = Compose( - [ - self.center_crop(self.patch_size), - self.nodata_check(self.patch_size), - self.preprocess, - self.remove_bbox, - ] - ) - val_transforms = Compose( - [ - self.center_crop(self.patch_size), - self.nodata_check(self.patch_size), - self.preprocess, - self.remove_bbox, - ] - ) - test_transforms = Compose( - [ - self.pad_to(self.original_patch_size, image_value=0, mask_value=0), - self.preprocess, - self.remove_bbox, - ] - ) - self.train_dataset = ChesapeakeCVPR( - splits=self.train_splits, - layers=self.layers, - transforms=train_transforms, - **self.kwargs, + splits=self.train_splits, layers=self.layers, **self.kwargs ) self.val_dataset = ChesapeakeCVPR( - splits=self.val_splits, - layers=self.layers, - transforms=val_transforms, - **self.kwargs, + splits=self.val_splits, layers=self.layers, **self.kwargs ) self.test_dataset = ChesapeakeCVPR( - splits=self.test_splits, - layers=self.layers, - transforms=test_transforms, - **self.kwargs, + splits=self.test_splits, layers=self.layers, **self.kwargs ) def train_dataloader(self) -> DataLoader[Any]: @@ -286,7 +133,7 @@ def train_dataloader(self) -> DataLoader[Any]: sampler = RandomBatchGeoSampler( self.train_dataset, size=self.original_patch_size, - batch_size=self.batch_size, + batch_size=self.num_tiles_per_batch, length=self.patches_per_tile * len(self.train_dataset), ) return DataLoader( @@ -309,7 +156,7 @@ def val_dataloader(self) -> DataLoader[Any]: ) return DataLoader( self.val_dataset, - batch_size=self.batch_size, + batch_size=self.num_tiles_per_batch, sampler=sampler, num_workers=self.num_workers, collate_fn=stack_samples, @@ -328,8 +175,30 @@ def test_dataloader(self) -> DataLoader[Any]: ) return DataLoader( self.test_dataset, - batch_size=self.batch_size, + batch_size=self.num_tiles_per_batch, sampler=sampler, num_workers=self.num_workers, collate_fn=stack_samples, ) + + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. 
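# Usage sketch for the dict-based wrapper with data_keys=["image", "mask"]:
# geometric ops such as CenterCrop are applied to both tensors, while Kornia
# routes intensity ops like Normalize to the image only (an assumption based
# on Kornia's documented mask handling). The patched wrapper also adds and
# removes the mask's channel dimension and restores its dtype, so an integer
# (B, H, W) mask round-trips cleanly.
import torch
from kornia.augmentation import CenterCrop, Normalize
from torchgeo.transforms import AugmentationSequential

aug = AugmentationSequential(
    CenterCrop(256), Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"]
)
batch = {
    "image": torch.randint(0, 256, (4, 4, 512, 512)).float(),
    "mask": torch.randint(0, 7, (4, 512, 512)),
}
out = aug(batch)  # image: (4, 4, 256, 256); mask: (4, 256, 256), still integer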
+ + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.ChesapeakeCVPR.plot`. + + .. versionadded:: 0.4 + """ + return self.test_dataset.plot(*args, **kwargs) From 1d3c4b333c6a34861c127f7b8606d209722f2c4b Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 14:33:55 -0600 Subject: [PATCH 005/108] Update COWC --- conf/cowc_counting.yaml | 1 - tests/conf/cowc_counting.yaml | 1 - torchgeo/datamodules/cowc.py | 64 +++++++++++++++++------------------ 3 files changed, 31 insertions(+), 35 deletions(-) diff --git a/conf/cowc_counting.yaml b/conf/cowc_counting.yaml index 91f0d9921b6..eb6363d1e99 100644 --- a/conf/cowc_counting.yaml +++ b/conf/cowc_counting.yaml @@ -13,6 +13,5 @@ experiment: learning_rate_schedule_patience: 2 datamodule: root: "data/cowc_counting" - seed: 0 batch_size: 64 num_workers: 4 diff --git a/tests/conf/cowc_counting.yaml b/tests/conf/cowc_counting.yaml index a4f25698100..fc3218e8fef 100644 --- a/tests/conf/cowc_counting.yaml +++ b/tests/conf/cowc_counting.yaml @@ -10,6 +10,5 @@ experiment: datamodule: root: "tests/data/cowc_counting" download: true - seed: 0 batch_size: 1 num_workers: 0 diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index f26db807f94..2fa0870eec6 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -7,50 +7,39 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -from torch import Generator +from kornia.augmentation import Normalize +from torch import Tensor from torch.utils.data import DataLoader, random_split from ..datasets import COWCCounting +from ..transforms import AugmentationSequential class COWCCountingDataModule(pl.LightningDataModule): """LightningDataModule implementation for the COWC Counting dataset.""" def __init__( - self, seed: int = 0, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for COWC Counting based DataLoaders. + """Initialize a new LightningDataModule instance. Args: - seed: The seed value to use when doing the dataset random_split batch_size: The batch size to use in all created DataLoaders num_workers: The number of workers to use in all created DataLoaders **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.COWCCounting` """ super().__init__() - self.seed = seed self.batch_size = batch_size self.num_workers = num_workers self.kwargs = kwargs - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and target - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 # scale to [0, 1] - if "label" in sample: - sample["label"] = sample["label"].float() - return sample + self.transform = AugmentationSequential( + Normalize(mean=0, std=255), data_keys=["image"] + ) def prepare_data(self) -> None: - """Initialize the main ``Dataset`` objects for use in :func:`setup`. + """Initialize the main Dataset objects for use in :func:`setup`. This includes optionally downloading the dataset. This is done once per node, while :func:`setup` is done once per GPU. 
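# The random_split call in setup() below now omits the seeded generator:
# without an explicit generator= argument, random_split draws from torch's
# global RNG, so reproducibility presumably falls to a global seed such as
# pl.seed_everything() rather than the removed per-datamodule seed. Minimal
# illustration with a toy dataset:
import torch
from torch.utils.data import TensorDataset, random_split

ds = TensorDataset(torch.arange(100))
n_test = 20
train, val = random_split(ds, [len(ds) - n_test, n_test])
print(len(train), len(val))  # 80 20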
@@ -59,27 +48,21 @@ def prepare_data(self) -> None: COWCCounting(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. + """Initialize the main Dataset objects. - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + This method is called once per GPU per run. Args: stage: stage to set up """ - train_val_dataset = COWCCounting( - split="train", transforms=self.preprocess, **self.kwargs - ) - self.test_dataset = COWCCounting( - split="test", transforms=self.preprocess, **self.kwargs - ) + train_val_dataset = COWCCounting(split="train", **self.kwargs) + self.test_dataset = COWCCounting(split="test", **self.kwargs) self.train_dataset, self.val_dataset = random_split( train_val_dataset, [len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)], - generator=Generator().manual_seed(self.seed), ) - def train_dataloader(self) -> DataLoader[Any]: + def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for training. Returns: @@ -92,7 +75,7 @@ def train_dataloader(self) -> DataLoader[Any]: shuffle=True, ) - def val_dataloader(self) -> DataLoader[Any]: + def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for validation. Returns: @@ -105,7 +88,7 @@ def val_dataloader(self) -> DataLoader[Any]: shuffle=False, ) - def test_dataloader(self) -> DataLoader[Any]: + def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for testing. Returns: @@ -118,6 +101,21 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.COWC.plot`. From 22f11b994f2344d45aba034678943a2337d11436 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 14:42:55 -0600 Subject: [PATCH 006/108] Update Cyclone --- conf/cyclone.yaml | 1 - tests/conf/cyclone.yaml | 1 - torchgeo/datamodules/cyclone.py | 76 ++++++++++++++++----------------- 3 files changed, 36 insertions(+), 42 deletions(-) diff --git a/conf/cyclone.yaml b/conf/cyclone.yaml index c0e038b3834..68a67f9a306 100644 --- a/conf/cyclone.yaml +++ b/conf/cyclone.yaml @@ -13,6 +13,5 @@ experiment: learning_rate_schedule_patience: 2 datamodule: root: "data/cyclone" - seed: 0 batch_size: 32 num_workers: 4 diff --git a/tests/conf/cyclone.yaml b/tests/conf/cyclone.yaml index fd0fd42b412..b3323d28999 100644 --- a/tests/conf/cyclone.yaml +++ b/tests/conf/cyclone.yaml @@ -10,6 +10,5 @@ experiment: datamodule: root: "tests/data/cyclone" download: true - seed: 0 batch_size: 1 num_workers: 0 diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 0a7f1a4eab4..cd3d6bb78c3 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -5,12 +5,15 @@ from typing import Any, Dict, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl -import torch +from kornia.augmentation import Normalize from sklearn.model_selection import GroupShuffleSplit +from torch import Tensor from torch.utils.data import DataLoader, Subset from ..datasets import TropicalCyclone +from ..transforms import AugmentationSequential class TropicalCycloneDataModule(pl.LightningDataModule): @@ -19,50 +22,33 @@ class TropicalCycloneDataModule(pl.LightningDataModule): Implements 80/20 train/val splits based on hurricane storm ids. See :func:`setup` for more details. - .. versionchanged:: 0.4.0 + .. versionchanged:: 0.4 Class name changed from CycloneDataModule to TropicalCycloneDataModule to be consistent with TropicalCyclone dataset. """ def __init__( - self, seed: int = 0, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: """Initialize a LightningDataModule for NASA Cyclone based DataLoaders. Args: - seed: The seed value to use when doing the sklearn based GroupShuffleSplit batch_size: The batch size to use in all created DataLoaders num_workers: The number of workers to use in all created DataLoaders **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.TropicalCyclone` """ super().__init__() - self.seed = seed self.batch_size = batch_size self.num_workers = num_workers self.kwargs = kwargs - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and target - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - sample["image"] = ( - sample["image"].unsqueeze(0).repeat(3, 1, 1) - ) # convert from grayscale to 3 channel - if "label" in sample: - sample["label"] = torch.as_tensor(sample["label"]).float() - - return sample + self.transform = AugmentationSequential( + Normalize(mean=0, std=255), data_keys=["image"] + ) def prepare_data(self) -> None: - """Initialize the main ``Dataset`` objects for use in :func:`setup`. + """Initialize the main Dataset objects for use in :func:`setup`. This includes optionally downloading the dataset. This is done once per node, while :func:`setup` is done once per GPU. 
@@ -73,8 +59,7 @@ def prepare_data(self) -> None: def setup(self, stage: Optional[str] = None) -> None: """Create the train/val/test splits based on the original Dataset objects. - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + This method is called once per GPU per run. We split samples between train/val by the ``storm_id`` property. I.e. all samples with the same ``storm_id`` value will be either in the train or the val @@ -86,13 +71,7 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up """ - self.all_train_dataset = TropicalCyclone( - split="train", transforms=self.preprocess, **self.kwargs - ) - - self.all_test_dataset = TropicalCyclone( - split="test", transforms=self.preprocess, **self.kwargs - ) + self.all_train_dataset = TropicalCyclone(split="train", **self.kwargs) storm_ids = [] for item in self.all_train_dataset.collection: @@ -100,18 +79,16 @@ def setup(self, stage: Optional[str] = None) -> None: storm_ids.append(storm_id) train_indices, val_indices = next( - GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( + GroupShuffleSplit(test_size=0.2, n_splits=2).split( storm_ids, groups=storm_ids ) ) self.train_dataset = Subset(self.all_train_dataset, train_indices) self.val_dataset = Subset(self.all_train_dataset, val_indices) - self.test_dataset = Subset( - self.all_test_dataset, range(len(self.all_test_dataset)) - ) + self.test_dataset = TropicalCyclone(split="test", **self.kwargs) - def train_dataloader(self) -> DataLoader[Any]: + def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for training. Returns: @@ -124,7 +101,7 @@ def train_dataloader(self) -> DataLoader[Any]: shuffle=True, ) - def val_dataloader(self) -> DataLoader[Any]: + def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for validation. Returns: @@ -137,7 +114,7 @@ def val_dataloader(self) -> DataLoader[Any]: shuffle=False, ) - def test_dataloader(self) -> DataLoader[Any]: + def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for testing. Returns: @@ -149,3 +126,22 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.TropicalCyclone.plot`.""" + return self.test_dataset.plot(*args, **kwargs) From 3263b71d30675accdbfd175227609dfbcb8d5a1f Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 14:52:48 -0600 Subject: [PATCH 007/108] Update ETCI2021 --- torchgeo/datamodules/etci2021.py | 89 ++++++++++++++------------------ 1 file changed, 39 insertions(+), 50 deletions(-) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index 101c9a3a318..26575968f44 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -7,12 +7,12 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -import torch -from torch import Generator +from kornia.augmentation import Normalize +from torch import Tensor from torch.utils.data import DataLoader, random_split -from torchvision.transforms import Normalize from ..datasets import ETCI2021 +from ..transforms import AugmentationSequential class ETCI2021DataModule(pl.LightningDataModule): @@ -24,55 +24,35 @@ class ETCI2021DataModule(pl.LightningDataModule): .. versionadded:: 0.2 """ - band_means = torch.tensor( - [0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701] - ) - - band_stds = torch.tensor( - [0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622] - ) + band_means = [ + 128.02253931, + 128.02253931, + 128.02253931, + 128.11221701, + 128.11221701, + 128.11221701, + ] + band_stds = [89.8145088, 89.8145088, 89.8145088, 95.2797861, 95.2797861, 95.2797861] def __init__( - self, seed: int = 0, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for ETCI2021 based DataLoaders. + """Initialize a new LightningDataModule instance. Args: - seed: The seed value to use when doing the dataset random_split batch_size: The batch size to use in all created DataLoaders num_workers: The number of workers to use in all created DataLoaders **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.ETCI2021` """ super().__init__() - self.seed = seed self.batch_size = batch_size self.num_workers = num_workers self.kwargs = kwargs - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Notably, moves the given water mask to act as an input layer. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - sample["image"] = self.norm(sample["image"]) - - if "mask" in sample: - flood_mask = sample["mask"][1] - flood_mask = (flood_mask > 0).long() - sample["mask"] = flood_mask - - return sample + self.transform = AugmentationSequential( + Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] + ) def prepare_data(self) -> None: """Make sure that the dataset is downloaded. @@ -83,31 +63,25 @@ def prepare_data(self) -> None: ETCI2021(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + """Initialize the main Dataset objects. This method is called once per GPU per run. 
Args: stage: stage to set up """ - train_val_dataset = ETCI2021( - split="train", transforms=self.preprocess, **self.kwargs - ) - self.test_dataset = ETCI2021( - split="val", transforms=self.preprocess, **self.kwargs - ) + train_val_dataset = ETCI2021(split="train", **self.kwargs) + self.test_dataset = ETCI2021(split="val", **self.kwargs) size_train_val = len(train_val_dataset) size_train = round(0.8 * size_train_val) size_val = size_train_val - size_train self.train_dataset, self.val_dataset = random_split( - train_val_dataset, - [size_train, size_val], - generator=Generator().manual_seed(self.seed), + train_val_dataset, [size_train, size_val] ) - def train_dataloader(self) -> DataLoader[Any]: + def train_dataloader(self) -> DataLoader[Dict[str, Any]]: """Return a DataLoader for training. Returns: @@ -120,7 +94,7 @@ def train_dataloader(self) -> DataLoader[Any]: shuffle=True, ) - def val_dataloader(self) -> DataLoader[Any]: + def val_dataloader(self) -> DataLoader[Dict[str, Any]]: """Return a DataLoader for validation. Returns: @@ -133,7 +107,7 @@ def val_dataloader(self) -> DataLoader[Any]: shuffle=False, ) - def test_dataloader(self) -> DataLoader[Any]: + def test_dataloader(self) -> DataLoader[Dict[str, Any]]: """Return a DataLoader for testing. Returns: @@ -146,6 +120,21 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.ETCI2021.plot`.""" return self.test_dataset.plot(*args, **kwargs) From 5d32674945541d190c218f9638f8dd30adb43914 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 15:57:24 -0600 Subject: [PATCH 008/108] mypy fixes --- torchgeo/datamodules/bigearthnet.py | 107 +++++++++++----------------- torchgeo/datamodules/chesapeake.py | 4 +- 2 files changed, 46 insertions(+), 65 deletions(-) diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py index d398515f9ab..9ce46783c63 100644 --- a/torchgeo/datamodules/bigearthnet.py +++ b/torchgeo/datamodules/bigearthnet.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl +import torch from kornia.augmentation import Normalize from torch import Tensor from torch.utils.data import DataLoader @@ -23,73 +24,51 @@ class BigEarthNetDataModule(pl.LightningDataModule): # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) # min/max band statistics computed on 100k random samples - band_mins_raw = [ - -70.0, - -72.0, - 1.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 1.0, - 0.0, - 0.0, - ] - band_maxs_raw = [ - 31.0, - 35.0, - 18556.0, - 20528.0, - 18976.0, - 17874.0, - 16611.0, - 16512.0, - 16394.0, - 16672.0, - 16141.0, - 16097.0, - 15336.0, - 15203.0, - ] + band_mins_raw = torch.tensor( + [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] + ) + band_maxs_raw = torch.tensor( + [ + 31.0, + 35.0, + 18556.0, + 20528.0, + 18976.0, + 17874.0, + 16611.0, + 16512.0, + 16394.0, + 16672.0, + 16141.0, + 16097.0, + 15336.0, + 15203.0, + ] + ) # min/max band statistics computed by percentile clipping the # above to samples to [2, 98] - band_mins = [ - -48.0, - -42.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - ] - band_maxs = [ - 6.0, - 16.0, - 9859.0, - 12872.0, - 13163.0, - 14445.0, - 12477.0, - 12563.0, - 12289.0, - 15596.0, - 12183.0, - 9458.0, - 5897.0, - 5544.0, - ] + band_mins = torch.tensor( + [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ) + band_maxs = torch.tensor( + [ + 6.0, + 16.0, + 9859.0, + 12872.0, + 13163.0, + 14445.0, + 12477.0, + 12563.0, + 12289.0, + 15596.0, + 12183.0, + 9458.0, + 5897.0, + 5544.0, + ] + ) def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index d618fe9a264..c0e288d4a25 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -95,7 +95,9 @@ def __init__( self.layers = ["naip-new", "lc"] self.transform = AugmentationSequential( - CenterCrop(patch_size), Normalize(mean=0, std=255) + CenterCrop(patch_size), + Normalize(mean=0, std=255), + data_keys=["image", "mask"], ) def prepare_data(self) -> None: From 353ba01efe6afafbe90cf0ff440b118f65fa83f2 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 16:04:45 -0600 Subject: [PATCH 009/108] Update FAIR1M --- torchgeo/datamodules/fair1m.py | 66 +++++++++++++++------------------- 1 file changed, 28 insertions(+), 38 deletions(-) diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index da923be83a8..ab370d28c41 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -3,32 +3,19 @@ """FAIR1M datamodule.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import matplotlib.pyplot as plt import pytorch_lightning as pl -import torch +from kornia.augmentation import Normalize from torch import Tensor from torch.utils.data import DataLoader from ..datasets import FAIR1M +from ..transforms import AugmentationSequential from .utils import dataset_split -def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: - """Custom object detection collate fn to handle variable number of boxes. - - Args: - batch: list of sample dicts return by dataset - Returns: - batch dict output - """ - output: Dict[str, Any] = {} - output["image"] = torch.stack([sample["image"] for sample in batch]) - output["boxes"] = [sample["boxes"] for sample in batch] - return output - - class FAIR1MDataModule(pl.LightningDataModule): """LightningDataModule implementation for the FAIR1M dataset. @@ -43,13 +30,13 @@ def __init__( test_split_pct: float = 0.2, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for FAIR1M based DataLoaders. + """Initialize a new LightningDataModule instance. Args: batch_size: The batch size to use in all created DataLoaders num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - test_split_pct: What percentage of the dataset to use as a test set + val_split_pct: Percentage of the dataset to use as a validation set + test_split_pct: Percentage of the dataset to use as a test set **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.FAIR1M` """ @@ -60,18 +47,9 @@ def __init__( self.test_split_pct = test_split_pct self.kwargs = kwargs - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample + self.transform = AugmentationSequential( + Normalize(mean=0, std=255), data_keys=["image"] + ) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. @@ -81,12 +59,12 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up """ - self.dataset = FAIR1M(transforms=self.preprocess, **self.kwargs) + self.dataset = FAIR1M(**self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) - def train_dataloader(self) -> DataLoader[Any]: + def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for training. Returns: @@ -97,10 +75,9 @@ def train_dataloader(self) -> DataLoader[Any]: batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, - collate_fn=collate_fn, ) - def val_dataloader(self) -> DataLoader[Any]: + def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for validation. 
Returns: @@ -111,10 +88,9 @@ def val_dataloader(self) -> DataLoader[Any]: batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, - collate_fn=collate_fn, ) - def test_dataloader(self) -> DataLoader[Any]: + def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: """Return a DataLoader for testing. Returns: @@ -125,9 +101,23 @@ def test_dataloader(self) -> DataLoader[Any]: batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, - collate_fn=collate_fn, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.FAIR1M.plot`. From 7e1d500fdca02b343b679a4ead3050436c8c7b6c Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 16:18:48 -0600 Subject: [PATCH 010/108] Update Inria --- torchgeo/datamodules/inria.py | 244 +++++++++++----------------------- 1 file changed, 81 insertions(+), 163 deletions(-) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 2c72de9b69c..cb93e952750 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -3,33 +3,21 @@ """InriaAerialImageLabeling datamodule.""" -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union -import kornia.augmentation as K import matplotlib.pyplot as plt import pytorch_lightning as pl -import torch -import torchvision.transforms as T -from einops import rearrange -from kornia.contrib import compute_padding, extract_tensor_patches -from torch.utils.data import DataLoader, Dataset -from torch.utils.data._utils.collate import default_collate +from kornia.augmentation import Normalize, RandomVerticalFlip, RandomHorizontalFlip +from torch import Tensor +from torch.utils.data import DataLoader from ..datasets import InriaAerialImageLabeling from ..samplers.utils import _to_tuple +from ..transforms import AugmentationSequential +from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop from .utils import dataset_split -def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: - """Flatten wrapper.""" - r_batch: Dict[str, Any] = default_collate(batch) # type: ignore[no-untyped-call] - r_batch["image"] = torch.flatten(r_batch["image"], 0, 1) - if "mask" in r_batch: - r_batch["mask"] = torch.flatten(r_batch["mask"], 0, 1) - - return r_batch - - class InriaAerialImageLabelingDataModule(pl.LightningDataModule): """LightningDataModule implementation for the InriaAerialImageLabeling dataset. @@ -39,209 +27,139 @@ class InriaAerialImageLabelingDataModule(pl.LightningDataModule): .. versionadded:: 0.3 """ - h, w = 5000, 5000 - def __init__( self, - batch_size: int = 32, + num_tiles_per_batch: int = 16, + num_patches_per_tile: int = 16, + patch_size: Union[Tuple[int, int], int] = 64, num_workers: int = 0, val_split_pct: float = 0.1, test_split_pct: float = 0.1, - patch_size: Union[int, Tuple[int, int]] = 512, - num_patches_per_tile: int = 32, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for InriaAerialImageLabeling. + """Initialize a new LightningDataModule instance. 
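# Sketch of the three-way dataset_split helper used in FAIR1M's setup() above,
# assuming it wraps torch's random_split with fractional sizes; the toy
# dataset and percentages are illustrative:
import torch
from torch.utils.data import TensorDataset, random_split

ds = TensorDataset(torch.arange(100))
n_val, n_test = int(0.2 * len(ds)), int(0.2 * len(ds))
train, val, test = random_split(ds, [len(ds) - n_val - n_test, n_val, n_test])
print(len(train), len(val), len(test))  # 60 20 20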
+ + The Inria Aerial Image Labeling dataset contains images that are too large to + pass directly through a model. Instead, we randomly sample patches from image + tiles during training and chop up image tiles into patch grids during + evaluation. During training, the effective batch size is equal to + ``num_tiles_per_batch`` x ``num_patches_per_tile``. Args: - batch_size: The batch size used in the train DataLoader - (val_batch_size == test_batch_size == 1) + num_tiles_per_batch: The number of image tiles to sample from during + training + num_patches_per_tile: The number of patches to randomly sample from each + image tile during training + patch_size: The size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures num_workers: The number of workers to use in all created DataLoaders val_split_pct: What percentage of the dataset to use as a validation set test_split_pct: What percentage of the dataset to use as a test set - patch_size: Size of random patch from image and mask (height, width) - num_patches_per_tile: Number of random patches per sample **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.InriaAerialImageLabeling` """ super().__init__() - self.batch_size = batch_size + + self.num_tiles_per_batch = num_tiles_per_batch + self.num_patches_per_tile = num_patches_per_tile + self.patch_size = _to_tuple(patch_size) self.num_workers = num_workers self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.patch_size = _to_tuple(patch_size) - self.num_patches_per_tile = num_patches_per_tile self.kwargs = kwargs - self.augmentations = K.AugmentationSequential( - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - data_keys=["input", "mask"], + self.train_transform = AugmentationSequential( + Normalize(mean=0, std=255), + RandomHorizontalFlip(p=0.5), + RandomVerticalFlip(p=0.5), + _RandomNCrop(self.patch_size, self.num_patches_per_tile), + data_keys=["image", "mask"], ) - self.random_crop = K.AugmentationSequential( - K.RandomCrop(self.patch_size, p=1.0, keepdim=False), - data_keys=["input", "mask"], + self.test_transform = AugmentationSequential( + Normalize(mean=0, std=255), + _ExtractTensorPatches(self.patch_size), + data_keys=["image", "mask"], ) - def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Extract patches from single sample.""" - assert sample["image"].ndim == 3 - _, h, w = sample["image"].shape - - padding = compute_padding((h, w), self.patch_size) - sample["original_shape"] = (h, w) - sample["patch_shape"] = self.patch_size - sample["padding"] = padding - sample["image"] = extract_tensor_patches( - sample["image"].unsqueeze(0), - self.patch_size, - self.patch_size, - padding=padding, - ) - # Needed for reconstruction of patches later - sample["num_patches"] = sample["image"].shape[1] - sample["image"] = rearrange(sample["image"], "b n c h w -> (b n) c h w") - return sample - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. 
-
-        Args:
-            sample: input image dictionary
-
-        Returns:
-            preprocessed sample
-        """
-        sample["image"] = sample["image"].float()
-        sample["image"] /= 255.0
-        sample["image"] = torch.clip(sample["image"], min=0.0, max=1.0)
-
-        if "mask" in sample:
-            sample["mask"] = rearrange(sample["mask"], "h w -> () h w")
-
-        return sample
-
-    def n_random_crop(self, sample: Dict[str, Any]) -> Dict[str, Any]:
-        """Get n random crops."""
-        images, masks = [], []
-        for _ in range(self.num_patches_per_tile):
-            image, mask = sample["image"], sample["mask"]
-            # RandomCrop needs image and mask to be in float
-            mask = mask.to(torch.float)
-            image, mask = self.random_crop(image, mask)
-            images.append(image.squeeze())
-            masks.append(mask.squeeze(0).long())
-        sample["image"] = torch.stack(images)  # (t,c,h,w)
-        sample["mask"] = torch.stack(masks)  # (t, 1, h, w)
-        return sample
-
     def setup(self, stage: Optional[str] = None) -> None:
         """Initialize the main ``Dataset`` objects.

         This method is called once per GPU per run.
         """
-        train_transforms = T.Compose([self.preprocess, self.n_random_crop])
-        test_transforms = T.Compose([self.preprocess, self.patch_sample])
-
-        self.dataset = InriaAerialImageLabeling(
-            split="train", transforms=train_transforms, **self.kwargs
+        dataset = InriaAerialImageLabeling(split="train", **self.kwargs)
+        self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
+            dataset, self.val_split_pct, self.test_split_pct
         )
+        self.predict_dataset = InriaAerialImageLabeling(split="test", **self.kwargs)

-        self.train_dataset: Dataset[Any]
-        self.val_dataset: Dataset[Any]
-        self.test_dataset: Dataset[Any]
-
-        if self.val_split_pct > 0.0:
-            if self.test_split_pct > 0.0:
-                self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
-                    self.dataset,
-                    val_pct=self.val_split_pct,
-                    test_pct=self.test_split_pct,
-                )
-            else:
-                self.train_dataset, self.val_dataset = dataset_split(
-                    self.dataset, val_pct=self.val_split_pct
-                )
-                self.test_dataset = self.val_dataset
-        else:
-            self.train_dataset = self.dataset
-            self.val_dataset = self.dataset
-            self.test_dataset = self.dataset
-
-        self.predict_dataset = InriaAerialImageLabeling(
-            split="test", transforms=test_transforms, **self.kwargs
-        )
+    def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
+        """Return a DataLoader for training.

-    def train_dataloader(self) -> DataLoader[Any]:
-        """Return a DataLoader for training."""
+        Returns:
+            training data loader
+        """
         return DataLoader(
             self.train_dataset,
-            batch_size=self.batch_size,
+            batch_size=self.num_tiles_per_batch,
             num_workers=self.num_workers,
-            collate_fn=collate_wrapper,
             shuffle=True,
         )

-    def val_dataloader(self) -> DataLoader[Any]:
-        """Return a DataLoader for validation."""
+    def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
+        """Return a DataLoader for validation.
+
+        Returns:
+            validation data loader
+        """
         return DataLoader(
-            self.val_dataset,
-            batch_size=1,
-            num_workers=self.num_workers,
-            collate_fn=collate_wrapper,
-            shuffle=False,
+            self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False
         )

-    def test_dataloader(self) -> DataLoader[Any]:
-        """Return a DataLoader for testing."""
+    def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
+        """Return a DataLoader for testing.
+ + Returns: + testing data loader + """ return DataLoader( - self.test_dataset, - batch_size=1, - num_workers=self.num_workers, - collate_fn=collate_wrapper, - shuffle=False, + self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False ) - def predict_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for prediction.""" + def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Return a DataLoader for prediction. + + Returns: + prediction data loader + """ return DataLoader( self.predict_dataset, batch_size=1, num_workers=self.num_workers, - collate_fn=collate_wrapper, shuffle=False, ) def on_after_batch_transfer( - self, batch: Dict[str, Any], dataloader_idx: int - ) -> Dict[str, Any]: + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: """Apply augmentations to batch after transferring to GPU. Args: - batch (dict): A batch of data that needs to be altered or augmented. - dataloader_idx (int): The index of the dataloader to which the batch - belongs. + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs Returns: - dict: A batch of data + A batch of data """ - # Training - if ( - hasattr(self, "trainer") - and self.trainer is not None - and hasattr(self.trainer, "training") - and self.trainer.training - and self.augmentations is not None - ): - batch["mask"] = batch["mask"].to(torch.float) - batch["image"], batch["mask"] = self.augmentations( - batch["image"], batch["mask"] - ) - batch["mask"] = batch["mask"].to(torch.long) - - # Validation - if "mask" in batch: - batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") + if self.trainer: + if self.trainer.training: + batch = self.train_transform(batch) + elif ( + self.trainer.validating + or self.trainer.testing + or self.trainer.predicting + ): + batch = self.test_transform(batch) + return batch def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: From b1b90b4b5432819a6e7f656c42e847e33eadc4d4 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 16:34:18 -0600 Subject: [PATCH 011/108] Update LandCoverAI --- torchgeo/datamodules/inria.py | 2 +- torchgeo/datamodules/landcoverai.py | 128 ++++++++++++---------------- 2 files changed, 54 insertions(+), 76 deletions(-) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index cb93e952750..7f246765925 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -7,7 +7,7 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -from kornia.augmentation import Normalize, RandomVerticalFlip, RandomHorizontalFlip +from kornia.augmentation import Normalize, RandomHorizontalFlip, RandomVerticalFlip from torch import Tensor from torch.utils.data import DataLoader diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index 67f952e37b3..fc1ae71df76 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -5,12 +5,21 @@ from typing import Any, Dict, Optional -import kornia.augmentation as K import matplotlib.pyplot as plt import pytorch_lightning as pl +from kornia.augmentation import ( + ColorJitter, + Normalize, + RandomHorizontalFlip, + RandomRotation, + RandomSharpness, + RandomVerticalFlip, +) +from torch import Tensor from torch.utils.data import DataLoader from ..datasets import LandCoverAI +from ..transforms import AugmentationSequential class LandCoverAIDataModule(pl.LightningDataModule): @@ -22,7 +31,7 @@ class LandCoverAIDataModule(pl.LightningDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for LandCover.ai based DataLoaders. + """Initialize a new LightningDataModule instance. Args: batch_size: The batch size to use in all created DataLoaders @@ -35,67 +44,25 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - def on_after_batch_transfer( - self, batch: Dict[str, Any], batch_idx: int - ) -> Dict[str, Any]: - """Apply batch augmentations after batch is transferred to the device. - - Args: - batch: mini-batch of data - batch_idx: batch index - - Returns: - augmented mini-batch - """ - if ( - hasattr(self, "trainer") - and self.trainer is not None - and hasattr(self.trainer, "training") - and self.trainer.training - ): - # Kornia expects masks to be floats with a channel dimension - x = batch["image"] - y = batch["mask"].float().unsqueeze(1) - - train_augmentations = K.AugmentationSequential( - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomSharpness(p=0.5), - K.ColorJitter( - p=0.5, - brightness=0.1, - contrast=0.1, - saturation=0.1, - hue=0.1, - silence_instantiation_warning=True, - ), - data_keys=["input", "mask"], - ) - x, y = train_augmentations(x, y) - - # torchmetrics expects masks to be longs without a channel dimension - batch["image"] = x - batch["mask"] = y.squeeze(1).long() - - return batch - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. 
- - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - - if "mask" in sample: - sample["mask"] = sample["mask"].long() - - return sample + self.train_transform = AugmentationSequential( + Normalize(mean=0, std=255), + RandomRotation(p=0.5, degrees=90), + RandomHorizontalFlip(p=0.5), + RandomVerticalFlip(p=0.5), + RandomSharpness(p=0.5), + ColorJitter( + p=0.5, + brightness=0.1, + contrast=0.1, + saturation=0.1, + hue=0.1, + silence_instantiation_warning=True, + ), + data_keys=["image", "mask"], + ) + self.test_transform = AugmentationSequential( + Normalize(mean=0, std=255), data_keys=["image", "mask"] + ) def prepare_data(self) -> None: """Make sure that the dataset is downloaded. @@ -113,20 +80,11 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up """ - train_transforms = self.preprocess - val_test_transforms = self.preprocess - - self.train_dataset = LandCoverAI( - split="train", transforms=train_transforms, **self.kwargs - ) + self.train_dataset = LandCoverAI(split="train", **self.kwargs) - self.val_dataset = LandCoverAI( - split="val", transforms=val_test_transforms, **self.kwargs - ) + self.val_dataset = LandCoverAI(split="val", **self.kwargs) - self.test_dataset = LandCoverAI( - split="test", transforms=val_test_transforms, **self.kwargs - ) + self.test_dataset = LandCoverAI(split="test", **self.kwargs) def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training. @@ -167,6 +125,26 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + if self.trainer: + if self.trainer.training: + batch = self.train_transform(batch) + elif self.trainer.validating or self.trainer.testing: + batch = self.test_transform(batch) + + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.LandCoverAI.plot`. From d0dc2741de84ab4391d91f1664dc77c3515bf545 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 16:53:08 -0600 Subject: [PATCH 012/108] Update LoveDA --- torchgeo/datamodules/loveda.py | 50 ++++++++++++++++------------------ 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index a82233f9844..70a76d25a8e 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -7,9 +7,12 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl +from kornia import Normalize +from torch import Tensor from torch.utils.data import DataLoader from ..datasets import LoveDA +from ..transforms import AugmentationSequential class LoveDADataModule(pl.LightningDataModule): @@ -36,19 +39,9 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. 
- - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - - return sample + self.transform = AugmentationSequential( + Normalize(mean=0, std=255), data_keys=["image"] + ) def prepare_data(self) -> None: """Make sure that the dataset is downloaded. @@ -66,21 +59,11 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up """ - train_transforms = self.preprocess - val_predict_transforms = self.preprocess - - self.train_dataset = LoveDA( - split="train", transforms=train_transforms, **self.kwargs - ) - - self.val_dataset = LoveDA( - split="val", transforms=val_predict_transforms, **self.kwargs - ) + self.train_dataset = LoveDA(split="train", **self.kwargs) + self.val_dataset = LoveDA(split="val", **self.kwargs) # Test set masks are not public, use for prediction instead - self.predict_dataset = LoveDA( - split="test", transforms=val_predict_transforms, **self.kwargs - ) + self.predict_dataset = LoveDA(split="test", **self.kwargs) def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training. @@ -121,6 +104,21 @@ def predict_dataloader(self) -> DataLoader[Any]: shuffle=False, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.LoveDA.plot`. From f5513b7b63bb918c319e5bdca60bd2e4db6de72c Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 17:08:55 -0600 Subject: [PATCH 013/108] Update NAIP --- torchgeo/datamodules/naip.py | 90 ++++++++++++------------------------ 1 file changed, 30 insertions(+), 60 deletions(-) diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index f9e9038da8d..a8fa9510dee 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -7,12 +7,13 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl +from kornia import Normalize +from torch import Tensor from torch.utils.data import DataLoader -from torchvision.transforms import Compose from ..datasets import NAIP, BoundingBox, Chesapeake13, stack_samples -from ..samplers.batch import RandomBatchGeoSampler -from ..samplers.single import GridGeoSampler +from ..samplers import GridGeoSampler, RandomBatchGeoSampler +from ..transforms import AugmentationSequential class NAIPChesapeakeDataModule(pl.LightningDataModule): @@ -21,15 +22,13 @@ class NAIPChesapeakeDataModule(pl.LightningDataModule): Uses the train/val/test splits from the dataset. """ - # TODO: tune these hyperparams - length = 1000 - stride = 128 - def __init__( self, batch_size: int = 64, num_workers: int = 0, patch_size: int = 256, + stride: int = 128, + length: int = 1000, **kwargs: Any, ) -> None: """Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders. 
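
The ``stride`` and ``length`` class attributes promoted to ``__init__`` arguments above feed the geographic samplers that ``setup`` builds. A rough sketch of their semantics, assuming ``dataset`` stands in for the NAIP/Chesapeake intersection constructed there (the numeric values are just the new defaults, not part of this patch):

    from torchgeo.samplers import GridGeoSampler, RandomBatchGeoSampler

    # Training: ``length`` caps how many randomly placed patches form one epoch.
    train_sampler = RandomBatchGeoSampler(dataset, size=256, batch_size=64, length=1000)
    # Evaluation: ``stride`` is the step of the sliding window; a stride smaller
    # than the patch size makes neighboring patches overlap.
    test_sampler = GridGeoSampler(dataset, size=256, stride=128)
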
@@ -38,6 +37,8 @@ def __init__( batch_size: The batch size to use in all created DataLoaders num_workers: The number of workers to use in all created DataLoaders patch_size: size of patches to sample + stride: stride of grid sampler + length: epoch size **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.NAIP` (prefix keys with ``naip_``) and :class:`~torchgeo.datasets.Chesapeake13` @@ -47,6 +48,8 @@ def __init__( self.batch_size = batch_size self.num_workers = num_workers self.patch_size = patch_size + self.stride = stride + self.length = length self.naip_kwargs = {} self.chesapeake_kwargs = {} @@ -56,44 +59,9 @@ def __init__( elif key.startswith("chesapeake_"): self.chesapeake_kwargs[key[11:]] = val - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the NAIP Dataset. - - Args: - sample: NAIP image dictionary - - Returns: - preprocessed NAIP data - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - - return sample - - def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Chesapeake Dataset. - - Args: - sample: Chesapeake mask dictionary - - Returns: - preprocessed Chesapeake data - """ - sample["mask"] = sample["mask"].long()[0] - - return sample - - def remove_bbox(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Removes the bounding box property from a sample. - - Args: - sample: dictionary with geographic metadata - - Returns - sample without the bbox property - """ - del sample["bbox"] - return sample + self.transform = AugmentationSequential( + Normalize(mean=0, std=255), data_keys=["image", "mask"] + ) def prepare_data(self) -> None: """Make sure that the dataset is downloaded. @@ -111,21 +79,8 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: state to set up """ - # TODO: these transforms will be applied independently, this won't work if we - # add things like random horizontal flip - - naip_transforms = Compose([self.preprocess, self.remove_bbox]) - chesapeak_transforms = Compose([self.chesapeake_transform, self.remove_bbox]) - - self.chesapeake = Chesapeake13( - transforms=chesapeak_transforms, **self.chesapeake_kwargs - ) - self.naip = NAIP( - crs=self.chesapeake.crs, - res=self.chesapeake.res, - transforms=naip_transforms, - **self.naip_kwargs, - ) + self.chesapeake = Chesapeake13(**self.chesapeake_kwargs) + self.naip = NAIP(**self.naip_kwargs) self.dataset = self.chesapeake & self.naip # TODO: figure out better train/val/test split @@ -187,6 +142,21 @@ def test_dataloader(self) -> DataLoader[Any]: collate_fn=stack_samples, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + def plot(self, *args: Any, **kwargs: Any) -> Tuple[plt.Figure, plt.Figure]: """Run NAIP and Chesapeake plot methods. From 95253638082d260ff848cf5b174abec8677b4a7d Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 20:43:19 -0600 Subject: [PATCH 014/108] Update NASA --- torchgeo/datamodules/loveda.py | 2 +- torchgeo/datamodules/naip.py | 2 +- torchgeo/datamodules/nasa_marine_debris.py | 58 +++++++++------------- 3 files changed, 25 insertions(+), 37 deletions(-) diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index 70a76d25a8e..6836041035b 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -7,7 +7,7 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -from kornia import Normalize +from kornia.augmentation import Normalize from torch import Tensor from torch.utils.data import DataLoader diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index a8fa9510dee..4a34a0a1ce4 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -7,7 +7,7 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -from kornia import Normalize +from kornia.augmentation import Normalize from torch import Tensor from torch.utils.data import DataLoader diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index d22324ce2fc..f3d4b33026e 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -3,34 +3,19 @@ """NASA Marine Debris datamodule.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import matplotlib.pyplot as plt import pytorch_lightning as pl -import torch +from kornia.augmentation import Normalize from torch import Tensor from torch.utils.data import DataLoader from ..datasets import NASAMarineDebris +from ..transforms import AugmentationSequential from .utils import dataset_split -def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: - """Custom object detection collate fn to handle variable boxes. - - Args: - batch: list of sample dicts return by dataset - - Returns: - batch dict output - """ - output: Dict[str, Any] = {} - output["image"] = torch.stack([sample["image"] for sample in batch]) - output["boxes"] = [sample["boxes"] for sample in batch] - output["labels"] = [torch.tensor([1] * len(sample["boxes"])) for sample in batch] - return output - - class NASAMarineDebrisDataModule(pl.LightningDataModule): """LightningDataModule implementation for the NASA Marine Debris dataset. @@ -45,7 +30,7 @@ def __init__( test_split_pct: float = 0.2, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for NASA Marine Debris based DataLoaders. + """Initialize a new LightningDataModule instance. Args: batch_size: The batch size to use in all created DataLoaders @@ -62,18 +47,9 @@ def __init__( self.test_split_pct = test_split_pct self.kwargs = kwargs - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample + self.transform = AugmentationSequential( + Normalize(mean=0, std=255), data_keys=["image"] + ) def prepare_data(self) -> None: """Make sure that the dataset is downloaded. 
@@ -91,7 +67,7 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up """ - self.dataset = NASAMarineDebris(transforms=self.preprocess, **self.kwargs) + self.dataset = NASAMarineDebris(**self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) @@ -107,7 +83,6 @@ def train_dataloader(self) -> DataLoader[Any]: batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, - collate_fn=collate_fn, ) def val_dataloader(self) -> DataLoader[Any]: @@ -121,7 +96,6 @@ def val_dataloader(self) -> DataLoader[Any]: batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, - collate_fn=collate_fn, ) def test_dataloader(self) -> DataLoader[Any]: @@ -135,9 +109,23 @@ def test_dataloader(self) -> DataLoader[Any]: batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, - collate_fn=collate_fn, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.NASAMarineDebris.plot`. From b087b748046d0ed80a1a0e27497215d0f660789b Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 20:51:17 -0600 Subject: [PATCH 015/108] Update OSCD --- torchgeo/datamodules/oscd.py | 161 ++++++++++++++++------------------- 1 file changed, 74 insertions(+), 87 deletions(-) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 90a4c94797b..1a1babe4c0b 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -3,17 +3,19 @@ """OSCD datamodule.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union -import kornia.augmentation as K +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch -from einops import repeat -from torch.utils.data import DataLoader, Dataset -from torch.utils.data._utils.collate import default_collate -from torchvision.transforms import Compose, Normalize +from kornia.augmentation import Normalize +from torch import Tensor +from torch.utils.data import DataLoader from ..datasets import OSCD +from ..samplers.utils import _to_tuple +from ..transforms import AugmentationSequential +from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop from .utils import dataset_split @@ -64,33 +66,40 @@ class OSCDDataModule(pl.LightningDataModule): def __init__( self, - train_batch_size: int = 32, - num_workers: int = 0, + num_tiles_per_batch: int = 16, + num_patches_per_tile: int = 16, + patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, - patch_size: Tuple[int, int] = (64, 64), - num_patches_per_tile: int = 32, - pad_size: Tuple[int, int] = (1280, 1280), + num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for OSCD based DataLoaders. + """Initialize a new LightningDataModule instance. + + The OSCD dataset contains images that are too large to pass + directly through a model. Instead, we randomly sample patches from image tiles + during training and chop up image tiles into patch grids during evaluation. 
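
As a standalone illustration of that evaluation-time grid chop, the Kornia helpers that this series removes from ``inria.py`` (and that ``_ExtractTensorPatches`` effectively wraps) can be called directly; the tile size here is made up:

    import torch
    from kornia.contrib import compute_padding, extract_tensor_patches

    tile = torch.rand(1, 3, 1000, 1000)  # one 3-band image tile
    patch_size = (64, 64)
    # pad the tile so it divides evenly, then slice it into a patch grid
    padding = compute_padding((1000, 1000), patch_size)
    patches = extract_tensor_patches(tile, patch_size, patch_size, padding=padding)
    print(patches.shape)  # torch.Size([1, 256, 3, 64, 64]): a 16 x 16 grid
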
+        During training, the effective batch size is equal to
+        ``num_tiles_per_batch`` x ``num_patches_per_tile``.

         Args:
-            train_batch_size: The batch size used in the train DataLoader
-                (val_batch_size == test_batch_size == 1)
-            num_workers: The number of workers to use in all created DataLoaders
-            val_split_pct: What percentage of the dataset to use as a validation set
-            patch_size: Size of random patch from image and mask (height, width)
-            num_patches_per_tile: number of random patches per sample
-            pad_size: size to pad images to during val/test steps
+            num_tiles_per_batch: The number of image tiles to sample from during
+                training
+            num_patches_per_tile: The number of patches to randomly sample from each
+                image tile during training
+            patch_size: The size of each patch, either ``size`` or ``(height, width)``.
+                Should be a multiple of 32 for most segmentation architectures
+            val_split_pct: The percentage of the dataset to use as a validation set
+            num_workers: The number of workers to use for parallel data loading
             **kwargs: Additional keyword arguments passed to
                 :class:`~torchgeo.datasets.OSCD`
         """
         super().__init__()
-        self.train_batch_size = train_batch_size
-        self.num_workers = num_workers
-        self.val_split_pct = val_split_pct
-        self.patch_size = patch_size
+
+        self.num_tiles_per_batch = num_tiles_per_batch
         self.num_patches_per_tile = num_patches_per_tile
+        self.patch_size = _to_tuple(patch_size)
+        self.val_split_pct = val_split_pct
+        self.num_workers = num_workers
         self.kwargs = kwargs

         bands = kwargs.get("bands", "all")
@@ -98,19 +107,16 @@ def __init__(
             self.band_means = self.band_means[[3, 2, 1]]
             self.band_stds = self.band_stds[[3, 2, 1]]

-        self.rcrop = K.AugmentationSequential(
-            K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True
+        self.train_transform = AugmentationSequential(
+            Normalize(mean=self.band_means, std=self.band_stds),
+            _RandomNCrop(self.patch_size, self.num_patches_per_tile),
+            data_keys=["image", "mask"],
+        )
+        self.test_transform = AugmentationSequential(
+            Normalize(mean=self.band_means, std=self.band_stds),
+            _ExtractTensorPatches(self.patch_size),
+            data_keys=["image", "mask"],
         )
-        self.padto = K.PadTo(pad_size)
-
-        self.norm = Normalize(self.band_means, self.band_stds)
-
-    def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
-        """Transform a single sample from the Dataset."""
-        sample["image"] = sample["image"].float()
-        sample["image"] = self.norm(sample["image"])
-        sample["image"] = torch.flatten(sample["image"], 0, 1)
-        return sample

     def prepare_data(self) -> None:
         """Make sure that the dataset is downloaded.

         This method is called only within a single process on CPU.
         """
         OSCD(split="train", **self.kwargs)

     def setup(self, stage: Optional[str] = None) -> None:
-        """Initialize the main ``Dataset`` objects.
+        """Initialize the main Dataset objects.

         This method is called once per GPU per run.
""" - - def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]: - images, masks = [], [] - for i in range(self.num_patches_per_tile): - mask = repeat(sample["mask"], "h w -> t h w", t=2).float() - image, mask = self.rcrop(sample["image"], mask) - mask = mask.squeeze()[0] - images.append(image.squeeze()) - masks.append(mask.long()) - sample["image"] = torch.stack(images) - sample["mask"] = torch.stack(masks) - return sample - - def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]: - sample["image"] = self.padto(sample["image"])[0] - sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0] - return sample - - train_transforms = Compose([self.preprocess, n_random_crop]) - # for testing and validation we pad all inputs to a fixed size to avoid issues - # with the upsampling paths in encoder-decoder architectures - test_transforms = Compose([self.preprocess, pad_to]) - - train_dataset = OSCD(split="train", transforms=train_transforms, **self.kwargs) - - self.train_dataset: Dataset[Any] - self.val_dataset: Dataset[Any] - - if self.val_split_pct > 0.0: - val_dataset = OSCD(split="train", transforms=test_transforms, **self.kwargs) - self.train_dataset, self.val_dataset, _ = dataset_split( - train_dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - self.val_dataset.dataset = val_dataset - else: - self.train_dataset = train_dataset - self.val_dataset = train_dataset - - self.test_dataset = OSCD( - split="test", transforms=test_transforms, **self.kwargs + train_dataset = OSCD(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + train_dataset, val_pct=self.val_split_pct ) + self.test_dataset = OSCD(split="test", **self.kwargs) def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training.""" - - def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: - r_batch: Dict[str, Any] = default_collate( # type: ignore[no-untyped-call] - batch - ) - r_batch["image"] = torch.flatten(r_batch["image"], 0, 1) - r_batch["mask"] = torch.flatten(r_batch["mask"], 0, 1) - return r_batch - return DataLoader( self.train_dataset, - batch_size=self.train_batch_size, + batch_size=self.num_tiles_per_batch, num_workers=self.num_workers, - collate_fn=collate_wrapper, shuffle=True, ) @@ -197,3 +157,30 @@ def test_dataloader(self) -> DataLoader[Any]: return DataLoader( self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False ) + + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + if self.trainer: + if self.trainer.training: + batch = self.train_transform(batch) + elif self.trainer.validating or self.trainer.testing: + batch = self.test_transform(batch) + + return batch + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.OSCD.plot`. + + .. versionadded:: 0.4 + """ + return self.test_dataset.plot(*args, **kwargs) From 128d27955d196d5bf8dd1f15d82c2838cdc43834 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 20:59:14 -0600 Subject: [PATCH 016/108] Update RESISC45 --- torchgeo/datamodules/resisc45.py | 119 +++++++++++++------------------ 1 file changed, 49 insertions(+), 70 deletions(-) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index 651bb47704a..3b61958a84a 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -8,11 +8,11 @@ import kornia.augmentation as K import matplotlib.pyplot as plt import pytorch_lightning as pl -import torch +from torch import Tensor from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize from ..datasets import RESISC45 +from ..transforms import AugmentationSequential class RESISC45DataModule(pl.LightningDataModule): @@ -21,13 +21,13 @@ class RESISC45DataModule(pl.LightningDataModule): Uses the train/val/test splits from the dataset. """ - band_means = torch.tensor([0.36820969, 0.38083247, 0.34341029]) - band_stds = torch.tensor([0.20339924, 0.18524736, 0.18455448]) + band_means = [127.86820969, 127.88083247, 127.84341029] + band_stds = [51.8668062, 47.2380768, 47.0613924] def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for RESISC45 based DataLoaders. + """Initialize a new LightningDataModule instance. Args: batch_size: The batch size to use in all created DataLoaders @@ -40,63 +40,26 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - self.norm = Normalize(self.band_means, self.band_stds) - - def on_after_batch_transfer( - self, batch: Dict[str, Any], batch_idx: int - ) -> Dict[str, Any]: - """Apply batch augmentations after batch is transferred to the device. - - Args: - batch: mini-batch of data - batch_idx: batch index - - Returns: - augmented mini-batch - """ - if ( - hasattr(self, "trainer") - and self.trainer is not None - and hasattr(self.trainer, "training") - and self.trainer.training - ): - x = batch["image"] - - train_augmentations = K.AugmentationSequential( - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomSharpness(p=0.5), - K.RandomErasing(p=0.1), - K.ColorJitter( - p=0.5, - brightness=0.1, - contrast=0.1, - saturation=0.1, - hue=0.1, - silence_instantiation_warning=True, - ), - data_keys=["input"], - ) - x = train_augmentations(x) - - batch["image"] = x - - return batch - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - sample["image"] = self.norm(sample["image"]) - return sample + self.train_transform = AugmentationSequential( + K.Normalize(mean=self.band_means, std=self.band_stds), + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + K.RandomErasing(p=0.1), + K.ColorJitter( + p=0.5, + brightness=0.1, + contrast=0.1, + saturation=0.1, + hue=0.1, + silence_instantiation_warning=True, + ), + data_keys=["image"], + ) + self.test_transform = AugmentationSequential( + K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] + ) def prepare_data(self) -> None: """Make sure that the dataset is downloaded. 
@@ -107,20 +70,16 @@ def prepare_data(self) -> None: RESISC45(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + """Initialize the main Dataset objects. This method is called once per GPU per run. Args: stage: stage to set up """ - transforms = Compose([self.preprocess]) - - self.train_dataset = RESISC45( - split="train", transforms=transforms, **self.kwargs - ) - self.val_dataset = RESISC45(split="val", transforms=transforms, **self.kwargs) - self.test_dataset = RESISC45(split="test", transforms=transforms, **self.kwargs) + self.train_dataset = RESISC45(split="train", **self.kwargs) + self.val_dataset = RESISC45(split="val", **self.kwargs) + self.test_dataset = RESISC45(split="test", **self.kwargs) def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training. @@ -161,6 +120,26 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + if self.trainer: + if self.trainer.training: + batch = self.train_transform(batch) + elif self.trainer.validating or self.trainer.testing: + batch = self.test_transform(batch) + + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.RESISC45.plot`. From 6b41c2d1cfae239c24e452feaf1662ffff1e124b Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 21:03:20 -0600 Subject: [PATCH 017/108] Update SEN12MS --- torchgeo/datamodules/sen12ms.py | 53 +++++++++------------------------ 1 file changed, 14 insertions(+), 39 deletions(-) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 8253791211c..cdc01c2cf08 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -3,8 +3,9 @@ """SEN12MS datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional +import matplotlib.pyplot as plt import pytorch_lightning as pl import torch from sklearn.model_selection import GroupShuffleSplit @@ -50,16 +51,14 @@ class SEN12MSDataModule(pl.LightningDataModule): def __init__( self, - seed: int = 0, band_set: str = "all", batch_size: int = 64, num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for SEN12MS based DataLoaders. + """Initialize a new LightningDataModule instance. Args: - seed: The seed value to use when doing the sklearn based ShuffleSplit band_set: The subset of S1/S2 bands to use. Options are: "all", "s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes: B2, B3, B4, B8, B11, and B12. @@ -71,43 +70,16 @@ def __init__( super().__init__() assert band_set in SEN12MS.BAND_SETS.keys() - self.seed = seed self.band_set = band_set self.bands = SEN12MS.BAND_SETS[band_set] self.batch_size = batch_size self.num_workers = num_workers self.kwargs = kwargs - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. 
- - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - - if self.band_set == "all": - sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 - sample["image"][2:] = sample["image"][2:].clamp(0, 10000) / 10000 - elif self.band_set == "s1": - sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 - else: - sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000 - - if "mask" in sample: - sample["mask"] = sample["mask"][0, :, :].long() - sample["mask"] = torch.take(self.DFC2020_CLASS_MAPPING, sample["mask"]) - - return sample - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. + """Initialize the main Dataset objects. - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + This method is called once per GPU per run. We split samples between train and val geographically with proportions of 80/20. This mimics the geographic test set split. @@ -117,13 +89,9 @@ def setup(self, stage: Optional[str] = None) -> None: """ season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000} - self.all_train_dataset = SEN12MS( - split="train", bands=self.bands, transforms=self.preprocess, **self.kwargs - ) + self.all_train_dataset = SEN12MS(split="train", bands=self.bands, **self.kwargs) - self.all_test_dataset = SEN12MS( - split="test", bands=self.bands, transforms=self.preprocess, **self.kwargs - ) + self.all_test_dataset = SEN12MS(split="test", bands=self.bands, **self.kwargs) # A patch is a filename like: "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" # This patch will belong to the scene that is uniquelly identified by its @@ -187,3 +155,10 @@ def test_dataloader(self) -> DataLoader[Any]: num_workers=self.num_workers, shuffle=False, ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run :meth:`torchgeo.datasets.SEN12MS.plot`. + + .. versionadded:: 0.4 + """ + return self.test_dataset.plot(*args, **kwargs) From 37905281a897c104d6d98c3f7499c0b707487356 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 21:08:25 -0600 Subject: [PATCH 018/108] Update So2Sat --- torchgeo/datamodules/so2sat.py | 146 ++++++++++++--------------------- 1 file changed, 54 insertions(+), 92 deletions(-) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index 1f331ea38c0..298ced010f4 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -3,15 +3,16 @@ """So2Sat datamodule.""" -from typing import Any, Dict, Optional, cast +from typing import Any, Dict, Optional import matplotlib.pyplot as plt import pytorch_lightning as pl -import torch +from kornia.augmentation import Normalize +from torch import Tensor from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize from ..datasets import So2Sat +from ..transforms import AugmentationSequential class So2SatDataModule(pl.LightningDataModule): @@ -20,35 +21,31 @@ class So2SatDataModule(pl.LightningDataModule): Uses the train/val/test splits from the dataset. 
""" - band_means = torch.tensor( - [ - 0.12375696117681859, - 0.1092774636368323, - 0.1010855203267882, - 0.1142398616114001, - 0.1592656692023089, - 0.18147236008771792, - 0.1745740312291377, - 0.19501607349635292, - 0.15428468872076637, - 0.10905050699570007, - ] - ) - - band_stds = torch.tensor( - [ - 0.03958795985905458, - 0.047778262752410296, - 0.06636616706371974, - 0.06358874912497474, - 0.07744387147984592, - 0.09101635085921553, - 0.09218466562387101, - 0.10164581233948201, - 0.09991773043519253, - 0.08780632509122865, - ] - ) + band_means = [ + 0.12375696117681859, + 0.1092774636368323, + 0.1010855203267882, + 0.1142398616114001, + 0.1592656692023089, + 0.18147236008771792, + 0.1745740312291377, + 0.19501607349635292, + 0.15428468872076637, + 0.10905050699570007, + ] + + band_stds = [ + 0.03958795985905458, + 0.047778262752410296, + 0.06636616706371974, + 0.06358874912497474, + 0.07744387147984592, + 0.09101635085921553, + 0.09218466562387101, + 0.10164581233948201, + 0.09991773043519253, + 0.08780632509122865, + ] # this reorders the bands to put S2 RGB first, then remainder of S2 reindex_to_rgb_first = [2, 1, 0, 3, 4, 5, 6, 7, 8, 9] @@ -61,7 +58,7 @@ def __init__( unsupervised_mode: bool = False, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for So2Sat based DataLoaders. + """Initialize a new LightningDataModule instance. Args: batch_size: The batch size to use in all created DataLoaders @@ -79,28 +76,12 @@ def __init__( self.unsupervised_mode = unsupervised_mode self.kwargs = kwargs - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] = self.norm(sample["image"]) - sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :] - - if self.band_set == "rgb": - sample["image"] = sample["image"][:3, :, :] - - return sample + self.transform = AugmentationSequential( + Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] + ) def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + """Initialize the main Dataset objects. This method is called once per GPU per run. 
@@ -108,45 +89,11 @@ def setup(self, stage: Optional[str] = None) -> None: stage: stage to set up """ bands = So2Sat.BAND_SETS["s2"] - train_transforms = Compose([self.preprocess]) - val_test_transforms = self.preprocess - - if not self.unsupervised_mode: - self.train_dataset = So2Sat( - split="train", bands=bands, transforms=train_transforms, **self.kwargs - ) - - self.val_dataset = So2Sat( - split="validation", - bands=bands, - transforms=val_test_transforms, - **self.kwargs, - ) - - self.test_dataset = So2Sat( - split="test", bands=bands, transforms=val_test_transforms, **self.kwargs - ) - - else: - - temp_train = So2Sat( - split="train", bands=bands, transforms=train_transforms, **self.kwargs - ) - - self.val_dataset = So2Sat( - split="validation", - bands=bands, - transforms=train_transforms, - **self.kwargs, - ) - - self.test_dataset = So2Sat( - split="test", bands=bands, transforms=train_transforms, **self.kwargs - ) - - self.train_dataset = cast( - So2Sat, temp_train + self.val_dataset + self.test_dataset - ) + self.train_dataset = So2Sat(split="train", bands=bands, **self.kwargs) + + self.val_dataset = So2Sat(split="validation", bands=bands, **self.kwargs) + + self.test_dataset = So2Sat(split="test", bands=bands, **self.kwargs) def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training. @@ -187,6 +134,21 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.So2Sat.plot`. From 176974fe96d62f4b9fbd6842c36ea4bc8e05f257 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 21:16:59 -0600 Subject: [PATCH 019/108] Update SpaceNet --- torchgeo/datamodules/spacenet.py | 112 +++++++++++++------------------ 1 file changed, 45 insertions(+), 67 deletions(-) diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 4af8ac8091d..1e53925eead 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -8,9 +8,11 @@ import kornia.augmentation as K import matplotlib.pyplot as plt import pytorch_lightning as pl +from torch import Tensor from torch.utils.data import DataLoader from ..datasets import SpaceNet1 +from ..transforms import AugmentationSequential from .utils import dataset_split @@ -47,72 +49,28 @@ def __init__( self.test_split_pct = test_split_pct self.kwargs = kwargs - self.padto = K.PadTo((448, 448)) - - def on_after_batch_transfer( - self, batch: Dict[str, Any], batch_idx: int - ) -> Dict[str, Any]: - """Apply batch augmentations after batch is transferred to the device. 
- - Args: - batch: mini-batch of data - batch_idx: batch index - - Returns: - augmented mini-batch - """ - if ( - hasattr(self, "trainer") - and self.trainer is not None - and hasattr(self.trainer, "training") - and self.trainer.training - ): - # Kornia expects masks to be floats with a channel dimension - x = batch["image"] - y = batch["mask"].float().unsqueeze(1) - - train_augmentations = K.AugmentationSequential( - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomSharpness(p=0.5), - K.ColorJitter( - p=0.5, - brightness=0.1, - contrast=0.1, - saturation=0.1, - hue=0.1, - silence_instantiation_warning=True, - ), - data_keys=["input", "mask"], - ) - x, y = train_augmentations(x, y) - - # torchmetrics expects masks to be longs without a channel dimension - batch["image"] = x - batch["mask"] = y.squeeze(1).long() - - return batch - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() / 255 - sample["image"] = self.padto(sample["image"]).squeeze() - - if "mask" in sample: - # We add 1 to the mask to map the current {background, building} labels to - # the values {1, 2}. This is necessary because we add 0 padding to the - # mask that we want to ignore in the loss function. - sample["mask"] = self.padto(sample["mask"].float() + 1).squeeze() - sample["mask"] = sample["mask"].long() - return sample + self.train_transform = AugmentationSequential( + K.Normalize(mean=0, std=255), + K.PadTo((448, 448)), + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + K.ColorJitter( + p=0.5, + brightness=0.1, + contrast=0.1, + saturation=0.1, + hue=0.1, + silence_instantiation_warning=True, + ), + data_keys=["image", "mask"], + ) + self.test_transform = AugmentationSequential( + K.Normalize(mean=0, std=255), + K.PadTo((448, 448)), + data_keys=["image", "mask"], + ) def prepare_data(self) -> None: """Make sure that the dataset is downloaded. @@ -130,7 +88,7 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up """ - self.dataset = SpaceNet1(transforms=self.preprocess, **self.kwargs) + self.dataset = SpaceNet1(**self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) @@ -174,6 +132,26 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + if self.trainer: + if self.trainer.training: + batch = self.train_transform(batch) + elif self.trainer.validating or self.trainer.testing: + batch = self.test_transform(batch) + + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.SpaceNet.plot`.""" return self.dataset.plot(*args, **kwargs) From 5c08e952e7b1ac2208d39db33bc136f20203f19f Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 21:23:42 -0600 Subject: [PATCH 020/108] Update UCMerced --- torchgeo/datamodules/ucmerced.py | 50 +++++++++++++++----------------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index bc56908814f..b3a6766e32d 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -7,11 +7,12 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -import torchvision +from kornia.augmentation import Normalize, Resize +from torch import Tensor from torch.utils.data import DataLoader -from torchvision.transforms import Compose from ..datasets import UCMerced +from ..transforms import AugmentationSequential class UCMercedDataModule(pl.LightningDataModule): @@ -36,23 +37,9 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - c, h, w = sample["image"].shape - if h != 256 or w != 256: - sample["image"] = torchvision.transforms.functional.resize( - sample["image"], size=(256, 256) - ) - return sample + self.transform = AugmentationSequential( + Normalize(mean=0, std=255), Resize(size=256), data_keys=["image"] + ) def prepare_data(self) -> None: """Make sure that the dataset is downloaded. @@ -70,13 +57,9 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up """ - transforms = Compose([self.preprocess]) - - self.train_dataset = UCMerced( - split="train", transforms=transforms, **self.kwargs - ) - self.val_dataset = UCMerced(split="val", transforms=transforms, **self.kwargs) - self.test_dataset = UCMerced(split="test", transforms=transforms, **self.kwargs) + self.train_dataset = UCMerced(split="train", **self.kwargs) + self.val_dataset = UCMerced(split="val", **self.kwargs) + self.test_dataset = UCMerced(split="test", **self.kwargs) def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training. @@ -117,6 +100,21 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.UCMerced.plot`. From 37ff0dd35538dca81f76894fc72f179deec4f9e5 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 21:27:49 -0600 Subject: [PATCH 021/108] Update USAVars --- torchgeo/datamodules/usavars.py | 45 ++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py index 9c7fa6d0333..aaac1be849f 100644 --- a/torchgeo/datamodules/usavars.py +++ b/torchgeo/datamodules/usavars.py @@ -7,9 +7,12 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl +from kornia.augmentation import Normalize +from torch import Tensor from torch.utils.data import DataLoader from ..datasets import USAVars +from ..transforms import AugmentationSequential class USAVarsDataModule(pl.LightningModule): @@ -36,18 +39,9 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample + self.transform = AugmentationSequential( + Normalize(mean=0, std=255), data_keys=["image"] + ) def prepare_data(self) -> None: """Make sure that the dataset is downloaded. @@ -62,15 +56,9 @@ def setup(self, stage: Optional[str] = None) -> None: This method is called once per GPU per run. """ - self.train_dataset = USAVars( - split="train", transforms=self.preprocess, **self.kwargs - ) - self.val_dataset = USAVars( - split="val", transforms=self.preprocess, **self.kwargs - ) - self.test_dataset = USAVars( - split="test", transforms=self.preprocess, **self.kwargs - ) + self.train_dataset = USAVars(split="train", **self.kwargs) + self.val_dataset = USAVars(split="val", **self.kwargs) + self.test_dataset = USAVars(split="test", **self.kwargs) def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training.""" @@ -99,6 +87,21 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.USAVars.plot`. From d47a72e7b4139f0eaa98125564835b22c44e8c5d Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 21:31:53 -0600 Subject: [PATCH 022/108] Update xview --- torchgeo/datamodules/xview.py | 57 +++++++++++++++++------------------ 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index 8f3fce7a8c4..5c1f1812b5d 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -7,10 +7,12 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -from torch.utils.data import DataLoader, Dataset -from torchvision.transforms import Compose +from kornia.augmentation import Normalize +from torch import Tensor +from torch.utils.data import DataLoader from ..datasets import XView2 +from ..transforms import AugmentationSequential from .utils import dataset_split @@ -44,18 +46,9 @@ def __init__( self.val_split_pct = val_split_pct self.kwargs = kwargs - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample + self.transform = AugmentationSequential( + Normalize(mean=0, std=255), data_keys=["image"] + ) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. @@ -65,22 +58,11 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up """ - transforms = Compose([self.preprocess]) - - dataset = XView2(split="train", transforms=transforms, **self.kwargs) - - self.train_dataset: Dataset[Any] - self.val_dataset: Dataset[Any] - - if self.val_split_pct > 0.0: - self.train_dataset, self.val_dataset, _ = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - else: - self.train_dataset = dataset - self.val_dataset = dataset - - self.test_dataset = XView2(split="test", transforms=transforms, **self.kwargs) + dataset = XView2(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + dataset, val_pct=self.val_split_pct + ) + self.test_dataset = XView2(split="test", **self.kwargs) def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training. @@ -121,6 +103,21 @@ def test_dataloader(self) -> DataLoader[Any]: shuffle=False, ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + batch = self.transform(batch) + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.XView2.plot`. From d99dfb75115130ebe8fa875dcf83bfd54e09f2eb Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 21:35:49 -0600 Subject: [PATCH 023/108] Remove seed --- conf/nasa_marine_debris.yaml | 1 - conf/sen12ms.yaml | 1 - conf/spacenet1.yaml | 3 +-- tests/conf/sen12ms_all.yaml | 1 - tests/conf/sen12ms_s1.yaml | 1 - tests/conf/sen12ms_s2_all.yaml | 1 - tests/conf/sen12ms_s2_reduced.yaml | 1 - torchgeo/datamodules/sen12ms.py | 2 +- 8 files changed, 2 insertions(+), 9 deletions(-) diff --git a/conf/nasa_marine_debris.yaml b/conf/nasa_marine_debris.yaml index 3b94582652d..48b0e8b0285 100644 --- a/conf/nasa_marine_debris.yaml +++ b/conf/nasa_marine_debris.yaml @@ -1,5 +1,4 @@ program: - seed: 0 overwrite: True trainer: diff --git a/conf/sen12ms.yaml b/conf/sen12ms.yaml index 1dba884add1..3946774328a 100644 --- a/conf/sen12ms.yaml +++ b/conf/sen12ms.yaml @@ -20,4 +20,3 @@ experiment: band_set: "all" batch_size: 32 num_workers: 4 - seed: 0 diff --git a/conf/spacenet1.yaml b/conf/spacenet1.yaml index ae9b9eeb68a..05949f26010 100644 --- a/conf/spacenet1.yaml +++ b/conf/spacenet1.yaml @@ -1,6 +1,5 @@ program: overwrite: False - seed: 0 trainer: gpus: [3] min_epochs: 50 @@ -22,4 +21,4 @@ experiment: datamodule: root: "data/spacenet" batch_size: 32 - num_workers: 4 \ No newline at end of file + num_workers: 4 diff --git a/tests/conf/sen12ms_all.yaml b/tests/conf/sen12ms_all.yaml index ecf01b3bc72..e5676876550 100644 --- a/tests/conf/sen12ms_all.yaml +++ b/tests/conf/sen12ms_all.yaml @@ -15,4 +15,3 @@ experiment: band_set: "all" batch_size: 1 num_workers: 0 - seed: 0 diff --git a/tests/conf/sen12ms_s1.yaml b/tests/conf/sen12ms_s1.yaml index a5a4f083d9c..5289c3c8b63 100644 --- a/tests/conf/sen12ms_s1.yaml +++ b/tests/conf/sen12ms_s1.yaml @@ -16,4 +16,3 @@ experiment: band_set: "s1" batch_size: 1 num_workers: 0 - seed: 0 diff --git a/tests/conf/sen12ms_s2_all.yaml b/tests/conf/sen12ms_s2_all.yaml index dafe781db89..f1499b523e3 100644 --- a/tests/conf/sen12ms_s2_all.yaml +++ b/tests/conf/sen12ms_s2_all.yaml @@ -15,4 +15,3 @@ experiment: band_set: "s2-all" batch_size: 1 num_workers: 0 - seed: 0 diff --git a/tests/conf/sen12ms_s2_reduced.yaml b/tests/conf/sen12ms_s2_reduced.yaml index 9891c4b44ef..72e85b56fc3 100644 --- a/tests/conf/sen12ms_s2_reduced.yaml +++ b/tests/conf/sen12ms_s2_reduced.yaml @@ -15,4 +15,3 @@ experiment: band_set: "s2-reduced" batch_size: 1 num_workers: 0 - seed: 0 diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index cdc01c2cf08..87ea34763c4 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -106,7 +106,7 @@ def setup(self, stage: Optional[str] = None) -> None: scenes.append(season_id + scene_id) train_indices, val_indices = next( - GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( + GroupShuffleSplit(test_size=0.2, n_splits=2).split( scenes, groups=scenes ) ) From c579ebae53b90c5989664fea753d01cd07ac2775 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 30 Dec 2022 21:37:51 -0600 Subject: [PATCH 024/108] mypy fixes --- torchgeo/datamodules/sen12ms.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 87ea34763c4..240f0ad135e 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -91,7 +91,7 @@ def setup(self, stage: Optional[str] = None) -> None: self.all_train_dataset = SEN12MS(split="train", bands=self.bands, **self.kwargs) - self.all_test_dataset = SEN12MS(split="test", bands=self.bands, **self.kwargs) + self.test_dataset = SEN12MS(split="test", bands=self.bands, **self.kwargs) # A patch is a filename like: "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" # This patch will belong to the scene that is uniquelly identified by its @@ -113,9 +113,6 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_dataset = Subset(self.all_train_dataset, train_indices) self.val_dataset = Subset(self.all_train_dataset, val_indices) - self.test_dataset = Subset( - self.all_test_dataset, range(len(self.all_test_dataset)) - ) def train_dataloader(self) -> DataLoader[Any]: """Return a DataLoader for training. From 722f6ad06df0b248fbc370139d827ea6c60ea1b5 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 30 Dec 2022 21:58:53 -0600 Subject: [PATCH 025/108] OSCD hacks --- conf/oscd.yaml | 11 ++++------- tests/datamodules/test_oscd.py | 26 ++++++++++++++++---------- torchgeo/datamodules/oscd.py | 2 +- torchgeo/datamodules/sen12ms.py | 4 +--- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/conf/oscd.yaml b/conf/oscd.yaml index 486dbb9b64b..48634f24878 100644 --- a/conf/oscd.yaml +++ b/conf/oscd.yaml @@ -19,11 +19,8 @@ experiment: ignore_index: 0 datamodule: root: "data/oscd" - train_batch_size: 32 - num_workers: 4 - val_split_pct: 0.1 - bands: "all" - pad_size: - - 1028 - - 1028 + num_tiles_per_batch: 32 num_patches_per_tile: 128 + patch_size: 64 + val_split_pct: 0.1 + num_workers: 4 diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index 3a8fdefc4e9..3d42aa60c11 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -4,6 +4,7 @@ import os import pytest +from pytorch_lightning import Trainer from _pytest.fixtures import SubRequest from torchgeo.datamodules import OSCDDataModule @@ -13,27 +14,30 @@ class TestOSCDDataModule: @pytest.fixture(scope="class", params=zip(["all", "rgb"], [0.0, 0.5])) def datamodule(self, request: SubRequest) -> OSCDDataModule: bands, val_split_pct = request.param - patch_size = (2, 2) + num_tiles_per_batch = 1 num_patches_per_tile = 2 + patch_size = 2 root = os.path.join("tests", "data", "oscd") - batch_size = 1 num_workers = 0 dm = OSCDDataModule( root=root, download=True, bands=bands, - train_batch_size=batch_size, - num_workers=num_workers, - val_split_pct=val_split_pct, - patch_size=patch_size, + num_tiles_per_batch=num_tiles_per_batch, num_patches_per_tile=num_patches_per_tile, + patch_size=patch_size, + val_split_pct=val_split_pct, + num_workers=num_workers, ) dm.prepare_data() dm.setup() + dm.trainer = Trainer() return dm def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: + datamodule.trainer.training = True # type: ignore[union-attr] sample = next(iter(datamodule.train_dataloader())) + sample = datamodule.on_after_batch_transfer(sample, 0) assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 
if datamodule.test_dataset.bands == "all": @@ -42,11 +46,11 @@ def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: assert sample["image"].shape[1] == 6 def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: + datamodule.trainer.validating = True # type: ignore[union-attr] sample = next(iter(datamodule.val_dataloader())) + sample = datamodule.on_after_batch_transfer(sample, 0) if datamodule.val_split_pct > 0.0: - assert ( - sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) - ) + assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 if datamodule.test_dataset.bands == "all": assert sample["image"].shape[1] == 26 @@ -54,8 +58,10 @@ def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: assert sample["image"].shape[1] == 6 def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: + datamodule.trainer.testing = True # type: ignore[union-attr] sample = next(iter(datamodule.test_dataloader())) - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) + sample = datamodule.on_after_batch_transfer(sample, 0) + assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 if datamodule.test_dataset.bands == "all": assert sample["image"].shape[1] == 26 diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 1a1babe4c0b..9d43794bda3 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -110,7 +110,7 @@ def __init__( self.train_transform = AugmentationSequential( Normalize(mean=self.band_means, std=self.band_stds), _RandomNCrop(self.patch_size, self.num_patches_per_tile), - data_keys=["input", "mask"], + data_keys=["image", "mask"], ) self.test_transform = AugmentationSequential( Normalize(mean=self.band_means, std=self.band_stds), diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 240f0ad135e..62dece84ecc 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -106,9 +106,7 @@ def setup(self, stage: Optional[str] = None) -> None: scenes.append(season_id + scene_id) train_indices, val_indices = next( - GroupShuffleSplit(test_size=0.2, n_splits=2).split( - scenes, groups=scenes - ) + GroupShuffleSplit(test_size=0.2, n_splits=2).split(scenes, groups=scenes) ) self.train_dataset = Subset(self.all_train_dataset, train_indices) From e237e2658182ae61a557d262a2c736be84c2a9a2 Mon Sep 17 00:00:00 2001 From: "Adam J. 
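The data_keys fix above ("input" -> "image") matters because torchgeo's AugmentationSequential routes each dict entry by key: intensity ops such as Normalize touch only the image, while the wrapper takes care of the mask's missing channel dimension. A minimal sketch with dummy tensors:

    import torch
    from kornia.augmentation import Normalize

    from torchgeo.transforms import AugmentationSequential

    aug = AugmentationSequential(
        Normalize(mean=0, std=255), data_keys=["image", "mask"]
    )
    batch = {
        "image": 255 * torch.rand(2, 3, 64, 64),
        "mask": torch.randint(0, 2, (2, 64, 64)).float(),  # no channel dim needed
    }
    out = aug(batch)  # image scaled to [0, 1]; mask labels pass through unscaled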
Stewart" Date: Sat, 31 Dec 2022 13:59:46 -0600 Subject: [PATCH 026/108] Add NonGeoDataModule base class --- torchgeo/datamodules/bigearthnet.py | 72 +--------- torchgeo/datamodules/chesapeake.py | 16 +-- torchgeo/datamodules/cowc.py | 75 +--------- torchgeo/datamodules/cyclone.py | 70 +--------- torchgeo/datamodules/deepglobelandcover.py | 76 +---------- torchgeo/datamodules/etci2021.py | 70 +--------- torchgeo/datamodules/eurosat.py | 69 +--------- torchgeo/datamodules/fair1m.py | 72 +--------- torchgeo/datamodules/geo.py | 151 +++++++++++++++++++++ torchgeo/datamodules/gid15.py | 77 +---------- torchgeo/datamodules/inria.py | 92 +------------ torchgeo/datamodules/landcoverai.py | 81 +---------- torchgeo/datamodules/loveda.py | 72 +--------- torchgeo/datamodules/naip.py | 8 +- torchgeo/datamodules/nasa_marine_debris.py | 72 +--------- torchgeo/datamodules/oscd.py | 63 +-------- torchgeo/datamodules/potsdam.py | 75 +--------- torchgeo/datamodules/resisc45.py | 79 +---------- torchgeo/datamodules/sen12ms.py | 53 +------- torchgeo/datamodules/so2sat.py | 74 +--------- torchgeo/datamodules/spacenet.py | 76 +---------- torchgeo/datamodules/ucmerced.py | 72 +--------- torchgeo/datamodules/usavars.py | 60 +------- torchgeo/datamodules/vaihingen.py | 75 +--------- torchgeo/datamodules/xview.py | 72 +--------- 25 files changed, 270 insertions(+), 1502 deletions(-) create mode 100644 torchgeo/datamodules/geo.py diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py index 9ce46783c63..176bbbd875c 100644 --- a/torchgeo/datamodules/bigearthnet.py +++ b/torchgeo/datamodules/bigearthnet.py @@ -3,20 +3,17 @@ """BigEarthNet datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional -import matplotlib.pyplot as plt -import pytorch_lightning as pl import torch from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import BigEarthNet from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule -class BigEarthNetDataModule(pl.LightningDataModule): +class BigEarthNetDataModule(NonGeoDataModule): """LightningDataModule implementation for the BigEarthNet dataset. Uses the train/val/test splits from the dataset. @@ -97,7 +94,7 @@ def __init__( self.mins = self.band_mins[2:, None, None] self.maxs = self.band_maxs[2:, None, None] - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( Normalize(mean=self.mins, std=self.maxs - self.mins), data_keys=["image"] ) @@ -120,64 +117,3 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_dataset = BigEarthNet(split="train", **self.kwargs) self.val_dataset = BigEarthNet(split="val", **self.kwargs) self.test_dataset = BigEarthNet(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for testing. 
- - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - batch = self.transform(batch) - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.BigEarthNet.plot`. - - .. versionadded:: 0.2 - """ - return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index c0e288d4a25..00d363510d5 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -6,8 +6,8 @@ from typing import Any, Dict, List, Optional import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import CenterCrop, Normalize +from pytorch_lightning import LightningDataModule from torch import Tensor from torch.utils.data import DataLoader @@ -17,7 +17,7 @@ from ..transforms import AugmentationSequential -class ChesapeakeCVPRDataModule(pl.LightningDataModule): +class ChesapeakeCVPRDataModule(LightningDataModule): """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset. Uses the random splits defined per state to partition tiles into train, val, @@ -74,7 +74,7 @@ def __init__( self.train_splits = train_splits self.val_splits = val_splits self.test_splits = test_splits - self.num_tiles_per_batch = num_tiles_per_batch + self.train_batch_size = num_tiles_per_batch self.num_patches_per_tile = num_patches_per_tile self.patch_size = patch_size # This is a rough estimate of how large of a patch we will need to sample in @@ -94,7 +94,7 @@ def __init__( else: self.layers = ["naip-new", "lc"] - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( CenterCrop(patch_size), Normalize(mean=0, std=255), data_keys=["image", "mask"], @@ -135,7 +135,7 @@ def train_dataloader(self) -> DataLoader[Any]: sampler = RandomBatchGeoSampler( self.train_dataset, size=self.original_patch_size, - batch_size=self.num_tiles_per_batch, + batch_size=self.train_batch_size, length=self.patches_per_tile * len(self.train_dataset), ) return DataLoader( @@ -158,7 +158,7 @@ def val_dataloader(self) -> DataLoader[Any]: ) return DataLoader( self.val_dataset, - batch_size=self.num_tiles_per_batch, + batch_size=self.train_batch_size, sampler=sampler, num_workers=self.num_workers, collate_fn=stack_samples, @@ -177,7 +177,7 @@ def test_dataloader(self) -> DataLoader[Any]: ) return DataLoader( self.test_dataset, - batch_size=self.num_tiles_per_batch, + batch_size=self.train_batch_size, sampler=sampler, num_workers=self.num_workers, collate_fn=stack_samples, @@ -195,7 +195,7 @@ def on_after_batch_transfer( Returns: A batch of data """ - batch = self.transform(batch) + batch = self.aug(batch) return batch def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index 2fa0870eec6..da4f1b39743 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -3,19 +3,17 @@ """COWC datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional -import matplotlib.pyplot as plt 
-import pytorch_lightning as pl
 from kornia.augmentation import Normalize
-from torch import Tensor
-from torch.utils.data import DataLoader, random_split
+from torch.utils.data import random_split

 from ..datasets import COWCCounting
 from ..transforms import AugmentationSequential
+from .geo import NonGeoDataModule


-class COWCCountingDataModule(pl.LightningDataModule):
+class COWCCountingDataModule(NonGeoDataModule):
     """LightningDataModule implementation for the COWC Counting dataset."""

     def __init__(
@@ -34,7 +32,7 @@ def __init__(
         self.num_workers = num_workers
         self.kwargs = kwargs

-        self.transform = AugmentationSequential(
+        self.aug = AugmentationSequential(
             Normalize(mean=0, std=255), data_keys=["image"]
         )

@@ -56,69 +54,8 @@ def setup(self, stage: Optional[str] = None) -> None:
             stage: stage to set up
         """
         train_val_dataset = COWCCounting(split="train", **self.kwargs)
         self.test_dataset = COWCCounting(split="test", **self.kwargs)
         self.train_dataset, self.val_dataset = random_split(
             train_val_dataset,
             [len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)],
         )
diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py
index cd3d6bb78c3..469c4be1b26 100644
--- a/torchgeo/datamodules/cyclone.py
+++ b/torchgeo/datamodules/cyclone.py
@@ -3,20 +3,18 @@

 """Tropical Cyclone Wind Estimation Competition datamodule."""

-from typing import Any, Dict, Optional
+from typing import Any, Optional

-import matplotlib.pyplot as plt
-import pytorch_lightning as pl
 from kornia.augmentation import Normalize
 from sklearn.model_selection import GroupShuffleSplit
-from torch import Tensor
-from torch.utils.data import DataLoader, Subset
+from torch.utils.data import Subset

 from ..datasets import TropicalCyclone
 from ..transforms import AugmentationSequential
+from .geo import NonGeoDataModule


-class TropicalCycloneDataModule(pl.LightningDataModule):
+class TropicalCycloneDataModule(NonGeoDataModule):
     """LightningDataModule implementation for the NASA Cyclone dataset.

     Implements 80/20 train/val splits based on hurricane storm ids.
@@ -43,7 +41,7 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( Normalize(mean=0, std=255), data_keys=["image"] ) @@ -87,61 +85,3 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_dataset = Subset(self.all_train_dataset, train_indices) self.val_dataset = Subset(self.all_train_dataset, val_indices) self.test_dataset = TropicalCyclone(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - batch = self.transform(batch) - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.TropicalCyclone.plot`.""" - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index ee2da98a5d9..cc9ac988da2 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -3,22 +3,19 @@ """DeepGlobe Land Cover Classification Challenge datamodule.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import DeepGlobeLandCover from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from .geo import NonGeoDataModule from .utils import dataset_split -class DeepGlobeLandCoverDataModule(pl.LightningDataModule): +class DeepGlobeLandCoverDataModule(NonGeoDataModule): """LightningDataModule implementation for the DeepGlobe Land Cover dataset. Uses the train/test splits from the dataset. 
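One detail worth spelling out from the BigEarthNet module earlier in this patch: Normalize(mean=self.mins, std=self.maxs - self.mins) is min-max scaling in disguise, since kornia's Normalize computes (x - mean) / std. A two-band toy example (values made up):

    import torch

    mins = torch.tensor([0.0, 100.0])     # illustrative per-band minimums
    maxs = torch.tensor([255.0, 2100.0])  # illustrative per-band maximums

    x = torch.tensor([127.5, 1100.0])
    scaled = (x - mins) / (maxs - mins)   # what Normalize(mean=mins, std=maxs - mins) computes
    # tensor([0.5000, 0.5000]): each band mapped into [0, 1]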
@@ -59,19 +56,20 @@ def __init__( """ super().__init__() - self.num_tiles_per_batch = num_tiles_per_batch + self.train_batch_size = num_tiles_per_batch + self.test_batch_size = 1 self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct self.num_workers = num_workers self.kwargs = kwargs - self.train_transform = AugmentationSequential( + self.train_aug = AugmentationSequential( Normalize(mean=0.0, std=255.0), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) - self.test_transform = AugmentationSequential( + self.test_aug = AugmentationSequential( Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], @@ -90,63 +88,3 @@ def setup(self, stage: Optional[str] = None) -> None: train_dataset, self.val_split_pct ) self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.num_tiles_per_batch, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - if self.trainer: - if self.trainer.training: - batch = self.train_transform(batch) - elif self.trainer.validating or self.trainer.testing: - batch = self.test_transform(batch) - - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`. - - .. versionadded:: 0.4 - """ - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index 26575968f44..369e3f9d332 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -3,19 +3,17 @@ """ETCI 2021 datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader, random_split +from torch.utils.data import random_split from ..datasets import ETCI2021 from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule -class ETCI2021DataModule(pl.LightningDataModule): +class ETCI2021DataModule(NonGeoDataModule): """LightningDataModule implementation for the ETCI2021 dataset. 
Splits the existing train split from the dataset into train/val with 80/20 @@ -50,7 +48,7 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] ) @@ -80,61 +78,3 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_dataset, self.val_dataset = random_split( train_val_dataset, [size_train, size_val] ) - - def train_dataloader(self) -> DataLoader[Dict[str, Any]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Any]]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Any]]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - batch = self.transform(batch) - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.ETCI2021.plot`.""" - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index 8ace880c321..9e97d8bbef2 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -3,19 +3,16 @@ """EuroSAT datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import EuroSAT from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule -class EuroSATDataModule(pl.LightningDataModule): +class EuroSATDataModule(NonGeoDataModule): """LightningDataModule implementation for the EuroSAT dataset. Uses the train/val/test splits from the dataset. @@ -71,7 +68,7 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] ) @@ -94,61 +91,3 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_dataset = EuroSAT(split="train", **self.kwargs) self.val_dataset = EuroSAT(split="val", **self.kwargs) self.test_dataset = EuroSAT(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. 
- - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - batch = self.transform(batch) - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.EuroSAT.plot`.""" - return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index ab370d28c41..cd64a4eca34 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -3,20 +3,17 @@ """FAIR1M datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import FAIR1M from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule from .utils import dataset_split -class FAIR1MDataModule(pl.LightningDataModule): +class FAIR1MDataModule(NonGeoDataModule): """LightningDataModule implementation for the FAIR1M dataset. .. versionadded:: 0.2 @@ -47,7 +44,7 @@ def __init__( self.test_split_pct = test_split_pct self.kwargs = kwargs - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( Normalize(mean=0, std=255), data_keys=["image"] ) @@ -63,64 +60,3 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. 
-
-        Args:
-            batch: A batch of data that needs to be altered or augmented
-            dataloader_idx: The index of the dataloader to which the batch belongs
-
-        Returns:
-            A batch of data
-        """
-        batch = self.transform(batch)
-        return batch
-
-    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
-        """Run :meth:`torchgeo.datasets.FAIR1M.plot`.
-
-        .. versionadded:: 0.4
-        """
-        return self.dataset.plot(*args, **kwargs)
diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py
new file mode 100644
index 00000000000..787e3d554d0
--- /dev/null
+++ b/torchgeo/datamodules/geo.py
@@ -0,0 +1,151 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Base classes for all :mod:`torchgeo` data modules."""
+
+from typing import Any, Dict, Optional
+
+import matplotlib.pyplot as plt
+from pytorch_lightning import LightningDataModule
+# TODO: import from lightning_lite instead
+from pytorch_lightning.utilities.exceptions import (  # type: ignore[attr-defined]
+    MisconfigurationException
+)
+from torch import Tensor
+from torch.utils.data import DataLoader
+
+from ..datasets import NonGeoDataset
+
+
+class NonGeoDataModule(LightningDataModule):
+    """Base class for data modules lacking geospatial information."""
+
+    train_dataset: Optional[NonGeoDataset] = None
+    val_dataset: Optional[NonGeoDataset] = None
+    test_dataset: Optional[NonGeoDataset] = None
+    predict_dataset: Optional[NonGeoDataset] = None
+
+    num_workers = 0
+
+    def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
+        """Implement one or more PyTorch DataLoaders for training.
+
+        Returns:
+            A collection of data loaders specifying training samples.
+
+        Raises:
+            MisconfigurationException: If :attr:`train_dataset` is not defined.
+        """
+        batch_size = getattr(self, "train_batch_size", getattr(self, "batch_size", 1))
+        if self.train_dataset is not None:
+            return DataLoader(
+                dataset=self.train_dataset,
+                batch_size=batch_size,
+                shuffle=True,
+                num_workers=self.num_workers,
+            )
+        else:
+            msg = f"{self.__class__.__name__} does not define a 'train_dataset'"
+            raise MisconfigurationException(msg)
+
+    def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
+        """Implement one or more PyTorch DataLoaders for validation.
+
+        Returns:
+            A collection of data loaders specifying validation samples.
+
+        Raises:
+            MisconfigurationException: If :attr:`val_dataset` is not defined.
+        """
+        batch_size = getattr(self, "val_batch_size", getattr(self, "batch_size", 1))
+        if self.val_dataset is not None:
+            return DataLoader(
+                dataset=self.val_dataset,
+                batch_size=batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+        else:
+            msg = f"{self.__class__.__name__} does not define a 'val_dataset'"
+            raise MisconfigurationException(msg)
+
+    def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
+        """Implement one or more PyTorch DataLoaders for testing.
+
+        Returns:
+            A collection of data loaders specifying testing samples.
+
+        Raises:
+            MisconfigurationException: If :attr:`test_dataset` is not defined.
+        """
+        batch_size = getattr(self, "test_batch_size", getattr(self, "batch_size", 1))
+        if self.test_dataset is not None:
+            return DataLoader(
+                dataset=self.test_dataset,
+                batch_size=batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+        else:
+            msg = f"{self.__class__.__name__} does not define a 'test_dataset'"
+            raise MisconfigurationException(msg)
+
+    def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
+        """Implement one or more PyTorch DataLoaders for prediction.
+
+        Returns:
+            A collection of data loaders specifying prediction samples.
+
+        Raises:
+            MisconfigurationException: If :attr:`predict_dataset` is not defined.
+        """
+        batch_size = getattr(self, "predict_batch_size", getattr(self, "batch_size", 1))
+        if self.predict_dataset is not None:
+            return DataLoader(
+                dataset=self.predict_dataset,
+                batch_size=batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+        else:
+            msg = f"{self.__class__.__name__} does not define a 'predict_dataset'"
+            raise MisconfigurationException(msg)
+
+    def on_after_batch_transfer(
+        self, batch: Dict[str, Tensor], dataloader_idx: int
+    ) -> Dict[str, Tensor]:
+        """Apply batch augmentations to the batch after it is transferred to the device.
+
+        Args:
+            batch: A batch of data that needs to be altered or augmented.
+            dataloader_idx: The index of the dataloader to which the batch belongs.
+
+        Returns:
+            A batch of data.
+        """
+        if self.trainer:
+            if self.trainer.training:
+                aug = getattr(self, "train_aug", getattr(self, "aug"))
+            elif self.trainer.validating:
+                aug = getattr(self, "val_aug", getattr(self, "aug"))
+            elif self.trainer.testing:
+                aug = getattr(self, "test_aug", getattr(self, "aug"))
+            elif self.trainer.predicting:
+                aug = getattr(self, "predict_aug", getattr(self, "aug"))
+
+            batch = aug(batch)
+
+        return batch
+
+    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
+        """Run the plot method of the dataset if one exists.
+
+        Args:
+            *args: Arguments passed to plot method.
+            **kwargs: Keyword arguments passed to plot method.
+
+        Returns:
+            A matplotlib Figure with the rendered sample.
+        """
+        if self.train_dataset is not None:
+            if hasattr(self.train_dataset, "plot"):
+                return self.train_dataset.plot(*args, **kwargs)
diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py
index 7f4c1dc962e..8bb7540b8bb 100644
--- a/torchgeo/datamodules/gid15.py
+++ b/torchgeo/datamodules/gid15.py
@@ -3,22 +3,19 @@

 """GID-15 datamodule."""

-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union

-import matplotlib.pyplot as plt
-import pytorch_lightning as pl
 from kornia.augmentation import Normalize
-from torch import Tensor
-from torch.utils.data import DataLoader

 from ..datasets import GID15
 from ..samplers.utils import _to_tuple
 from ..transforms import AugmentationSequential
 from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop
+from .geo import NonGeoDataModule
 from .utils import dataset_split


-class GID15DataModule(pl.LightningDataModule):
+class GID15DataModule(NonGeoDataModule):
     """LightningDataModule implementation for the GID-15 dataset.

     Uses the train/test splits from the dataset.
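With the base class in place, a concrete datamodule reduces to datasets plus augmentations; the dataloaders and the device-side augmentation hook are inherited. A minimal hypothetical subclass (MyDataset stands in for any NonGeoDataset with train/val/test splits; it is not a real torchgeo class):

    from typing import Any, Optional

    from kornia.augmentation import Normalize

    from torchgeo.datamodules.geo import NonGeoDataModule
    from torchgeo.transforms import AugmentationSequential


    class MyDataModule(NonGeoDataModule):
        def __init__(
            self, batch_size: int = 32, num_workers: int = 0, **kwargs: Any
        ) -> None:
            super().__init__()
            self.batch_size = batch_size  # found by the getattr fallbacks above
            self.num_workers = num_workers
            self.kwargs = kwargs
            # used for every stage unless a train_aug/val_aug/... is also defined
            self.aug = AugmentationSequential(
                Normalize(mean=0, std=255), data_keys=["image"]
            )

        def setup(self, stage: Optional[str] = None) -> None:
            self.train_dataset = MyDataset(split="train", **self.kwargs)
            self.val_dataset = MyDataset(split="val", **self.kwargs)
            self.test_dataset = MyDataset(split="test", **self.kwargs)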
@@ -57,19 +54,19 @@ def __init__( """ super().__init__() - self.num_tiles_per_batch = num_tiles_per_batch + self.train_batch_size = num_tiles_per_batch self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct self.num_workers = num_workers self.kwargs = kwargs - self.train_transform = AugmentationSequential( + self.train_aug = AugmentationSequential( Normalize(mean=0.0, std=255.0), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) - self.val_transform = AugmentationSequential( + self.val_aug = AugmentationSequential( Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], @@ -104,65 +101,3 @@ def setup(self, stage: Optional[str] = None) -> None: # Test set masks are not public, use for prediction instead self.predict_dataset = GID15(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.num_tiles_per_batch, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for predicting. - - Returns: - predicting data loader - """ - return DataLoader( - self.predict_dataset, - batch_size=1, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - if self.trainer: - if self.trainer.training: - batch = self.train_transform(batch) - elif self.trainer.validating: - batch = self.val_transform(batch) - elif self.trainer.predicting: - batch = self.predict_transform(batch) - - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.GID15.plot`.""" - return self.predict_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 7f246765925..de4a6646aa7 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -3,22 +3,19 @@ """InriaAerialImageLabeling datamodule.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize, RandomHorizontalFlip, RandomVerticalFlip -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import InriaAerialImageLabeling from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from .geo import NonGeoDataModule from .utils import dataset_split -class InriaAerialImageLabelingDataModule(pl.LightningDataModule): +class InriaAerialImageLabelingDataModule(NonGeoDataModule): """LightningDataModule implementation for the InriaAerialImageLabeling dataset. 
Uses the train/test splits from the dataset and further splits @@ -60,7 +57,7 @@ def __init__( """ super().__init__() - self.num_tiles_per_batch = num_tiles_per_batch + self.train_batch_size = num_tiles_per_batch self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.num_workers = num_workers @@ -68,14 +65,14 @@ def __init__( self.test_split_pct = test_split_pct self.kwargs = kwargs - self.train_transform = AugmentationSequential( + self.train_aug = AugmentationSequential( Normalize(mean=0, std=255), RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.5), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) - self.test_transform = AugmentationSequential( + self.test_aug = AugmentationSequential( Normalize(mean=0, std=255), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], @@ -91,80 +88,3 @@ def setup(self, stage: Optional[str] = None) -> None: dataset, self.val_split_pct, self.test_split_pct ) self.predict_dataset = InriaAerialImageLabeling(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.num_patches_per_tile, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for prediction. - - Returns: - prediction data loader - """ - return DataLoader( - self.predict_dataset, - batch_size=1, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - if self.trainer: - if self.trainer.training: - batch = self.train_transform(batch) - elif ( - self.trainer.validating - or self.trainer.testing - or self.trainer.predicting - ): - batch = self.test_transform(batch) - - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.InriaAerialImageLabeling.plot`. - - .. 
versionadded:: 0.4
-        """
-        return self.dataset.plot(*args, **kwargs)
diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py
index fc1ae71df76..d970c03379c 100644
--- a/torchgeo/datamodules/landcoverai.py
+++ b/torchgeo/datamodules/landcoverai.py
@@ -3,10 +3,8 @@

 """LandCover.ai datamodule."""

-from typing import Any, Dict, Optional
+from typing import Any, Optional

-import matplotlib.pyplot as plt
-import pytorch_lightning as pl
 from kornia.augmentation import (
     ColorJitter,
     Normalize,
@@ -15,14 +13,13 @@
     RandomSharpness,
     RandomVerticalFlip,
 )
-from torch import Tensor
-from torch.utils.data import DataLoader

 from ..datasets import LandCoverAI
 from ..transforms import AugmentationSequential
+from .geo import NonGeoDataModule


-class LandCoverAIDataModule(pl.LightningDataModule):
+class LandCoverAIDataModule(NonGeoDataModule):
     """LightningDataModule implementation for the LandCover.ai dataset.

     Uses the train/val/test splits from the dataset.
@@ -44,7 +41,7 @@ def __init__(
         self.num_workers = num_workers
         self.kwargs = kwargs

-        self.train_transform = AugmentationSequential(
+        self.train_aug = AugmentationSequential(
             Normalize(mean=0, std=255),
             RandomRotation(p=0.5, degrees=90),
             RandomHorizontalFlip(p=0.5),
@@ -60,7 +57,7 @@
             ),
             data_keys=["image", "mask"],
         )
-        self.test_transform = AugmentationSequential(
+        self.test_aug = AugmentationSequential(
             Normalize(mean=0, std=255), data_keys=["image", "mask"]
         )

@@ -81,73 +78,7 @@ def setup(self, stage: Optional[str] = None) -> None:
             stage: stage to set up
         """
         self.train_dataset = LandCoverAI(split="train", **self.kwargs)
         self.val_dataset = LandCoverAI(split="val", **self.kwargs)
         self.test_dataset = LandCoverAI(split="test", **self.kwargs)
-
-    def train_dataloader(self) -> DataLoader[Any]:
-        """Return a DataLoader for training.
-
-        Returns:
-            training data loader
-        """
-        return DataLoader(
-            self.train_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            shuffle=True,
-        )
-
-    def val_dataloader(self) -> DataLoader[Any]:
-        """Return a DataLoader for validation.
-
-        Returns:
-            validation data loader
-        """
-        return DataLoader(
-            self.val_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            shuffle=False,
-        )
-
-    def test_dataloader(self) -> DataLoader[Any]:
-        """Return a DataLoader for testing.
-
-        Returns:
-            testing data loader
-        """
-        return DataLoader(
-            self.test_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            shuffle=False,
-        )
-
-    def on_after_batch_transfer(
-        self, batch: Dict[str, Tensor], dataloader_idx: int
-    ) -> Dict[str, Tensor]:
-        """Apply augmentations to batch after transferring to GPU.
-
-        Args:
-            batch: A batch of data that needs to be altered or augmented
-            dataloader_idx: The index of the dataloader to which the batch belongs
-
-        Returns:
-            A batch of data
-        """
-        if self.trainer:
-            if self.trainer.training:
-                batch = self.train_transform(batch)
-            elif self.trainer.validating or self.trainer.testing:
-                batch = self.test_transform(batch)
-
-        return batch
-
-    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
-        """Run :meth:`torchgeo.datasets.LandCoverAI.plot`.
-
-        .. 
versionadded:: 0.2 - """ - return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index 6836041035b..706e7c8b21a 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -3,19 +3,16 @@ """LoveDA datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import LoveDA from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule -class LoveDADataModule(pl.LightningDataModule): +class LoveDADataModule(NonGeoDataModule): """LightningDataModule implementation for the LoveDA dataset. Uses the train/val/test splits from the dataset. @@ -39,7 +36,7 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( Normalize(mean=0, std=255), data_keys=["image"] ) @@ -64,64 +61,3 @@ def setup(self, stage: Optional[str] = None) -> None: # Test set masks are not public, use for prediction instead self.predict_dataset = LoveDA(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def predict_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for prediction. - - Returns: - predict data loader - """ - return DataLoader( - self.predict_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - batch = self.transform(batch) - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.LoveDA.plot`. - - .. versionadded:: 0.4 - """ - return self.train_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 4a34a0a1ce4..893f268108d 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -6,8 +6,8 @@ from typing import Any, Dict, Optional, Tuple import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize +from pytorch_lightning import LightningDataModule from torch import Tensor from torch.utils.data import DataLoader @@ -16,7 +16,7 @@ from ..transforms import AugmentationSequential -class NAIPChesapeakeDataModule(pl.LightningDataModule): +class NAIPChesapeakeDataModule(LightningDataModule): """LightningDataModule implementation for the NAIP and Chesapeake datasets. Uses the train/val/test splits from the dataset. 
@@ -59,7 +59,7 @@ def __init__( elif key.startswith("chesapeake_"): self.chesapeake_kwargs[key[11:]] = val - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( Normalize(mean=0, std=255), data_keys=["image", "mask"] ) @@ -154,7 +154,7 @@ def on_after_batch_transfer( Returns: A batch of data """ - batch = self.transform(batch) + batch = self.aug(batch) return batch def plot(self, *args: Any, **kwargs: Any) -> Tuple[plt.Figure, plt.Figure]: diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index f3d4b33026e..53e798f5ba2 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -3,20 +3,17 @@ """NASA Marine Debris datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import NASAMarineDebris from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule from .utils import dataset_split -class NASAMarineDebrisDataModule(pl.LightningDataModule): +class NASAMarineDebrisDataModule(NonGeoDataModule): """LightningDataModule implementation for the NASA Marine Debris dataset. .. versionadded:: 0.2 @@ -47,7 +44,7 @@ def __init__( self.test_split_pct = test_split_pct self.kwargs = kwargs - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( Normalize(mean=0, std=255), data_keys=["image"] ) @@ -71,64 +68,3 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - batch = self.transform(batch) - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.NASAMarineDebris.plot`. - - .. 
versionadded:: 0.4 - """ - return self.dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 9d43794bda3..d3bfc9ab690 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -3,23 +3,20 @@ """OSCD datamodule.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union -import matplotlib.pyplot as plt -import pytorch_lightning as pl import torch from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import OSCD from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from .geo import NonGeoDataModule from .utils import dataset_split -class OSCDDataModule(pl.LightningDataModule): +class OSCDDataModule(NonGeoDataModule): """LightningDataModule implementation for the OSCD dataset. Uses the train/test splits from the dataset and further splits @@ -95,7 +92,7 @@ def __init__( """ super().__init__() - self.num_tiles_per_batch = num_tiles_per_batch + self.train_batch_size = num_tiles_per_batch self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct @@ -107,12 +104,12 @@ def __init__( self.band_means = self.band_means[[3, 2, 1]] self.band_stds = self.band_stds[[3, 2, 1]] - self.train_transform = AugmentationSequential( + self.train_aug = AugmentationSequential( Normalize(mean=self.band_means, std=self.band_stds), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) - self.test_transform = AugmentationSequential( + self.test_aug = AugmentationSequential( Normalize(mean=self.band_means, std=self.band_stds), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], @@ -136,51 +133,3 @@ def setup(self, stage: Optional[str] = None) -> None: train_dataset, val_pct=self.val_split_pct ) self.test_dataset = OSCD(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training.""" - return DataLoader( - self.train_dataset, - batch_size=self.num_tiles_per_batch, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation.""" - return DataLoader( - self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing.""" - return DataLoader( - self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - if self.trainer: - if self.trainer.training: - batch = self.train_transform(batch) - elif self.trainer.validating or self.trainer.testing: - batch = self.test_transform(batch) - - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.OSCD.plot`. - - .. 
versionadded:: 0.4 - """ - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 84108ef8603..279afdc109e 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -3,22 +3,19 @@ """Potsdam datamodule.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import Potsdam2D from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from .geo import NonGeoDataModule from .utils import dataset_split -class Potsdam2DDataModule(pl.LightningDataModule): +class Potsdam2DDataModule(NonGeoDataModule): """LightningDataModule implementation for the Potsdam2D dataset. Uses the train/test splits from the dataset. @@ -61,19 +58,19 @@ def __init__( """ super().__init__() - self.num_tiles_per_batch = num_tiles_per_batch + self.train_batch_size = num_tiles_per_batch self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct self.num_workers = num_workers self.kwargs = kwargs - self.train_transform = AugmentationSequential( + self.train_aug = AugmentationSequential( Normalize(mean=0.0, std=255.0), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) - self.test_transform = AugmentationSequential( + self.test_aug = AugmentationSequential( Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], @@ -92,63 +89,3 @@ def setup(self, stage: Optional[str] = None) -> None: train_dataset, self.val_split_pct ) self.test_dataset = Potsdam2D(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.num_tiles_per_batch, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - if self.trainer: - if self.trainer.training: - batch = self.train_transform(batch) - elif self.trainer.validating or self.trainer.testing: - batch = self.test_transform(batch) - - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.Potsdam2D.plot`. - - .. 
versionadded:: 0.4 - """ - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index 3b61958a84a..cf72aeb3a41 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -3,19 +3,16 @@ """RESISC45 datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional import kornia.augmentation as K -import matplotlib.pyplot as plt -import pytorch_lightning as pl -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import RESISC45 from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule -class RESISC45DataModule(pl.LightningDataModule): +class RESISC45DataModule(NonGeoDataModule): """LightningDataModule implementation for the RESISC45 dataset. Uses the train/val/test splits from the dataset. @@ -40,7 +37,7 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - self.train_transform = AugmentationSequential( + self.train_aug = AugmentationSequential( K.Normalize(mean=self.band_means, std=self.band_stds), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), @@ -57,7 +54,7 @@ def __init__( ), data_keys=["image"], ) - self.test_transform = AugmentationSequential( + self.test_aug = AugmentationSequential( K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] ) @@ -80,69 +77,3 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_dataset = RESISC45(split="train", **self.kwargs) self.val_dataset = RESISC45(split="val", **self.kwargs) self.test_dataset = RESISC45(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - if self.trainer: - if self.trainer.training: - batch = self.train_transform(batch) - elif self.trainer.validating or self.trainer.testing: - batch = self.test_transform(batch) - - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.RESISC45.plot`. - - .. 
versionadded:: 0.2 - """ - return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 62dece84ecc..2fe0f3c741c 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -5,16 +5,15 @@ from typing import Any, Optional -import matplotlib.pyplot as plt -import pytorch_lightning as pl import torch from sklearn.model_selection import GroupShuffleSplit -from torch.utils.data import DataLoader, Subset +from torch.utils.data import Subset from ..datasets import SEN12MS +from .geo import NonGeoDataModule -class SEN12MSDataModule(pl.LightningDataModule): +class SEN12MSDataModule(NonGeoDataModule): """LightningDataModule implementation for the SEN12MS dataset. Implements 80/20 geographic train/val splits and uses the test split from the @@ -111,49 +110,3 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_dataset = Subset(self.all_train_dataset, train_indices) self.val_dataset = Subset(self.all_train_dataset, val_indices) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.SEN12MS.plot`. - - .. versionadded:: 0.4 - """ - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index 298ced010f4..95e7801f262 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -3,19 +3,16 @@ """So2Sat datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import So2Sat from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule -class So2SatDataModule(pl.LightningDataModule): +class So2SatDataModule(NonGeoDataModule): """LightningDataModule implementation for the So2Sat dataset. Uses the train/val/test splits from the dataset. @@ -76,7 +73,7 @@ def __init__( self.unsupervised_mode = unsupervised_mode self.kwargs = kwargs - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] ) @@ -90,68 +87,5 @@ def setup(self, stage: Optional[str] = None) -> None: """ bands = So2Sat.BAND_SETS["s2"] self.train_dataset = So2Sat(split="train", bands=bands, **self.kwargs) - self.val_dataset = So2Sat(split="validation", bands=bands, **self.kwargs) - self.test_dataset = So2Sat(split="test", bands=bands, **self.kwargs) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. 
- - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - batch = self.transform(batch) - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.So2Sat.plot`. - - .. versionadded:: 0.4 - """ - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 1e53925eead..7cd0e34b6aa 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -3,20 +3,17 @@ """SpaceNet datamodules.""" -from typing import Any, Dict, Optional +from typing import Any, Optional import kornia.augmentation as K -import matplotlib.pyplot as plt -import pytorch_lightning as pl -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import SpaceNet1 from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule from .utils import dataset_split -class SpaceNet1DataModule(pl.LightningDataModule): +class SpaceNet1DataModule(NonGeoDataModule): """LightningDataModule implementation for the SpaceNet1 dataset. Randomly splits into train/val/test. @@ -49,7 +46,7 @@ def __init__( self.test_split_pct = test_split_pct self.kwargs = kwargs - self.train_transform = AugmentationSequential( + self.train_aug = AugmentationSequential( K.Normalize(mean=0, std=255), K.PadTo((448, 448)), K.RandomRotation(p=0.5, degrees=90), @@ -66,7 +63,7 @@ def __init__( ), data_keys=["image", "mask"], ) - self.test_transform = AugmentationSequential( + self.test_aug = AugmentationSequential( K.Normalize(mean=0, std=255), K.PadTo((448, 448)), data_keys=["image", "mask"], @@ -92,66 +89,3 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. 
- - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - if self.trainer: - if self.trainer.training: - batch = self.train_transform(batch) - elif self.trainer.validating or self.trainer.testing: - batch = self.test_transform(batch) - - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.SpaceNet.plot`.""" - return self.dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index b3a6766e32d..9a9ce1c26b2 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -3,19 +3,16 @@ """UC Merced datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize, Resize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import UCMerced from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule -class UCMercedDataModule(pl.LightningDataModule): +class UCMercedDataModule(NonGeoDataModule): """LightningDataModule implementation for the UC Merced dataset. Uses random train/val/test splits. @@ -37,7 +34,7 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( Normalize(mean=0, std=255), Resize(size=256), data_keys=["image"] ) @@ -60,64 +57,3 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_dataset = UCMerced(split="train", **self.kwargs) self.val_dataset = UCMerced(split="val", **self.kwargs) self.test_dataset = UCMerced(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - batch = self.transform(batch) - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.UCMerced.plot`. - - .. 
versionadded:: 0.2 - """ - return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py index aaac1be849f..0d3219ef23c 100644 --- a/torchgeo/datamodules/usavars.py +++ b/torchgeo/datamodules/usavars.py @@ -3,19 +3,16 @@ """USAVars datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import USAVars from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule -class USAVarsDataModule(pl.LightningModule): +class USAVarsDataModule(NonGeoDataModule): """LightningDataModule implementation for the USAVars dataset. Uses random train/val/test splits. @@ -39,7 +36,7 @@ def __init__( self.num_workers = num_workers self.kwargs = kwargs - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( Normalize(mean=0, std=255), data_keys=["image"] ) @@ -59,52 +56,3 @@ def setup(self, stage: Optional[str] = None) -> None: self.train_dataset = USAVars(split="train", **self.kwargs) self.val_dataset = USAVars(split="val", **self.kwargs) self.test_dataset = USAVars(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training.""" - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation.""" - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing.""" - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - batch = self.transform(batch) - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.USAVars.plot`. - - .. versionadded:: 0.4 - """ - return self.train_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 69ea79d243b..61bf3ebfd18 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -3,22 +3,19 @@ """Vaihingen datamodule.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import Vaihingen2D from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from .geo import NonGeoDataModule from .utils import dataset_split -class Vaihingen2DDataModule(pl.LightningDataModule): +class Vaihingen2DDataModule(NonGeoDataModule): """LightningDataModule implementation for the Vaihingen2D dataset. 
Uses the train/test splits from the dataset. @@ -61,19 +58,19 @@ def __init__( """ super().__init__() - self.num_tiles_per_batch = num_tiles_per_batch + self.train_batch_size = num_tiles_per_batch self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct self.num_workers = num_workers self.kwargs = kwargs - self.train_transform = AugmentationSequential( + self.train_aug = AugmentationSequential( Normalize(mean=0.0, std=255.0), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) - self.test_transform = AugmentationSequential( + self.test_aug = AugmentationSequential( Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], @@ -92,63 +89,3 @@ def setup(self, stage: Optional[str] = None) -> None: train_dataset, self.val_split_pct ) self.test_dataset = Vaihingen2D(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.num_tiles_per_batch, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - if self.trainer: - if self.trainer.training: - batch = self.train_transform(batch) - elif self.trainer.validating or self.trainer.testing: - batch = self.test_transform(batch) - - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.Vaihingen2D.plot`. - - .. versionadded:: 0.4 - """ - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index 5c1f1812b5d..5dfa75a21ab 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -3,20 +3,17 @@ """xView2 datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Optional -import matplotlib.pyplot as plt -import pytorch_lightning as pl from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import XView2 from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule from .utils import dataset_split -class XView2DataModule(pl.LightningDataModule): +class XView2DataModule(NonGeoDataModule): """LightningDataModule implementation for the xView2 dataset. Uses the train/val/test splits from the dataset. 
@@ -46,7 +43,7 @@ def __init__( self.val_split_pct = val_split_pct self.kwargs = kwargs - self.transform = AugmentationSequential( + self.aug = AugmentationSequential( Normalize(mean=0, std=255), data_keys=["image"] ) @@ -63,64 +60,3 @@ def setup(self, stage: Optional[str] = None) -> None: dataset, val_pct=self.val_split_pct ) self.test_dataset = XView2(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - batch = self.transform(batch) - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.XView2.plot`. - - .. versionadded:: 0.4 - """ - return self.test_dataset.plot(*args, **kwargs) From 34c8d7eef47caf8405a6d66b683d9697ef4ebce5 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sat, 31 Dec 2022 14:09:09 -0600 Subject: [PATCH 027/108] Fixes --- tests/datamodules/test_oscd.py | 2 +- torchgeo/datamodules/__init__.py | 3 +++ torchgeo/datamodules/geo.py | 4 +++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index 3d42aa60c11..0d8ee7e2398 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -4,8 +4,8 @@ import os import pytest -from pytorch_lightning import Trainer from _pytest.fixtures import SubRequest +from pytorch_lightning import Trainer from torchgeo.datamodules import OSCDDataModule diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 14e50ad7d45..85ea803cd67 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -11,6 +11,7 @@ from .etci2021 import ETCI2021DataModule from .eurosat import EuroSATDataModule from .fair1m import FAIR1MDataModule +from .geo import NonGeoDataModule from .gid15 import GID15DataModule from .inria import InriaAerialImageLabelingDataModule from .landcoverai import LandCoverAIDataModule @@ -55,4 +56,6 @@ "USAVarsDataModule", "Vaihingen2DDataModule", "XView2DataModule", + # Base classes + "NonGeoDataModule", ) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 787e3d554d0..639e6c9ac6b 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -7,9 +7,10 @@ import matplotlib.pyplot as plt from pytorch_lightning import LightningDataModule + # TODO: import from lightning_lite instead from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException + MisconfigurationException, ) from torch import Tensor from torch.utils.data import DataLoader @@ -145,6 +146,7 @@ def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: **kwargs: Keyword arguments passed to plot method. Returns: + a matplotlib Figure with the image, ground truth, and predictions """ if self.train_dataset is not None: if hasattr(self.train_dataset, "plot"): From d66672a9394f12e5c7b48fc1f64344a4b852b182 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 31 Dec 2022 14:10:33 -0600 Subject: [PATCH 028/108] Add base class to docs --- docs/api/datamodules.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index 4833ff815e4..d5a868faa69 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -128,3 +128,11 @@ xView2 ^^^^^^ .. autoclass:: XView2DataModule + +Base Classes +------------ + +NonGeoDataModule +^^^^^^^^^^^^^^^^ + +.. autoclass:: NonGeoDataModule From 3a47cd8402a45adbfee2cb9cf0820731e8c35121 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sat, 31 Dec 2022 14:22:01 -0600 Subject: [PATCH 029/108] mypy fixes --- tests/datamodules/test_oscd.py | 6 +++--- torchgeo/datamodules/cowc.py | 2 +- torchgeo/datamodules/geo.py | 19 ++++++++++++------- torchgeo/datamodules/oscd.py | 4 ++-- 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index 0d8ee7e2398..6e5b76a9870 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -40,7 +40,7 @@ def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: sample = datamodule.on_after_batch_transfer(sample, 0) assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 - if datamodule.test_dataset.bands == "all": + if datamodule.bands == "all": assert sample["image"].shape[1] == 26 else: assert sample["image"].shape[1] == 6 @@ -52,7 +52,7 @@ def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: if datamodule.val_split_pct > 0.0: assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 - if datamodule.test_dataset.bands == "all": + if datamodule.bands == "all": assert sample["image"].shape[1] == 26 else: assert sample["image"].shape[1] == 6 @@ -63,7 +63,7 @@ def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: sample = datamodule.on_after_batch_transfer(sample, 0) assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 - if datamodule.test_dataset.bands == "all": + if datamodule.bands == "all": assert sample["image"].shape[1] == 26 else: assert sample["image"].shape[1] == 6 diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index da4f1b39743..7daba6d2e13 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -54,8 +54,8 @@ def setup(self, stage: Optional[str] = None) -> None: stage: stage to set up """ train_val_dataset = COWCCounting(split="train", **self.kwargs) + self.test_dataset = COWCCounting(split="test", **self.kwargs) self.train_dataset, self.val_dataset = random_split( train_val_dataset, [len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)], ) - self.test_dataset = COWCCounting(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 639e6c9ac6b..f9755de5776 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -13,18 +13,23 @@ MisconfigurationException, ) from torch import Tensor -from torch.utils.data import DataLoader - -from ..datasets import NonGeoDataset +from torch.utils.data import DataLoader, Dataset class NonGeoDataModule(LightningDataModule): """Base class for data modules lacking geospatial information.""" - train_dataset: Optional[NonGeoDataset] = None - val_dataset: Optional[NonGeoDataset] = None - test_dataset: Optional[NonGeoDataset] = None - predict_dataset: Optional[NonGeoDataset] = None + #: Training dataset + train_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + + #: Validation dataset + val_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + + #: Testing dataset + test_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + + #: Prediction dataset + predict_dataset: Optional[Dataset[Dict[str, Tensor]]] = None num_workers = 0 diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index d3bfc9ab690..76a6ec5d092 100644 
--- a/torchgeo/datamodules/oscd.py
+++ b/torchgeo/datamodules/oscd.py
@@ -99,8 +99,8 @@ def __init__(
         self.num_workers = num_workers
         self.kwargs = kwargs
 
-        bands = kwargs.get("bands", "all")
-        if bands == "rgb":
+        self.bands = kwargs.get("bands", "all")
+        if self.bands == "rgb":
             self.band_means = self.band_means[[3, 2, 1]]
             self.band_stds = self.band_stds[[3, 2, 1]]
 

From afc824677263698974c12617288bede9bc3a6f90 Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Sat, 31 Dec 2022 15:30:29 -0600
Subject: [PATCH 030/108] Fix several tests

---
 tests/trainers/test_classification.py | 5 ++--
 torchgeo/datamodules/geo.py | 38 +++++++++++++++++----------
 torchgeo/datamodules/resisc45.py | 2 +-
 3 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py
index 922e171d422..96617c49bb1 100644
--- a/tests/trainers/test_classification.py
+++ b/tests/trainers/test_classification.py
@@ -15,6 +15,7 @@
 from torch.nn.modules import Module
 from torchvision.models._api import WeightsEnum
 
+from torchgeo.datasets import BigEarthNet, EuroSAT
 from torchgeo.datamodules import (
     BigEarthNetDataModule,
     EuroSATDataModule,
@@ -140,7 +141,7 @@ def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None:
     def test_missing_attributes(
         self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch
     ) -> None:
-        monkeypatch.delattr(EuroSATDataModule, "plot")
+        monkeypatch.delattr(EuroSAT, "plot")
         datamodule = EuroSATDataModule(
             root="tests/data/eurosat", batch_size=1, num_workers=0
         )
@@ -218,7 +219,7 @@ def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None:
     def test_missing_attributes(
         self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch
     ) -> None:
-        monkeypatch.delattr(BigEarthNetDataModule, "plot")
+        monkeypatch.delattr(BigEarthNet, "plot")
         datamodule = BigEarthNetDataModule(
             root="tests/data/bigearthnet", batch_size=1, num_workers=0
         )
diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py
index f9755de5776..e840ddbed77 100644
--- a/torchgeo/datamodules/geo.py
+++ b/torchgeo/datamodules/geo.py
@@ -13,6 +13,7 @@
     MisconfigurationException,
 )
 from torch import Tensor
+from torch.nn import Identity, Module
 from torch.utils.data import DataLoader, Dataset
 
 
@@ -31,8 +32,21 @@ class NonGeoDataModule(LightningDataModule):
     #: Prediction dataset
     predict_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
 
+    # DataLoader arguments
+    batch_size: Optional[int] = None
+    train_batch_size: Optional[int] = None
+    val_batch_size: Optional[int] = None
+    test_batch_size: Optional[int] = None
+    predict_batch_size: Optional[int] = None
     num_workers = 0
 
+    # Data augmentation
+    aug: Optional[Module] = None
+    train_aug: Optional[Module] = None
+    val_aug: Optional[Module] = None
+    test_aug: Optional[Module] = None
+    predict_aug: Optional[Module] = None
+
     def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Implement one or more PyTorch DataLoaders for training.
 
@@ -42,11 +56,10 @@ def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         Raises:
             MisconfigurationException: If :attr:`train_dataset` is not defined.
""" - batch_size = getattr(self, "train_batch_size", getattr(self, "batch_size", 1)) if self.train_dataset is not None: return DataLoader( dataset=self.train_dataset, - batch_size=batch_size, + batch_size=self.train_batch_size or self.batch_size or 1, shuffle=True, num_workers=self.num_workers, ) @@ -63,11 +76,10 @@ def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: Raises: MisconfigurationException: If :attr:`val_dataset` is not defined. """ - batch_size = getattr(self, "val_batch_size", getattr(self, "batch_size", 1)) if self.val_dataset is not None: return DataLoader( dataset=self.val_dataset, - batch_size=batch_size, + batch_size=self.val_batch_size or self.batch_size or 1, shuffle=False, num_workers=self.num_workers, ) @@ -84,12 +96,11 @@ def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: Raises: MisconfigurationException: If :attr:`test_dataset` is not defined. """ - batch_size = getattr(self, "test_batch_size", getattr(self, "batch_size", 1)) if self.test_dataset is not None: return DataLoader( dataset=self.test_dataset, - batch_size=batch_size, - shuffle=True, + batch_size=self.test_batch_size or self.batch_size or 1, + shuffle=False, num_workers=self.num_workers, ) else: @@ -105,12 +116,11 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: Raises: MisconfigurationException: If :attr:`predict_dataset` is not defined. """ - batch_size = getattr(self, "predict_batch_size", getattr(self, "batch_size", 1)) if self.predict_dataset is not None: return DataLoader( dataset=self.predict_dataset, - batch_size=batch_size, - shuffle=True, + batch_size=self.predict_batch_size or self.batch_size or 1, + shuffle=False, num_workers=self.num_workers, ) else: @@ -131,13 +141,13 @@ def on_after_batch_transfer( """ if self.trainer: if self.trainer.training: - aug = getattr(self, "train_aug", getattr(self, "aug")) + aug = self.train_aug or self.aug or Identity() elif self.trainer.validating: - aug = getattr(self, "val_aug", getattr(self, "aug")) + aug = self.val_aug or self.aug or Identity() elif self.trainer.testing: - aug = getattr(self, "test_aug", getattr(self, "aug")) + aug = self.test_aug or self.aug or Identity() elif self.trainer.predicting: - aug = getattr(self, "predict_aug", getattr(self, "aug")) + aug = self.predict_aug or self.aug or Identity() batch = aug(batch) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index cf72aeb3a41..dbfc67da430 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -54,7 +54,7 @@ def __init__( ), data_keys=["image"], ) - self.test_aug = AugmentationSequential( + self.aug = AugmentationSequential( K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] ) From 22f695d29b621d7f98c1b7ee516633132e83f43b Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sat, 31 Dec 2022 15:37:49 -0600 Subject: [PATCH 031/108] Fix Normalize --- torchgeo/datamodules/chesapeake.py | 2 +- torchgeo/datamodules/cowc.py | 2 +- torchgeo/datamodules/cyclone.py | 2 +- torchgeo/datamodules/fair1m.py | 2 +- torchgeo/datamodules/inria.py | 4 ++-- torchgeo/datamodules/landcoverai.py | 4 ++-- torchgeo/datamodules/loveda.py | 2 +- torchgeo/datamodules/naip.py | 2 +- torchgeo/datamodules/nasa_marine_debris.py | 2 +- torchgeo/datamodules/spacenet.py | 4 ++-- torchgeo/datamodules/ucmerced.py | 2 +- torchgeo/datamodules/usavars.py | 2 +- torchgeo/datamodules/xview.py | 2 +- torchgeo/datasets/geo.py | 2 +- 14 files changed, 17 insertions(+), 17 deletions(-) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 00d363510d5..942c2311f29 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -96,7 +96,7 @@ def __init__( self.aug = AugmentationSequential( CenterCrop(patch_size), - Normalize(mean=0, std=255), + Normalize(mean=0.0 std=255.0), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index 7daba6d2e13..13d2d50e22b 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -33,7 +33,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0, std=255), data_keys=["image"] + Normalize(mean=0.0 std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 469c4be1b26..dae979a9c9b 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -42,7 +42,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0, std=255), data_keys=["image"] + Normalize(mean=0.0 std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index cd64a4eca34..d3790007553 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -45,7 +45,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0, std=255), data_keys=["image"] + Normalize(mean=0.0 std=255.0), data_keys=["image"] ) def setup(self, stage: Optional[str] = None) -> None: diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index de4a6646aa7..80a458c3e6a 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -66,14 +66,14 @@ def __init__( self.kwargs = kwargs self.train_aug = AugmentationSequential( - Normalize(mean=0, std=255), + Normalize(mean=0.0 std=255.0), RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.5), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.test_aug = AugmentationSequential( - Normalize(mean=0, std=255), + Normalize(mean=0.0 std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index d970c03379c..5290e9144c6 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -42,7 +42,7 @@ def __init__( self.kwargs = kwargs self.train_aug = AugmentationSequential( - Normalize(mean=0, std=255), + Normalize(mean=0.0 std=255.0), RandomRotation(p=0.5, degrees=90), RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.5), @@ -58,7 +58,7 @@ def __init__( data_keys=["image", "mask"], ) self.test_aug = 
AugmentationSequential( - Normalize(mean=0, std=255), data_keys=["image", "mask"] + Normalize(mean=0.0 std=255.0), data_keys=["image", "mask"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index 706e7c8b21a..e43a4c236ef 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -37,7 +37,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0, std=255), data_keys=["image"] + Normalize(mean=0.0 std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 893f268108d..5b34a795177 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -60,7 +60,7 @@ def __init__( self.chesapeake_kwargs[key[11:]] = val self.aug = AugmentationSequential( - Normalize(mean=0, std=255), data_keys=["image", "mask"] + Normalize(mean=0.0 std=255.0), data_keys=["image", "mask"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index 53e798f5ba2..80e3ae5c91c 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -45,7 +45,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0, std=255), data_keys=["image"] + Normalize(mean=0.0 std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 7cd0e34b6aa..1bef346bdd7 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -47,7 +47,7 @@ def __init__( self.kwargs = kwargs self.train_aug = AugmentationSequential( - K.Normalize(mean=0, std=255), + K.Normalize(mean=0.0 std=255.0), K.PadTo((448, 448)), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), @@ -64,7 +64,7 @@ def __init__( data_keys=["image", "mask"], ) self.test_aug = AugmentationSequential( - K.Normalize(mean=0, std=255), + K.Normalize(mean=0.0 std=255.0), K.PadTo((448, 448)), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index 9a9ce1c26b2..5c76e27c3a6 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -35,7 +35,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0, std=255), Resize(size=256), data_keys=["image"] + Normalize(mean=0.0, std=255.0), Resize(size=256), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py index 0d3219ef23c..891bd2a04a3 100644 --- a/torchgeo/datamodules/usavars.py +++ b/torchgeo/datamodules/usavars.py @@ -37,7 +37,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0, std=255), data_keys=["image"] + Normalize(mean=0.0 std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index 5dfa75a21ab..a76fb69a222 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -44,7 +44,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0, std=255), data_keys=["image"] + Normalize(mean=0.0 std=255.0), data_keys=["image"] ) def setup(self, stage: Optional[str] = None) -> None: diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 
dbc4011d0a7..189875f3c1f 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -777,7 +777,7 @@ def _load_image(self, index: int) -> Tuple[Tensor, Tensor]: """ img, label = ImageFolder.__getitem__(self, index) array: "np.typing.NDArray[np.int_]" = np.array(img) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) label = torch.tensor(label) From 97ce260499c9e99ce2e98d3abeb4a6c8b339b79b Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 31 Dec 2022 15:38:56 -0600 Subject: [PATCH 032/108] Syntax error --- torchgeo/datamodules/chesapeake.py | 2 +- torchgeo/datamodules/cowc.py | 2 +- torchgeo/datamodules/cyclone.py | 2 +- torchgeo/datamodules/fair1m.py | 2 +- torchgeo/datamodules/inria.py | 4 ++-- torchgeo/datamodules/landcoverai.py | 4 ++-- torchgeo/datamodules/loveda.py | 2 +- torchgeo/datamodules/naip.py | 2 +- torchgeo/datamodules/nasa_marine_debris.py | 2 +- torchgeo/datamodules/spacenet.py | 4 ++-- torchgeo/datamodules/usavars.py | 2 +- torchgeo/datamodules/xview.py | 2 +- 12 files changed, 15 insertions(+), 15 deletions(-) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 942c2311f29..13425f55e49 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -96,7 +96,7 @@ def __init__( self.aug = AugmentationSequential( CenterCrop(patch_size), - Normalize(mean=0.0 std=255.0), + Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index 13d2d50e22b..458f9c9a626 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -33,7 +33,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0 std=255.0), data_keys=["image"] + Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index dae979a9c9b..0761e9b3877 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -42,7 +42,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0 std=255.0), data_keys=["image"] + Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index d3790007553..92714d7c915 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -45,7 +45,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0 std=255.0), data_keys=["image"] + Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def setup(self, stage: Optional[str] = None) -> None: diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 80a458c3e6a..371fe56f819 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -66,14 +66,14 @@ def __init__( self.kwargs = kwargs self.train_aug = AugmentationSequential( - Normalize(mean=0.0 std=255.0), + Normalize(mean=0.0, std=255.0), RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.5), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.test_aug = AugmentationSequential( - Normalize(mean=0.0 std=255.0), + Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/landcoverai.py 
b/torchgeo/datamodules/landcoverai.py index 5290e9144c6..d3b6d6bada8 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -42,7 +42,7 @@ def __init__( self.kwargs = kwargs self.train_aug = AugmentationSequential( - Normalize(mean=0.0 std=255.0), + Normalize(mean=0.0, std=255.0), RandomRotation(p=0.5, degrees=90), RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.5), @@ -58,7 +58,7 @@ def __init__( data_keys=["image", "mask"], ) self.test_aug = AugmentationSequential( - Normalize(mean=0.0 std=255.0), data_keys=["image", "mask"] + Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index e43a4c236ef..67929266e02 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -37,7 +37,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0 std=255.0), data_keys=["image"] + Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 5b34a795177..d1b40531157 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -60,7 +60,7 @@ def __init__( self.chesapeake_kwargs[key[11:]] = val self.aug = AugmentationSequential( - Normalize(mean=0.0 std=255.0), data_keys=["image", "mask"] + Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index 80e3ae5c91c..5c49d3709be 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -45,7 +45,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0 std=255.0), data_keys=["image"] + Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 1bef346bdd7..9994c58a1be 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -47,7 +47,7 @@ def __init__( self.kwargs = kwargs self.train_aug = AugmentationSequential( - K.Normalize(mean=0.0 std=255.0), + K.Normalize(mean=0.0, std=255.0), K.PadTo((448, 448)), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), @@ -64,7 +64,7 @@ def __init__( data_keys=["image", "mask"], ) self.test_aug = AugmentationSequential( - K.Normalize(mean=0.0 std=255.0), + K.Normalize(mean=0.0, std=255.0), K.PadTo((448, 448)), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py index 891bd2a04a3..b444031c9b9 100644 --- a/torchgeo/datamodules/usavars.py +++ b/torchgeo/datamodules/usavars.py @@ -37,7 +37,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0 std=255.0), data_keys=["image"] + Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index a76fb69a222..5416adb71b4 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -44,7 +44,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0 std=255.0), data_keys=["image"] + Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def setup(self, stage: Optional[str] = None) -> None: From 
5cdc234506bfa15c05e53021e67f88aa78e43c67 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 31 Dec 2022 16:09:12 -0600 Subject: [PATCH 033/108] Fix bigearthnet --- torchgeo/datamodules/bigearthnet.py | 14 +++++++------- torchgeo/datasets/bigearthnet.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py index 176bbbd875c..f4c0a53c6d0 100644 --- a/torchgeo/datamodules/bigearthnet.py +++ b/torchgeo/datamodules/bigearthnet.py @@ -85,14 +85,14 @@ def __init__( bands = kwargs.get("bands", "all") if bands == "all": - self.mins = self.band_mins[:, None, None] - self.maxs = self.band_maxs[:, None, None] + self.mins = self.band_mins + self.maxs = self.band_maxs elif bands == "s1": - self.mins = self.band_mins[:2, None, None] - self.maxs = self.band_maxs[:2, None, None] + self.mins = self.band_mins[:2] + self.maxs = self.band_maxs[:2] else: - self.mins = self.band_mins[2:, None, None] - self.maxs = self.band_maxs[2:, None, None] + self.mins = self.band_mins[2:] + self.maxs = self.band_maxs[2:] self.aug = AugmentationSequential( Normalize(mean=self.mins, std=self.maxs - self.mins), data_keys=["image"] @@ -104,7 +104,7 @@ def prepare_data(self) -> None: This method is only called once per run. """ if self.kwargs.get("download", False): - BigEarthNet(split="train", **self.kwargs) + BigEarthNet(**self.kwargs) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main Dataset objects. diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 81c318baa52..cbfe0c2203b 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -394,7 +394,7 @@ def _load_image(self, index: int) -> Tensor: ) images.append(array) arrays: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0) - tensor = torch.from_numpy(arrays) + tensor = torch.from_numpy(arrays).float() return tensor def _load_target(self, index: int) -> Tensor: From ed059ec1123c1f1d8193b622185335efe74156c1 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 31 Dec 2022 17:04:38 -0600 Subject: [PATCH 034/108] Fix dtype --- torchgeo/datasets/so2sat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 7c592f94cde..0770228a229 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -234,7 +234,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: s1 = torch.from_numpy(s1) s2 = torch.from_numpy(s2) - sample = {"image": torch.cat([s1, s2]), "label": label} + sample = {"image": torch.cat([s1, s2]).float(), "label": label} if self.transforms is not None: sample = self.transforms(sample) From e5752fc3b36581a8e4eb8a2df8c454b6b182035e Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sun, 1 Jan 2023 13:24:48 -0600 Subject: [PATCH 035/108] Consistent kornia import --- torchgeo/datamodules/bigearthnet.py | 4 ++-- torchgeo/datamodules/chesapeake.py | 6 +++--- torchgeo/datamodules/cowc.py | 4 ++-- torchgeo/datamodules/cyclone.py | 4 ++-- torchgeo/datamodules/deepglobelandcover.py | 6 +++--- torchgeo/datamodules/etci2021.py | 4 ++-- torchgeo/datamodules/eurosat.py | 4 ++-- torchgeo/datamodules/fair1m.py | 4 ++-- torchgeo/datamodules/gid15.py | 8 ++++---- torchgeo/datamodules/inria.py | 10 +++++----- torchgeo/datamodules/landcoverai.py | 23 ++++++++-------------- torchgeo/datamodules/loveda.py | 4 ++-- torchgeo/datamodules/naip.py | 4 ++-- torchgeo/datamodules/nasa_marine_debris.py | 4 ++-- torchgeo/datamodules/oscd.py | 6 +++--- torchgeo/datamodules/potsdam.py | 6 +++--- torchgeo/datamodules/so2sat.py | 4 ++-- torchgeo/datamodules/ucmerced.py | 4 ++-- torchgeo/datamodules/usavars.py | 4 ++-- torchgeo/datamodules/vaihingen.py | 6 +++--- torchgeo/datamodules/xview.py | 4 ++-- 21 files changed, 58 insertions(+), 65 deletions(-) diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py index f4c0a53c6d0..66f19bf373a 100644 --- a/torchgeo/datamodules/bigearthnet.py +++ b/torchgeo/datamodules/bigearthnet.py @@ -5,8 +5,8 @@ from typing import Any, Optional +import kornia.augmentation as K import torch -from kornia.augmentation import Normalize from ..datasets import BigEarthNet from ..transforms import AugmentationSequential @@ -95,7 +95,7 @@ def __init__( self.maxs = self.band_maxs[2:] self.aug = AugmentationSequential( - Normalize(mean=self.mins, std=self.maxs - self.mins), data_keys=["image"] + K.Normalize(mean=self.mins, std=self.maxs - self.mins), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 13425f55e49..5debfabd2d4 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -5,8 +5,8 @@ from typing import Any, Dict, List, Optional +import kornia.augmentation as K import matplotlib.pyplot as plt -from kornia.augmentation import CenterCrop, Normalize from pytorch_lightning import LightningDataModule from torch import Tensor from torch.utils.data import DataLoader @@ -95,8 +95,8 @@ def __init__( self.layers = ["naip-new", "lc"] self.aug = AugmentationSequential( - CenterCrop(patch_size), - Normalize(mean=0.0, std=255.0), + K.CenterCrop(patch_size), + K.Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index 458f9c9a626..3cab8ba0ed6 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -5,7 +5,7 @@ from typing import Any, Optional -from kornia.augmentation import Normalize +import kornia.augmentation as K from torch.utils.data import random_split from ..datasets import COWCCounting @@ -33,7 +33,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), data_keys=["image"] + K.Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 0761e9b3877..9a3beaf3b99 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -5,7 +5,7 @@ from typing import Any, Optional -from kornia.augmentation import Normalize +import kornia.augmentation as K from sklearn.model_selection import GroupShuffleSplit from torch.utils.data 
import Subset @@ -42,7 +42,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), data_keys=["image"] + K.Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index cc9ac988da2..a473e3de180 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -5,7 +5,7 @@ from typing import Any, Optional, Tuple, Union -from kornia.augmentation import Normalize +import kornia.augmentation as K from ..datasets import DeepGlobeLandCover from ..samplers.utils import _to_tuple @@ -65,12 +65,12 @@ def __init__( self.kwargs = kwargs self.train_aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), + K.Normalize(mean=0.0, std=255.0), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.test_aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), + K.Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index 369e3f9d332..fac22ce7adf 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -5,7 +5,7 @@ from typing import Any, Optional -from kornia.augmentation import Normalize +import kornia.augmentation as K from torch.utils.data import random_split from ..datasets import ETCI2021 @@ -49,7 +49,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] + K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index 9e97d8bbef2..637508506da 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -5,7 +5,7 @@ from typing import Any, Optional -from kornia.augmentation import Normalize +import kornia.augmentation as K from ..datasets import EuroSAT from ..transforms import AugmentationSequential @@ -69,7 +69,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] + K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index 92714d7c915..147b5c37e08 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -5,7 +5,7 @@ from typing import Any, Optional -from kornia.augmentation import Normalize +import kornia.augmentation as K from ..datasets import FAIR1M from ..transforms import AugmentationSequential @@ -45,7 +45,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), data_keys=["image"] + K.Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def setup(self, stage: Optional[str] = None) -> None: diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 8bb7540b8bb..54a300760d5 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -5,7 +5,7 @@ from typing import Any, Optional, Tuple, Union -from kornia.augmentation import Normalize +import kornia.augmentation as K from ..datasets import GID15 from ..samplers.utils import _to_tuple @@ -62,17 +62,17 @@ def __init__( 
self.kwargs = kwargs self.train_aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), + K.Normalize(mean=0.0, std=255.0), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.val_aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), + K.Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) self.predict_transform = AugmentationSequential( - Normalize(mean=0.0, std=255.0), + K.Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image"], ) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 371fe56f819..3a68831afd1 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -5,7 +5,7 @@ from typing import Any, Optional, Tuple, Union -from kornia.augmentation import Normalize, RandomHorizontalFlip, RandomVerticalFlip +import kornia.augmentation as K from ..datasets import InriaAerialImageLabeling from ..samplers.utils import _to_tuple @@ -66,14 +66,14 @@ def __init__( self.kwargs = kwargs self.train_aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), - RandomHorizontalFlip(p=0.5), - RandomVerticalFlip(p=0.5), + K.Normalize(mean=0.0, std=255.0), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.test_aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), + K.Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index d3b6d6bada8..ee9540ae7cb 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -5,14 +5,7 @@ from typing import Any, Optional -from kornia.augmentation import ( - ColorJitter, - Normalize, - RandomHorizontalFlip, - RandomRotation, - RandomSharpness, - RandomVerticalFlip, -) +import kornia.augmentation as K from ..datasets import LandCoverAI from ..transforms import AugmentationSequential @@ -42,12 +35,12 @@ def __init__( self.kwargs = kwargs self.train_aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), - RandomRotation(p=0.5, degrees=90), - RandomHorizontalFlip(p=0.5), - RandomVerticalFlip(p=0.5), - RandomSharpness(p=0.5), - ColorJitter( + K.Normalize(mean=0.0, std=255.0), + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + K.ColorJitter( p=0.5, brightness=0.1, contrast=0.1, @@ -58,7 +51,7 @@ def __init__( data_keys=["image", "mask"], ) self.test_aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"] + K.Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index 67929266e02..3f3207a1eb9 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -5,7 +5,7 @@ from typing import Any, Optional -from kornia.augmentation import Normalize +import kornia.augmentation as K from ..datasets import LoveDA from ..transforms import AugmentationSequential @@ -37,7 +37,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), data_keys=["image"] + K.Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/naip.py 
b/torchgeo/datamodules/naip.py index d1b40531157..6c495f29f78 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -5,8 +5,8 @@ from typing import Any, Dict, Optional, Tuple +import kornia.augmentation as K import matplotlib.pyplot as plt -from kornia.augmentation import Normalize from pytorch_lightning import LightningDataModule from torch import Tensor from torch.utils.data import DataLoader @@ -60,7 +60,7 @@ def __init__( self.chesapeake_kwargs[key[11:]] = val self.aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"] + K.Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index 5c49d3709be..cb38b8ace9e 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -5,7 +5,7 @@ from typing import Any, Optional -from kornia.augmentation import Normalize +import kornia.augmentation as K from ..datasets import NASAMarineDebris from ..transforms import AugmentationSequential @@ -45,7 +45,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), data_keys=["image"] + K.Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 76a6ec5d092..9041fd83d8c 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -5,8 +5,8 @@ from typing import Any, Optional, Tuple, Union +import kornia.augmentation as K import torch -from kornia.augmentation import Normalize from ..datasets import OSCD from ..samplers.utils import _to_tuple @@ -105,12 +105,12 @@ def __init__( self.band_stds = self.band_stds[[3, 2, 1]] self.train_aug = AugmentationSequential( - Normalize(mean=self.band_means, std=self.band_stds), + K.Normalize(mean=self.band_means, std=self.band_stds), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.test_aug = AugmentationSequential( - Normalize(mean=self.band_means, std=self.band_stds), + K.Normalize(mean=self.band_means, std=self.band_stds), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 279afdc109e..716dcfd8a7c 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -5,7 +5,7 @@ from typing import Any, Optional, Tuple, Union -from kornia.augmentation import Normalize +import kornia.augmentation as K from ..datasets import Potsdam2D from ..samplers.utils import _to_tuple @@ -66,12 +66,12 @@ def __init__( self.kwargs = kwargs self.train_aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), + K.Normalize(mean=0.0, std=255.0), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.test_aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), + K.Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index 95e7801f262..f7ccf7e2c22 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -5,7 +5,7 @@ from typing import Any, Optional -from kornia.augmentation import Normalize +import kornia.augmentation as K from ..datasets import So2Sat from ..transforms import AugmentationSequential @@ 
-74,7 +74,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] + K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] ) def setup(self, stage: Optional[str] = None) -> None: diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index 5c76e27c3a6..90c9c29bad6 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -5,7 +5,7 @@ from typing import Any, Optional -from kornia.augmentation import Normalize, Resize +import kornia.augmentation as K from ..datasets import UCMerced from ..transforms import AugmentationSequential @@ -35,7 +35,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), Resize(size=256), data_keys=["image"] + K.Normalize(mean=0.0, std=255.0), K.Resize(size=256), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py index b444031c9b9..fc2e05e5a0e 100644 --- a/torchgeo/datamodules/usavars.py +++ b/torchgeo/datamodules/usavars.py @@ -5,7 +5,7 @@ from typing import Any, Optional -from kornia.augmentation import Normalize +import kornia.augmentation as K from ..datasets import USAVars from ..transforms import AugmentationSequential @@ -37,7 +37,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), data_keys=["image"] + K.Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 61bf3ebfd18..a5b1b5a7375 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -5,7 +5,7 @@ from typing import Any, Optional, Tuple, Union -from kornia.augmentation import Normalize +import kornia.augmentation as K from ..datasets import Vaihingen2D from ..samplers.utils import _to_tuple @@ -66,12 +66,12 @@ def __init__( self.kwargs = kwargs self.train_aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), + K.Normalize(mean=0.0, std=255.0), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.test_aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), + K.Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index 5416adb71b4..744d5f68595 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -5,7 +5,7 @@ from typing import Any, Optional -from kornia.augmentation import Normalize +import kornia.augmentation as K from ..datasets import XView2 from ..transforms import AugmentationSequential @@ -44,7 +44,7 @@ def __init__( self.kwargs = kwargs self.aug = AugmentationSequential( - Normalize(mean=0.0, std=255.0), data_keys=["image"] + K.Normalize(mean=0.0, std=255.0), data_keys=["image"] ) def setup(self, stage: Optional[str] = None) -> None: From 39885005a2a037bbf3f96fddc3ed505615ff246e Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sun, 1 Jan 2023 14:03:14 -0600 Subject: [PATCH 036/108] Get regression datasets working --- torchgeo/datasets/cowc.py | 4 ++-- torchgeo/datasets/cyclone.py | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index 1b40974d6d9..d34749cf2d0 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -147,7 +147,7 @@ def _load_image(self, index: int) -> Tensor: filename = os.path.join(self.root, self.images[index]) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor @@ -162,7 +162,7 @@ def _load_target(self, index: int) -> Tensor: the target """ target = int(self.targets[index]) - tensor = torch.tensor(target) + tensor = torch.tensor(target).float() return tensor def _check_integrity(self) -> bool: diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index c87dcc997a7..ebcc48c37c3 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -153,10 +153,8 @@ def _load_image(self, directory: str) -> Tensor: except AttributeError: resample = Image.BILINEAR img = img.resize(size=(self.size, self.size), resample=resample) - array: "np.typing.NDArray[np.int_]" = np.array(img) - if len(array.shape) == 3: - array = array[:, :, 0] - tensor = torch.from_numpy(array) + array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + tensor = torch.from_numpy(array).permute((2, 0, 1)).float() return tensor def _load_features(self, directory: str) -> Dict[str, Any]: @@ -178,7 +176,7 @@ def _load_features(self, directory: str) -> Dict[str, Any]: features["relative_time"] = int(features["relative_time"]) features["ocean"] = int(features["ocean"]) - features["label"] = int(features["wind_speed"]) + features["label"] = torch.tensor(int(features["wind_speed"])).float() return features From 3d5ac476a9e581ca71517363e73e2141777430a3 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sun, 1 Jan 2023 14:49:25 -0600 Subject: [PATCH 037/108] Fix detection tests --- tests/trainers/test_detection.py | 3 ++- torchgeo/datasets/nasa_marine_debris.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 68e48e92817..0b4ec5314ae 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -9,6 +9,7 @@ from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer +from torchgeo.datasets import NASAMarineDebris from torchgeo.datamodules import NASAMarineDebrisDataModule from torchgeo.trainers import ObjectDetectionTask @@ -59,7 +60,7 @@ def test_non_pretrained_backbone(self, model_kwargs: Dict[Any, Any]) -> None: def test_missing_attributes( self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch ) -> None: - monkeypatch.delattr(NASAMarineDebrisDataModule, "plot") + monkeypatch.delattr(NASAMarineDebris, "plot") datamodule = NASAMarineDebrisDataModule( root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 ) diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 56428a5af84..06e37de81f9 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -104,6 +104,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: h_check = (sample["boxes"][:, 3] - sample["boxes"][:, 1]) > 0 indices = w_check & h_check sample["boxes"] = sample["boxes"][indices] + sample["labels"] = torch.ones(len(indices), dtype=torch.int64) if self.transforms is not None: sample = self.transforms(sample) @@ -129,7 +130,7 @@ def _load_image(self, path: str) -> Tensor: """ with rasterio.open(path) as f: array = f.read() - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() return tensor def _load_target(self, path: str) -> Tensor: From 6056e9b817d7cfddd6ad80d774fd0e6fe69b4dc0 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sun, 1 Jan 2023 14:56:57 -0600 Subject: [PATCH 038/108] Fix some chesapeake bugs --- conf/chesapeake_cvpr.yaml | 4 ++-- tests/conf/chesapeake_cvpr_5.yaml | 4 ++-- tests/conf/chesapeake_cvpr_7.yaml | 4 ++-- tests/conf/chesapeake_cvpr_prior.yaml | 4 ++-- torchgeo/datamodules/chesapeake.py | 2 +- torchgeo/datasets/geo.py | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/conf/chesapeake_cvpr.yaml b/conf/chesapeake_cvpr.yaml index 2ee411a6b17..560d61f7a7b 100644 --- a/conf/chesapeake_cvpr.yaml +++ b/conf/chesapeake_cvpr.yaml @@ -25,9 +25,9 @@ experiment: - "de-val" test_splits: - "de-test" - patches_per_tile: 200 + num_tiles_per_batch: 64 + num_patches_per_tile: 200 patch_size: 256 - batch_size: 64 num_workers: 4 class_set: ${experiment.module.num_classes} use_prior_labels: False diff --git a/tests/conf/chesapeake_cvpr_5.yaml b/tests/conf/chesapeake_cvpr_5.yaml index 7e1663c9493..af1394da67d 100644 --- a/tests/conf/chesapeake_cvpr_5.yaml +++ b/tests/conf/chesapeake_cvpr_5.yaml @@ -20,9 +20,9 @@ experiment: - "de-test" test_splits: - "de-test" - patches_per_tile: 2 + num_tiles_per_batch: 2 + num_patches_per_tile: 2 patch_size: 64 - batch_size: 2 num_workers: 0 class_set: ${experiment.module.num_classes} use_prior_labels: False diff --git a/tests/conf/chesapeake_cvpr_7.yaml b/tests/conf/chesapeake_cvpr_7.yaml index 9a34e401d19..ba40d618c1a 100644 --- a/tests/conf/chesapeake_cvpr_7.yaml +++ b/tests/conf/chesapeake_cvpr_7.yaml @@ -20,9 +20,9 @@ experiment: - "de-test" test_splits: - "de-test" - patches_per_tile: 2 + num_tiles_per_batch: 2 + num_patches_per_tile: 2 patch_size: 64 - batch_size: 2 num_workers: 0 class_set: ${experiment.module.num_classes} use_prior_labels: False diff --git a/tests/conf/chesapeake_cvpr_prior.yaml b/tests/conf/chesapeake_cvpr_prior.yaml index 907b17ac76d..ca774e9917b 100644 --- a/tests/conf/chesapeake_cvpr_prior.yaml +++ b/tests/conf/chesapeake_cvpr_prior.yaml @@ -20,9 +20,9 @@ experiment: - "de-test" test_splits: - "de-test" - patches_per_tile: 2 + num_tiles_per_batch: 2 + num_patches_per_tile: 2 patch_size: 64 - batch_size: 2 num_workers: 0 class_set: ${experiment.module.num_classes} use_prior_labels: True diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 5debfabd2d4..82c3756495e 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -136,7 +136,7 @@ def train_dataloader(self) -> DataLoader[Any]: self.train_dataset, size=self.original_patch_size, batch_size=self.train_batch_size, - length=self.patches_per_tile * len(self.train_dataset), + length=self.num_patches_per_tile * len(self.train_dataset), ) return DataLoader( self.train_dataset, diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 189875f3c1f..a88bd3fcd7a 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -428,7 +428,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: data = self._merge_files(filepaths, query, self.band_indexes) key = "image" if self.is_image else "mask" - sample = {key: data, "crs": self.crs, "bbox": query} + sample = {key: data} if self.transforms is not None: sample = self.transforms(sample) From 96252bb3aeef1dc708e22c2178b4d845406bc9e4 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sun, 1 Jan 2023 15:18:08 -0600 Subject: [PATCH 039/108] Fix several segmentation issues --- conf/inria.yaml | 12 ++++++------ tests/conf/{inria_test.yaml => inria.yaml} | 14 +++++++------- tests/conf/inria_train.yaml | 20 -------------------- tests/conf/inria_val.yaml | 20 -------------------- tests/trainers/test_segmentation.py | 11 +++++------ torchgeo/datamodules/deepglobelandcover.py | 2 +- torchgeo/datamodules/geo.py | 2 +- torchgeo/datamodules/inria.py | 7 ++++++- torchgeo/datamodules/landcoverai.py | 2 +- torchgeo/datamodules/oscd.py | 2 +- torchgeo/datamodules/potsdam.py | 2 +- torchgeo/datamodules/spacenet.py | 2 +- torchgeo/datamodules/vaihingen.py | 2 +- torchgeo/datasets/inria.py | 4 ++-- torchgeo/datasets/landcoverai.py | 4 ++-- torchgeo/datasets/loveda.py | 2 +- torchgeo/datasets/spacenet.py | 2 +- 17 files changed, 37 insertions(+), 73 deletions(-) rename tests/conf/{inria_test.yaml => inria.yaml} (59%) delete mode 100644 tests/conf/inria_train.yaml delete mode 100644 tests/conf/inria_val.yaml diff --git a/conf/inria.yaml b/conf/inria.yaml index 462c873fda6..234ddffcb01 100644 --- a/conf/inria.yaml +++ b/conf/inria.yaml @@ -1,7 +1,7 @@ program: overwrite: True - + trainer: gpus: 1 min_epochs: 5 @@ -23,8 +23,8 @@ experiment: num_classes: 2 ignore_index: null datamodule: - root: "data/inria" - batch_size: 2 - num_workers: 32 - patch_size: 512 - num_patches_per_tile: 4 + root: "data/inria" + num_tiles_per_batch: 2 + num_patches_per_tile: 4 + patch_size: 512 + num_workers: 32 diff --git a/tests/conf/inria_test.yaml b/tests/conf/inria.yaml similarity index 59% rename from tests/conf/inria_test.yaml rename to tests/conf/inria.yaml index 6bc4b5bd7cc..7cb05607bff 100644 --- a/tests/conf/inria_test.yaml +++ b/tests/conf/inria.yaml @@ -11,10 +11,10 @@ experiment: num_classes: 2 ignore_index: null datamodule: - root: "tests/data/inria" - batch_size: 1 - num_workers: 0 - val_split_pct: 0.2 - test_split_pct: 0.2 - patch_size: 2 - num_patches_per_tile: 2 + root: "tests/data/inria" + num_tiles_per_batch: 1 + num_patches_per_tile: 2 + patch_size: 2 + num_workers: 0 + val_split_pct: 0.2 + test_split_pct: 0.2 diff --git a/tests/conf/inria_train.yaml b/tests/conf/inria_train.yaml deleted file mode 100644 index 99db7925f27..00000000000 --- a/tests/conf/inria_train.yaml +++ /dev/null @@ -1,20 +0,0 @@ -experiment: - task: "inria" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: "imagenet" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 3 - num_classes: 2 - ignore_index: null - datamodule: - root: "tests/data/inria" - batch_size: 1 - num_workers: 0 - val_split_pct: 0.0 - test_split_pct: 0.0 - patch_size: 2 - num_patches_per_tile: 2 diff --git a/tests/conf/inria_val.yaml b/tests/conf/inria_val.yaml deleted file mode 100644 index c20f8923439..00000000000 --- a/tests/conf/inria_val.yaml +++ /dev/null @@ -1,20 +0,0 @@ -experiment: - task: "inria" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: "imagenet" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 3 - num_classes: 2 - ignore_index: null - datamodule: - root: "tests/data/inria" - batch_size: 1 - num_workers: 0 - val_split_pct: 0.2 - test_split_pct: 0.0 - patch_size: 2 - num_patches_per_tile: 2 diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 35467c179d9..79d5be11036 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -11,6 +11,7 @@ from 
pytorch_lightning import LightningDataModule, Trainer from torch.nn.modules import Module +from torchgeo.datasets import LandCoverAI from torchgeo.datamodules import ( ChesapeakeCVPRDataModule, DeepGlobeLandCoverDataModule, @@ -43,9 +44,7 @@ class TestSemanticSegmentationTask: ("deepglobelandcover", DeepGlobeLandCoverDataModule), ("etci2021", ETCI2021DataModule), ("gid15", GID15DataModule), - ("inria_train", InriaAerialImageLabelingDataModule), - ("inria_val", InriaAerialImageLabelingDataModule), - ("inria_test", InriaAerialImageLabelingDataModule), + ("inria", InriaAerialImageLabelingDataModule), ("landcoverai", LandCoverAIDataModule), ("loveda", LoveDADataModule), ("naipchesapeake", NAIPChesapeakeDataModule), @@ -86,10 +85,10 @@ def test_trainer( trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - if hasattr(datamodule, "test_dataset") or hasattr(datamodule, "test_sampler"): + if datamodule.test_dataset is not None or hasattr(datamodule, "test_sampler"): trainer.test(model=model, datamodule=datamodule) - if hasattr(datamodule, "predict_dataset"): + if datamodule.predict_dataset is not None: trainer.predict(model=model, datamodule=datamodule) def test_no_logger(self) -> None: @@ -151,7 +150,7 @@ def test_ignoreindex_with_jaccard(self, model_kwargs: Dict[Any, Any]) -> None: def test_missing_attributes( self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch ) -> None: - monkeypatch.delattr(LandCoverAIDataModule, "plot") + monkeypatch.delattr(LandCoverAI, "plot") datamodule = LandCoverAIDataModule( root="tests/data/landcoverai", batch_size=1, num_workers=0 ) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index a473e3de180..cf868f9b60e 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -69,7 +69,7 @@ def __init__( _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) - self.test_aug = AugmentationSequential( + self.aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index e840ddbed77..bf1206b0e70 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -37,7 +37,7 @@ class NonGeoDataModule(LightningDataModule): train_batch_size: Optional[int] = None val_batch_size: Optional[int] = None test_batch_size: Optional[int] = None - predict_patch_size: Optional[int] = None + predict_batch_size: Optional[int] = None num_workers = 0 # Data augmentation diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 3a68831afd1..9b5035821c9 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -72,11 +72,16 @@ def __init__( _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) - self.test_aug = AugmentationSequential( + self.val_aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) + self.predict_aug = AugmentationSequential( + K.Normalize(mean=0.0, std=255.0), + _ExtractTensorPatches(self.patch_size), + data_keys=["image"], + ) def setup(self, stage: Optional[str] = None) -> None: """Initialize the main ``Dataset`` objects. 
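
The Inria change above sets the pattern for the remaining datamodules in this patch: one augmentation pipeline per stage, with predict_aug declaring only the "image" key because no mask exists at prediction time. These pipelines can be exercised outside of a Trainer. A minimal sketch, assuming the dict-style batches used throughout these datamodules; the batch contents here are invented for illustration, not taken from the dataset:

    import kornia.augmentation as K
    import torch

    from torchgeo.transforms import AugmentationSequential

    # Two fake 3-channel 512x512 tiles with raw pixel values in [0, 255].
    batch = {"image": torch.randint(0, 256, (2, 3, 512, 512)).float()}

    # The same normalization predict_aug applies; no "mask" key is declared.
    predict_aug = AugmentationSequential(
        K.Normalize(mean=0.0, std=255.0), data_keys=["image"]
    )
    batch = predict_aug(batch)
    assert batch["image"].min() >= 0.0 and batch["image"].max() <= 1.0

Because the pipeline is an ordinary nn.Module, it runs on whatever device the batch already occupies, which is what lets these datamodules augment on the GPU after batch transfer.
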
diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index ee9540ae7cb..84fbbf64e86 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -50,7 +50,7 @@ def __init__( ), data_keys=["image", "mask"], ) - self.test_aug = AugmentationSequential( + self.aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"] ) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 9041fd83d8c..2b7588271c2 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -109,7 +109,7 @@ def __init__( _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) - self.test_aug = AugmentationSequential( + self.aug = AugmentationSequential( K.Normalize(mean=self.band_means, std=self.band_stds), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 716dcfd8a7c..accea53ac92 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -70,7 +70,7 @@ def __init__( _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) - self.test_aug = AugmentationSequential( + self.aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 9994c58a1be..cda314518ea 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -63,7 +63,7 @@ def __init__( ), data_keys=["image", "mask"], ) - self.test_aug = AugmentationSequential( + self.aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), K.PadTo((448, 448)), data_keys=["image", "mask"], diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index a5b1b5a7375..761c352ab3a 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -70,7 +70,7 @@ def __init__( _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) - self.test_aug = AugmentationSequential( + self.aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index c1698cf6c15..395d0c909bb 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -116,7 +116,7 @@ def _load_image(self, path: str) -> Tensor: """ with rio.open(path) as img: array = img.read().astype(np.int32) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() return tensor def _load_target(self, path: str) -> Tensor: @@ -131,7 +131,7 @@ def _load_target(self, path: str) -> Tensor: with rio.open(path) as img: array = img.read().astype(np.int32) array = np.clip(array, 0, 1) - mask = torch.from_numpy(array[0]) + mask = torch.from_numpy(array[0]).long() return mask def __len__(self) -> int: diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index a6514aae510..002ff595bd6 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -149,7 +149,7 @@ def _load_image(self, id_: str) -> Tensor: filename = os.path.join(self.root, "output", id_ + ".jpg") with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # 
Convert from HxWxC to CxHxW
         tensor = tensor.permute((2, 0, 1))
         return tensor
@@ -167,7 +167,7 @@ def _load_target(self, id_: str) -> Tensor:
         filename = os.path.join(self.root, "output", id_ + "_m.png")
         with Image.open(filename) as img:
             array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L"))
-            tensor = torch.from_numpy(array)
+            tensor = torch.from_numpy(array).long()
             return tensor
 
     def _verify(self) -> None:
diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py
index 70d005c0df3..f6e2b83f3d2 100644
--- a/torchgeo/datasets/loveda.py
+++ b/torchgeo/datasets/loveda.py
@@ -214,7 +214,7 @@ def _load_image(self, path: str) -> Tensor:
         filename = os.path.join(path)
         with Image.open(filename) as img:
             array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
-            tensor = torch.from_numpy(array)
+            tensor = torch.from_numpy(array).float()
             # Convert from HxWxC to CxHxW
             tensor = tensor.permute((2, 0, 1))
             return tensor
diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py
index 3e5baa7ad75..f0e3987dd64 100644
--- a/torchgeo/datasets/spacenet.py
+++ b/torchgeo/datasets/spacenet.py
@@ -158,7 +158,7 @@ def _load_image(self, path: str) -> Tuple[Tensor, Affine, CRS]:
         filename = os.path.join(path)
         with rio.open(filename) as img:
             array = img.read().astype(np.int32)
-            tensor = torch.from_numpy(array)
+            tensor = torch.from_numpy(array).float()
             return tensor, img.transform, img.crs
 
     def _load_mask(

From e3348c7d3aa99ffbeffe5b05a3912baf2897f88e Mon Sep 17 00:00:00 2001
From: "Adam J. Stewart"
Date: Sun, 1 Jan 2023 15:22:58 -0600
Subject: [PATCH 040/108] isort fixes

---
 tests/trainers/test_classification.py | 2 +-
 tests/trainers/test_detection.py      | 2 +-
 tests/trainers/test_segmentation.py   | 1 -
 3 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py
index 96617c49bb1..ad28022114e 100644
--- a/tests/trainers/test_classification.py
+++ b/tests/trainers/test_classification.py
@@ -15,16 +15,16 @@
 from torch.nn.modules import Module
 from torchvision.models._api import WeightsEnum
 
-from torchgeo.datasets import BigEarthNet, EuroSAT
 from torchgeo.datamodules import (
     BigEarthNetDataModule,
     EuroSATDataModule,
     RESISC45DataModule,
     So2SatDataModule,
     UCMercedDataModule,
 )
+from torchgeo.datasets import BigEarthNet, EuroSAT
 from torchgeo.models import ResNet18_Weights
 from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask
 
 from .test_utils import ClassificationTestModel
diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py
index 0b4ec5314ae..293b7bc33a2 100644
--- a/tests/trainers/test_detection.py
+++ b/tests/trainers/test_detection.py
@@ -9,8 +9,8 @@
 from omegaconf import OmegaConf
 from pytorch_lightning import LightningDataModule, Trainer
 
-from torchgeo.datasets import NASAMarineDebris
 from torchgeo.datamodules import NASAMarineDebrisDataModule
+from torchgeo.datasets import NASAMarineDebris
 from torchgeo.trainers import ObjectDetectionTask
 
 
diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py
index 79d5be11036..1ed7f61b6bb 100644
--- a/tests/trainers/test_segmentation.py
+++ b/tests/trainers/test_segmentation.py
@@ -11,7 +11,6 @@
 from pytorch_lightning import LightningDataModule, Trainer
 from torch.nn.modules import Module
 
-from torchgeo.datasets import LandCoverAI
 from torchgeo.datamodules import (
     ChesapeakeCVPRDataModule,
     DeepGlobeLandCoverDataModule,

From 2409d8dd6b7bf5b84df23b2340bb089dc4e9ef2f Mon Sep 17 00:00:00 2001
From: 
"Adam J. Stewart" Date: Sun, 1 Jan 2023 20:31:31 -0600 Subject: [PATCH 041/108] Undo breaking change --- torchgeo/datasets/geo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index a88bd3fcd7a..189875f3c1f 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -428,7 +428,7 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: data = self._merge_files(filepaths, query, self.band_indexes) key = "image" if self.is_image else "mask" - sample = {key: data} + sample = {key: data, "crs": self.crs, "bbox": query} if self.transforms is not None: sample = self.transforms(sample) From 8b0a6bff3d9c5dfe2a80b9a000a024402c294ea1 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sun, 1 Jan 2023 22:58:28 -0600 Subject: [PATCH 042/108] Remove more code duplication, standardize docstrings --- torchgeo/datamodules/bigearthnet.py | 49 ++------ torchgeo/datamodules/cowc.py | 32 ++--- torchgeo/datamodules/cyclone.py | 69 ++++------- torchgeo/datamodules/deepglobelandcover.py | 46 ++++---- torchgeo/datamodules/etci2021.py | 43 +------ torchgeo/datamodules/eurosat.py | 35 +----- torchgeo/datamodules/fair1m.py | 30 ++--- torchgeo/datamodules/geo.py | 131 +++++++++++++++------ torchgeo/datamodules/gid15.py | 55 ++++----- torchgeo/datamodules/inria.py | 47 ++++---- torchgeo/datamodules/landcoverai.py | 35 +----- torchgeo/datamodules/loveda.py | 43 +++---- torchgeo/datamodules/nasa_marine_debris.py | 40 +++---- torchgeo/datamodules/oscd.py | 52 ++++---- torchgeo/datamodules/potsdam.py | 45 ++++--- torchgeo/datamodules/resisc45.py | 35 +----- torchgeo/datamodules/sen12ms.py | 111 +++++++---------- torchgeo/datamodules/so2sat.py | 38 +----- torchgeo/datamodules/spacenet.py | 40 +++---- torchgeo/datamodules/ucmerced.py | 35 +----- torchgeo/datamodules/usavars.py | 32 +---- torchgeo/datamodules/vaihingen.py | 45 ++++--- torchgeo/datamodules/xview.py | 36 +++--- 23 files changed, 430 insertions(+), 694 deletions(-) diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py index 66f19bf373a..78b55d643fd 100644 --- a/torchgeo/datamodules/bigearthnet.py +++ b/torchgeo/datamodules/bigearthnet.py @@ -3,7 +3,7 @@ """BigEarthNet datamodule.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K import torch @@ -70,50 +70,27 @@ class BigEarthNetDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for BigEarthNet based DataLoaders. + """Initialize a new BigEarthNetDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.BigEarthNet` + :class:`~torchgeo.datasets.BigEarthNet`. 
""" - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs + super().__init__(BigEarthNet, batch_size, num_workers, **kwargs) bands = kwargs.get("bands", "all") if bands == "all": - self.mins = self.band_mins - self.maxs = self.band_maxs + mins = self.band_mins + maxs = self.band_maxs elif bands == "s1": - self.mins = self.band_mins[:2] - self.maxs = self.band_maxs[:2] + mins = self.band_mins[:2] + maxs = self.band_maxs[:2] else: - self.mins = self.band_mins[2:] - self.maxs = self.band_maxs[2:] + mins = self.band_mins[2:] + maxs = self.band_maxs[2:] self.aug = AugmentationSequential( - K.Normalize(mean=self.mins, std=self.maxs - self.mins), data_keys=["image"] + K.Normalize(mean=mins, std=maxs - mins), data_keys=["image"] ) - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - BigEarthNet(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - self.train_dataset = BigEarthNet(split="train", **self.kwargs) - self.val_dataset = BigEarthNet(split="val", **self.kwargs) - self.test_dataset = BigEarthNet(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index 3cab8ba0ed6..3621bad76e3 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -3,7 +3,7 @@ """COWC datamodule.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K from torch.utils.data import random_split @@ -19,39 +19,25 @@ class COWCCountingDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new COWCCountingDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.COWCCounting` + :class:`~torchgeo.datasets.COWCCounting`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs + super().__init__(COWCCounting, batch_size, num_workers, **kwargs) self.aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), data_keys=["image"] ) - def prepare_data(self) -> None: - """Initialize the main Dataset objects for use in :func:`setup`. - - This includes optionally downloading the dataset. This is done once per node, - while :func:`setup` is done once per GPU. - """ - if self.kwargs.get("download", False): - COWCCounting(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" train_val_dataset = COWCCounting(split="train", **self.kwargs) self.test_dataset = COWCCounting(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 9a3beaf3b99..b63cb533e13 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -3,7 +3,7 @@ """Tropical Cyclone Wind Estimation Competition datamodule.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K from sklearn.model_selection import GroupShuffleSplit @@ -28,60 +28,41 @@ class TropicalCycloneDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for NASA Cyclone based DataLoaders. + """Initialize a new TropicalCycloneDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.TropicalCyclone` + :class:`~torchgeo.datasets.TropicalCyclone`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs + super().__init__(TropicalCyclone, batch_size, num_workers, **kwargs) self.aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), data_keys=["image"] ) - def prepare_data(self) -> None: - """Initialize the main Dataset objects for use in :func:`setup`. - - This includes optionally downloading the dataset. This is done once per node, - while :func:`setup` is done once per GPU. - """ - if self.kwargs.get("download", False): - TropicalCyclone(split="train", **self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - - This method is called once per GPU per run. - - We split samples between train/val by the ``storm_id`` property. I.e. all - samples with the same ``storm_id`` value will be either in the train or the val - split. This is important to test one type of generalizability -- given a new - storm, can we predict its windspeed. The test set, however, contains *some* - storms from the training set (specifically, the latter parts of the storms) as - well as some novel storms. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - self.all_train_dataset = TropicalCyclone(split="train", **self.kwargs) - - storm_ids = [] - for item in self.all_train_dataset.collection: - storm_id = item["href"].split("/")[0].split("_")[-2] - storm_ids.append(storm_id) - - train_indices, val_indices = next( - GroupShuffleSplit(test_size=0.2, n_splits=2).split( - storm_ids, groups=storm_ids + if stage in ["fit", "validate"]: + dataset = TropicalCyclone(split="train", **self.kwargs) + + storm_ids = [] + for item in dataset.collection: + storm_id = item["href"].split("/")[0].split("_")[-2] + storm_ids.append(storm_id) + + train_indices, val_indices = next( + GroupShuffleSplit(test_size=0.2, n_splits=2).split( + storm_ids, groups=storm_ids + ) ) - ) - self.train_dataset = Subset(self.all_train_dataset, train_indices) - self.val_dataset = Subset(self.all_train_dataset, val_indices) - self.test_dataset = TropicalCyclone(split="test", **self.kwargs) + self.train_dataset = Subset(dataset, train_indices) + self.val_dataset = Subset(dataset, val_indices) + elif stage in ["test"]: + self.test_dataset = TropicalCyclone(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index cf868f9b60e..c6d92008eea 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -3,7 +3,7 @@ """DeepGlobe Land Cover Classification Challenge datamodule.""" -from typing import Any, Optional, Tuple, Union +from typing import Any, Tuple, Union import kornia.augmentation as K @@ -30,7 +30,7 @@ def __init__( num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new DeepGlobeLandCoverDataModule instance. The DeepGlobe Land Cover dataset contains images that are too large to pass directly through a model. Instead, we randomly sample patches from image tiles @@ -43,26 +43,22 @@ def __init__( and *patch_size*. Args: - num_tiles_per_batch: The number of image tiles to sample from during - training - num_patches_per_tile: The number of patches to randomly sample from each - image tile during training - patch_size: The size of each patch, either ``size`` or ``(height, width)``. - Should be a multiple of 32 for most segmentation architectures - val_split_pct: The percentage of the dataset to use as a validation set - num_workers: The number of workers to use for parallel data loading + num_tiles_per_batch: Number of image tiles to sample from during training. + num_patches_per_tile: Number of patches to randomly sample from each image + tile during training. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + val_split_pct: Percentage of the dataset to use as a validation set. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.DeepGlobeLandCover` + :class:`~torchgeo.datasets.DeepGlobeLandCover`. 
""" - super().__init__() + super().__init__(DeepGlobeLandCover, 1, num_workers, **kwargs) self.train_batch_size = num_tiles_per_batch - self.test_batch_size = 1 self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.num_workers = num_workers - self.kwargs = kwargs self.train_aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), @@ -75,16 +71,16 @@ def __init__( data_keys=["image", "mask"], ) - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - train_dataset = DeepGlobeLandCover(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - train_dataset, self.val_split_pct - ) - self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs) + if stage in ["fit", "validate"]: + dataset = DeepGlobeLandCover(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + dataset, self.val_split_pct + ) + elif stage in ["test"]: + self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index fac22ce7adf..a6bcde7ad1a 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -3,10 +3,9 @@ """ETCI 2021 datamodule.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K -from torch.utils.data import random_split from ..datasets import ETCI2021 from ..transforms import AugmentationSequential @@ -35,46 +34,16 @@ class ETCI2021DataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new ETCI2021DataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.ETCI2021` + :class:`~torchgeo.datasets.ETCI2021`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs + super().__init__(ETCI2021, batch_size, num_workers, **kwargs) self.aug = AugmentationSequential( K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] ) - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - ETCI2021(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. 
- - Args: - stage: stage to set up - """ - train_val_dataset = ETCI2021(split="train", **self.kwargs) - self.test_dataset = ETCI2021(split="val", **self.kwargs) - - size_train_val = len(train_val_dataset) - size_train = round(0.8 * size_train_val) - size_val = size_train_val - size_train - - self.train_dataset, self.val_dataset = random_split( - train_val_dataset, [size_train, size_val] - ) diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index 637508506da..21cd2c901db 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -3,7 +3,7 @@ """EuroSAT datamodule.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K @@ -55,39 +55,16 @@ class EuroSATDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new EuroSATDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.EuroSAT` + :class:`~torchgeo.datasets.EuroSAT`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs + super().__init__(EuroSAT, batch_size, num_workers, **kwargs) self.aug = AugmentationSequential( K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] ) - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - EuroSAT(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - self.train_dataset = EuroSAT(split="train", **self.kwargs) - self.val_dataset = EuroSAT(split="val", **self.kwargs) - self.test_dataset = EuroSAT(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index 147b5c37e08..32b423dc769 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -3,7 +3,7 @@ """FAIR1M datamodule.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K @@ -27,36 +27,32 @@ def __init__( test_split_pct: float = 0.2, **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new FAIR1MDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: Percentage of the dataset to use as a validation set - test_split_pct: Percentage of the dataset to use as a test set + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a test set. 
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.FAIR1M`
        """
-        super().__init__()
-        self.batch_size = batch_size
-        self.num_workers = num_workers
+        super().__init__(FAIR1M, batch_size, num_workers, **kwargs)
+
         self.val_split_pct = val_split_pct
         self.test_split_pct = test_split_pct
-        self.kwargs = kwargs

         self.aug = AugmentationSequential(
             K.Normalize(mean=0.0, std=255.0), data_keys=["image"]
         )

-    def setup(self, stage: Optional[str] = None) -> None:
-        """Initialize the main ``Dataset`` objects.
-
-        This method is called once per GPU per run.
+    def setup(self, stage: str) -> None:
+        """Set up datasets.

         Args:
-            stage: stage to set up
+            stage: Either 'fit', 'validate', 'test', or 'predict'.
         """
-        self.dataset = FAIR1M(**self.kwargs)
+        dataset = FAIR1M(**self.kwargs)
         self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
-            self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
+            dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
         )
diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py
index bf1206b0e70..3234303fdb6 100644
--- a/torchgeo/datamodules/geo.py
+++ b/torchgeo/datamodules/geo.py
@@ -3,7 +3,7 @@

 """Base classes for all :mod:`torchgeo` data modules."""

-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Type

 import matplotlib.pyplot as plt
 from pytorch_lightning import LightningDataModule
@@ -16,36 +16,85 @@
 from torch.nn import Identity, Module
 from torch.utils.data import DataLoader, Dataset

+from ..datasets import NonGeoDataset
+

 class NonGeoDataModule(LightningDataModule):
     """Base class for data modules lacking geospatial information."""

-    #: Training dataset
-    train_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
-
-    #: Validation dataset
-    val_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
+    def __init__(
+        self,
+        dataset_class: Type[NonGeoDataset],
+        batch_size: int = 1,
+        num_workers: int = 0,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize a new NonGeoDataModule instance.

-    #: Testing dataset
-    test_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
+        Args:
+            dataset_class: Class used to instantiate a new dataset.
+            batch_size: Size of each mini-batch.
+            num_workers: Number of workers for parallel data loading.
+            **kwargs: Additional keyword arguments passed to ``dataset_class``.
+        """
+        super().__init__()
+
+        self.dataset_class = dataset_class
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.kwargs = kwargs
+
+        # Datasets
+        self.train_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
+        self.val_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
+        self.test_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
+        self.predict_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
+
+        # Data loaders
+        self.train_batch_size: Optional[int] = None
+        self.val_batch_size: Optional[int] = None
+        self.test_batch_size: Optional[int] = None
+        self.predict_batch_size: Optional[int] = None
+
+        # Data augmentation
+        self.aug: Optional[Module] = None
+        self.train_aug: Optional[Module] = None
+        self.val_aug: Optional[Module] = None
+        self.test_aug: Optional[Module] = None
+        self.predict_aug: Optional[Module] = None
+
+    def prepare_data(self) -> None:
+        """Download and prepare data.
+
+        During distributed training, this method is called only within a single process
+        to avoid corrupted data. This method should not set state since it is not called
+        on every device; use :meth:`setup` instead.
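The constructor above centralizes what every datamodule used to duplicate. A
minimal sketch of the intended subclass usage (not part of the patch; EuroSAT
stands in for any NonGeoDataset that accepts root/split/download kwargs, and
the import path assumes this commit's layout):

    from typing import Any

    from torchgeo.datamodules.geo import NonGeoDataModule
    from torchgeo.datasets import EuroSAT


    class ExampleDataModule(NonGeoDataModule):
        """Sketch: download and split handling are inherited from the base class."""

        def __init__(
            self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
        ) -> None:
            super().__init__(EuroSAT, batch_size, num_workers, **kwargs)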
+        """
+        if self.kwargs.get("download", False):
+            self.dataset_class(**self.kwargs)

-    #: Prediction dataset
-    predict_dataset: Optional[Dataset[Dict[str, Tensor]]] = None
+    def setup(self, stage: str) -> None:
+        """Set up datasets.

-    # DataLoader arguments
-    batch_size: Optional[int] = None
-    train_batch_size: Optional[int] = None
-    val_batch_size: Optional[int] = None
-    test_batch_size: Optional[int] = None
-    predict_batch_size: Optional[int] = None
-    num_workers = 0
+        Called at the beginning of fit, validate, test, or predict. During distributed
+        training, this method is called from every process across all the nodes. Setting
+        state here is recommended.

-    # Data augmentation
-    aug: Optional[Module] = None
-    train_aug: Optional[Module] = None
-    val_aug: Optional[Module] = None
-    test_aug: Optional[Module] = None
-    predict_aug: Optional[Module] = None
+        Args:
+            stage: Either 'fit', 'validate', 'test', or 'predict'.
+        """
+        if stage in ["fit"]:
+            self.train_dataset = self.dataset_class(  # type: ignore[call-arg]
+                split="train", **self.kwargs
+            )
+        if stage in ["fit", "validate"]:
+            self.val_dataset = self.dataset_class(  # type: ignore[call-arg]
+                split="val", **self.kwargs
+            )
+        if stage in ["test"]:
+            self.test_dataset = self.dataset_class(  # type: ignore[call-arg]
+                split="test", **self.kwargs
+            )

     def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Implement one or more PyTorch DataLoaders for training.
@@ -54,17 +103,18 @@ def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
             A collection of data loaders specifying training samples.

         Raises:
-            MisconfigurationException: If :attr:`train_dataset` is not defined.
+            MisconfigurationException: If :meth:`setup` does not define a
+                'train_dataset'.
         """
         if self.train_dataset is not None:
             return DataLoader(
                 dataset=self.train_dataset,
-                batch_size=self.train_batch_size or self.batch_size or 1,
+                batch_size=self.train_batch_size or self.batch_size,
                 shuffle=True,
                 num_workers=self.num_workers,
             )
         else:
-            msg = f"{self.__class__.__name__} does not define a 'train_dataset'"
+            msg = f"{self.__class__.__name__}.setup does not define a 'train_dataset'"
             raise MisconfigurationException(msg)

     def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Implement one or more PyTorch DataLoaders for validation.

         Returns:
             A collection of data loaders specifying validation samples.

         Raises:
-            MisconfigurationException: If :attr:`val_dataset` is not defined.
+            MisconfigurationException: If :meth:`setup` does not define a
+                'val_dataset'.
         """
         if self.val_dataset is not None:
             return DataLoader(
                 dataset=self.val_dataset,
-                batch_size=self.val_batch_size or self.batch_size or 1,
+                batch_size=self.val_batch_size or self.batch_size,
                 shuffle=False,
                 num_workers=self.num_workers,
             )
         else:
-            msg = f"{self.__class__.__name__} does not define a 'val_dataset'"
+            msg = f"{self.__class__.__name__}.setup does not define a 'val_dataset'"
             raise MisconfigurationException(msg)

     def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]:
         """Implement one or more PyTorch DataLoaders for testing.

         Returns:
             A collection of data loaders specifying testing samples.

         Raises:
-            MisconfigurationException: If :attr:`test_dataset` is not defined.
+            MisconfigurationException: If :meth:`setup` does not define a
+                'test_dataset'.
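The stage-specific setup above follows Lightning's contract: ``trainer.fit()``
triggers ``setup("fit")``, ``trainer.validate()`` triggers ``setup("validate")``,
and so on. Because validation also runs during fitting, ``setup("fit")`` must
populate both the train and val datasets, which is why the conditions are
checked sequentially rather than as an if/elif chain. A usage sketch, reusing
the hypothetical ExampleDataModule from the previous note (the root path is
illustrative):

    dm = ExampleDataModule(batch_size=32, root="data/eurosat", download=True)
    dm.prepare_data()  # runs in a single process
    dm.setup("fit")    # creates both train_dataset and val_dataset
    batch = next(iter(dm.train_dataloader()))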
""" if self.test_dataset is not None: return DataLoader( dataset=self.test_dataset, - batch_size=self.test_batch_size or self.batch_size or 1, + batch_size=self.test_batch_size or self.batch_size, shuffle=False, num_workers=self.num_workers, ) else: - msg = f"{self.__class__.__name__} does not define a 'test_dataset'" + msg = f"{self.__class__.__name__}.setup does not define a 'test_dataset'" raise MisconfigurationException(msg) def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: @@ -114,17 +166,18 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: A collection of data loaders specifying prediction samples. Raises: - MisconfigurationException: If :attr:`predict_dataset` is not defined. + MisconfigurationException: If :meth:`setup` does not define a + 'predict_dataset'. """ if self.predict_dataset is not None: return DataLoader( dataset=self.predict_dataset, - batch_size=self.predict_batch_size or self.batch_size or 1, + batch_size=self.predict_batch_size or self.batch_size, shuffle=False, num_workers=self.num_workers, ) else: - msg = f"{self.__class__.__name__} does not define a 'predict_dataset'" + msg = f"{self.__class__.__name__}.setup does not define a 'predict_dataset'" raise MisconfigurationException(msg) def on_after_batch_transfer( @@ -161,8 +214,8 @@ def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: **kwargs: Keyword arguments passed to plot method. Returns: - a matplotlib Figure with the image, ground truth, and predictions + A matplotlib Figure with the image, ground truth, and predictions. """ - if self.train_dataset is not None: - if hasattr(self.train_dataset, "plot"): - return self.train_dataset.plot(*args, **kwargs) + if self.val_dataset is not None: + if hasattr(self.val_dataset, "plot"): + return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 54a300760d5..7ad19b33a76 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -3,7 +3,7 @@ """GID-15 datamodule.""" -from typing import Any, Optional, Tuple, Union +from typing import Any, Tuple, Union import kornia.augmentation as K @@ -32,7 +32,7 @@ def __init__( num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new GID15DataModule instance. The GID-15 dataset contains images that are too large to pass directly through a model. Instead, we randomly sample patches from image tiles @@ -41,25 +41,22 @@ def __init__( ``num_tiles_per_batch`` x ``num_patches_per_tile``. Args: - num_tiles_per_batch: The number of image tiles to sample from during - training - num_patches_per_tile: The number of patches to randomly sample from each - image tile during training - patch_size: The size of each patch, either ``size`` or ``(height, width)``. + num_tiles_per_batch: Number of image tiles to sample from during training. + num_patches_per_tile: Number of patches to randomly sample from each image + tile during training. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures - val_split_pct: The percentage of the dataset to use as a validation set - num_workers: The number of workers to use for parallel data loading + val_split_pct: Percentage of the dataset to use as a validation set + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.GID15` + :class:`~torchgeo.datasets.GID15`. 
""" - super().__init__() + super().__init__(GID15, 1, num_workers, **kwargs) self.train_batch_size = num_tiles_per_batch self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.num_workers = num_workers - self.kwargs = kwargs self.train_aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), @@ -77,27 +74,17 @@ def __init__( data_keys=["image"], ) - def prepare_data(self) -> None: - """Initialize the main Dataset objects for use in :func:`setup`. - - This includes optionally downloading the dataset. This is done once per node, - while :func:`setup` is done once per GPU. - """ - if self.kwargs.get("download", False): - GID15(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - train_dataset = GID15(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - train_dataset, self.val_split_pct - ) - - # Test set masks are not public, use for prediction instead - self.predict_dataset = GID15(split="test", **self.kwargs) + if stage in ["fit", "validate"]: + dataset = GID15(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + dataset, self.val_split_pct + ) + elif stage in ["test"]: + # Test set masks are not public, use for prediction instead + self.predict_dataset = GID15(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 9b5035821c9..fb802b6f3e9 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -3,7 +3,7 @@ """InriaAerialImageLabeling datamodule.""" -from typing import Any, Optional, Tuple, Union +from typing import Any, Tuple, Union import kornia.augmentation as K @@ -34,7 +34,7 @@ def __init__( test_split_pct: float = 0.1, **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new InriaAerialImageLabelingDataModule instance. The Inria Aerial Image Labeling dataset contains images that are too large to pass directly through a model. Instead, we randomly sample patches from image @@ -43,27 +43,24 @@ def __init__( ``num_tiles_per_batch`` x ``num_patches_per_tile``. Args: - num_tiles_per_batch: The number of image tiles to sample from during - training - num_patches_per_tile: The number of patches to randomly sample from each - image tile during training - patch_size: The size of each patch, either ``size`` or ``(height, width)``. - Should be a multiple of 32 for most segmentation architectures - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - test_split_pct: What percentage of the dataset to use as a test set + num_tiles_per_batch: Number of image tiles to sample from during training. + num_patches_per_tile: Number of patches to randomly sample from each image + tile during training. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + num_workers: Number of workers for parallel data loading. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a test set. 
**kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.InriaAerialImageLabeling` + :class:`~torchgeo.datasets.InriaAerialImageLabeling`. """ - super().__init__() + super().__init__(InriaAerialImageLabeling, 1, num_workers, **kwargs) self.train_batch_size = num_tiles_per_batch self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) - self.num_workers = num_workers self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.kwargs = kwargs self.train_aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), @@ -83,13 +80,17 @@ def __init__( data_keys=["image"], ) - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + def setup(self, stage: str) -> None: + """Set up datasets. - This method is called once per GPU per run. + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - dataset = InriaAerialImageLabeling(split="train", **self.kwargs) - self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - dataset, self.val_split_pct, self.test_split_pct - ) - self.predict_dataset = InriaAerialImageLabeling(split="test", **self.kwargs) + if stage in ["fit", "validate", "test"]: + dataset = InriaAerialImageLabeling(split="train", **self.kwargs) + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + dataset, self.val_split_pct, self.test_split_pct + ) + elif stage in ["predict"]: + # Test set masks are not public, use for prediction instead + self.predict_dataset = InriaAerialImageLabeling(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index 84fbbf64e86..eabe26466ff 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -3,7 +3,7 @@ """LandCover.ai datamodule.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K @@ -21,18 +21,15 @@ class LandCoverAIDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new LandCoverAIDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.LandCoverAI` + :class:`~torchgeo.datasets.LandCoverAI`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs + super().__init__(LandCoverAI, batch_size, num_workers, **kwargs) self.train_aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), @@ -53,23 +50,3 @@ def __init__( self.aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"] ) - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - LandCoverAI(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. 
-
-        Args:
-            stage: stage to set up
-        """
-        self.train_dataset = LandCoverAI(split="train", **self.kwargs)
-        self.val_dataset = LandCoverAI(split="val", **self.kwargs)
-        self.test_dataset = LandCoverAI(split="test", **self.kwargs)
diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py
index 3f3207a1eb9..40cde2da63f 100644
--- a/torchgeo/datamodules/loveda.py
+++ b/torchgeo/datamodules/loveda.py
@@ -3,7 +3,7 @@

 """LoveDA datamodule."""

-from typing import Any, Optional
+from typing import Any

 import kornia.augmentation as K

@@ -23,41 +23,30 @@ def __init__(
         self, batch_size: int = 32, num_workers: int = 0, **kwargs: Any
     ) -> None:
-        """Initialize a LightningDataModule for LoveDA based DataLoaders.
+        """Initialize a new LoveDADataModule instance.

         Args:
-            batch_size: The batch size to use in all created DataLoaders
-            num_workers: The number of workers to use in all created DataLoaders
+            batch_size: Size of each mini-batch.
+            num_workers: Number of workers for parallel data loading.
             **kwargs: Additional keyword arguments passed to
-                :class:`~torchgeo.datasets.LoveDA`
+                :class:`~torchgeo.datasets.LoveDA`.
         """
-        super().__init__()
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-        self.kwargs = kwargs
+        super().__init__(LoveDA, batch_size, num_workers, **kwargs)

         self.aug = AugmentationSequential(
             K.Normalize(mean=0.0, std=255.0), data_keys=["image"]
         )

-    def prepare_data(self) -> None:
-        """Make sure that the dataset is downloaded.
-
-        This method is only called once per run.
-        """
-        if self.kwargs.get("download", False):
-            LoveDA(**self.kwargs)
-
-    def setup(self, stage: Optional[str] = None) -> None:
-        """Initialize the main ``Dataset`` objects.
-
-        This method is called once per GPU per run.
+    def setup(self, stage: str) -> None:
+        """Set up datasets.

         Args:
-            stage: stage to set up
+            stage: Either 'fit', 'validate', 'test', or 'predict'.
         """
-        self.train_dataset = LoveDA(split="train", **self.kwargs)
-        self.val_dataset = LoveDA(split="val", **self.kwargs)
-
-        # Test set masks are not public, use for prediction instead
-        self.predict_dataset = LoveDA(split="test", **self.kwargs)
+        if stage in ["fit"]:
+            self.train_dataset = LoveDA(split="train", **self.kwargs)
+        if stage in ["fit", "validate"]:
+            self.val_dataset = LoveDA(split="val", **self.kwargs)
+        if stage in ["predict"]:
+            # Test set masks are not public, use for prediction instead
+            self.predict_dataset = LoveDA(split="test", **self.kwargs)
diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py
index cb38b8ace9e..83f40a8a249 100644
--- a/torchgeo/datamodules/nasa_marine_debris.py
+++ b/torchgeo/datamodules/nasa_marine_debris.py
@@ -3,7 +3,7 @@

 """NASA Marine Debris datamodule."""

-from typing import Any, Optional
+from typing import Any

 import kornia.augmentation as K

@@ -27,44 +27,32 @@ def __init__(
         test_split_pct: float = 0.2,
         **kwargs: Any,
     ) -> None:
-        """Initialize a new LightningDataModule instance.
+        """Initialize a new NASAMarineDebrisDataModule instance.

         Args:
-            batch_size: The batch size to use in all created DataLoaders
-            num_workers: The number of workers to use in all created DataLoaders
-            val_split_pct: What percentage of the dataset to use as a validation set
-            test_split_pct: What percentage of the dataset to use as a test set
+            batch_size: Size of each mini-batch.
+            num_workers: Number of workers for parallel data loading.
+            val_split_pct: Percentage of the dataset to use as a validation set.
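The sequential if statements in LoveDA's setup matter for the same reason as
in the base class: ``setup("fit")`` must create every dataset used while
fitting. A quick check, assuming a local copy of the dataset at the
illustrative root below:

    dm = LoveDADataModule(root="data/loveda", download=True)
    dm.prepare_data()
    dm.setup("fit")
    assert dm.train_dataset is not None and dm.val_dataset is not None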
+ test_split_pct: Percentage of the dataset to use as a test set. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.NASAMarineDebris` + :class:`~torchgeo.datasets.NASAMarineDebris`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + super().__init__(NASAMarineDebris, batch_size, num_workers, **kwargs) + self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.kwargs = kwargs self.aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), data_keys=["image"] ) - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - NASAMarineDebris(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - self.dataset = NASAMarineDebris(**self.kwargs) + dataset = NASAMarineDebris(**self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 2b7588271c2..ed7cf69b3e5 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -3,7 +3,7 @@ """OSCD datamodule.""" -from typing import Any, Optional, Tuple, Union +from typing import Any, Tuple, Union import kornia.augmentation as K import torch @@ -70,7 +70,7 @@ def __init__( num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new OSCDDataModule instance. The OSCD dataset contains images that are too large to pass directly through a model. Instead, we randomly sample patches from image tiles @@ -79,25 +79,22 @@ def __init__( ``num_tiles_per_batch`` x ``num_patches_per_tile``. Args: - num_tiles_per_batch: The number of image tiles to sample from during - training - num_patches_per_tile: The number of patches to randomly sample from each - image tile during training - patch_size: The size of each patch, either ``size`` or ``(height, width)``. - Should be a multiple of 32 for most segmentation architectures - val_split_pct: The percentage of the dataset to use as a validation set - num_workers: The number of workers to use for parallel data loading + num_tiles_per_batch: Number of image tiles to sample from during training. + num_patches_per_tile: Number of patches to randomly sample from each image + tile during training. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + val_split_pct: Percentage of the dataset to use as a validation set. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.OSCD` + :class:`~torchgeo.datasets.OSCD`. 
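The three-way dataset_split used by NASAMarineDebris above divides a single
dataset by percentage. A sketch of the expected arithmetic, assuming
torchgeo.datamodules.utils.dataset_split rounds each split size before handing
the lengths to random_split:

    import torch
    from torch.utils.data import TensorDataset

    from torchgeo.datamodules.utils import dataset_split

    ds = TensorDataset(torch.zeros(100, 1))
    train, val, test = dataset_split(ds, val_pct=0.2, test_pct=0.2)
    assert (len(train), len(val), len(test)) == (60, 20, 20)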
""" - super().__init__() + super().__init__(OSCD, 1, num_workers, **kwargs) self.train_batch_size = num_tiles_per_batch self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.num_workers = num_workers - self.kwargs = kwargs self.bands = kwargs.get("bands", "all") if self.bands == "rgb": @@ -115,21 +112,16 @@ def __init__( data_keys=["image", "mask"], ) - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. + def setup(self, stage: str) -> None: + """Set up datasets. - This method is only called once per run. - """ - if self.kwargs.get("download", False): - OSCD(split="train", **self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - train_dataset = OSCD(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - train_dataset, val_pct=self.val_split_pct - ) - self.test_dataset = OSCD(split="test", **self.kwargs) + if stage in ["fit", "validate"]: + dataset = OSCD(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + dataset, val_pct=self.val_split_pct + ) + elif stage in ["test"]: + self.test_dataset = OSCD(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index accea53ac92..e9c561a9036 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -3,7 +3,7 @@ """Potsdam datamodule.""" -from typing import Any, Optional, Tuple, Union +from typing import Any, Tuple, Union import kornia.augmentation as K @@ -32,7 +32,7 @@ def __init__( num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new Potsdam2DDataModule instance. The Potsdam2D dataset contains images that are too large to pass directly through a model. Instead, we randomly sample patches from image tiles @@ -45,25 +45,22 @@ def __init__( and *patch_size*. Args: - num_tiles_per_batch: The number of image tiles to sample from during - training - num_patches_per_tile: The number of patches to randomly sample from each - image tile during training - patch_size: The size of each patch, either ``size`` or ``(height, width)``. - Should be a multiple of 32 for most segmentation architectures - val_split_pct: The percentage of the dataset to use as a validation set - num_workers: The number of workers to use for parallel data loading + num_tiles_per_batch: Number of image tiles to sample from during training. + num_patches_per_tile: Number of patches to randomly sample from each image + tile during training. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + val_split_pct: Percentage of the dataset to use as a validation set. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.Potsdam2D` + :class:`~torchgeo.datasets.Potsdam2D`. 
""" - super().__init__() + super().__init__(Potsdam2D, 1, num_workers, **kwargs) self.train_batch_size = num_tiles_per_batch self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.num_workers = num_workers - self.kwargs = kwargs self.train_aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), @@ -76,16 +73,16 @@ def __init__( data_keys=["image", "mask"], ) - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - train_dataset = Potsdam2D(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - train_dataset, self.val_split_pct - ) - self.test_dataset = Potsdam2D(split="test", **self.kwargs) + if stage in ["fit", "validate"]: + dataset = Potsdam2D(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + dataset, self.val_split_pct + ) + elif stage in ["test"]: + self.test_dataset = Potsdam2D(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index dbfc67da430..c29402d4f84 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -3,7 +3,7 @@ """RESISC45 datamodule.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K @@ -24,18 +24,15 @@ class RESISC45DataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new RESISC45DataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.RESISC45` + :class:`~torchgeo.datasets.RESISC45`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs + super().__init__(RESISC45, batch_size, num_workers, **kwargs) self.train_aug = AugmentationSequential( K.Normalize(mean=self.band_means, std=self.band_stds), @@ -57,23 +54,3 @@ def __init__( self.aug = AugmentationSequential( K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] ) - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - RESISC45(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. 
- - Args: - stage: stage to set up - """ - self.train_dataset = RESISC45(split="train", **self.kwargs) - self.val_dataset = RESISC45(split="val", **self.kwargs) - self.test_dataset = RESISC45(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 2fe0f3c741c..df2ae188f8e 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -3,7 +3,7 @@ """SEN12MS datamodule.""" -from typing import Any, Optional +from typing import Any import torch from sklearn.model_selection import GroupShuffleSplit @@ -17,96 +17,75 @@ class SEN12MSDataModule(NonGeoDataModule): """LightningDataModule implementation for the SEN12MS dataset. Implements 80/20 geographic train/val splits and uses the test split from the - classification dataset definitions. See :func:`setup` for more details. + classification dataset definitions. Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See https://arxiv.org/abs/2002.08254. """ #: Mapping from the IGBP class definitions to the DFC2020, taken from the dataloader - #: here https://github.com/lukasliebel/dfc2020_baseline. + #: here: https://github.com/lukasliebel/dfc2020_baseline. DFC2020_CLASS_MAPPING = torch.tensor( - [ - 0, # maps 0s to 0 - 1, # maps 1s to 1 - 1, # maps 2s to 1 - 1, # ... - 1, - 1, - 2, - 2, - 3, - 3, - 4, - 5, - 6, - 7, - 6, - 8, - 9, - 10, - ] + [0, 1, 1, 1, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 6, 8, 9, 10] ) def __init__( self, - band_set: str = "all", batch_size: int = 64, num_workers: int = 0, + band_set: str = "all", **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new SEN12MSDataModule instance. Args: - band_set: The subset of S1/S2 bands to use. Options are: "all", + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + band_set: Subset of S1/S2 bands to use. Options are: "all", "s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes: B2, B3, B4, B8, B11, and B12. - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.SEN12MS` + :class:`~torchgeo.datasets.SEN12MS`. """ - super().__init__() - assert band_set in SEN12MS.BAND_SETS.keys() + super().__init__(SEN12MS, batch_size, num_workers, **kwargs) + assert band_set in SEN12MS.BAND_SETS.keys() self.band_set = band_set self.bands = SEN12MS.BAND_SETS[band_set] - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. - We split samples between train and val geographically with proportions of 80/20. - This mimics the geographic test set split. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000} - - self.all_train_dataset = SEN12MS(split="train", bands=self.bands, **self.kwargs) - - self.test_dataset = SEN12MS(split="test", bands=self.bands, **self.kwargs) - - # A patch is a filename like: "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" - # This patch will belong to the scene that is uniquelly identified by its - # (season, scene_id) tuple. 
Because the largest scene_id is 149, we can simply
-        # give each season a large number and representing a `unique_scene_id` as
-        # `season_id + scene_id`.
-        scenes = []
-        for scene_fn in self.all_train_dataset.ids:
-            parts = scene_fn.split("_")
-            season_id = season_to_int[parts[1]]
-            scene_id = int(parts[3])
-            scenes.append(season_id + scene_id)
-
-        train_indices, val_indices = next(
-            GroupShuffleSplit(test_size=0.2, n_splits=2).split(scenes, groups=scenes)
-        )
-
-        self.train_dataset = Subset(self.all_train_dataset, train_indices)
-        self.val_dataset = Subset(self.all_train_dataset, val_indices)
+        if stage in ["fit", "validate"]:
+            season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000}
+
+            self.all_train_dataset = SEN12MS(
+                split="train", bands=self.bands, **self.kwargs
+            )
+
+            # A patch is a filename like:
+            # "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif"
+            # This patch will belong to the scene that is uniquely identified by its
+            # (season, scene_id) tuple. Because the largest scene_id is 149, we can
+            # simply give each season a large number and represent a unique_scene_id
+            # as (season_id + scene_id).
+            scenes = []
+            for scene_fn in self.all_train_dataset.ids:
+                parts = scene_fn.split("_")
+                season_id = season_to_int[parts[1]]
+                scene_id = int(parts[3])
+                scenes.append(season_id + scene_id)
+
+            train_indices, val_indices = next(
+                GroupShuffleSplit(test_size=0.2, n_splits=2).split(
+                    scenes, groups=scenes
+                )
+            )
+
+            self.train_dataset = Subset(self.all_train_dataset, train_indices)
+            self.val_dataset = Subset(self.all_train_dataset, val_indices)
+        elif stage in ["test"]:
+            self.test_dataset = SEN12MS(split="test", bands=self.bands, **self.kwargs)
diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py
index f7ccf7e2c22..249021e0c82 100644
--- a/torchgeo/datamodules/so2sat.py
+++ b/torchgeo/datamodules/so2sat.py
@@ -3,7 +3,7 @@

 """So2Sat datamodule."""

-from typing import Any, Optional
+from typing import Any

 import kornia.augmentation as K

@@ -48,44 +48,18 @@ class So2SatDataModule(NonGeoDataModule):
     reindex_to_rgb_first = [2, 1, 0, 3, 4, 5, 6, 7, 8, 9]

     def __init__(
-        self,
-        batch_size: int = 64,
-        num_workers: int = 0,
-        band_set: str = "rgb",
-        unsupervised_mode: bool = False,
-        **kwargs: Any,
+        self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any
     ) -> None:
-        """Initialize a new LightningDataModule instance.
+        """Initialize a new So2SatDataModule instance.

         Args:
-            batch_size: The batch size to use in all created DataLoaders
-            num_workers: The number of workers to use in all created DataLoaders
-            band_set: Collection of So2Sat bands to use
-            unsupervised_mode: Makes the train dataloader return imagery from the train,
-                val, and test sets
+            batch_size: Size of each mini-batch.
+            num_workers: Number of workers for parallel data loading.
             **kwargs: Additional keyword arguments passed to
                 :class:`~torchgeo.datasets.So2Sat`
         """
-        super().__init__()
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-        self.band_set = band_set
-        self.unsupervised_mode = unsupervised_mode
-        self.kwargs = kwargs
+        super().__init__(So2Sat, batch_size, num_workers, **kwargs)

         self.aug = AugmentationSequential(
             K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"]
         )
-
-    def setup(self, stage: Optional[str] = None) -> None:
-        """Initialize the main Dataset objects.
-
-        This method is called once per GPU per run.
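The unique_scene_id encoding above deserves a worked example. The filename
below is a representative SEN12MS name; GroupShuffleSplit then guarantees that
all patches of a scene end up on the same side of the split:

    season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000}

    fn = "ROIs1158_spring_s2_1_p30.tif"  # illustrative filename
    parts = fn.split("_")                # ["ROIs1158", "spring", "s2", "1", "p30.tif"]
    unique_scene_id = season_to_int[parts[1]] + int(parts[3])  # 1000 + 1 = 1001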
- - Args: - stage: stage to set up - """ - bands = So2Sat.BAND_SETS["s2"] - self.train_dataset = So2Sat(split="train", bands=bands, **self.kwargs) - self.val_dataset = So2Sat(split="validation", bands=bands, **self.kwargs) - self.test_dataset = So2Sat(split="test", bands=bands, **self.kwargs) diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index cda314518ea..6bfa89f8485 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -3,7 +3,7 @@ """SpaceNet datamodules.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K @@ -29,22 +29,20 @@ def __init__( test_split_pct: float = 0.2, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for SpaceNet1. + """Initialize a new SpaceNet1DataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - test_split_pct: What percentage of the dataset to use as a test set + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a test set. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.SpaceNet1` + :class:`~torchgeo.datasets.SpaceNet1`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + super().__init__(SpaceNet1, batch_size, num_workers, **kwargs) + self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.kwargs = kwargs self.train_aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), @@ -69,23 +67,13 @@ def __init__( data_keys=["image", "mask"], ) - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - SpaceNet1(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - self.dataset = SpaceNet1(**self.kwargs) + dataset = SpaceNet1(**self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + dataset, self.val_split_pct, self.test_split_pct ) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index 90c9c29bad6..6cc20eff146 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -3,7 +3,7 @@ """UC Merced datamodule.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K @@ -21,39 +21,16 @@ class UCMercedDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for UCMerced based DataLoaders. + """Initialize a new UCMercedDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. 
**kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.UCMerced` + :class:`~torchgeo.datasets.UCMerced`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs + super().__init__(UCMerced, batch_size, num_workers, **kwargs) self.aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), K.Resize(size=256), data_keys=["image"] ) - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - UCMerced(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - self.train_dataset = UCMerced(split="train", **self.kwargs) - self.val_dataset = UCMerced(split="val", **self.kwargs) - self.test_dataset = UCMerced(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py index fc2e05e5a0e..97c964a9b77 100644 --- a/torchgeo/datamodules/usavars.py +++ b/torchgeo/datamodules/usavars.py @@ -3,7 +3,7 @@ """USAVars datamodule.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K @@ -23,36 +23,16 @@ class USAVarsDataModule(NonGeoDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for USAVars based DataLoaders. + """Initialize a new USAVarsDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.USAVars` + :class:`~torchgeo.datasets.USAVars`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs + super().__init__(USAVars, batch_size, num_workers, **kwargs) self.aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), data_keys=["image"] ) - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - USAVars(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. - """ - self.train_dataset = USAVars(split="train", **self.kwargs) - self.val_dataset = USAVars(split="val", **self.kwargs) - self.test_dataset = USAVars(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 761c352ab3a..ac2fcf21b67 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -3,7 +3,7 @@ """Vaihingen datamodule.""" -from typing import Any, Optional, Tuple, Union +from typing import Any, Tuple, Union import kornia.augmentation as K @@ -32,7 +32,7 @@ def __init__( num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new Vaihingen2DDataModule instance. The Vaihingen2D dataset contains images that are too large to pass directly through a model. Instead, we randomly sample patches from image tiles @@ -45,25 +45,22 @@ def __init__( and *patch_size*. 
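The UCMerced pipeline above chains normalization and resizing inside a single
AugmentationSequential, so both run on whatever device the batch was moved to.
A sketch of what it does to a batch, assuming torchgeo's wrapper accepts
sample dicts as elsewhere in this patch:

    import kornia.augmentation as K
    import torch

    from torchgeo.transforms import AugmentationSequential

    aug = AugmentationSequential(
        K.Normalize(mean=0.0, std=255.0), K.Resize(size=256), data_keys=["image"]
    )
    batch = {"image": torch.randint(0, 256, (2, 3, 256, 256)).float()}
    batch = aug(batch)  # pixel values scaled to [0, 1], images sized 256 x 256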
Args: - num_tiles_per_batch: The number of image tiles to sample from during - training - num_patches_per_tile: The number of patches to randomly sample from each - image tile during training - patch_size: The size of each patch, either ``size`` or ``(height, width)``. - Should be a multiple of 32 for most segmentation architectures - val_split_pct: The percentage of the dataset to use as a validation set - num_workers: The number of workers to use for parallel data loading + num_tiles_per_batch: Number of image tiles to sample from during training. + num_patches_per_tile: Number of patches to randomly sample from each image + tile during training. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + val_split_pct: Percentage of the dataset to use as a validation set. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.Vaihingen2D` + :class:`~torchgeo.datasets.Vaihingen2D`. """ - super().__init__() + super().__init__(Vaihingen2D, 1, num_workers, **kwargs) self.train_batch_size = num_tiles_per_batch self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.num_workers = num_workers - self.kwargs = kwargs self.train_aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), @@ -76,16 +73,16 @@ def __init__( data_keys=["image", "mask"], ) - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - train_dataset = Vaihingen2D(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - train_dataset, self.val_split_pct - ) - self.test_dataset = Vaihingen2D(split="test", **self.kwargs) + if stage in ["fit", "validate"]: + dataset = Vaihingen2D(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + dataset, self.val_split_pct + ) + elif stage in ["test"]: + self.test_dataset = Vaihingen2D(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index 744d5f68595..d89655b16ed 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -3,7 +3,7 @@ """xView2 datamodule.""" -from typing import Any, Optional +from typing import Any import kornia.augmentation as K @@ -28,35 +28,33 @@ def __init__( val_split_pct: float = 0.2, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for xView2 based DataLoaders. + """Initialize a new XView2DataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. val_split_pct: What percentage of the dataset to use as a validation set **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.XView2` + :class:`~torchgeo.datasets.XView2`. 
""" - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + super().__init__(XView2, batch_size, num_workers, **kwargs) + self.val_split_pct = val_split_pct - self.kwargs = kwargs self.aug = AugmentationSequential( K.Normalize(mean=0.0, std=255.0), data_keys=["image"] ) - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - dataset = XView2(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - dataset, val_pct=self.val_split_pct - ) - self.test_dataset = XView2(split="test", **self.kwargs) + if stage in ["fit", "validate"]: + dataset = XView2(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + dataset, val_pct=self.val_split_pct + ) + elif stage in ["test"]: + self.test_dataset = XView2(split="test", **self.kwargs) From 2e1356fffab27a3f94fa2f6db81896964052a0bc Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sun, 1 Jan 2023 23:08:54 -0600 Subject: [PATCH 043/108] mypy fixes --- tests/datamodules/test_fair1m.py | 7 +++++-- tests/datamodules/test_oscd.py | 4 +++- tests/datamodules/test_usavars.py | 7 +++++-- tests/datamodules/test_xview2.py | 7 +++++-- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/datamodules/test_fair1m.py b/tests/datamodules/test_fair1m.py index ac9d196d58f..000b0144cc2 100644 --- a/tests/datamodules/test_fair1m.py +++ b/tests/datamodules/test_fair1m.py @@ -23,20 +23,23 @@ def datamodule(self) -> FAIR1MDataModule: val_split_pct=0.33, test_split_pct=0.33, ) - dm.setup() return dm def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None: + datamodule.setup("fit") next(iter(datamodule.train_dataloader())) def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None: + datamodule.setup("validate") next(iter(datamodule.val_dataloader())) def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None: + datamodule.setup("test") next(iter(datamodule.test_dataloader())) def test_plot(self, datamodule: FAIR1MDataModule) -> None: - batch = next(iter(datamodule.train_dataloader())) + datamodule.setup("validate") + batch = next(iter(datamodule.val_dataloader())) sample = unbind_samples(batch)[0] datamodule.plot(sample) plt.close() diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index 6e5b76a9870..93efb909fa1 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -30,11 +30,11 @@ def datamodule(self, request: SubRequest) -> OSCDDataModule: num_workers=num_workers, ) dm.prepare_data() - dm.setup() dm.trainer = Trainer() return dm def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: + datamodule.setup("fit") datamodule.trainer.training = True # type: ignore[union-attr] sample = next(iter(datamodule.train_dataloader())) sample = datamodule.on_after_batch_transfer(sample, 0) @@ -46,6 +46,7 @@ def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: assert sample["image"].shape[1] == 6 def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: + datamodule.setup("validate") datamodule.trainer.validating = True # type: ignore[union-attr] sample = next(iter(datamodule.val_dataloader())) sample = datamodule.on_after_batch_transfer(sample, 0) @@ -58,6 +59,7 @@ def test_val_dataloader(self, datamodule: 
OSCDDataModule) -> None: assert sample["image"].shape[1] == 6 def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: + datamodule.setup("test") datamodule.trainer.testing = True # type: ignore[union-attr] sample = next(iter(datamodule.test_dataloader())) sample = datamodule.on_after_batch_transfer(sample, 0) diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index ddcde7e6b26..2d8e3d7b9eb 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -23,26 +23,29 @@ def datamodule(self, request: SubRequest) -> USAVarsDataModule: root=root, batch_size=batch_size, num_workers=num_workers, download=True ) dm.prepare_data() - dm.setup() return dm def test_train_dataloader(self, datamodule: USAVarsDataModule) -> None: + datamodule.setup("fit") assert len(datamodule.train_dataloader()) == 3 sample = next(iter(datamodule.train_dataloader())) assert sample["image"].shape[0] == datamodule.batch_size def test_val_dataloader(self, datamodule: USAVarsDataModule) -> None: + datamodule.setup("validate") assert len(datamodule.val_dataloader()) == 2 sample = next(iter(datamodule.val_dataloader())) assert sample["image"].shape[0] == datamodule.batch_size def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None: + datamodule.setup("test") assert len(datamodule.test_dataloader()) == 1 sample = next(iter(datamodule.test_dataloader())) assert sample["image"].shape[0] == datamodule.batch_size def test_plot(self, datamodule: USAVarsDataModule) -> None: - batch = next(iter(datamodule.train_dataloader())) + datamodule.setup("validate") + batch = next(iter(datamodule.val_dataloader())) sample = unbind_samples(batch)[0] datamodule.plot(sample) plt.close() diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview2.py index c190b5d2acd..84735a6e135 100644 --- a/tests/datamodules/test_xview2.py +++ b/tests/datamodules/test_xview2.py @@ -25,20 +25,23 @@ def datamodule(self, request: SubRequest) -> XView2DataModule: val_split_pct=val_split_size, ) dm.prepare_data() - dm.setup() return dm def test_train_dataloader(self, datamodule: XView2DataModule) -> None: + datamodule.setup("fit") next(iter(datamodule.train_dataloader())) def test_val_dataloader(self, datamodule: XView2DataModule) -> None: + datamodule.setup("validate") next(iter(datamodule.val_dataloader())) def test_test_dataloader(self, datamodule: XView2DataModule) -> None: + datamodule.setup("test") next(iter(datamodule.test_dataloader())) def test_plot(self, datamodule: XView2DataModule) -> None: - batch = next(iter(datamodule.train_dataloader())) + datamodule.setup("validate") + batch = next(iter(datamodule.val_dataloader())) sample = unbind_samples(batch)[0] datamodule.plot(sample) plt.close() From b9cd885604caaa6187a4082c198ad892532fe030 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Mon, 2 Jan 2023 01:21:02 -0600 Subject: [PATCH 044/108] Add default augmentation --- torchgeo/datamodules/bigearthnet.py | 30 ++++----- torchgeo/datamodules/cowc.py | 6 -- torchgeo/datamodules/cyclone.py | 6 -- torchgeo/datamodules/deepglobelandcover.py | 4 +- torchgeo/datamodules/etci2021.py | 29 +++++---- torchgeo/datamodules/eurosat.py | 71 +++++++++++----------- torchgeo/datamodules/fair1m.py | 7 --- torchgeo/datamodules/geo.py | 20 ++++-- torchgeo/datamodules/gid15.py | 6 +- torchgeo/datamodules/inria.py | 6 +- torchgeo/datamodules/landcoverai.py | 4 +- torchgeo/datamodules/loveda.py | 7 --- torchgeo/datamodules/nasa_marine_debris.py | 7 --- torchgeo/datamodules/oscd.py | 12 ++-- torchgeo/datamodules/potsdam.py | 4 +- torchgeo/datamodules/resisc45.py | 10 ++- torchgeo/datamodules/so2sat.py | 59 +++++++++--------- torchgeo/datamodules/spacenet.py | 4 +- torchgeo/datamodules/ucmerced.py | 4 +- torchgeo/datamodules/usavars.py | 7 --- torchgeo/datamodules/vaihingen.py | 4 +- torchgeo/datamodules/xview.py | 7 --- 22 files changed, 134 insertions(+), 180 deletions(-) diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py index 78b55d643fd..b75c38af3dd 100644 --- a/torchgeo/datamodules/bigearthnet.py +++ b/torchgeo/datamodules/bigearthnet.py @@ -5,11 +5,9 @@ from typing import Any -import kornia.augmentation as K import torch from ..datasets import BigEarthNet -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -21,10 +19,10 @@ class BigEarthNetDataModule(NonGeoDataModule): # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) # min/max band statistics computed on 100k random samples - band_mins_raw = torch.tensor( + mins_raw = torch.tensor( [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] ) - band_maxs_raw = torch.tensor( + maxs_raw = torch.tensor( [ 31.0, 35.0, @@ -45,10 +43,10 @@ class BigEarthNetDataModule(NonGeoDataModule): # min/max band statistics computed by percentile clipping the # above to samples to [2, 98] - band_mins = torch.tensor( + mins = torch.tensor( [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ) - band_maxs = torch.tensor( + maxs = torch.tensor( [ 6.0, 16.0, @@ -78,19 +76,17 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.BigEarthNet`. 
""" - super().__init__(BigEarthNet, batch_size, num_workers, **kwargs) - bands = kwargs.get("bands", "all") if bands == "all": - mins = self.band_mins - maxs = self.band_maxs + mins = self.mins + maxs = self.maxs elif bands == "s1": - mins = self.band_mins[:2] - maxs = self.band_maxs[:2] + mins = self.mins[:2] + maxs = self.maxs[:2] else: - mins = self.band_mins[2:] - maxs = self.band_maxs[2:] + mins = self.mins[2:] + maxs = self.maxs[2:] + self.mean = mins + self.std = maxs - mins - self.aug = AugmentationSequential( - K.Normalize(mean=mins, std=maxs - mins), data_keys=["image"] - ) + super().__init__(BigEarthNet, batch_size, num_workers, **kwargs) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index 3621bad76e3..da2a9a14870 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -5,11 +5,9 @@ from typing import Any -import kornia.augmentation as K from torch.utils.data import random_split from ..datasets import COWCCounting -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -29,10 +27,6 @@ def __init__( """ super().__init__(COWCCounting, batch_size, num_workers, **kwargs) - self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), data_keys=["image"] - ) - def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index b63cb533e13..d15eeda90f4 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -5,12 +5,10 @@ from typing import Any -import kornia.augmentation as K from sklearn.model_selection import GroupShuffleSplit from torch.utils.data import Subset from ..datasets import TropicalCyclone -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -38,10 +36,6 @@ def __init__( """ super().__init__(TropicalCyclone, batch_size, num_workers, **kwargs) - self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), data_keys=["image"] - ) - def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index c6d92008eea..90ad3dc4501 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -61,12 +61,12 @@ def __init__( self.val_split_pct = val_split_pct self.train_aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index a6bcde7ad1a..f6fa38005bd 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -5,10 +5,9 @@ from typing import Any -import kornia.augmentation as K +import torch from ..datasets import ETCI2021 -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -21,15 +20,19 @@ class ETCI2021DataModule(NonGeoDataModule): .. 
versionadded:: 0.2 """ - band_means = [ - 128.02253931, - 128.02253931, - 128.02253931, - 128.11221701, - 128.11221701, - 128.11221701, - ] - band_stds = [89.8145088, 89.8145088, 89.8145088, 95.2797861, 95.2797861, 95.2797861] + mean = torch.tensor( + [ + 128.02253931, + 128.02253931, + 128.02253931, + 128.11221701, + 128.11221701, + 128.11221701, + ] + ) + std = torch.tensor( + [89.8145088, 89.8145088, 89.8145088, 95.2797861, 95.2797861, 95.2797861] + ) def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any @@ -43,7 +46,3 @@ def __init__( :class:`~torchgeo.datasets.ETCI2021`. """ super().__init__(ETCI2021, batch_size, num_workers, **kwargs) - - self.aug = AugmentationSequential( - K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] - ) diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index 21cd2c901db..44f7c0ec7a4 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -5,10 +5,9 @@ from typing import Any -import kornia.augmentation as K +import torch from ..datasets import EuroSAT -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -20,37 +19,41 @@ class EuroSATDataModule(NonGeoDataModule): .. versionadded:: 0.2 """ - band_means = [ - 1354.40546513, - 1118.24399958, - 1042.92983953, - 947.62620298, - 1199.47283961, - 1999.79090914, - 2369.22292565, - 2296.82608323, - 732.08340178, - 12.11327804, - 1819.01027855, - 1118.92391149, - 2594.14080798, - ] + mean = torch.tensor( + [ + 1354.40546513, + 1118.24399958, + 1042.92983953, + 947.62620298, + 1199.47283961, + 1999.79090914, + 2369.22292565, + 2296.82608323, + 732.08340178, + 12.11327804, + 1819.01027855, + 1118.92391149, + 2594.14080798, + ] + ) - band_stds = [ - 245.71762908, - 333.00778264, - 395.09249139, - 593.75055589, - 566.4170017, - 861.18399006, - 1086.63139075, - 1117.98170791, - 404.91978886, - 4.77584468, - 1002.58768311, - 761.30323499, - 1231.58581042, - ] + std = torch.tensor( + [ + 245.71762908, + 333.00778264, + 395.09249139, + 593.75055589, + 566.4170017, + 861.18399006, + 1086.63139075, + 1117.98170791, + 404.91978886, + 4.77584468, + 1002.58768311, + 761.30323499, + 1231.58581042, + ] + ) def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any @@ -64,7 +67,3 @@ def __init__( :class:`~torchgeo.datasets.EuroSAT`. """ super().__init__(EuroSAT, batch_size, num_workers, **kwargs) - - self.aug = AugmentationSequential( - K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] - ) diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index 32b423dc769..cc7418764b9 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -5,10 +5,7 @@ from typing import Any -import kornia.augmentation as K - from ..datasets import FAIR1M -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule from .utils import dataset_split @@ -42,10 +39,6 @@ def __init__( self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), data_keys=["image"] - ) - def setup(self, stage: str) -> None: """Set up datasets. 
diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 3234303fdb6..37c2b449112 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -5,7 +5,9 @@ from typing import Any, Dict, Optional, Type +import kornia.augmentation as K import matplotlib.pyplot as plt +import torch from pytorch_lightning import LightningDataModule # TODO: import from lightning_lite instead @@ -13,15 +15,19 @@ MisconfigurationException, ) from torch import Tensor -from torch.nn import Identity, Module +from torch.nn import Module from torch.utils.data import DataLoader, Dataset from ..datasets import NonGeoDataset +from ..transforms import AugmentationSequential class NonGeoDataModule(LightningDataModule): """Base class for data modules lacking geospatial information.""" + mean = torch.tensor(0) + std = torch.tensor(255) + def __init__( self, dataset_class: Type[NonGeoDataset], @@ -57,7 +63,9 @@ def __init__( self.predict_batch_size: Optional[int] = None # Data augmentation - self.aug: Optional[Module] = None + self.aug: Module = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image"] + ) self.train_aug: Optional[Module] = None self.val_aug: Optional[Module] = None self.test_aug: Optional[Module] = None @@ -194,13 +202,13 @@ def on_after_batch_transfer( """ if self.trainer: if self.trainer.training: - aug = self.train_aug or self.aug or Identity() + aug = self.train_aug or self.aug elif self.trainer.validating: - aug = self.val_aug or self.aug or Identity() + aug = self.val_aug or self.aug elif self.trainer.testing: - aug = self.test_aug or self.aug or Identity() + aug = self.test_aug or self.aug elif self.trainer.predicting: - aug = self.predict_aug or self.aug or Identity() + aug = self.predict_aug or self.aug batch = aug(batch) diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 7ad19b33a76..65b4ea32a20 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -59,17 +59,17 @@ def __init__( self.val_split_pct = val_split_pct self.train_aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.val_aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) self.predict_transform = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), _ExtractTensorPatches(self.patch_size), data_keys=["image"], ) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index fb802b6f3e9..edf7c4d173d 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -63,19 +63,19 @@ def __init__( self.test_split_pct = test_split_pct self.train_aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.val_aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) self.predict_aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), _ExtractTensorPatches(self.patch_size), 
data_keys=["image"], ) diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index eabe26466ff..0ac1ec385ac 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -32,7 +32,7 @@ def __init__( super().__init__(LandCoverAI, batch_size, num_workers, **kwargs) self.train_aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), @@ -48,5 +48,5 @@ def __init__( data_keys=["image", "mask"], ) self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"] + K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] ) diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index 40cde2da63f..8dc7d050fb2 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -5,10 +5,7 @@ from typing import Any -import kornia.augmentation as K - from ..datasets import LoveDA -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -33,10 +30,6 @@ def __init__( """ super().__init__(LoveDA, batch_size, num_workers, **kwargs) - self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), data_keys=["image"] - ) - def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index 83f40a8a249..b07c464eabf 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -5,10 +5,7 @@ from typing import Any -import kornia.augmentation as K - from ..datasets import NASAMarineDebris -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule from .utils import dataset_split @@ -42,10 +39,6 @@ def __init__( self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), data_keys=["image"] - ) - def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index ed7cf69b3e5..89c20749236 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -25,7 +25,7 @@ class OSCDDataModule(NonGeoDataModule): .. 
versionadded:: 0.2 """ - band_means = torch.tensor( + mean = torch.tensor( [ 1583.0741, 1374.3202, @@ -43,7 +43,7 @@ class OSCDDataModule(NonGeoDataModule): ] ) - band_stds = torch.tensor( + std = torch.tensor( [ 52.1937, 83.4168, @@ -98,16 +98,16 @@ def __init__( self.bands = kwargs.get("bands", "all") if self.bands == "rgb": - self.band_means = self.band_means[[3, 2, 1]] - self.band_stds = self.band_stds[[3, 2, 1]] + self.mean = self.mean[[3, 2, 1]] + self.std = self.std[[3, 2, 1]] self.train_aug = AugmentationSequential( - K.Normalize(mean=self.band_means, std=self.band_stds), + K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.aug = AugmentationSequential( - K.Normalize(mean=self.band_means, std=self.band_stds), + K.Normalize(mean=self.mean, std=self.std), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index e9c561a9036..5b7637d9e78 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -63,12 +63,12 @@ def __init__( self.val_split_pct = val_split_pct self.train_aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index c29402d4f84..cb1e9553553 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -6,6 +6,7 @@ from typing import Any import kornia.augmentation as K +import torch from ..datasets import RESISC45 from ..transforms import AugmentationSequential @@ -18,8 +19,8 @@ class RESISC45DataModule(NonGeoDataModule): Uses the train/val/test splits from the dataset. """ - band_means = [127.86820969, 127.88083247, 127.84341029] - band_stds = [51.8668062, 47.2380768, 47.0613924] + mean = torch.tensor([127.86820969, 127.88083247, 127.84341029]) + std = torch.tensor([51.8668062, 47.2380768, 47.0613924]) def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any @@ -35,7 +36,7 @@ def __init__( super().__init__(RESISC45, batch_size, num_workers, **kwargs) self.train_aug = AugmentationSequential( - K.Normalize(mean=self.band_means, std=self.band_stds), + K.Normalize(mean=self.mean, std=self.std), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), @@ -51,6 +52,3 @@ def __init__( ), data_keys=["image"], ) - self.aug = AugmentationSequential( - K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] - ) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index 249021e0c82..73d83efe109 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -5,10 +5,9 @@ from typing import Any -import kornia.augmentation as K +import torch from ..datasets import So2Sat -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -18,31 +17,35 @@ class So2SatDataModule(NonGeoDataModule): Uses the train/val/test splits from the dataset. 
""" - band_means = [ - 0.12375696117681859, - 0.1092774636368323, - 0.1010855203267882, - 0.1142398616114001, - 0.1592656692023089, - 0.18147236008771792, - 0.1745740312291377, - 0.19501607349635292, - 0.15428468872076637, - 0.10905050699570007, - ] + mean = torch.tensor( + [ + 0.12375696117681859, + 0.1092774636368323, + 0.1010855203267882, + 0.1142398616114001, + 0.1592656692023089, + 0.18147236008771792, + 0.1745740312291377, + 0.19501607349635292, + 0.15428468872076637, + 0.10905050699570007, + ] + ) - band_stds = [ - 0.03958795985905458, - 0.047778262752410296, - 0.06636616706371974, - 0.06358874912497474, - 0.07744387147984592, - 0.09101635085921553, - 0.09218466562387101, - 0.10164581233948201, - 0.09991773043519253, - 0.08780632509122865, - ] + std = torch.tensor( + [ + 0.03958795985905458, + 0.047778262752410296, + 0.06636616706371974, + 0.06358874912497474, + 0.07744387147984592, + 0.09101635085921553, + 0.09218466562387101, + 0.10164581233948201, + 0.09991773043519253, + 0.08780632509122865, + ] + ) # this reorders the bands to put S2 RGB first, then remainder of S2 reindex_to_rgb_first = [2, 1, 0, 3, 4, 5, 6, 7, 8, 9] @@ -59,7 +62,3 @@ def __init__( :class:`~torchgeo.datasets.So2Sat` """ super().__init__(So2Sat, batch_size, num_workers, **kwargs) - - self.aug = AugmentationSequential( - K.Normalize(mean=self.band_means, std=self.band_stds), data_keys=["image"] - ) diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 6bfa89f8485..d8ed9f249de 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -45,7 +45,7 @@ def __init__( self.test_split_pct = test_split_pct self.train_aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), K.RandomRotation(p=0.5, degrees=90), K.RandomHorizontalFlip(p=0.5), @@ -62,7 +62,7 @@ def __init__( data_keys=["image", "mask"], ) self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), K.PadTo((448, 448)), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index 6cc20eff146..dc8f43b5828 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -32,5 +32,7 @@ def __init__( super().__init__(UCMerced, batch_size, num_workers, **kwargs) self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), K.Resize(size=256), data_keys=["image"] + K.Normalize(mean=self.mean, std=self.std), + K.Resize(size=256), + data_keys=["image"], ) diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py index 97c964a9b77..329afe94058 100644 --- a/torchgeo/datamodules/usavars.py +++ b/torchgeo/datamodules/usavars.py @@ -5,10 +5,7 @@ from typing import Any -import kornia.augmentation as K - from ..datasets import USAVars -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule @@ -32,7 +29,3 @@ def __init__( :class:`~torchgeo.datasets.USAVars`. 
""" super().__init__(USAVars, batch_size, num_workers, **kwargs) - - self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), data_keys=["image"] - ) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index ac2fcf21b67..56451883d75 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -63,12 +63,12 @@ def __init__( self.val_split_pct = val_split_pct self.train_aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), _ExtractTensorPatches(self.patch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index d89655b16ed..5246a92010b 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -5,10 +5,7 @@ from typing import Any -import kornia.augmentation as K - from ..datasets import XView2 -from ..transforms import AugmentationSequential from .geo import NonGeoDataModule from .utils import dataset_split @@ -41,10 +38,6 @@ def __init__( self.val_split_pct = val_split_pct - self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), data_keys=["image"] - ) - def setup(self, stage: str) -> None: """Set up datasets. From 98f75b8ad535ae858a56861d17718c07c09ee31f Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 09:30:43 -0600 Subject: [PATCH 045/108] Augmentations can be any callable --- torchgeo/datamodules/geo.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 37c2b449112..b1e56fea011 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -3,7 +3,7 @@ """Base classes for all :mod:`torchgeo` data modules.""" -from typing import Any, Dict, Optional, Type +from typing import Any, Callable, Dict, Optional, Type import kornia.augmentation as K import matplotlib.pyplot as plt @@ -15,7 +15,6 @@ MisconfigurationException, ) from torch import Tensor -from torch.nn import Module from torch.utils.data import DataLoader, Dataset from ..datasets import NonGeoDataset @@ -63,13 +62,14 @@ def __init__( self.predict_batch_size: Optional[int] = None # Data augmentation - self.aug: Module = AugmentationSequential( + Transform = Callable[[Dict[str, Tensor]], Dict[str, Tensor]] + self.aug: Transform = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=["image"] ) - self.train_aug: Optional[Module] = None - self.val_aug: Optional[Module] = None - self.test_aug: Optional[Module] = None - self.predict_aug: Optional[Module] = None + self.train_aug: Optional[Transform] = None + self.val_aug: Optional[Transform] = None + self.test_aug: Optional[Transform] = None + self.predict_aug: Optional[Transform] = None def prepare_data(self) -> None: """Download and prepare data. From 8e1b18d994a83b53b73a35654f1cb8bb11063bcf Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Mon, 2 Jan 2023 12:02:37 -0600 Subject: [PATCH 046/108] Fix datasets tests --- tests/datasets/test_bigearthnet.py | 2 +- tests/datasets/test_cyclone.py | 6 +++--- torchgeo/datasets/nasa_marine_debris.py | 1 + 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py index 349530e9b18..b6b7324ad5c 100644 --- a/tests/datasets/test_bigearthnet.py +++ b/tests/datasets/test_bigearthnet.py @@ -75,7 +75,7 @@ def test_getitem(self, dataset: BigEarthNet) -> None: assert isinstance(x["image"], torch.Tensor) assert isinstance(x["label"], torch.Tensor) assert x["label"].shape == (dataset.num_classes,) - assert x["image"].dtype == torch.int32 + assert x["image"].dtype == torch.float32 assert x["label"].dtype == torch.int64 if dataset.bands == "all": diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py index a7310cc8cd9..1452ad18225 100644 --- a/tests/datasets/test_cyclone.py +++ b/tests/datasets/test_cyclone.py @@ -61,8 +61,8 @@ def test_getitem(self, dataset: TropicalCyclone, index: int) -> None: assert isinstance(x["storm_id"], str) assert isinstance(x["relative_time"], int) assert isinstance(x["ocean"], int) - assert isinstance(x["label"], int) - assert x["image"].shape == (dataset.size, dataset.size) + assert isinstance(x["label"], torch.Tensor) + assert x["image"].shape == (3, dataset.size, dataset.size) def test_len(self, dataset: TropicalCyclone) -> None: assert len(dataset) == 5 @@ -88,6 +88,6 @@ def test_plot(self, dataset: TropicalCyclone) -> None: plt.close() sample = dataset[0] - sample["prediction"] = torch.tensor(sample["label"]) + sample["prediction"] = sample["label"] dataset.plot(sample) plt.close() diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 06e37de81f9..83dae527b17 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -234,6 +234,7 @@ def plot( """ ncols = 1 + sample["image"] = sample["image"].byte() image = sample["image"] if "boxes" in sample and len(sample["boxes"]): image = draw_bounding_boxes(image=sample["image"], boxes=sample["boxes"]) From a0c835af5795c501da42492c8d111376de5f414b Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Mon, 2 Jan 2023 12:30:28 -0600 Subject: [PATCH 047/108] Fix datamodule tests --- tests/datamodules/test_oscd.py | 10 +++++----- tests/datamodules/test_xview2.py | 10 +++------- tests/datasets/test_oscd.py | 6 +++--- torchgeo/datamodules/oscd.py | 5 +++++ torchgeo/datasets/oscd.py | 6 ++++-- 5 files changed, 20 insertions(+), 17 deletions(-) diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index 93efb909fa1..fec19474c96 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -11,9 +11,9 @@ class TestOSCDDataModule: - @pytest.fixture(scope="class", params=zip(["all", "rgb"], [0.0, 0.5])) + @pytest.fixture(scope="class", params=["all", "rgb"]) def datamodule(self, request: SubRequest) -> OSCDDataModule: - bands, val_split_pct = request.param + bands = request.param num_tiles_per_batch = 1 num_patches_per_tile = 2 patch_size = 2 @@ -26,7 +26,7 @@ def datamodule(self, request: SubRequest) -> OSCDDataModule: num_tiles_per_batch=num_tiles_per_batch, num_patches_per_tile=num_patches_per_tile, patch_size=patch_size, - val_split_pct=val_split_pct, + val_split_pct=0.5, num_workers=num_workers, ) dm.prepare_data() @@ -52,7 +52,7 @@ def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: sample = datamodule.on_after_batch_transfer(sample, 0) if datamodule.val_split_pct > 0.0: assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 + assert sample["image"].shape[0] == sample["mask"].shape[0] == 1024 if datamodule.bands == "all": assert sample["image"].shape[1] == 26 else: @@ -64,7 +64,7 @@ def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: sample = next(iter(datamodule.test_dataloader())) sample = datamodule.on_after_batch_transfer(sample, 0) assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 + assert sample["image"].shape[0] == sample["mask"].shape[0] == 1024 if datamodule.bands == "all": assert sample["image"].shape[1] == 26 else: diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview2.py index 84735a6e135..8bb6042aa92 100644 --- a/tests/datamodules/test_xview2.py +++ b/tests/datamodules/test_xview2.py @@ -12,17 +12,13 @@ class TestXView2DataModule: - @pytest.fixture(scope="class", params=[0.0, 0.5]) - def datamodule(self, request: SubRequest) -> XView2DataModule: + @pytest.fixture(scope="class") + def datamodule(self) -> XView2DataModule: root = os.path.join("tests", "data", "xview2") batch_size = 1 num_workers = 0 - val_split_size = request.param dm = XView2DataModule( - root=root, - batch_size=batch_size, - num_workers=num_workers, - val_split_pct=val_split_size, + root=root, batch_size=batch_size, num_workers=num_workers, val_split_pct=0.5 ) dm.prepare_data() return dm diff --git a/tests/datasets/test_oscd.py b/tests/datasets/test_oscd.py index f2f0582dc7e..365bcd97a60 100644 --- a/tests/datasets/test_oscd.py +++ b/tests/datasets/test_oscd.py @@ -73,14 +73,14 @@ def test_getitem(self, dataset: OSCD) -> None: x = dataset[0] assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) - assert x["image"].ndim == 4 + assert x["image"].ndim == 3 assert isinstance(x["mask"], torch.Tensor) assert x["mask"].ndim == 2 if dataset.bands == "rgb": - assert x["image"].shape[:2] == (2, 3) + assert x["image"].shape[0] == 6 else: - assert x["image"].shape[:2] == (2, 13) + assert x["image"].shape[0] == 26 def 
test_len(self, dataset: OSCD) -> None: if dataset.split == "train": diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 89c20749236..81197bfa9bc 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -7,6 +7,7 @@ import kornia.augmentation as K import torch +from einops import repeat from ..datasets import OSCD from ..samplers.utils import _to_tuple @@ -101,6 +102,10 @@ def __init__( self.mean = self.mean[[3, 2, 1]] self.std = self.std[[3, 2, 1]] + # Change detection, 2 images from different times + self.mean = repeat(self.mean, "c -> (t c)", t=2) + self.std = repeat(self.std, "c -> (t c)", t=2) + self.train_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), _RandomNCrop(self.patch_size, self.num_patches_per_tile), diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 3977933432c..b302e8e4d21 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -129,7 +129,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: image2 = self._load_image(files["images2"]) mask = self._load_target(str(files["mask"])) - image = torch.stack(tensors=[image1, image2], dim=0) + image = torch.cat([image1, image2]) sample = {"image": image, "mask": mask} if self.transforms is not None: @@ -306,7 +306,9 @@ def get_masked(img: Tensor) -> "np.typing.NDArray[np.uint8]": ) return array - image1, image2 = get_masked(sample["image"][0]), get_masked(sample["image"][1]) + idx = sample["image"].shape[0] // 2 + image1 = get_masked(sample["image"][:idx]) + image2 = get_masked(sample["image"][idx:]) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) axs[0].imshow(image1) axs[0].axis("off") From 5fa6e0b7c60ffc05324eb5b27e83f623c8f322ec Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 12:43:01 -0600 Subject: [PATCH 048/108] Fix more datamodules --- tests/datamodules/test_chesapeake.py | 21 --------------- tests/datamodules/test_oscd.py | 38 ++++++++++++++-------------- tests/datamodules/test_usavars.py | 16 ++++++------ tests/datamodules/test_xview2.py | 1 - 4 files changed, 27 insertions(+), 49 deletions(-) diff --git a/tests/datamodules/test_chesapeake.py b/tests/datamodules/test_chesapeake.py index fab7b9088a1..55ae702de0e 100644 --- a/tests/datamodules/test_chesapeake.py +++ b/tests/datamodules/test_chesapeake.py @@ -2,34 +2,13 @@ # Licensed under the MIT License. 
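The OSCD fix above is easy to misread, so a quick illustration: einops.repeat tiles the 13 per-band statistics across the two timestamps, matching the channel order produced by torch.cat([image1, image2]) (all bands of t1, then all bands of t2):

    import torch
    from einops import repeat

    mean = torch.tensor([1.0, 2.0, 3.0])  # stand-in for the 13 per-band means
    print(repeat(mean, "c -> (t c)", t=2))  # tensor([1., 2., 3., 1., 2., 3.])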
import os -from typing import Any, Dict, cast import pytest -import torch -from omegaconf import OmegaConf from torchgeo.datamodules import ChesapeakeCVPRDataModule class TestChesapeakeCVPRDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> ChesapeakeCVPRDataModule: - conf = OmegaConf.load(os.path.join("tests", "conf", "chesapeake_cvpr_5.yaml")) - kwargs = OmegaConf.to_object(conf.experiment.datamodule) - kwargs = cast(Dict[str, Any], kwargs) - - datamodule = ChesapeakeCVPRDataModule(**kwargs) - datamodule.prepare_data() - datamodule.setup() - return datamodule - - def test_nodata_check(self, datamodule: ChesapeakeCVPRDataModule) -> None: - nodata_check = datamodule.nodata_check(4) - sample = {"image": torch.ones(1, 2, 2), "mask": torch.ones(2, 2)} - out = nodata_check(sample) - assert torch.equal(out["image"], torch.zeros(1, 4, 4)) - assert torch.equal(out["mask"], torch.zeros(4, 4)) - def test_invalid_param_config(self) -> None: with pytest.raises(ValueError, match="The pre-generated prior labels"): ChesapeakeCVPRDataModule( diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index fec19474c96..bb430156fe4 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -30,42 +30,42 @@ def datamodule(self, request: SubRequest) -> OSCDDataModule: num_workers=num_workers, ) dm.prepare_data() - dm.trainer = Trainer() + dm.trainer = Trainer(max_epochs=1) return dm def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: datamodule.setup("fit") datamodule.trainer.training = True # type: ignore[union-attr] - sample = next(iter(datamodule.train_dataloader())) - sample = datamodule.on_after_batch_transfer(sample, 0) - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 + batch = next(iter(datamodule.train_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image"].shape[0] == batch["mask"].shape[0] == 2 if datamodule.bands == "all": - assert sample["image"].shape[1] == 26 + assert batch["image"].shape[1] == 26 else: - assert sample["image"].shape[1] == 6 + assert batch["image"].shape[1] == 6 def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: datamodule.setup("validate") datamodule.trainer.validating = True # type: ignore[union-attr] - sample = next(iter(datamodule.val_dataloader())) - sample = datamodule.on_after_batch_transfer(sample, 0) + batch = next(iter(datamodule.val_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) if datamodule.val_split_pct > 0.0: - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 1024 + assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image"].shape[0] == batch["mask"].shape[0] == 1024 if datamodule.bands == "all": - assert sample["image"].shape[1] == 26 + assert batch["image"].shape[1] == 26 else: - assert sample["image"].shape[1] == 6 + assert batch["image"].shape[1] == 6 def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: datamodule.setup("test") datamodule.trainer.testing = True # type: ignore[union-attr] - sample = next(iter(datamodule.test_dataloader())) - sample = datamodule.on_after_batch_transfer(sample, 0) - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) - assert sample["image"].shape[0] == 
sample["mask"].shape[0] == 1024 + batch = next(iter(datamodule.test_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image"].shape[0] == batch["mask"].shape[0] == 1024 if datamodule.bands == "all": - assert sample["image"].shape[1] == 26 + assert batch["image"].shape[1] == 26 else: - assert sample["image"].shape[1] == 6 + assert batch["image"].shape[1] == 6 diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index 2d8e3d7b9eb..dc56845e3e4 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -8,7 +8,7 @@ from _pytest.fixtures import SubRequest from torchgeo.datamodules import USAVarsDataModule -from torchgeo.datasets import unbind_samples +from torchgeo.datasets import unbind_batchs class TestUSAVarsDataModule: @@ -28,24 +28,24 @@ def datamodule(self, request: SubRequest) -> USAVarsDataModule: def test_train_dataloader(self, datamodule: USAVarsDataModule) -> None: datamodule.setup("fit") assert len(datamodule.train_dataloader()) == 3 - sample = next(iter(datamodule.train_dataloader())) - assert sample["image"].shape[0] == datamodule.batch_size + batch = next(iter(datamodule.train_dataloader())) + assert batch["image"].shape[0] == datamodule.batch_size def test_val_dataloader(self, datamodule: USAVarsDataModule) -> None: datamodule.setup("validate") assert len(datamodule.val_dataloader()) == 2 - sample = next(iter(datamodule.val_dataloader())) - assert sample["image"].shape[0] == datamodule.batch_size + batch = next(iter(datamodule.val_dataloader())) + assert batch["image"].shape[0] == datamodule.batch_size def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None: datamodule.setup("test") assert len(datamodule.test_dataloader()) == 1 - sample = next(iter(datamodule.test_dataloader())) - assert sample["image"].shape[0] == datamodule.batch_size + batch = next(iter(datamodule.test_dataloader())) + assert batch["image"].shape[0] == datamodule.batch_size def test_plot(self, datamodule: USAVarsDataModule) -> None: datamodule.setup("validate") batch = next(iter(datamodule.val_dataloader())) - sample = unbind_samples(batch)[0] + sample = unbind_batchs(batch)[0] datamodule.plot(sample) plt.close() diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview2.py index 8bb6042aa92..6d614184b56 100644 --- a/tests/datamodules/test_xview2.py +++ b/tests/datamodules/test_xview2.py @@ -5,7 +5,6 @@ import matplotlib.pyplot as plt import pytest -from _pytest.fixtures import SubRequest from torchgeo.datamodules import XView2DataModule from torchgeo.datasets import unbind_samples From 509c8db4b93acac231e671a32db2119213b5c324 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Mon, 2 Jan 2023 13:42:41 -0600 Subject: [PATCH 049/108] Typo fix --- tests/datamodules/test_usavars.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index dc56845e3e4..14b5fed5390 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -8,7 +8,7 @@ from _pytest.fixtures import SubRequest from torchgeo.datamodules import USAVarsDataModule -from torchgeo.datasets import unbind_batchs +from torchgeo.datasets import unbind_samples class TestUSAVarsDataModule: @@ -46,6 +46,6 @@ def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None: def test_plot(self, datamodule: USAVarsDataModule) -> None: datamodule.setup("validate") batch = next(iter(datamodule.val_dataloader())) - sample = unbind_batchs(batch)[0] + sample = unbind_samples(batch)[0] datamodule.plot(sample) plt.close() From fac17b8a00c8f728d2bec478b6f36d5784d281df Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 14:14:37 -0600 Subject: [PATCH 050/108] Set up val_dataset even when fit --- torchgeo/datamodules/cyclone.py | 2 +- torchgeo/datamodules/deepglobelandcover.py | 2 +- torchgeo/datamodules/geo.py | 4 ++-- torchgeo/datamodules/gid15.py | 2 +- torchgeo/datamodules/inria.py | 2 +- torchgeo/datamodules/loveda.py | 4 ++-- torchgeo/datamodules/oscd.py | 2 +- torchgeo/datamodules/potsdam.py | 2 +- torchgeo/datamodules/sen12ms.py | 2 +- torchgeo/datamodules/vaihingen.py | 2 +- torchgeo/datamodules/xview.py | 2 +- 11 files changed, 13 insertions(+), 13 deletions(-) diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index d15eeda90f4..06e77380867 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -58,5 +58,5 @@ def setup(self, stage: str) -> None: self.train_dataset = Subset(dataset, train_indices) self.val_dataset = Subset(dataset, val_indices) - elif stage in ["test"]: + if stage in ["test"]: self.test_dataset = TropicalCyclone(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 90ad3dc4501..26976814baa 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -82,5 +82,5 @@ def setup(self, stage: str) -> None: self.train_dataset, self.val_dataset = dataset_split( dataset, self.val_split_pct ) - elif stage in ["test"]: + if stage in ["test"]: self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index b1e56fea011..6e61f6c7b06 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -95,11 +95,11 @@ def setup(self, stage: str) -> None: self.train_dataset = self.dataset_class( # type: ignore[call-arg] split="train", **self.kwargs ) - elif stage in ["fit", "validate"]: + if stage in ["fit", "validate"]: self.val_dataset = self.dataset_class( # type: ignore[call-arg] split="val", **self.kwargs ) - elif stage in ["test"]: + if stage in ["test"]: self.test_dataset = self.dataset_class( # type: ignore[call-arg] split="test", **self.kwargs ) diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 65b4ea32a20..c8ad28154a4 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -85,6 +85,6 @@ def setup(self, stage: str) -> None: self.train_dataset, self.val_dataset = dataset_split( dataset, self.val_split_pct ) - elif stage in 
["test"]: + if stage in ["test"]: # Test set masks are not public, use for prediction instead self.predict_dataset = GID15(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index edf7c4d173d..cd4e6a0e5cf 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -91,6 +91,6 @@ def setup(self, stage: str) -> None: self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( dataset, self.val_split_pct, self.test_split_pct ) - elif stage in ["predict"]: + if stage in ["predict"]: # Test set masks are not public, use for prediction instead self.predict_dataset = InriaAerialImageLabeling(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index 8dc7d050fb2..f8462ff6588 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -38,8 +38,8 @@ def setup(self, stage: str) -> None: """ if stage in ["fit"]: self.train_dataset = LoveDA(split="train", **self.kwargs) - elif stage in ["fit", "validate"]: + if stage in ["fit", "validate"]: self.val_dataset = LoveDA(split="val", **self.kwargs) - elif stage in ["predict"]: + if stage in ["predict"]: # Test set masks are not public, use for prediction instead self.predict_dataset = LoveDA(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 81197bfa9bc..8a999de4aa2 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -128,5 +128,5 @@ def setup(self, stage: str) -> None: self.train_dataset, self.val_dataset = dataset_split( dataset, val_pct=self.val_split_pct ) - elif stage in ["test"]: + if stage in ["test"]: self.test_dataset = OSCD(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 5b7637d9e78..e93b09f8ef2 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -84,5 +84,5 @@ def setup(self, stage: str) -> None: self.train_dataset, self.val_dataset = dataset_split( dataset, self.val_split_pct ) - elif stage in ["test"]: + if stage in ["test"]: self.test_dataset = Potsdam2D(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index df2ae188f8e..fac39df5551 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -87,5 +87,5 @@ def setup(self, stage: str) -> None: self.train_dataset = Subset(self.all_train_dataset, train_indices) self.val_dataset = Subset(self.all_train_dataset, val_indices) - elif stage in ["test"]: + if stage in ["test"]: self.test_dataset = SEN12MS(split="test", bands=self.bands, **self.kwargs) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 56451883d75..f8a2af9c6ef 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -84,5 +84,5 @@ def setup(self, stage: str) -> None: self.train_dataset, self.val_dataset = dataset_split( dataset, self.val_split_pct ) - elif stage in ["test"]: + if stage in ["test"]: self.test_dataset = Vaihingen2D(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index 5246a92010b..183a5f6b7c2 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -49,5 +49,5 @@ def setup(self, stage: str) -> None: self.train_dataset, self.val_dataset = dataset_split( dataset, val_pct=self.val_split_pct ) - elif stage in ["test"]: + if stage in ["test"]: 
self.test_dataset = XView2(split="test", **self.kwargs) From c94fa3eddb51271aabe45638ecf3e14485c5214d Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 14:41:59 -0600 Subject: [PATCH 051/108] Fix classification tests --- conf/so2sat.yaml | 5 +- ...so2sat_supervised.yaml => so2sat_all.yaml} | 5 +- ...o2sat_unsupervised.yaml => so2sat_s1.yaml} | 7 +-- tests/conf/so2sat_s2.yaml | 15 +++++ tests/trainers/test_classification.py | 5 +- torchgeo/datamodules/fair1m.py | 2 +- torchgeo/datamodules/so2sat.py | 56 +++++++++++++++++-- torchgeo/trainers/classification.py | 2 +- torchgeo/trainers/detection.py | 2 +- torchgeo/trainers/regression.py | 2 +- torchgeo/trainers/segmentation.py | 2 +- 11 files changed, 80 insertions(+), 23 deletions(-) rename tests/conf/{so2sat_supervised.yaml => so2sat_all.yaml} (79%) rename tests/conf/{so2sat_unsupervised.yaml => so2sat_s1.yaml} (73%) create mode 100644 tests/conf/so2sat_s2.yaml diff --git a/conf/so2sat.yaml b/conf/so2sat.yaml index e8157c5c138..e27cbf02e82 100644 --- a/conf/so2sat.yaml +++ b/conf/so2sat.yaml @@ -11,11 +11,10 @@ experiment: learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: null - in_channels: 3 + in_channels: 18 num_classes: 17 datamodule: root: "data/so2sat" batch_size: 128 num_workers: 4 - band_set: "rgb" - unsupervised_mode: False + band_set: "all" diff --git a/tests/conf/so2sat_supervised.yaml b/tests/conf/so2sat_all.yaml similarity index 79% rename from tests/conf/so2sat_supervised.yaml rename to tests/conf/so2sat_all.yaml index 0cbe484d6fc..83889f10735 100644 --- a/tests/conf/so2sat_supervised.yaml +++ b/tests/conf/so2sat_all.yaml @@ -6,11 +6,10 @@ experiment: learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: null - in_channels: 3 + in_channels: 18 num_classes: 17 datamodule: root: "tests/data/so2sat" batch_size: 1 num_workers: 0 - band_set: "rgb" - unsupervised_mode: False + band_set: "all" diff --git a/tests/conf/so2sat_unsupervised.yaml b/tests/conf/so2sat_s1.yaml similarity index 73% rename from tests/conf/so2sat_unsupervised.yaml rename to tests/conf/so2sat_s1.yaml index 02c1e6a32e7..8c87ff55a53 100644 --- a/tests/conf/so2sat_unsupervised.yaml +++ b/tests/conf/so2sat_s1.yaml @@ -1,16 +1,15 @@ experiment: task: "so2sat" module: - loss: "jaccard" + loss: "focal" model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: null - in_channels: 3 + in_channels: 8 num_classes: 17 datamodule: root: "tests/data/so2sat" batch_size: 1 num_workers: 0 - band_set: "rgb" - unsupervised_mode: True + band_set: "s1" diff --git a/tests/conf/so2sat_s2.yaml b/tests/conf/so2sat_s2.yaml new file mode 100644 index 00000000000..e8e61dcb739 --- /dev/null +++ b/tests/conf/so2sat_s2.yaml @@ -0,0 +1,15 @@ +experiment: + task: "so2sat" + module: + loss: "focal" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 10 + num_classes: 17 + datamodule: + root: "tests/data/so2sat" + batch_size: 1 + num_workers: 0 + band_set: "s2" diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index ad28022114e..8eb29a10566 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -44,8 +44,9 @@ class TestClassificationTask: [ ("eurosat", EuroSATDataModule), ("resisc45", RESISC45DataModule), - ("so2sat_supervised", So2SatDataModule), - ("so2sat_unsupervised", So2SatDataModule), + ("so2sat_all", So2SatDataModule), + ("so2sat_s1", So2SatDataModule), + ("so2sat_s2", 
So2SatDataModule), ("ucmerced", UCMercedDataModule), ], ) diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index cc7418764b9..6c684c45455 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -32,7 +32,7 @@ def __init__( val_split_pct: Percentage of the dataset to use as a validation set. test_split_pct: Percentage of the dataset to use as a test set. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.FAIR1M` + :class:`~torchgeo.datasets.FAIR1M`. """ super().__init__(FAIR1M, batch_size, num_workers, **kwargs) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index 73d83efe109..e8e9c43f742 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -17,8 +17,17 @@ class So2SatDataModule(NonGeoDataModule): Uses the train/val/test splits from the dataset. """ + # TODO: calculate mean/std dev of s1 bands mean = torch.tensor( [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, 0.12375696117681859, 0.1092774636368323, 0.1010855203267882, @@ -31,9 +40,16 @@ class So2SatDataModule(NonGeoDataModule): 0.10905050699570007, ] ) - std = torch.tensor( [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, 0.03958795985905458, 0.047778262752410296, 0.06636616706371974, @@ -47,18 +63,46 @@ class So2SatDataModule(NonGeoDataModule): ] ) - # this reorders the bands to put S2 RGB first, then remainder of S2 - reindex_to_rgb_first = [2, 1, 0, 3, 4, 5, 6, 7, 8, 9] - def __init__( - self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + self, + batch_size: int = 64, + num_workers: int = 0, + band_set: str = "all", + **kwargs: Any, ) -> None: """Initialize a new So2SatDataModule instance. Args: batch_size: Size of each mini-batch. num_workers: Number of workers for parallel data loading. + band_set: One of 'all', 's1', or 's2'. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.So2Sat` + :class:`~torchgeo.datasets.So2Sat`. """ + kwargs["bands"] = So2Sat.BAND_SETS[band_set] + + if band_set == "s1": + self.mean = self.mean[:8] + self.std = self.std[:8] + elif band_set == "s2": + self.mean = self.mean[8:] + self.std = self.std[8:] + super().__init__(So2Sat, batch_size, num_workers, **kwargs) + + def setup(self, stage: str) -> None: + """Set up datasets. + + Called at the beginning of fit, validate, test, or predict. During distributed + training, this method is called from every process across all the nodes. Setting + state here is recommended. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. 
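Two details are worth flagging here. First, So2Sat names its second split "validation" rather than the "val" that the generic NonGeoDataModule.setup expects, which is why this class overrides setup() at all. Second, the zero-filled S1 statistics above are placeholders (see the TODO): a standard deviation of 0 would make Normalize divide by zero, so real S1 statistics presumably need to land before the "s1" and "all" band sets are usable.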
+ """ + if stage in ["fit"]: + self.train_dataset = So2Sat(split="train", **self.kwargs) + if stage in ["fit", "validate"]: + self.val_dataset = So2Sat(split="validation", **self.kwargs) + if stage in ["test"]: + self.test_dataset = So2Sat(split="test", **self.kwargs) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 41d7c922ecd..35a71e01e22 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -200,7 +200,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: f"image/{batch_idx}", fig, global_step=self.global_step ) plt.close() - except AttributeError: + except (AttributeError, ValueError): pass def validation_epoch_end(self, outputs: Any) -> None: diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index b50d273d419..13a69846bca 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -187,7 +187,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: f"image/{batch_idx}", fig, global_step=self.global_step ) plt.close() - except AttributeError: + except (AttributeError, ValueError): pass def validation_epoch_end(self, outputs: Any) -> None: diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 73bffd5934f..0d5595fa8a9 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -157,7 +157,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: f"image/{batch_idx}", fig, global_step=self.global_step ) plt.close() - except AttributeError: + except (AttributeError, ValueError): pass def validation_epoch_end(self, outputs: Any) -> None: diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 1d92c0842b1..44090009dfb 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -210,7 +210,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: f"image/{batch_idx}", fig, global_step=self.global_step ) plt.close() - except AttributeError: + except (AttributeError, ValueError): pass def validation_epoch_end(self, outputs: Any) -> None: From 1c99f282d386f3bea9e1902834c7bced5f2a344b Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 14:58:45 -0600 Subject: [PATCH 052/108] Fix ETCI2021 --- torchgeo/datamodules/etci2021.py | 35 +++++++++++++++++++++++++++++++- torchgeo/datasets/etci2021.py | 2 +- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index f6fa38005bd..e8ce5a396d6 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -3,9 +3,10 @@ """ETCI 2021 datamodule.""" -from typing import Any +from typing import Any, Dict import torch +from torch import Tensor from ..datasets import ETCI2021 from .geo import NonGeoDataModule @@ -46,3 +47,35 @@ def __init__( :class:`~torchgeo.datasets.ETCI2021`. """ super().__init__(ETCI2021, batch_size, num_workers, **kwargs) + + def setup(self, stage: str) -> None: + """Set up datasets. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. 
+ """ + if stage in ["fit"]: + self.train_dataset = ETCI2021(split="train", **self.kwargs) + if stage in ["fit", "validate"]: + self.val_dataset = ETCI2021(split="val", **self.kwargs) + if stage in ["predict"]: + # Test set masks are not public, use for prediction instead + self.predict_dataset = ETCI2021(split="test", **self.kwargs) + + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data. + """ + if "mask" in batch: + # Predict flood mask, not water mask + batch["mask"] = (batch["mask"][:, 1] > 0).long() + + return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index 5e2919102b8..2a11edc9204 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -207,7 +207,7 @@ def _load_image(self, path: str) -> Tensor: filename = os.path.join(path) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor From ee80cc4b623b07c21d7a39c8cff8bc18daff5ab1 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 15:21:16 -0600 Subject: [PATCH 053/108] Fix SEN12MS --- torchgeo/datamodules/sen12ms.py | 48 ++++++++++++++++++++++++--------- torchgeo/datasets/sen12ms.py | 2 +- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index fac39df5551..3692c4a757c 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -3,10 +3,11 @@ """SEN12MS datamodule.""" -from typing import Any +from typing import Any, Dict import torch from sklearn.model_selection import GroupShuffleSplit +from torch import Tensor from torch.utils.data import Subset from ..datasets import SEN12MS @@ -29,6 +30,10 @@ class SEN12MSDataModule(NonGeoDataModule): [0, 1, 1, 1, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 6, 8, 9, 10] ) + std = torch.tensor( + [-25, -25, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4] + ) + def __init__( self, batch_size: int = 64, @@ -47,11 +52,16 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.SEN12MS`. """ - super().__init__(SEN12MS, batch_size, num_workers, **kwargs) + kwargs["bands"] = SEN12MS.BAND_SETS[band_set] - assert band_set in SEN12MS.BAND_SETS.keys() - self.band_set = band_set - self.bands = SEN12MS.BAND_SETS[band_set] + if band_set == "s1": + self.std = self.std[:2] + elif band_set == "s2-all": + self.std = self.std[2:] + elif band_set == "s2-reduced": + self.std = self.std[torch.tensor([3, 4, 5, 9, 12, 13])] + + super().__init__(SEN12MS, batch_size, num_workers, **kwargs) def setup(self, stage: str) -> None: """Set up datasets. 
@@ -62,9 +72,7 @@ def setup(self, stage: str) -> None: if stage in ["fit", "validate"]: season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000} - self.all_train_dataset = SEN12MS( - split="train", bands=self.bands, **self.kwargs - ) + dataset = SEN12MS(split="train", **self.kwargs) # A patch is a filename like: # "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" @@ -73,7 +81,7 @@ def setup(self, stage: str) -> None: # simply give each season a large number and representing a unique_scene_id # as (season_id + scene_id). scenes = [] - for scene_fn in self.all_train_dataset.ids: + for scene_fn in dataset.ids: parts = scene_fn.split("_") season_id = season_to_int[parts[1]] scene_id = int(parts[3]) @@ -85,7 +93,23 @@ def setup(self, stage: str) -> None: ) ) - self.train_dataset = Subset(self.all_train_dataset, train_indices) - self.val_dataset = Subset(self.all_train_dataset, val_indices) + self.train_dataset = Subset(dataset, train_indices) + self.val_dataset = Subset(dataset, val_indices) if stage in ["test"]: - self.test_dataset = SEN12MS(split="test", bands=self.bands, **self.kwargs) + self.test_dataset = SEN12MS(split="test", **self.kwargs) + + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data. + """ + batch["mask"] = torch.take(self.DFC2020_CLASS_MAPPING, batch["mask"][:, 0]) + + return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index 787098bf764..fba2475b0eb 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -223,7 +223,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: """ filename = self.ids[index] - lc = self._load_raster(filename, "lc") + lc = self._load_raster(filename, "lc").long() s1 = self._load_raster(filename, "s1") s2 = self._load_raster(filename, "s2") From 96658bea0396accd2f78ba8446429c22dac12fd3 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 16:55:31 -0600 Subject: [PATCH 054/108] Add GeoDataModule base class --- docs/api/datamodules.rst | 5 + torchgeo/datamodules/__init__.py | 3 +- torchgeo/datamodules/chesapeake.py | 183 ++++++-------------- torchgeo/datamodules/geo.py | 260 ++++++++++++++++++++++++++++- torchgeo/datamodules/gid15.py | 2 +- torchgeo/datamodules/naip.py | 157 +++++------------ 6 files changed, 364 insertions(+), 246 deletions(-) diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index d5a868faa69..69292c1b5d2 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -132,6 +132,11 @@ xView2 Base Classes ------------ +GeoDataModule +^^^^^^^^^^^^^ + +.. 
autoclass:: GeoDataModule + NonGeoDataModule ^^^^^^^^^^^^^^^^ diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 85ea803cd67..2f846c4a5fc 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -11,7 +11,7 @@ from .etci2021 import ETCI2021DataModule from .eurosat import EuroSATDataModule from .fair1m import FAIR1MDataModule -from .geo import NonGeoDataModule +from .geo import GeoDataModule, NonGeoDataModule from .gid15 import GID15DataModule from .inria import InriaAerialImageLabelingDataModule from .landcoverai import LandCoverAIDataModule @@ -57,5 +57,6 @@ "Vaihingen2DDataModule", "XView2DataModule", # Base classes + "GeoDataModule", "NonGeoDataModule", ) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 82c3756495e..0142bd0bfe8 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -3,21 +3,17 @@ """Chesapeake Bay High-Resolution Land Cover Project datamodule.""" -from typing import Any, Dict, List, Optional +from typing import Any, List import kornia.augmentation as K -import matplotlib.pyplot as plt -from pytorch_lightning import LightningDataModule -from torch import Tensor -from torch.utils.data import DataLoader - -from ..datasets import ChesapeakeCVPR, stack_samples -from ..samplers.batch import RandomBatchGeoSampler -from ..samplers.single import GridGeoSampler + +from ..datasets import ChesapeakeCVPR +from ..samplers import GridGeoSampler, RandomBatchGeoSampler from ..transforms import AugmentationSequential +from .geo import GeoDataModule -class ChesapeakeCVPRDataModule(LightningDataModule): +class ChesapeakeCVPRDataModule(GeoDataModule): """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset. Uses the random splits defined per state to partition tiles into train, val, @@ -32,36 +28,38 @@ def __init__( num_tiles_per_batch: int = 64, num_patches_per_tile: int = 200, patch_size: int = 256, + length: int = 1000, num_workers: int = 0, class_set: int = 7, use_prior_labels: bool = False, prior_smoothing_constant: float = 1e-4, **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. + """Initialize a new ChesapeakeCVPRDataModule instance. Args: - train_splits: The splits used to train the model, e.g. ["ny-train"] - val_splits: The splits used to validate the model, e.g. ["ny-val"] - test_splits: The splits used to test the model, e.g. ["ny-test"] - num_tiles_per_batch: The number of image tiles to sample from during - training - num_patches_per_tile: The number of patches to randomly sample from each - image tile during training - patch_size: The size of each patch, either ``size`` or ``(height, width)``. - Should be a multiple of 32 for most segmentation architectures - num_workers: The number of workers to use in all created DataLoaders - class_set: The high-resolution land cover class set to use - 5 or 7 + train_splits: Splits used to train the model, e.g., ["ny-train"]. + val_splits: Splits used to validate the model, e.g., ["ny-val"]. + test_splits: Splits used to test the model, e.g., ["ny-test"]. + num_tiles_per_batch: Number of image tiles to sample from during training. + num_patches_per_tile: Number of patches to randomly sample from each image + tile during training + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + length: Length of each training epoch. 
+ num_workers: Number of workers for parallel data loading. + class_set: The high-resolution land cover class set to use (5 or 7). use_prior_labels: Flag for using a prior over high-resolution classes - instead of the high-resolution labels themselves - prior_smoothing_constant: additive smoothing to add when using prior labels + instead of the high-resolution labels themselves. + prior_smoothing_constant: Additive smoothing to add when using prior labels. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.ChesapeakeCVPR` + :class:`~torchgeo.datasets.ChesapeakeCVPR`. Raises: - ValueError: if ``use_prior_labels`` is used with ``class_set==7`` + ValueError: If ``use_prior_labels=True`` is used with ``class_set=7``. """ - super().__init__() + super().__init__(ChesapeakeCVPR, 1, patch_size, length, num_workers, **kwargs) + for state in train_splits + val_splits + test_splits: assert state in ChesapeakeCVPR.splits assert class_set in [5, 7] @@ -76,15 +74,12 @@ def __init__( self.test_splits = test_splits self.train_batch_size = num_tiles_per_batch self.num_patches_per_tile = num_patches_per_tile - self.patch_size = patch_size # This is a rough estimate of how large of a patch we will need to sample in # EPSG:3857 in order to guarantee a large enough patch in the local CRS. self.original_patch_size = patch_size * 2 - self.num_workers = num_workers self.class_set = class_set self.use_prior_labels = use_prior_labels self.prior_smoothing_constant = prior_smoothing_constant - self.kwargs = kwargs if self.use_prior_labels: self.layers = [ @@ -100,107 +95,33 @@ def __init__( data_keys=["image", "mask"], ) - def prepare_data(self) -> None: - """Confirms that the dataset is downloaded on the local node. - - This method is called once per node, while :func:`setup` is called once per GPU. - """ - if self.kwargs.get("download", False): - ChesapeakeCVPR(splits=self.train_splits, layers=self.layers, **self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - self.train_dataset = ChesapeakeCVPR( - splits=self.train_splits, layers=self.layers, **self.kwargs - ) - self.val_dataset = ChesapeakeCVPR( - splits=self.val_splits, layers=self.layers, **self.kwargs - ) - self.test_dataset = ChesapeakeCVPR( - splits=self.test_splits, layers=self.layers, **self.kwargs - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - sampler = RandomBatchGeoSampler( - self.train_dataset, - size=self.original_patch_size, - batch_size=self.train_batch_size, - length=self.num_patches_per_tile * len(self.train_dataset), - ) - return DataLoader( - self.train_dataset, - batch_sampler=sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - sampler = GridGeoSampler( - self.val_dataset, - size=self.original_patch_size, - stride=self.original_patch_size, - ) - return DataLoader( - self.val_dataset, - batch_size=self.train_batch_size, - sampler=sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. 
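
As an aside on the ``original_patch_size`` trick above: the sampler requests patches at twice the target size in EPSG:3857 so that, after reprojection into the local CRS, there is always enough data for ``K.CenterCrop`` to cut an exact ``patch_size`` window. A minimal, self-contained sketch of the shape guarantee (tensor sizes are illustrative):

    import torch
    import kornia.augmentation as K

    # a reprojected patch may come back larger than requested, e.g. 512x512
    oversized = torch.rand(2, 4, 512, 512)
    crop = K.CenterCrop(256)
    # CenterCrop trims the batch back to the exact size the model expects
    assert crop(oversized).shape == (2, 4, 256, 256)
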
-
-        Returns:
-            testing data loader
-        """
-        sampler = GridGeoSampler(
-            self.test_dataset,
-            size=self.original_patch_size,
-            stride=self.original_patch_size,
-        )
-        return DataLoader(
-            self.test_dataset,
-            batch_size=self.train_batch_size,
-            sampler=sampler,
-            num_workers=self.num_workers,
-            collate_fn=stack_samples,
-        )
-
-    def on_after_batch_transfer(
-        self, batch: Dict[str, Tensor], dataloader_idx: int
-    ) -> Dict[str, Tensor]:
-        """Apply augmentations to batch after transferring to GPU.
+    def setup(self, stage: str) -> None:
+        """Set up datasets and samplers.

        Args:
-            batch: A batch of data that needs to be altered or augmented
-            dataloader_idx: The index of the dataloader to which the batch belongs
-
-        Returns:
-            A batch of data
+            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
-        batch = self.aug(batch)
-        return batch
-
-    def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
-        """Run :meth:`torchgeo.datasets.ChesapeakeCVPR.plot`.
-
-        .. versionadded:: 0.4
-        """
-        return self.test_dataset.plot(*args, **kwargs)
+        if stage in ["fit"]:
+            self.train_dataset = ChesapeakeCVPR(
+                splits=self.train_splits, layers=self.layers, **self.kwargs
+            )
+            self.train_batch_sampler = RandomBatchGeoSampler(
+                self.train_dataset,
+                self.original_patch_size,
+                self.train_batch_size or 1,
+                self.length,
+            )
+        if stage in ["fit", "validate"]:
+            self.val_dataset = ChesapeakeCVPR(
+                splits=self.val_splits, layers=self.layers, **self.kwargs
+            )
+            self.val_sampler = GridGeoSampler(
+                self.val_dataset, self.original_patch_size, self.original_patch_size
+            )
+        if stage in ["test"]:
+            self.test_dataset = ChesapeakeCVPR(
+                splits=self.test_splits, layers=self.layers, **self.kwargs
+            )
+            self.test_sampler = GridGeoSampler(
+                self.test_dataset, self.original_patch_size, self.original_patch_size
+            )
diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py
index 6e61f6c7b06..26b4dc4085c 100644
--- a/torchgeo/datamodules/geo.py
+++ b/torchgeo/datamodules/geo.py
@@ -3,7 +3,7 @@

 """Base classes for all :mod:`torchgeo` data modules."""

-from typing import Any, Callable, Dict, Optional, Type
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

 import kornia.augmentation as K
 import matplotlib.pyplot as plt
@@ -17,10 +17,266 @@
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset

-from ..datasets import NonGeoDataset
+from ..datasets import GeoDataset, NonGeoDataset, stack_samples
+from ..samplers import (
+    BatchGeoSampler,
+    GeoSampler,
+    GridGeoSampler,
+    RandomBatchGeoSampler,
+)
 from ..transforms import AugmentationSequential


+class GeoDataModule(LightningDataModule):
+    """Base class for data modules containing geospatial information."""
+
+    mean = torch.tensor(0)
+    std = torch.tensor(255)
+
+    def __init__(
+        self,
+        dataset_class: Type[GeoDataset],
+        batch_size: int = 1,
+        patch_size: Union[int, Tuple[int, int]] = 64,
+        length: int = 1000,
+        num_workers: int = 0,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize a new GeoDataModule instance.
+
+        Args:
+            dataset_class: Class used to instantiate a new dataset.
+            batch_size: Size of each mini-batch.
+            patch_size: Size of each patch, either ``size`` or ``(height, width)``.
+            length: Length of each training epoch.
+            num_workers: Number of workers for parallel data loading.
+ **kwargs: Additional keyword arguments passed to ``dataset_class`` + """ + super().__init__() + + self.dataset_class = dataset_class + self.batch_size = batch_size + self.patch_size = patch_size + self.length = length + self.num_workers = num_workers + self.kwargs = kwargs + + # Datasets + self.dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.train_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.val_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.test_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.predict_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + + # Samplers + self.sampler: Optional[GeoSampler] = None + self.train_sampler: Optional[GeoSampler] = None + self.val_sampler: Optional[GeoSampler] = None + self.test_sampler: Optional[GeoSampler] = None + self.predict_sampler: Optional[GeoSampler] = None + + # Batch samplers + self.batch_sampler: Optional[BatchGeoSampler] = None + self.train_batch_sampler: Optional[BatchGeoSampler] = None + self.val_batch_sampler: Optional[BatchGeoSampler] = None + self.test_batch_sampler: Optional[BatchGeoSampler] = None + self.predict_batch_sampler: Optional[BatchGeoSampler] = None + + # Data loaders + self.train_batch_size: Optional[int] = None + self.val_batch_size: Optional[int] = None + self.test_batch_size: Optional[int] = None + self.predict_batch_size: Optional[int] = None + + # Data augmentation + Transform = Callable[[Dict[str, Tensor]], Dict[str, Tensor]] + self.aug: Transform = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image"] + ) + self.train_aug: Optional[Transform] = None + self.val_aug: Optional[Transform] = None + self.test_aug: Optional[Transform] = None + self.predict_aug: Optional[Transform] = None + + def prepare_data(self) -> None: + """Download and prepare data. + + During distributed training, this method is called only within a single process + to avoid corrupted data. This method should not set state since it is not called + on every device, use :meth:`setup` instead. + """ + if self.kwargs.get("download", False): + self.dataset_class(**self.kwargs) + + def setup(self, stage: str) -> None: + """Set up datasets and samplers. + + Called at the beginning of fit, validate, test, or predict. During distributed + training, this method is called from every process across all the nodes. Setting + state here is recommended. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + if stage in ["fit"]: + self.train_dataset = self.dataset_class( # type: ignore[call-arg] + split="train", **self.kwargs + ) + self.train_batch_sampler = RandomBatchGeoSampler( + self.train_dataset, self.patch_size, self.batch_size, self.length + ) + if stage in ["fit", "validate"]: + self.val_dataset = self.dataset_class( # type: ignore[call-arg] + split="val", **self.kwargs + ) + self.val_sampler = GridGeoSampler( + self.val_dataset, self.patch_size, self.patch_size + ) + if stage in ["test"]: + self.test_dataset = self.dataset_class( # type: ignore[call-arg] + split="test", **self.kwargs + ) + self.test_sampler = GridGeoSampler( + self.test_dataset, self.patch_size, self.patch_size + ) + + def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders for training. + + Returns: + A collection of data loaders specifying training samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + 'train_dataset'. 
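
The data loaders defined below all pass ``collate_fn=stack_samples``, which turns a list of TorchGeo sample dicts into a single batch dict by stacking each key; a quick standalone illustration:

    import torch
    from torchgeo.datasets import stack_samples

    samples = [
        {"image": torch.rand(4, 64, 64), "mask": torch.randint(0, 2, (64, 64))}
        for _ in range(8)
    ]
    batch = stack_samples(samples)
    assert batch["image"].shape == (8, 4, 64, 64)
    assert batch["mask"].shape == (8, 64, 64)
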
+ """ + if self.train_dataset is not None or self.dataset is not None: + return DataLoader( + dataset=self.train_dataset or self.dataset, # type: ignore[arg-type] + batch_size=self.train_batch_size or self.batch_size, + shuffle=True, + sampler=self.train_sampler or self.sampler, + batch_sampler=self.train_batch_sampler or self.batch_sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) + else: + msg = f"{self.__class__.__name__}.setup does not define a 'train_dataset'" + raise MisconfigurationException(msg) + + def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders for validation. + + Returns: + A collection of data loaders specifying validation samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + 'val_dataset'. + """ + if self.val_dataset is not None or self.dataset is not None: + return DataLoader( + dataset=self.val_dataset or self.dataset, # type: ignore[arg-type] + batch_size=self.val_batch_size or self.batch_size, + shuffle=True, + sampler=self.val_sampler or self.sampler, + batch_sampler=self.val_batch_sampler or self.batch_sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) + else: + msg = f"{self.__class__.__name__}.setup does not define a 'val_dataset'" + raise MisconfigurationException(msg) + + def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders for testing. + + Returns: + A collection of data loaders specifying testing samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + 'test_dataset'. + """ + if self.test_dataset is not None or self.dataset is not None: + return DataLoader( + dataset=self.test_dataset or self.dataset, # type: ignore[arg-type] + batch_size=self.test_batch_size or self.batch_size, + shuffle=True, + sampler=self.test_sampler or self.sampler, + batch_sampler=self.test_batch_sampler or self.batch_sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) + else: + msg = f"{self.__class__.__name__}.setup does not define a 'test_dataset'" + raise MisconfigurationException(msg) + + def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders for prediction. + + Returns: + A collection of data loaders specifying prediction samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + 'predict_dataset'. + """ + if self.predict_dataset is not None or self.dataset is not None: + return DataLoader( + dataset=self.predict_dataset or self.dataset, # type: ignore[arg-type] + batch_size=self.predict_batch_size or self.batch_size, + shuffle=True, + sampler=self.predict_sampler or self.sampler, + batch_sampler=self.predict_batch_sampler or self.batch_sampler, + num_workers=self.num_workers, + collate_fn=stack_samples, + ) + else: + msg = f"{self.__class__.__name__}.setup does not define a 'predict_dataset'" + raise MisconfigurationException(msg) + + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data. 
+ """ + if self.trainer: + if self.trainer.training: + aug = self.train_aug or self.aug + elif self.trainer.validating: + aug = self.val_aug or self.aug + elif self.trainer.testing: + aug = self.test_aug or self.aug + elif self.trainer.predicting: + aug = self.predict_aug or self.aug + + batch = aug(batch) + + return batch + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run the plot method of the dataset if one exists. + + Args: + *args: Arguments passed to plot method. + **kwargs: Keyword arguments passed to plot method. + + Returns: + A matplotlib Figure with the image, ground truth, and predictions. + """ + if self.val_dataset is not None: + if hasattr(self.val_dataset, "plot"): + return self.val_dataset.plot(*args, **kwargs) + + class NonGeoDataModule(LightningDataModule): """Base class for data modules lacking geospatial information.""" diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index c8ad28154a4..b896160c6da 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -45,7 +45,7 @@ def __init__( num_patches_per_tile: Number of patches to randomly sample from each image tile during training. patch_size: Size of each patch, either ``size`` or ``(height, width)``. - Should be a multiple of 32 for most segmentation architectures + Should be a multiple of 32 for most segmentation architectures. val_split_pct: Percentage of the dataset to use as a validation set num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 6c495f29f78..a4c691e45a4 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -3,20 +3,16 @@ """National Agriculture Imagery Program (NAIP) datamodule.""" -from typing import Any, Dict, Optional, Tuple +from typing import Any, Tuple, Union -import kornia.augmentation as K import matplotlib.pyplot as plt -from pytorch_lightning import LightningDataModule -from torch import Tensor -from torch.utils.data import DataLoader -from ..datasets import NAIP, BoundingBox, Chesapeake13, stack_samples +from ..datasets import NAIP, BoundingBox, Chesapeake13 from ..samplers import GridGeoSampler, RandomBatchGeoSampler -from ..transforms import AugmentationSequential +from .geo import GeoDataModule -class NAIPChesapeakeDataModule(LightningDataModule): +class NAIPChesapeakeDataModule(GeoDataModule): """LightningDataModule implementation for the NAIP and Chesapeake datasets. Uses the train/val/test splits from the dataset. @@ -25,32 +21,23 @@ class NAIPChesapeakeDataModule(LightningDataModule): def __init__( self, batch_size: int = 64, - num_workers: int = 0, - patch_size: int = 256, - stride: int = 128, + patch_size: Union[int, Tuple[int, int]] = 256, length: int = 1000, + num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders. + """Initialize a new NAIPChesapeakeDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - patch_size: size of patches to sample - stride: stride of grid sampler - length: epoch size + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + length: Length of each training epoch. + num_workers: Number of workers for parallel data loading. 
**kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.NAIP` (prefix keys with ``naip_``) and :class:`~torchgeo.datasets.Chesapeake13` (prefix keys with ``chesapeake_``) """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.patch_size = patch_size - self.stride = stride - self.length = length - self.naip_kwargs = {} self.chesapeake_kwargs = {} for key, val in kwargs.items(): @@ -59,103 +46,44 @@ def __init__( elif key.startswith("chesapeake_"): self.chesapeake_kwargs[key[11:]] = val - self.aug = AugmentationSequential( - K.Normalize(mean=0.0, std=255.0), data_keys=["image", "mask"] + super().__init__( + Chesapeake13, + batch_size, + patch_size, + length, + num_workers, + **self.chesapeake_kwargs, ) - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.chesapeake_kwargs.get("download", False): - Chesapeake13(**self.chesapeake_kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets and samplers. Args: - stage: state to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - self.chesapeake = Chesapeake13(**self.chesapeake_kwargs) - self.naip = NAIP(**self.naip_kwargs) - self.dataset = self.chesapeake & self.naip + chesapeake = Chesapeake13(**self.chesapeake_kwargs) + naip = NAIP(**self.naip_kwargs) + self.dataset = chesapeake & naip - # TODO: figure out better train/val/test split roi = self.dataset.bounds midx = roi.minx + (roi.maxx - roi.minx) / 2 midy = roi.miny + (roi.maxy - roi.miny) / 2 - train_roi = BoundingBox(roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt) - val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt) - test_roi = BoundingBox(roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt) - - self.train_sampler = RandomBatchGeoSampler( - self.naip, self.patch_size, self.batch_size, self.length, train_roi - ) - self.val_sampler = GridGeoSampler( - self.naip, self.patch_size, self.stride, val_roi - ) - self.test_sampler = GridGeoSampler( - self.naip, self.patch_size, self.stride, test_roi - ) - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.dataset, - batch_sampler=self.train_sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.dataset, - batch_size=self.batch_size, - sampler=self.val_sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.dataset, - batch_size=self.batch_size, - sampler=self.test_sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. 
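
The prefix routing above can be seen in isolation; the slice offsets 5 and 11 are simply ``len("naip_")`` and ``len("chesapeake_")`` (key names here are illustrative):

    kwargs = {"naip_root": "data/naip", "chesapeake_root": "data/chesa", "chesapeake_download": True}

    naip_kwargs = {k[len("naip_"):]: v for k, v in kwargs.items() if k.startswith("naip_")}
    chesapeake_kwargs = {
        k[len("chesapeake_"):]: v for k, v in kwargs.items() if k.startswith("chesapeake_")
    }

    assert naip_kwargs == {"root": "data/naip"}
    assert chesapeake_kwargs == {"root": "data/chesa", "download": True}
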
- - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - batch = self.aug(batch) - return batch + if stage in ["fit"]: + roi = BoundingBox(roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt) + self.train_batch_sampler = RandomBatchGeoSampler( + self.dataset, self.patch_size, self.batch_size, self.length, roi + ) + if stage in ["fit", "validate"]: + roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt) + self.val_sampler = GridGeoSampler( + self.dataset, self.patch_size, self.patch_size, roi + ) + if stage in ["test"]: + roi = BoundingBox(roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt) + self.test_sampler = GridGeoSampler( + self.dataset, self.patch_size, self.patch_size, roi + ) def plot(self, *args: Any, **kwargs: Any) -> Tuple[plt.Figure, plt.Figure]: """Run NAIP and Chesapeake plot methods. @@ -163,6 +91,13 @@ def plot(self, *args: Any, **kwargs: Any) -> Tuple[plt.Figure, plt.Figure]: See :meth:`torchgeo.datasets.NAIP.plot` and :meth:`torchgeo.datasets.Chesapeake.plot`. + Args: + *args: Arguments passed to plot method. + **kwargs: Keyword arguments passed to plot method. + + Returns: + A list of matplotlib Figures with the image, ground truth, and predictions. + .. versionadded:: 0.4 """ image = self.naip.plot(*args, **kwargs) From aa8108e908ccbb8e82946abc6f3a8390947a4d98 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 20:47:37 -0600 Subject: [PATCH 055/108] Fix several chesapeake bugs --- conf/chesapeake_cvpr.yaml | 3 +- tests/conf/chesapeake_cvpr_5.yaml | 3 +- tests/conf/chesapeake_cvpr_7.yaml | 3 +- tests/conf/chesapeake_cvpr_prior.yaml | 3 +- tests/datamodules/test_chesapeake.py | 4 +-- torchgeo/datamodules/chesapeake.py | 45 +++++++++++++++++++-------- torchgeo/datamodules/geo.py | 37 +++++++++++++++------- torchgeo/datamodules/naip.py | 8 ++++- torchgeo/transforms/transforms.py | 28 ++++++++++------- 9 files changed, 87 insertions(+), 47 deletions(-) diff --git a/conf/chesapeake_cvpr.yaml b/conf/chesapeake_cvpr.yaml index 560d61f7a7b..858e2624199 100644 --- a/conf/chesapeake_cvpr.yaml +++ b/conf/chesapeake_cvpr.yaml @@ -25,8 +25,7 @@ experiment: - "de-val" test_splits: - "de-test" - num_tiles_per_batch: 64 - num_patches_per_tile: 200 + batch_size: 200 patch_size: 256 num_workers: 4 class_set: ${experiment.module.num_classes} diff --git a/tests/conf/chesapeake_cvpr_5.yaml b/tests/conf/chesapeake_cvpr_5.yaml index af1394da67d..7ef269dd661 100644 --- a/tests/conf/chesapeake_cvpr_5.yaml +++ b/tests/conf/chesapeake_cvpr_5.yaml @@ -20,8 +20,7 @@ experiment: - "de-test" test_splits: - "de-test" - num_tiles_per_batch: 2 - num_patches_per_tile: 2 + batch_size: 2 patch_size: 64 num_workers: 0 class_set: ${experiment.module.num_classes} diff --git a/tests/conf/chesapeake_cvpr_7.yaml b/tests/conf/chesapeake_cvpr_7.yaml index ba40d618c1a..653f4934ca0 100644 --- a/tests/conf/chesapeake_cvpr_7.yaml +++ b/tests/conf/chesapeake_cvpr_7.yaml @@ -20,8 +20,7 @@ experiment: - "de-test" test_splits: - "de-test" - num_tiles_per_batch: 2 - num_patches_per_tile: 2 + batch_size: 2 patch_size: 64 num_workers: 0 class_set: ${experiment.module.num_classes} diff --git a/tests/conf/chesapeake_cvpr_prior.yaml b/tests/conf/chesapeake_cvpr_prior.yaml index ca774e9917b..3e9713fbb59 100644 --- a/tests/conf/chesapeake_cvpr_prior.yaml +++ b/tests/conf/chesapeake_cvpr_prior.yaml @@ -20,8 +20,7 @@ experiment: - "de-test" test_splits: - "de-test" 
- num_tiles_per_batch: 2 - num_patches_per_tile: 2 + batch_size: 2 patch_size: 64 num_workers: 0 class_set: ${experiment.module.num_classes} diff --git a/tests/datamodules/test_chesapeake.py b/tests/datamodules/test_chesapeake.py index 55ae702de0e..ecf30775544 100644 --- a/tests/datamodules/test_chesapeake.py +++ b/tests/datamodules/test_chesapeake.py @@ -16,9 +16,9 @@ def test_invalid_param_config(self) -> None: train_splits=["de-test"], val_splits=["de-test"], test_splits=["de-test"], - patch_size=32, - patches_per_tile=2, batch_size=2, + patch_size=32, + length=4, num_workers=0, class_set=7, use_prior_labels=True, diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 0142bd0bfe8..875063e83e8 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -3,9 +3,11 @@ """Chesapeake Bay High-Resolution Land Cover Project datamodule.""" -from typing import Any, List +from typing import Any, Dict, List import kornia.augmentation as K +import torch.nn.functional as F +from torch import Tensor from ..datasets import ChesapeakeCVPR from ..samplers import GridGeoSampler, RandomBatchGeoSampler @@ -25,8 +27,7 @@ def __init__( train_splits: List[str], val_splits: List[str], test_splits: List[str], - num_tiles_per_batch: int = 64, - num_patches_per_tile: int = 200, + batch_size: int = 64, patch_size: int = 256, length: int = 1000, num_workers: int = 0, @@ -41,9 +42,7 @@ def __init__( train_splits: Splits used to train the model, e.g., ["ny-train"]. val_splits: Splits used to validate the model, e.g., ["ny-val"]. test_splits: Splits used to test the model, e.g., ["ny-test"]. - num_tiles_per_batch: Number of image tiles to sample from during training. - num_patches_per_tile: Number of patches to randomly sample from each image - tile during training + batch_size: Size of each mini-batch. patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures. length: Length of each training epoch. @@ -58,12 +57,10 @@ def __init__( Raises: ValueError: If ``use_prior_labels=True`` is used with ``class_set=7``. """ - super().__init__(ChesapeakeCVPR, 1, patch_size, length, num_workers, **kwargs) + super().__init__(ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs) - for state in train_splits + val_splits + test_splits: - assert state in ChesapeakeCVPR.splits assert class_set in [5, 7] - if use_prior_labels and class_set != 5: + if use_prior_labels and class_set == 7: raise ValueError( "The pre-generated prior labels are only valid for the 5" + " class set of labels" @@ -72,8 +69,6 @@ def __init__( self.train_splits = train_splits self.val_splits = val_splits self.test_splits = test_splits - self.train_batch_size = num_tiles_per_batch - self.num_patches_per_tile = num_patches_per_tile # This is a rough estimate of how large of a patch we will need to sample in # EPSG:3857 in order to guarantee a large enough patch in the local CRS. 
self.original_patch_size = patch_size * 2 @@ -91,7 +86,7 @@ def __init__( self.aug = AugmentationSequential( K.CenterCrop(patch_size), - K.Normalize(mean=0.0, std=255.0), + K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"], ) @@ -125,3 +120,27 @@ def setup(self, stage: str) -> None: self.test_sampler = GridGeoSampler( self.test_dataset, self.original_patch_size, self.original_patch_size ) + + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data. + """ + if self.use_prior_labels: + batch["mask"] = F.normalize(batch["mask"], p=1, dim=1) + batch["mask"] = F.normalize( + batch["mask"] + self.prior_smoothing_constant, p=1, dim=1 + ) + else: + if self.class_set == 5: + batch["mask"][batch["mask"] == 5] = 4 + batch["mask"][batch["mask"] == 6] = 4 + + return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 26b4dc4085c..e471f091e6f 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -151,10 +151,15 @@ def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: 'train_dataset'. """ if self.train_dataset is not None or self.dataset is not None: + batch_size = self.train_batch_size or self.batch_size + shuffle = True + if self.train_batch_sampler is not None or self.batch_sampler is not None: + batch_size = 1 + shuffle = False return DataLoader( dataset=self.train_dataset or self.dataset, # type: ignore[arg-type] - batch_size=self.train_batch_size or self.batch_size, - shuffle=True, + batch_size=batch_size, + shuffle=shuffle, sampler=self.train_sampler or self.sampler, batch_sampler=self.train_batch_sampler or self.batch_sampler, num_workers=self.num_workers, @@ -175,10 +180,13 @@ def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: 'val_dataset'. """ if self.val_dataset is not None or self.dataset is not None: + batch_size = self.val_batch_size or self.batch_size + if self.val_batch_sampler is not None or self.batch_sampler is not None: + batch_size = 1 return DataLoader( dataset=self.val_dataset or self.dataset, # type: ignore[arg-type] - batch_size=self.val_batch_size or self.batch_size, - shuffle=True, + batch_size=batch_size, + shuffle=False, sampler=self.val_sampler or self.sampler, batch_sampler=self.val_batch_sampler or self.batch_sampler, num_workers=self.num_workers, @@ -199,10 +207,13 @@ def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: 'test_dataset'. """ if self.test_dataset is not None or self.dataset is not None: + batch_size = self.test_batch_size or self.batch_size + if self.test_batch_sampler is not None or self.batch_sampler is not None: + batch_size = 1 return DataLoader( dataset=self.test_dataset or self.dataset, # type: ignore[arg-type] - batch_size=self.test_batch_size or self.batch_size, - shuffle=True, + batch_size=batch_size, + shuffle=False, sampler=self.test_sampler or self.sampler, batch_sampler=self.test_batch_sampler or self.batch_sampler, num_workers=self.num_workers, @@ -223,10 +234,13 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: 'predict_dataset'. 
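
The ``batch_size = 1`` / ``shuffle = False`` resets above exist because PyTorch's ``DataLoader`` treats ``batch_sampler`` as mutually exclusive with ``batch_size``, ``shuffle``, ``sampler``, and ``drop_last``; anything other than their defaults raises a ``ValueError``. A standalone illustration:

    import torch
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

    ds = TensorDataset(torch.arange(10))
    batch_sampler = BatchSampler(SequentialSampler(ds), batch_size=4, drop_last=False)

    # with a batch_sampler, batch_size must stay at its default of 1 and shuffle False
    loader = DataLoader(ds, batch_size=1, shuffle=False, batch_sampler=batch_sampler)
    assert [len(b[0]) for b in loader] == [4, 4, 2]
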
""" if self.predict_dataset is not None or self.dataset is not None: + batch_size = self.predict_batch_size or self.batch_size + if self.predict_batch_sampler is not None or self.batch_sampler is not None: + batch_size = 1 return DataLoader( dataset=self.predict_dataset or self.dataset, # type: ignore[arg-type] - batch_size=self.predict_batch_size or self.batch_size, - shuffle=True, + batch_size=batch_size, + shuffle=False, sampler=self.predict_sampler or self.sampler, batch_sampler=self.predict_batch_sampler or self.batch_sampler, num_workers=self.num_workers, @@ -272,9 +286,10 @@ def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: Returns: A matplotlib Figure with the image, ground truth, and predictions. """ - if self.val_dataset is not None: - if hasattr(self.val_dataset, "plot"): - return self.val_dataset.plot(*args, **kwargs) + dataset = self.val_dataset or self.dataset + if dataset is not None: + if hasattr(dataset, "plot"): + return dataset.plot(*args, **kwargs) class NonGeoDataModule(LightningDataModule): diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index a4c691e45a4..110c8825902 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -5,10 +5,12 @@ from typing import Any, Tuple, Union +import kornia.augmentation as K import matplotlib.pyplot as plt from ..datasets import NAIP, BoundingBox, Chesapeake13 from ..samplers import GridGeoSampler, RandomBatchGeoSampler +from ..transforms import AugmentationSequential from .geo import GeoDataModule @@ -36,7 +38,7 @@ def __init__( **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.NAIP` (prefix keys with ``naip_``) and :class:`~torchgeo.datasets.Chesapeake13` - (prefix keys with ``chesapeake_``) + (prefix keys with ``chesapeake_``). """ self.naip_kwargs = {} self.chesapeake_kwargs = {} @@ -55,6 +57,10 @@ def __init__( **self.chesapeake_kwargs, ) + self.aug: Transform = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] + ) + def setup(self, stage: str) -> None: """Set up datasets and samplers. diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index f874e266ff4..ed16383598e 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -45,26 +45,30 @@ def __init__(self, *args: Module, data_keys: List[str]) -> None: self.augs = kornia.augmentation.AugmentationSequential(*args, data_keys=keys) - def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: + def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: """Perform augmentations and update data dict. 
Args: - sample: the input + batch: the input Returns: the augmented input """ + # TorchGeo bbox is very different from Kornia bbox + if "bbox" in batch: + del batch["bbox"] + # Kornia augmentations require all inputs to be float dtype = {} for key in self.data_keys: - dtype[key] = sample[key].dtype - sample[key] = sample[key].float() + dtype[key] = batch[key].dtype + batch[key] = batch[key].float() # Kornia requires masks to have a channel dimension - if "mask" in sample: - sample["mask"] = rearrange(sample["mask"], "b h w -> b () h w") + if "mask" in batch: + batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") - inputs = [sample[k] for k in self.data_keys] + inputs = [batch[k] for k in self.data_keys] outputs_list: Union[Tensor, List[Tensor]] = self.augs(*inputs) outputs_list = ( outputs_list if isinstance(outputs_list, list) else [outputs_list] @@ -72,17 +76,17 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: outputs: Dict[str, Tensor] = { k: v for k, v in zip(self.data_keys, outputs_list) } - sample.update(outputs) + batch.update(outputs) # Convert all inputs back to their previous dtype for key in self.data_keys: - sample[key] = sample[key].to(dtype[key]) + batch[key] = batch[key].to(dtype[key]) # Torchmetrics does not support masks with a channel dimension - if "mask" in sample: - sample["mask"] = rearrange(sample["mask"], "b () h w -> b h w") + if "mask" in batch: + batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - return sample + return batch class _ExtractTensorPatches(GeometricAugmentationBase2D): From 2e5b2a86c3df084dfd6f7b32db8f03c58c11bb3e Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 20:58:58 -0600 Subject: [PATCH 056/108] Fix dtype and shape --- torchgeo/datamodules/chesapeake.py | 4 +++- torchgeo/datamodules/naip.py | 2 +- torchgeo/datasets/chesapeake.py | 4 ++-- torchgeo/datasets/geo.py | 7 +++++-- torchgeo/transforms/transforms.py | 2 +- 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 875063e83e8..473e363385a 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -57,7 +57,9 @@ def __init__( Raises: ValueError: If ``use_prior_labels=True`` is used with ``class_set=7``. 
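
The dtype and shape bookkeeping in ``forward`` can be exercised on its own: Kornia wants float inputs with a channel axis, while the metrics downstream want integer ``(B, H, W)`` masks, hence the round trip:

    import torch
    from einops import rearrange

    mask = torch.randint(0, 5, (8, 64, 64))               # (B, H, W), torch.long
    mask4 = rearrange(mask, "b h w -> b () h w").float()  # what Kornia consumes
    # ... augmentations would run here ...
    restored = rearrange(mask4, "b () h w -> b h w").long()
    assert restored.shape == (8, 64, 64) and restored.dtype == torch.long
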
""" - super().__init__(ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs) + super().__init__( + ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs + ) assert class_set in [5, 7] if use_prior_labels and class_set == 7: diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 110c8825902..a141b3543eb 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -57,7 +57,7 @@ def __init__( **self.chesapeake_kwargs, ) - self.aug: Transform = AugmentationSequential( + self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] ) diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 454c52f980b..3d646c79481 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -688,8 +688,8 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: sample["image"] = np.concatenate(sample["image"], axis=0) sample["mask"] = np.concatenate(sample["mask"], axis=0) - sample["image"] = torch.from_numpy(sample["image"]) - sample["mask"] = torch.from_numpy(sample["mask"]) + sample["image"] = torch.from_numpy(sample["image"]).float() + sample["mask"] = torch.from_numpy(sample["mask"]).long() if self.transforms is not None: sample = self.transforms(sample) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 189875f3c1f..84a531268e8 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -427,8 +427,11 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: else: data = self._merge_files(filepaths, query, self.band_indexes) - key = "image" if self.is_image else "mask" - sample = {key: data, "crs": self.crs, "bbox": query} + sample = {"crs": self.crs, "bbox": query} + if self.is_image: + sample["image"] = data.float() + else: + sample["mask"] = data.long() if self.transforms is not None: sample = self.transforms(sample) diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index ed16383598e..d8ff1670a3e 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -65,7 +65,7 @@ def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: batch[key] = batch[key].float() # Kornia requires masks to have a channel dimension - if "mask" in batch: + if "mask" in batch and len(batch["mask"].shape) == 3: batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") inputs = [batch[k] for k in self.data_keys] From 989bbaadb0cd576cfc89be213f64babd3f54f3d4 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 21:09:10 -0600 Subject: [PATCH 057/108] Fix crs/bbox issue --- torchgeo/datamodules/geo.py | 22 ++++++++++++++++++++++ torchgeo/transforms/transforms.py | 4 ---- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index e471f091e6f..13b2346609c 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -250,6 +250,28 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: msg = f"{self.__class__.__name__}.setup does not define a 'predict_dataset'" raise MisconfigurationException(msg) + def transfer_batch_to_device( + self, batch: Dict[str, Tensor], device: torch.device, dataloader_idx: int + ) -> Dict[str, Tensor]: + """Transfer batch to device. + + Defines how custom data types are moved to the target device. + + Args: + batch: A batch of data that needs to be transferred to a new device. 
+ device: The target device as defined in PyTorch. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A reference to the data on the new device. + """ + # Non-Tensor values cannot be moved to a device + del batch["crs"] + del batch["bbox"] + + batch = super().transfer_batch_to_device(batch, device, dataloader_idx) + return batch + def on_after_batch_transfer( self, batch: Dict[str, Tensor], dataloader_idx: int ) -> Dict[str, Tensor]: diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index d8ff1670a3e..c56c95d0b05 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -54,10 +54,6 @@ def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: Returns: the augmented input """ - # TorchGeo bbox is very different from Kornia bbox - if "bbox" in batch: - del batch["bbox"] - # Kornia augmentations require all inputs to be float dtype = {} for key in self.data_keys: From a0aae636226fd13773ab483e1ba40ee106e6238f Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 21:10:18 -0600 Subject: [PATCH 058/108] Fix test dtype --- tests/datasets/test_geo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index b691961806e..5a27522f404 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -211,7 +211,7 @@ def test_getitem_uint_dtype(self, custom_dtype_ds: RasterDataset) -> None: x = custom_dtype_ds[custom_dtype_ds.bounds] assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) - assert x["image"].dtype == torch.int64 + assert x["image"].dtype == torch.float32 def test_invalid_query(self, sentinel: Sentinel2) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) From ad67c8847e3f66a3b87433130521c5d7a1f2143e Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 21:32:21 -0600 Subject: [PATCH 059/108] Fix unequal size stacking error --- torchgeo/datamodules/chesapeake.py | 16 ++++++++++------ torchgeo/transforms/transforms.py | 2 +- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 473e363385a..00ae98d646a 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -57,6 +57,14 @@ def __init__( Raises: ValueError: If ``use_prior_labels=True`` is used with ``class_set=7``. """ + # This is a rough estimate of how large of a patch we will need to sample in + # EPSG:3857 in order to guarantee a large enough patch in the local CRS. + self.original_patch_size = patch_size * 2 + kwargs["transforms"] = AugmentationSequential( + K.CenterCrop(patch_size), + data_keys=["image", "mask"], + ) + super().__init__( ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs ) @@ -71,9 +79,6 @@ def __init__( self.train_splits = train_splits self.val_splits = val_splits self.test_splits = test_splits - # This is a rough estimate of how large of a patch we will need to sample in - # EPSG:3857 in order to guarantee a large enough patch in the local CRS. 
- self.original_patch_size = patch_size * 2 self.class_set = class_set self.use_prior_labels = use_prior_labels self.prior_smoothing_constant = prior_smoothing_constant @@ -87,7 +92,6 @@ def __init__( self.layers = ["naip-new", "lc"] self.aug = AugmentationSequential( - K.CenterCrop(patch_size), K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"], ) @@ -136,10 +140,10 @@ def on_after_batch_transfer( A batch of data. """ if self.use_prior_labels: - batch["mask"] = F.normalize(batch["mask"], p=1, dim=1) + batch["mask"] = F.normalize(batch["mask"].float(), p=1, dim=1) batch["mask"] = F.normalize( batch["mask"] + self.prior_smoothing_constant, p=1, dim=1 - ) + ).long() else: if self.class_set == 5: batch["mask"][batch["mask"] == 5] = 4 diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index c56c95d0b05..331b2bc0127 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -79,7 +79,7 @@ def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: batch[key] = batch[key].to(dtype[key]) # Torchmetrics does not support masks with a channel dimension - if "mask" in batch: + if "mask" in batch and batch["mask"].shape[1] == 1: batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") return batch From f62700eb53979f457bcd4bf835db75b687924a85 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 2 Jan 2023 21:36:27 -0600 Subject: [PATCH 060/108] flake8 fix --- torchgeo/datamodules/chesapeake.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 00ae98d646a..1e47b5b0fe8 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -61,8 +61,7 @@ def __init__( # EPSG:3857 in order to guarantee a large enough patch in the local CRS. self.original_patch_size = patch_size * 2 kwargs["transforms"] = AugmentationSequential( - K.CenterCrop(patch_size), - data_keys=["image", "mask"], + K.CenterCrop(patch_size), data_keys=["image", "mask"] ) super().__init__( @@ -92,8 +91,7 @@ def __init__( self.layers = ["naip-new", "lc"] self.aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - data_keys=["image", "mask"], + K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] ) def setup(self, stage: str) -> None: From 4de4eff13b6e1dd1f1261c346662753a03616f67 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 3 Jan 2023 15:07:36 -0600 Subject: [PATCH 061/108] Better checks on sampler --- torchgeo/datamodules/geo.py | 56 ++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 13b2346609c..987e359140c 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -150,18 +150,22 @@ def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a 'train_dataset'. 
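
The prior-label fix above first L1-normalizes the soft mask over the class dimension, then renormalizes after additive smoothing so that no class ends up with exactly zero probability. In isolation (smoothing constant illustrative; note the priors must stay floating point to remain soft labels):

    import torch
    import torch.nn.functional as F

    prior = torch.rand(2, 5, 4, 4)                 # (B, num_classes, H, W) soft labels
    prior = F.normalize(prior, p=1, dim=1)         # class probabilities sum to 1
    prior = F.normalize(prior + 1e-4, p=1, dim=1)  # additive smoothing, renormalized
    assert torch.allclose(prior.sum(dim=1), torch.ones(2, 4, 4))
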
""" - if self.train_dataset is not None or self.dataset is not None: + dataset = self.train_dataset or self.dataset + if dataset is not None: + sampler = self.train_sampler or self.sampler + batch_sampler = self.train_batch_sampler or self.batch_sampler + assert sampler or batch_sampler batch_size = self.train_batch_size or self.batch_size shuffle = True - if self.train_batch_sampler is not None or self.batch_sampler is not None: + if batch_sampler is not None: batch_size = 1 shuffle = False return DataLoader( - dataset=self.train_dataset or self.dataset, # type: ignore[arg-type] + dataset=dataset, batch_size=batch_size, shuffle=shuffle, - sampler=self.train_sampler or self.sampler, - batch_sampler=self.train_batch_sampler or self.batch_sampler, + sampler=sampler, + batch_sampler=batch_sampler, num_workers=self.num_workers, collate_fn=stack_samples, ) @@ -179,16 +183,20 @@ def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a 'val_dataset'. """ - if self.val_dataset is not None or self.dataset is not None: + dataset = self.val_dataset or self.dataset + if dataset is not None: + sampler = self.val_sampler or self.sampler + batch_sampler = self.val_batch_sampler or self.batch_sampler + assert sampler or batch_sampler batch_size = self.val_batch_size or self.batch_size - if self.val_batch_sampler is not None or self.batch_sampler is not None: + if batch_sampler is not None: batch_size = 1 return DataLoader( - dataset=self.val_dataset or self.dataset, # type: ignore[arg-type] + dataset=dataset, batch_size=batch_size, shuffle=False, - sampler=self.val_sampler or self.sampler, - batch_sampler=self.val_batch_sampler or self.batch_sampler, + sampler=sampler, + batch_sampler=batch_sampler, num_workers=self.num_workers, collate_fn=stack_samples, ) @@ -206,16 +214,20 @@ def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a 'test_dataset'. """ - if self.test_dataset is not None or self.dataset is not None: + dataset = self.test_dataset or self.dataset + if dataset is not None: + sampler = self.test_sampler or self.sampler + batch_sampler = self.test_batch_sampler or self.batch_sampler + assert sampler or batch_sampler batch_size = self.test_batch_size or self.batch_size - if self.test_batch_sampler is not None or self.batch_sampler is not None: + if batch_sampler is not None: batch_size = 1 return DataLoader( - dataset=self.test_dataset or self.dataset, # type: ignore[arg-type] + dataset=dataset, batch_size=batch_size, shuffle=False, - sampler=self.test_sampler or self.sampler, - batch_sampler=self.test_batch_sampler or self.batch_sampler, + sampler=sampler, + batch_sampler=batch_sampler, num_workers=self.num_workers, collate_fn=stack_samples, ) @@ -233,16 +245,20 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a 'predict_dataset'. 
""" - if self.predict_dataset is not None or self.dataset is not None: + dataset = self.predict_dataset or self.dataset + if dataset is not None: + sampler = self.predict_sampler or self.sampler + batch_sampler = self.predict_batch_sampler or self.batch_sampler + assert sampler or batch_sampler batch_size = self.predict_batch_size or self.batch_size - if self.predict_batch_sampler is not None or self.batch_sampler is not None: + if batch_sampler is not None: batch_size = 1 return DataLoader( - dataset=self.predict_dataset or self.dataset, # type: ignore[arg-type] + dataset=dataset, batch_size=batch_size, shuffle=False, - sampler=self.predict_sampler or self.sampler, - batch_sampler=self.predict_batch_sampler or self.batch_sampler, + sampler=sampler, + batch_sampler=batch_sampler, num_workers=self.num_workers, collate_fn=stack_samples, ) From dd66a6241870ae10eae3befe8e7d303a9ef46ede Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 3 Jan 2023 15:19:53 -0600 Subject: [PATCH 062/108] Fix bug introduced in NAIP dm --- torchgeo/datamodules/naip.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index a141b3543eb..d4774d9b242 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -76,19 +76,23 @@ def setup(self, stage: str) -> None: midy = roi.miny + (roi.maxy - roi.miny) / 2 if stage in ["fit"]: - roi = BoundingBox(roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt) + train_roi = BoundingBox( + roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt + ) self.train_batch_sampler = RandomBatchGeoSampler( - self.dataset, self.patch_size, self.batch_size, self.length, roi + self.dataset, self.patch_size, self.batch_size, self.length, train_roi ) if stage in ["fit", "validate"]: - roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt) + val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt) self.val_sampler = GridGeoSampler( - self.dataset, self.patch_size, self.patch_size, roi + self.dataset, self.patch_size, self.patch_size, val_roi ) if stage in ["test"]: - roi = BoundingBox(roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt) + test_roi = BoundingBox( + roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt + ) self.test_sampler = GridGeoSampler( - self.dataset, self.patch_size, self.patch_size, roi + self.dataset, self.patch_size, self.patch_size, test_roi ) def plot(self, *args: Any, **kwargs: Any) -> Tuple[plt.Figure, plt.Figure]: From a1970eab67449080ac14e4d08808952347534fc9 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 3 Jan 2023 15:31:34 -0600 Subject: [PATCH 063/108] Fix chesapeake dimensions --- torchgeo/datamodules/chesapeake.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 1e47b5b0fe8..617cfe3152b 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -7,6 +7,7 @@ import kornia.augmentation as K import torch.nn.functional as F +from einops import rearrange from torch import Tensor from ..datasets import ChesapeakeCVPR @@ -137,6 +138,10 @@ def on_after_batch_transfer( Returns: A batch of data. 
""" + # CenterCrop adds additional dimensions + batch["image"] = rearrange(batch["image"], "b () c h w -> b c h w") + batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") + if self.use_prior_labels: batch["mask"] = F.normalize(batch["mask"].float(), p=1, dim=1) batch["mask"] = F.normalize( From 61621b0bb40c138387ec21c26a58397293abc60b Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 4 Jan 2023 10:38:36 -0600 Subject: [PATCH 064/108] Add one to mask --- torchgeo/datamodules/spacenet.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index d8ed9f249de..f88472e0b7f 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -77,3 +77,22 @@ def setup(self, stage: str) -> None: self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( dataset, self.val_split_pct, self.test_split_pct ) + + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data. + """ + # We add 1 to the mask to map the current {background, building} labels to + # the values {1, 2}. This is necessary because we add 0 padding to the + # mask that we want to ignore in the loss function. + batch["mask"] += 1 + + return super().on_after_batch_transfer(batch, dataloader_idx) From 3b7a9f492cd1571ba1d0d684bb3333b2f77fe1b3 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 4 Jan 2023 10:40:23 -0600 Subject: [PATCH 065/108] Fix missing imports --- torchgeo/datamodules/spacenet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index f88472e0b7f..5bd7cdba476 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -3,9 +3,10 @@ """SpaceNet datamodules.""" -from typing import Any +from typing import Any, Dict import kornia.augmentation as K +from torch import Tensor from ..datasets import SpaceNet1 from ..transforms import AugmentationSequential From c63996e8e396ed7247a2e40e7b4d569e5b1d1bde Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 4 Jan 2023 21:54:34 -0600 Subject: [PATCH 066/108] Fix batch size --- torchgeo/datamodules/chesapeake.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 617cfe3152b..a6d90a117fd 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -108,7 +108,7 @@ def setup(self, stage: str) -> None: self.train_batch_sampler = RandomBatchGeoSampler( self.train_dataset, self.original_patch_size, - self.train_batch_size or 1, + self.batch_size, self.length, ) if stage in ["fit", "validate"]: From 716055f38aa013db8bfbda8ad070a35553d1e310 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Wed, 4 Jan 2023 22:10:26 -0600 Subject: [PATCH 067/108] Simplify augmentations --- torchgeo/trainers/byol.py | 73 +++++++++++---------------------------- 1 file changed, 21 insertions(+), 52 deletions(-) diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index d733df142ae..3b9e8d0db3e 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -4,18 +4,15 @@ """BYOL tasks.""" import os -import random -from typing import Any, Callable, Dict, Optional, Tuple, cast +from typing import Any, Dict, Optional, Tuple, cast import pytorch_lightning as pl import timm import torch +import torch.nn as nn import torch.nn.functional as F from kornia import augmentation as K -from kornia import filters -from kornia.geometry import transform as KorniaTransform from torch import Tensor, optim -from torch.nn.modules import BatchNorm1d, Linear, Module, ReLU, Sequential from torch.optim.lr_scheduler import ReduceLROnPlateau from torchvision.models._api import WeightsEnum @@ -39,38 +36,10 @@ def normalized_mse(x: Tensor, y: Tensor) -> Tensor: return mse -# TODO: Move this to transforms -class RandomApply(Module): - """Applies augmentation function (augm) with probability p.""" - - def __init__(self, augm: Callable[[Tensor], Tensor], p: float) -> None: - """Initialize RandomApply. - - Args: - augm: augmentation function to apply - p: probability with which the augmentation function is applied - """ - super().__init__() - self.augm = augm - self.p = p - - def forward(self, x: Tensor) -> Tensor: - """Applies an augmentation to the input with some probability. - - Args: - x: a batch of imagery - - Returns - augmented version of ``x`` with probability ``self.p`` else an un-augmented - version - """ - return x if random.random() > self.p else self.augm(x) - - # TODO: This isn't _really_ applying the augmentations from SimCLR as we have # multispectral imagery and thus can't naively apply color jittering or grayscale # conversions. We should think more about what makes sense here. -class SimCLRAugmentation(Module): +class SimCLRAugmentation(nn.Module): """A module for applying SimCLR augmentations. 
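
The rewrite below leans on the fact that most Kornia augmentations already take a probability ``p``, which is what makes the hand-rolled ``RandomApply`` wrapper above redundant. The equivalent stack can be exercised standalone on multispectral batches:

    import torch
    import torch.nn as nn
    import kornia.augmentation as K

    augment = nn.Sequential(
        K.Resize(size=(256, 256), align_corners=False),
        K.RandomHorizontalFlip(),
        K.RandomGaussianBlur((3, 3), (1.5, 1.5), p=0.1),  # applied with probability 0.1
        K.RandomResizedCrop(size=(256, 256)),
    )

    x = torch.rand(2, 4, 300, 300)  # 4-band imagery; no PIL, no 3-channel assumption
    assert augment(x).shape == (2, 4, 256, 256)
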
SimCLR was one of the first papers to show the effectiveness of random data @@ -87,13 +56,13 @@ def __init__(self, image_size: Tuple[int, int] = (256, 256)) -> None: super().__init__() self.size = image_size - self.augmentation = Sequential( - KorniaTransform.Resize(size=image_size, align_corners=False), + self.augmentation = nn.Sequential( + K.Resize(size=image_size, align_corners=False), # Not suitable for multispectral adapt - # RandomApply(K.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), + # K.ColorJitter(0.8, 0.8, 0.8, 0.8, 0.2), # K.RandomGrayscale(p=0.2), K.RandomHorizontalFlip(), - RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), + K.RandomGaussianBlur((3, 3), (1.5, 1.5), p=0.1), K.RandomResizedCrop(size=image_size), ) @@ -109,7 +78,7 @@ def forward(self, x: Tensor) -> Tensor: return cast(Tensor, self.augmentation(x)) -class MLP(Module): +class MLP(nn.Module): """MLP used in the BYOL projection head.""" def __init__( @@ -123,11 +92,11 @@ def __init__( hidden_size: size of the hidden layer """ super().__init__() - self.mlp = Sequential( - Linear(dim, hidden_size), - BatchNorm1d(hidden_size), - ReLU(inplace=True), - Linear(hidden_size, projection_size), + self.mlp = nn.Sequential( + nn.Linear(dim, hidden_size), + nn.BatchNorm1d(hidden_size), + nn.ReLU(inplace=True), + nn.Linear(hidden_size, projection_size), ) def forward(self, x: Tensor) -> Tensor: @@ -142,7 +111,7 @@ def forward(self, x: Tensor) -> Tensor: return cast(Tensor, self.mlp(x)) -class BackboneWrapper(Module): +class BackboneWrapper(nn.Module): """Backbone wrapper for joining a model and a projection head. When we call .forward() on this module the following steps happen: @@ -158,7 +127,7 @@ class BackboneWrapper(Module): def __init__( self, - model: Module, + model: nn.Module, projection_size: int = 256, hidden_size: int = 4096, layer: int = -2, @@ -178,13 +147,13 @@ def __init__( self.hidden_size = hidden_size self.layer = layer - self._projector: Optional[Module] = None + self._projector: Optional[nn.Module] = None self._projector_dim: Optional[int] = None self._encoded = torch.empty(0) self._register_hook() @property - def projector(self) -> Module: + def projector(self) -> nn.Module: """Wrapper module for the projector head.""" assert self._projector_dim is not None if self._projector is None: @@ -234,7 +203,7 @@ def forward(self, x: Tensor) -> Tensor: return self._encoded -class BYOL(Module): +class BYOL(nn.Module): """BYOL implementation. BYOL contains two identical backbone networks. The first is trained as usual, and @@ -247,13 +216,13 @@ class BYOL(Module): def __init__( self, - model: Module, + model: nn.Module, image_size: Tuple[int, int] = (256, 256), hidden_layer: int = -2, in_channels: int = 4, projection_size: int = 256, hidden_size: int = 4096, - augment_fn: Optional[Module] = None, + augment_fn: Optional[nn.Module] = None, beta: float = 0.99, **kwargs: Any, ) -> None: @@ -273,7 +242,7 @@ def __init__( """ super().__init__() - self.augment: Module + self.augment: nn.Module if augment_fn is None: self.augment = SimCLRAugmentation(image_size) else: From 521c09cb4111a979ee26dd8431f32624379e9ea2 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Wed, 4 Jan 2023 22:20:19 -0600 Subject: [PATCH 068/108] Don't run test or predict without datasets --- tests/trainers/test_byol.py | 10 ++++++++-- tests/trainers/test_classification.py | 10 ++++++++-- tests/trainers/test_detection.py | 10 ++++++++-- tests/trainers/test_regression.py | 10 ++++++++-- tests/trainers/test_segmentation.py | 4 +++- 5 files changed, 35 insertions(+), 9 deletions(-) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 492ca87eda4..e326b3d6bdb 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -72,8 +72,14 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: # Instantiate trainer trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - trainer.test(model=model, datamodule=datamodule) - trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) + + if datamodule.test_dataset is not None or hasattr(datamodule, "test_sampler"): + trainer.test(model=model, datamodule=datamodule) + + if datamodule.predict_dataset is not None or hasattr( + datamodule, "predict_sampler" + ): + trainer.predict(model=model, datamodule=datamodule) @pytest.fixture def model_kwargs(self) -> Dict[str, Any]: diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 8eb29a10566..19921bc15f7 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -72,8 +72,14 @@ def test_trainer( # Instantiate trainer trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - trainer.test(model=model, datamodule=datamodule) - trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) + + if datamodule.test_dataset is not None or hasattr(datamodule, "test_sampler"): + trainer.test(model=model, datamodule=datamodule) + + if datamodule.predict_dataset is not None or hasattr( + datamodule, "predict_sampler" + ): + trainer.predict(model=model, datamodule=datamodule) def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml")) diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 293b7bc33a2..ffc114315db 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -34,8 +34,14 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: # Instantiate trainer trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - trainer.test(model=model, datamodule=datamodule) - trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) + + if datamodule.test_dataset is not None or hasattr(datamodule, "test_sampler"): + trainer.test(model=model, datamodule=datamodule) + + if datamodule.predict_dataset is not None or hasattr( + datamodule, "predict_sampler" + ): + trainer.predict(model=model, datamodule=datamodule) @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 22524b0a8e2..9428ea9367e 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -52,8 +52,14 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: # Instantiate trainer trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - 
trainer.test(model=model, datamodule=datamodule) - trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) + + if datamodule.test_dataset is not None or hasattr(datamodule, "test_sampler"): + trainer.test(model=model, datamodule=datamodule) + + if datamodule.predict_dataset is not None or hasattr( + datamodule, "predict_sampler" + ): + trainer.predict(model=model, datamodule=datamodule) def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml")) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 1ed7f61b6bb..2db38d45005 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -87,7 +87,9 @@ def test_trainer( if datamodule.test_dataset is not None or hasattr(datamodule, "test_sampler"): trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset is not None: + if datamodule.predict_dataset is not None or hasattr( + datamodule, "predict_sampler" + ): trainer.predict(model=model, datamodule=datamodule) def test_no_logger(self) -> None: From 3201664d494f6f87c18dc42283fde7fae1ee0806 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 4 Jan 2023 22:32:38 -0600 Subject: [PATCH 069/108] Fix tests --- tests/trainers/test_byol.py | 6 ++---- tests/trainers/test_classification.py | 6 ++---- tests/trainers/test_detection.py | 6 ++---- tests/trainers/test_regression.py | 6 ++---- tests/trainers/test_segmentation.py | 6 ++---- 5 files changed, 10 insertions(+), 20 deletions(-) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index e326b3d6bdb..386a6aa7ac9 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -73,12 +73,10 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - if datamodule.test_dataset is not None or hasattr(datamodule, "test_sampler"): + if datamodule.test_dataset or datamodule.dataset: trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset is not None or hasattr( - datamodule, "predict_sampler" - ): + if datamodule.predict_dataset or datamodule.dataset: trainer.predict(model=model, datamodule=datamodule) @pytest.fixture def model_kwargs(self) -> Dict[str, Any]: diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 19921bc15f7..61e3795af1d 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -73,12 +73,10 @@ def test_trainer( trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - if datamodule.test_dataset is not None or hasattr(datamodule, "test_sampler"): + if datamodule.test_dataset or datamodule.dataset: trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset is not None or hasattr( - datamodule, "predict_sampler" - ): + if datamodule.predict_dataset or datamodule.dataset: trainer.predict(model=model, datamodule=datamodule) def test_no_logger(self) -> None: diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index ffc114315db..a9bcd666e30 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -35,12 +35,10 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model,
datamodule=datamodule) - if datamodule.test_dataset is not None or hasattr(datamodule, "test_sampler"): + if datamodule.test_dataset or datamodule.dataset: trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset is not None or hasattr( - datamodule, "predict_sampler" - ): + if datamodule.predict_dataset or datamodule.dataset: trainer.predict(model=model, datamodule=datamodule) @pytest.fixture diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 9428ea9367e..00a4009873c 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -53,12 +53,10 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - if datamodule.test_dataset is not None or hasattr(datamodule, "test_sampler"): + if datamodule.test_dataset or datamodule.dataset: trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset is not None or hasattr( - datamodule, "predict_sampler" - ): + if datamodule.predict_dataset or datamodule.dataset: trainer.predict(model=model, datamodule=datamodule) def test_no_logger(self) -> None: diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 2db38d45005..c7b113039b6 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -84,12 +84,10 @@ def test_trainer( trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - if datamodule.test_dataset is not None or hasattr(datamodule, "test_sampler"): + if datamodule.test_dataset or datamodule.dataset: trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset is not None or hasattr( - datamodule, "predict_sampler" - ): + if datamodule.predict_dataset or datamodule.dataset: trainer.predict(model=model, datamodule=datamodule) def test_no_logger(self) -> None: From db87b4290a1f8fc4e64a3d629abd9bdba438dbca Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 4 Jan 2023 22:41:59 -0600 Subject: [PATCH 070/108] Allow shared dataset --- torchgeo/datamodules/geo.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 987e359140c..f2699f778db 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -359,6 +359,7 @@ def __init__( self.kwargs = kwargs # Datasets + self.dataset: Optional[Dataset[Dict[str, Tensor]]] = None self.train_dataset: Optional[Dataset[Dict[str, Tensor]]] = None self.val_dataset: Optional[Dataset[Dict[str, Tensor]]] = None self.test_dataset: Optional[Dataset[Dict[str, Tensor]]] = None @@ -423,9 +424,10 @@ def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a 'train_dataset'. """ - if self.train_dataset is not None: + dataset = self.train_dataset or self.dataset + if dataset is not None: return DataLoader( - dataset=self.train_dataset, + dataset=dataset, batch_size=self.train_batch_size or self.batch_size, shuffle=True, num_workers=self.num_workers, @@ -444,9 +446,10 @@ def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a 'val_dataset'. 
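The `x or y` fallback adopted in this patch is idiomatic but not identical to an `is None` check: it falls through whenever the left operand is falsy, which for datasets includes a defined-but-empty one. A small sketch of the semantics, with toy sequences standing in for datasets:

```python
from typing import Optional, Sequence


def resolve(
    split: Optional[Sequence[int]], shared: Optional[Sequence[int]]
) -> Optional[Sequence[int]]:
    # Mirrors `dataset = self.train_dataset or self.dataset` in geo.py.
    return split or shared


assert resolve(None, [1, 2]) == [1, 2]  # missing split -> shared dataset
assert resolve([], [1, 2]) == [1, 2]    # an empty split also falls through
assert resolve([0], [1, 2]) == [0]      # otherwise the split-specific one wins
```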
""" - if self.val_dataset is not None: + dataset = self.val_dataset or self.dataset + if dataset is not None: return DataLoader( - dataset=self.val_dataset, + dataset=dataset, batch_size=self.val_batch_size or self.batch_size, shuffle=False, num_workers=self.num_workers, @@ -465,9 +468,10 @@ def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a 'test_dataset'. """ - if self.test_dataset is not None: + dataset = self.test_dataset or self.dataset + if dataset is not None: return DataLoader( - dataset=self.test_dataset, + dataset=dataset, batch_size=self.test_batch_size or self.batch_size, shuffle=False, num_workers=self.num_workers, @@ -486,9 +490,10 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: MisconfigurationException: If :meth:`setup` does not define a 'predict_dataset'. """ - if self.predict_dataset is not None: + dataset = self.predict_dataset or self.dataset + if dataset is not None: return DataLoader( - dataset=self.predict_dataset, + dataset=dataset, batch_size=self.predict_batch_size or self.batch_size, shuffle=False, num_workers=self.num_workers, From 752e830b71345e179904ad377a7fd1b4755599dd Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 4 Jan 2023 23:11:06 -0600 Subject: [PATCH 071/108] One more try --- tests/trainers/test_byol.py | 21 ++++++++++----- tests/trainers/test_classification.py | 17 ++++++++---- tests/trainers/test_detection.py | 21 ++++++++++----- tests/trainers/test_regression.py | 22 ++++++++++----- tests/trainers/test_segmentation.py | 17 ++++++++---- torchgeo/datamodules/geo.py | 39 ++++++++++++++------------- 6 files changed, 90 insertions(+), 47 deletions(-) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 386a6aa7ac9..f2fb865efdc 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -16,7 +16,11 @@ from torchvision.models import resnet18 from torchvision.models._api import WeightsEnum -from torchgeo.datamodules import ChesapeakeCVPRDataModule +from torchgeo.datamodules import ( + ChesapeakeCVPRDataModule, + GeoDataModule, + NonGeoDataModule, +) from torchgeo.models import ResNet18_Weights from torchgeo.trainers import BYOLTask from torchgeo.trainers.byol import BYOL, SimCLRAugmentation @@ -73,11 +77,16 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - if datamodule.test_dataset or datamodule.dataset: - trainer.test(model=model, datamodule=datamodule) - - if datamodule.predict_dataset or datamodule.predict_dataset: - trainer.predict(model=model, datamodule=datamodule) + if isinstance(datamodule, GeoDataModule): + if datamodule.test_dataset or datamodule.test_sampler: + trainer.test(model=model, datamodule=datamodule) + if datamodule.predict_dataset or datamodule.predict_sampler: + trainer.predict(model=model, datamodule=datamodule) + elif isinstance(datamodule, NonGeoDataModule): + if datamodule.test_dataset or datamodule.dataset: + trainer.test(model=model, datamodule=datamodule) + if datamodule.predict_dataset or datamodule.dataset: + trainer.predict(model=model, datamodule=datamodule) @pytest.fixture def model_kwargs(self) -> Dict[str, Any]: diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 61e3795af1d..c5fb54b4f6c 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py 
@@ -18,6 +18,8 @@ from torchgeo.datamodules import ( BigEarthNetDataModule, EuroSATDataModule, + GeoDataModule, + NonGeoDataModule, RESISC45DataModule, So2SatDataModule, UCMercedDataModule, @@ -73,11 +75,16 @@ def test_trainer( trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - if datamodule.test_dataset or datamodule.dataset: - trainer.test(model=model, datamodule=datamodule) - - if datamodule.predict_dataset or datamodule.dataset: - trainer.predict(model=model, datamodule=datamodule) + if isinstance(datamodule, GeoDataModule): + if datamodule.test_dataset or datamodule.test_sampler: + trainer.test(model=model, datamodule=datamodule) + if datamodule.predict_dataset or datamodule.predict_sampler: + trainer.predict(model=model, datamodule=datamodule) + elif isinstance(datamodule, NonGeoDataModule): + if datamodule.test_dataset or datamodule.dataset: + trainer.test(model=model, datamodule=datamodule) + if datamodule.predict_dataset or datamodule.dataset: + trainer.predict(model=model, datamodule=datamodule) def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml")) diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index a9bcd666e30..691e710fdee 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -9,7 +9,11 @@ from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer -from torchgeo.datamodules import NASAMarineDebrisDataModule +from torchgeo.datamodules import ( + GeoDataModule, + NASAMarineDebrisDataModule, + NonGeoDataModule, +) from torchgeo.datasets import NASAMarineDebris from torchgeo.trainers import ObjectDetectionTask @@ -35,11 +39,16 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - if datamodule.test_dataset or datamodule.dataset: - trainer.test(model=model, datamodule=datamodule) - - if datamodule.predict_dataset or datamodule.dataset: - trainer.predict(model=model, datamodule=datamodule) + if isinstance(datamodule, GeoDataModule): + if datamodule.test_dataset or datamodule.test_sampler: + trainer.test(model=model, datamodule=datamodule) + if datamodule.predict_dataset or datamodule.predict_sampler: + trainer.predict(model=model, datamodule=datamodule) + elif isinstance(datamodule, NonGeoDataModule): + if datamodule.test_dataset or datamodule.dataset: + trainer.test(model=model, datamodule=datamodule) + if datamodule.predict_dataset or datamodule.dataset: + trainer.predict(model=model, datamodule=datamodule) @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 00a4009873c..3909fe6112b 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -14,7 +14,12 @@ from pytorch_lightning import LightningDataModule, Trainer from torchvision.models._api import WeightsEnum -from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule +from torchgeo.datamodules import ( + COWCCountingDataModule, + GeoDataModule, + NonGeoDataModule, + TropicalCycloneDataModule, +) from torchgeo.models import ResNet18_Weights from torchgeo.trainers import RegressionTask @@ -53,11 +58,16 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: trainer = 
Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - if datamodule.test_dataset or datamodule.dataset: - trainer.test(model=model, datamodule=datamodule) - - if datamodule.predict_dataset or datamodule.dataset: - trainer.predict(model=model, datamodule=datamodule) + if isinstance(datamodule, GeoDataModule): + if datamodule.test_dataset or datamodule.test_sampler: + trainer.test(model=model, datamodule=datamodule) + if datamodule.predict_dataset or datamodule.predict_sampler: + trainer.predict(model=model, datamodule=datamodule) + elif isinstance(datamodule, NonGeoDataModule): + if datamodule.test_dataset or datamodule.dataset: + trainer.test(model=model, datamodule=datamodule) + if datamodule.predict_dataset or datamodule.dataset: + trainer.predict(model=model, datamodule=datamodule) def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml")) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index c7b113039b6..d28faa49fe7 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -15,11 +15,13 @@ ChesapeakeCVPRDataModule, DeepGlobeLandCoverDataModule, ETCI2021DataModule, + GeoDataModule, GID15DataModule, InriaAerialImageLabelingDataModule, LandCoverAIDataModule, LoveDADataModule, NAIPChesapeakeDataModule, + NonGeoDataModule, Potsdam2DDataModule, SEN12MSDataModule, SpaceNet1DataModule, @@ -84,11 +86,16 @@ def test_trainer( trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - if datamodule.test_dataset or datamodule.dataset: - trainer.test(model=model, datamodule=datamodule) - - if datamodule.predict_dataset or datamodule.dataset: - trainer.predict(model=model, datamodule=datamodule) + if isinstance(datamodule, GeoDataModule): + if datamodule.test_dataset or datamodule.test_sampler: + trainer.test(model=model, datamodule=datamodule) + if datamodule.predict_dataset or datamodule.predict_sampler: + trainer.predict(model=model, datamodule=datamodule) + elif isinstance(datamodule, NonGeoDataModule): + if datamodule.test_dataset or datamodule.dataset: + trainer.test(model=model, datamodule=datamodule) + if datamodule.predict_dataset or datamodule.dataset: + trainer.predict(model=model, datamodule=datamodule) def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "landcoverai.yaml")) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index f2699f778db..2013d7f6ce4 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -151,15 +151,15 @@ def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: 'train_dataset'. """ dataset = self.train_dataset or self.dataset - if dataset is not None: - sampler = self.train_sampler or self.sampler - batch_sampler = self.train_batch_sampler or self.batch_sampler - assert sampler or batch_sampler + sampler = self.train_sampler or self.sampler + batch_sampler = self.train_batch_sampler or self.batch_sampler + if dataset is not None and (sampler or batch_sampler) is not None: batch_size = self.train_batch_size or self.batch_size shuffle = True if batch_sampler is not None: batch_size = 1 shuffle = False + sampler = None return DataLoader( dataset=dataset, batch_size=batch_size, @@ -184,13 +184,13 @@ def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: 'val_dataset'. 
""" dataset = self.val_dataset or self.dataset - if dataset is not None: - sampler = self.val_sampler or self.sampler - batch_sampler = self.val_batch_sampler or self.batch_sampler - assert sampler or batch_sampler + sampler = self.val_sampler or self.sampler + batch_sampler = self.val_batch_sampler or self.batch_sampler + if dataset is not None and (sampler or batch_sampler) is not None: batch_size = self.val_batch_size or self.batch_size if batch_sampler is not None: batch_size = 1 + sampler = None return DataLoader( dataset=dataset, batch_size=batch_size, @@ -215,13 +215,13 @@ def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: 'test_dataset'. """ dataset = self.test_dataset or self.dataset - if dataset is not None: - sampler = self.test_sampler or self.sampler - batch_sampler = self.test_batch_sampler or self.batch_sampler - assert sampler or batch_sampler + sampler = self.test_sampler or self.sampler + batch_sampler = self.test_batch_sampler or self.batch_sampler + if dataset is not None and (sampler or batch_sampler) is not None: batch_size = self.test_batch_size or self.batch_size if batch_sampler is not None: batch_size = 1 + sampler = None return DataLoader( dataset=dataset, batch_size=batch_size, @@ -246,13 +246,13 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: 'predict_dataset'. """ dataset = self.predict_dataset or self.dataset - if dataset is not None: - sampler = self.predict_sampler or self.sampler - batch_sampler = self.predict_batch_sampler or self.batch_sampler - assert sampler or batch_sampler + sampler = self.predict_sampler or self.sampler + batch_sampler = self.predict_batch_sampler or self.batch_sampler + if dataset is not None and (sampler or batch_sampler) is not None: batch_size = self.predict_batch_size or self.batch_size if batch_sampler is not None: batch_size = 1 + sampler = None return DataLoader( dataset=dataset, batch_size=batch_size, @@ -538,6 +538,7 @@ def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: Returns: A matplotlib Figure with the image, ground truth, and predictions. """ - if self.val_dataset is not None: - if hasattr(self.val_dataset, "plot"): - return self.val_dataset.plot(*args, **kwargs) + dataset = self.val_dataset or self.dataset + if dataset is not None: + if hasattr(dataset, "plot"): + return dataset.plot(*args, **kwargs) From 11242f3efd9a15502ac72bb61fa1e90954417b70 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 5 Jan 2023 12:30:18 -0600 Subject: [PATCH 072/108] Fix typo --- torchgeo/datamodules/geo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 2013d7f6ce4..6b0f3790670 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -28,7 +28,7 @@ class GeoDataModule(LightningDataModule): - """Base class for data modules lacking geospatial information.""" + """Base class for data modules containing geospatial information.""" mean = torch.tensor(0) std = torch.tensor(255) From 1ff4aa01b98cbd0179a5d3481d683a3d8f4f6cfe Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Thu, 5 Jan 2023 12:36:16 -0600 Subject: [PATCH 073/108] Fix another typo --- torchgeo/datamodules/geo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 6b0f3790670..647903ea2ca 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -42,7 +42,7 @@ def __init__( num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a new NonGeoDataModule instance. + """Initialize a new GeoDataModule instance. Args: dataset_class: Class used to instantiate a new dataset. From 05fb13aa253abda9bb27a649178e9626a312e226 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 5 Jan 2023 14:31:24 -0600 Subject: [PATCH 074/108] Fix Chesapeake dimensions --- torchgeo/datamodules/chesapeake.py | 41 +++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index a6d90a117fd..50e699aa188 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List import kornia.augmentation as K +import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torch import Tensor @@ -16,6 +17,38 @@ from .geo import GeoDataModule +class _Transform(nn.Module): + """Version of AugmentationSequential designed for samples, not batches.""" + + def __init__(self, aug: nn.Module) -> None: + """Initialize a new _Transform instance. + + Args: + aug: Augmentation to apply. + """ + super().__init__() + self.aug = aug + + def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Apply the augmentation. + + Args: + sample: Input sample. + + Returns: + Augmented sample. + """ + for key in ["image", "mask"]: + dtype = sample[key].dtype + # All inputs must be float + sample[key] = sample[key].float() + sample[key] = self.aug(sample[key]) + sample[key] = sample[key].to(dtype) + # Kornia adds batch dimension + sample[key] = rearrange(sample[key], "() c h w -> c h w") + return sample + + class ChesapeakeCVPRDataModule(GeoDataModule): """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset. @@ -61,9 +94,7 @@ def __init__( # This is a rough estimate of how large of a patch we will need to sample in # EPSG:3857 in order to guarantee a large enough patch in the local CRS. self.original_patch_size = patch_size * 2 - kwargs["transforms"] = AugmentationSequential( - K.CenterCrop(patch_size), data_keys=["image", "mask"] - ) + kwargs["transforms"] = _Transform(K.CenterCrop(patch_size)) super().__init__( ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs @@ -138,10 +169,6 @@ def on_after_batch_transfer( Returns: A batch of data. """ - # CenterCrop adds additional dimensions - batch["image"] = rearrange(batch["image"], "b () c h w -> b c h w") - batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - if self.use_prior_labels: batch["mask"] = F.normalize(batch["mask"].float(), p=1, dim=1) batch["mask"] = F.normalize( From 8f41fe583fe3367e8af9126c08a2f2d559ea2df8 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Thu, 5 Jan 2023 14:32:52 -0600 Subject: [PATCH 075/108] Apply augmentations during sanity check too --- torchgeo/datamodules/geo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 647903ea2ca..042e5a3e213 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -303,7 +303,7 @@ def on_after_batch_transfer( if self.trainer: if self.trainer.training: aug = self.train_aug or self.aug - elif self.trainer.validating: + elif self.trainer.validating or self.trainer.sanity_checking: aug = self.val_aug or self.aug elif self.trainer.testing: aug = self.test_aug or self.aug @@ -517,7 +517,7 @@ def on_after_batch_transfer( if self.trainer: if self.trainer.training: aug = self.train_aug or self.aug - elif self.trainer.validating: + elif self.trainer.validating or self.trainer.sanity_checking: aug = self.val_aug or self.aug elif self.trainer.testing: aug = self.test_aug or self.aug From 96befe31c1dc07e8a12b99a0dc752e4b227a84ca Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 5 Jan 2023 14:49:01 -0600 Subject: [PATCH 076/108] Don't reuse fixtures --- tests/datamodules/test_fair1m.py | 2 +- tests/datamodules/test_oscd.py | 2 +- tests/datamodules/test_usavars.py | 2 +- tests/datamodules/test_xview2.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/datamodules/test_fair1m.py b/tests/datamodules/test_fair1m.py index 000b0144cc2..dd3f2bd13f2 100644 --- a/tests/datamodules/test_fair1m.py +++ b/tests/datamodules/test_fair1m.py @@ -11,7 +11,7 @@ class TestFAIR1MDataModule: - @pytest.fixture(scope="class") + @pytest.fixture def datamodule(self) -> FAIR1MDataModule: root = os.path.join("tests", "data", "fair1m") batch_size = 2 diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index bb430156fe4..f155920dc02 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -11,7 +11,7 @@ class TestOSCDDataModule: - @pytest.fixture(scope="class", params=["all", "rgb"]) + @pytest.fixture(params=["all", "rgb"]) def datamodule(self, request: SubRequest) -> OSCDDataModule: bands = request.param num_tiles_per_batch = 1 diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index 14b5fed5390..7f04644cc20 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -12,7 +12,7 @@ class TestUSAVarsDataModule: - @pytest.fixture() + @pytest.fixture def datamodule(self, request: SubRequest) -> USAVarsDataModule: pytest.importorskip("pandas", minversion="0.23.2") root = os.path.join("tests", "data", "usavars") diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview2.py index 6d614184b56..1a0e158b366 100644 --- a/tests/datamodules/test_xview2.py +++ b/tests/datamodules/test_xview2.py @@ -11,7 +11,7 @@ class TestXView2DataModule: - @pytest.fixture(scope="class") + @pytest.fixture def datamodule(self) -> XView2DataModule: root = os.path.join("tests", "data", "xview2") batch_size = 1 From 227a84e7a729fa738608653409014b54dfa60742 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Thu, 5 Jan 2023 15:04:34 -0600 Subject: [PATCH 077/108] Increase coverage --- tests/trainers/test_byol.py | 30 +++++++++++++-------------- tests/trainers/test_classification.py | 26 +++++++++++------------ tests/trainers/test_detection.py | 28 ++++++++++++------------- tests/trainers/test_regression.py | 29 ++++++++++++-------------- tests/trainers/test_segmentation.py | 26 +++++++++++------------ 5 files changed, 66 insertions(+), 73 deletions(-) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index f2fb865efdc..a9138685a37 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -13,14 +13,15 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer + +# TODO: import from lightning_lite instead +from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] + MisconfigurationException, +) from torchvision.models import resnet18 from torchvision.models._api import WeightsEnum -from torchgeo.datamodules import ( - ChesapeakeCVPRDataModule, - GeoDataModule, - NonGeoDataModule, -) +from torchgeo.datamodules import ChesapeakeCVPRDataModule from torchgeo.models import ResNet18_Weights from torchgeo.trainers import BYOLTask from torchgeo.trainers.byol import BYOL, SimCLRAugmentation @@ -76,17 +77,14 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: # Instantiate trainer trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - - if isinstance(datamodule, GeoDataModule): - if datamodule.test_dataset or datamodule.test_sampler: - trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset or datamodule.predict_sampler: - trainer.predict(model=model, datamodule=datamodule) - elif isinstance(datamodule, NonGeoDataModule): - if datamodule.test_dataset or datamodule.dataset: - trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset or datamodule.dataset: - trainer.predict(model=model, datamodule=datamodule) + try: + trainer.test(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass @pytest.fixture def model_kwargs(self) -> Dict[str, Any]: diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index c5fb54b4f6c..36c5d19650d 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -12,14 +12,17 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer + +# TODO: import from lightning_lite instead +from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] + MisconfigurationException, +) from torch.nn.modules import Module from torchvision.models._api import WeightsEnum from torchgeo.datamodules import ( BigEarthNetDataModule, EuroSATDataModule, - GeoDataModule, - NonGeoDataModule, RESISC45DataModule, So2SatDataModule, UCMercedDataModule, @@ -74,17 +77,14 @@ def test_trainer( # Instantiate trainer trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - - if isinstance(datamodule, GeoDataModule): - if datamodule.test_dataset or datamodule.test_sampler: - trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset or 
datamodule.predict_sampler: - trainer.predict(model=model, datamodule=datamodule) - elif isinstance(datamodule, NonGeoDataModule): - if datamodule.test_dataset or datamodule.dataset: - trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset or datamodule.dataset: - trainer.predict(model=model, datamodule=datamodule) + try: + trainer.test(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml")) diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 691e710fdee..d7aae815e0c 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -9,11 +9,12 @@ from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer -from torchgeo.datamodules import ( - GeoDataModule, - NASAMarineDebrisDataModule, - NonGeoDataModule, +# TODO: import from lightning_lite instead +from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] + MisconfigurationException, ) + +from torchgeo.datamodules import NASAMarineDebrisDataModule from torchgeo.datasets import NASAMarineDebris from torchgeo.trainers import ObjectDetectionTask @@ -38,17 +39,14 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: # Instantiate trainer trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - - if isinstance(datamodule, GeoDataModule): - if datamodule.test_dataset or datamodule.test_sampler: - trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset or datamodule.predict_sampler: - trainer.predict(model=model, datamodule=datamodule) - elif isinstance(datamodule, NonGeoDataModule): - if datamodule.test_dataset or datamodule.dataset: - trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset or datamodule.dataset: - trainer.predict(model=model, datamodule=datamodule) + try: + trainer.test(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 3909fe6112b..f1156b52730 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -14,12 +14,12 @@ from pytorch_lightning import LightningDataModule, Trainer from torchvision.models._api import WeightsEnum -from torchgeo.datamodules import ( - COWCCountingDataModule, - GeoDataModule, - NonGeoDataModule, - TropicalCycloneDataModule, +# TODO: import from lightning_lite instead +from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] + MisconfigurationException, ) + +from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule from torchgeo.models import ResNet18_Weights from torchgeo.trainers import RegressionTask @@ -57,17 +57,14 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: # Instantiate trainer trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - - if isinstance(datamodule, GeoDataModule): - if datamodule.test_dataset or datamodule.test_sampler: - 
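The try/except pattern introduced in this patch leans on Lightning's own error handling: when a stage is requested but the datamodule built no dataloader for it, the *_dataloader hook raises MisconfigurationException, which the tests now interpret as "skip this stage". A condensed sketch (the helper is illustrative, not part of the test suite):

```python
from typing import Callable

from pytorch_lightning.utilities.exceptions import MisconfigurationException


def run_stage_if_configured(stage: Callable[[], None]) -> None:
    try:
        stage()
    except MisconfigurationException:
        pass  # the datamodule defines no dataset/sampler for this stage


# e.g. run_stage_if_configured(lambda: trainer.test(model=model, datamodule=dm))
```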
trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset or datamodule.predict_sampler: - trainer.predict(model=model, datamodule=datamodule) - elif isinstance(datamodule, NonGeoDataModule): - if datamodule.test_dataset or datamodule.dataset: - trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset or datamodule.dataset: - trainer.predict(model=model, datamodule=datamodule) + try: + trainer.test(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml")) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index d28faa49fe7..edb1be72da6 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -9,19 +9,22 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer + +# TODO: import from lightning_lite instead +from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] + MisconfigurationException, +) from torch.nn.modules import Module from torchgeo.datamodules import ( ChesapeakeCVPRDataModule, DeepGlobeLandCoverDataModule, ETCI2021DataModule, - GeoDataModule, GID15DataModule, InriaAerialImageLabelingDataModule, LandCoverAIDataModule, LoveDADataModule, NAIPChesapeakeDataModule, - NonGeoDataModule, Potsdam2DDataModule, SEN12MSDataModule, SpaceNet1DataModule, @@ -85,17 +88,14 @@ def test_trainer( # Instantiate trainer trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - - if isinstance(datamodule, GeoDataModule): - if datamodule.test_dataset or datamodule.test_sampler: - trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset or datamodule.predict_sampler: - trainer.predict(model=model, datamodule=datamodule) - elif isinstance(datamodule, NonGeoDataModule): - if datamodule.test_dataset or datamodule.dataset: - trainer.test(model=model, datamodule=datamodule) - if datamodule.predict_dataset or datamodule.dataset: - trainer.predict(model=model, datamodule=datamodule) + try: + trainer.test(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "landcoverai.yaml")) From 2f8b5b4f02979be31fb6b3223928ce5decbb2ad2 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 5 Jan 2023 15:15:10 -0600 Subject: [PATCH 078/108] Fix ETCI tests --- torchgeo/datamodules/etci2021.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index e8ce5a396d6..3c4a8dd1e83 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -74,8 +74,9 @@ def on_after_batch_transfer( Returns: A batch of data. 
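The ETCI2021 hunk that follows keeps only the second channel of the two-channel mask (water and flood) and binarizes it, and now skips that step while predicting, presumably because predict batches carry no mask. A sketch of the tensor manipulation, with a random mask standing in for real data:

```python
import torch

# Random stand-in for a batch of ETCI2021 masks:
# channel 0 is the water mask, channel 1 the flood mask.
mask = torch.randint(0, 2, (8, 2, 64, 64))

flood = (mask[:, 1] > 0).long()  # keep the flood channel, binarize it
assert flood.shape == (8, 64, 64)
assert flood.dtype == torch.long
```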
""" - if "mask" in batch: - # Predict flood mask, not water mask - batch["mask"] = (batch["mask"][:, 1] > 0).long() + if self.trainer: + if not self.trainer.predicting: + # Evaluate against flood mask, not water mask + batch["mask"] = (batch["mask"][:, 1] > 0).long() return super().on_after_batch_transfer(batch, dataloader_idx) From fb2e73beb8abb425fd4440b6196139bea33c4b33 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 5 Jan 2023 17:52:00 -0600 Subject: [PATCH 079/108] Test predict_step --- tests/trainers/test_byol.py | 27 ++++++++++++++++++++ tests/trainers/test_classification.py | 36 +++++++++++++++++++++++++-- tests/trainers/test_detection.py | 13 ++++++++++ tests/trainers/test_regression.py | 14 +++++++++++ 4 files changed, 88 insertions(+), 2 deletions(-) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index a9138685a37..b7690e78ee2 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -22,7 +22,9 @@ from torchvision.models._api import WeightsEnum from torchgeo.datamodules import ChesapeakeCVPRDataModule +from torchgeo.datasets import ChesapeakeCVPR from torchgeo.models import ResNet18_Weights +from torchgeo.samplers import GridGeoSampler from torchgeo.trainers import BYOLTask from torchgeo.trainers.byol import BYOL, SimCLRAugmentation @@ -34,6 +36,16 @@ def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: return state_dict +class CustomBYOLDataModule(ChesapeakeCVPRDataModule): + def setup(self, stage: str) -> None: + self.predict_dataset = ChesapeakeCVPR( + splits=self.test_splits, layers=self.layers, **self.kwargs + ) + self.predict_sampler = GridGeoSampler( + self.predict_dataset, self.original_patch_size, self.original_patch_size + ) + + class TestBYOL: def test_custom_augment_fn(self) -> None: backbone = resnet18() @@ -115,3 +127,18 @@ def test_weight_str( ) -> None: model_kwargs["weights"] = str(mocked_weights) BYOLTask(**model_kwargs) + + def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: + datamodule = CustomBYOLDataModule( + root="tests/data/chesapeake/cvpr", + train_splits=["de-test"], + val_splits=["de-test"], + test_splits=["de-test"], + batch_size=1, + patch_size=64, + num_workers=0, + ) + model_kwargs["in_channels"] = 4 + model = BYOLTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.predict(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 36c5d19650d..76945211692 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -34,6 +34,16 @@ from .test_utils import ClassificationTestModel +class CustomClassificationDataModule(EuroSATDataModule): + def setup(self, stage: str) -> None: + self.predict_dataset = EuroSAT(split="test", **self.kwargs) + + +class CustomMultiLabelClassificationDataModule(BigEarthNetDataModule): + def setup(self, stage: str) -> None: + self.predict_dataset = BigEarthNet(split="test", **self.kwargs) + + def create_model(*args: Any, **kwargs: Any) -> Module: return ClassificationTestModel(**kwargs) @@ -161,6 +171,14 @@ def test_missing_attributes( trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.validate(model=model, datamodule=datamodule) + def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: + datamodule = CustomClassificationDataModule( + root="tests/data/eurosat", batch_size=1, num_workers=0 + ) + model = ClassificationTask(**model_kwargs) + trainer = 
Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.predict(model=model, dataloaders=datamodule) + class TestMultiLabelClassificationTask: @pytest.mark.parametrize( @@ -190,8 +208,14 @@ def test_trainer( # Instantiate trainer trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - trainer.test(model=model, datamodule=datamodule) - trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) + try: + trainer.test(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass def test_no_logger(self) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml")) @@ -238,3 +262,11 @@ def test_missing_attributes( model = MultiLabelClassificationTask(**model_kwargs) trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.validate(model=model, datamodule=datamodule) + + def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: + datamodule = CustomMultiLabelClassificationDataModule( + root="tests/data/bigearthnet", batch_size=1, num_workers=0 + ) + model = MultiLabelClassificationTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.predict(model=model, dataloaders=datamodule) diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index d7aae815e0c..f362e2566be 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -19,6 +19,11 @@ from torchgeo.trainers import ObjectDetectionTask +class CustomObjectDetectionDataModule(NASAMarineDebrisDataModule): + def setup(self, stage: str) -> None: + self.predict_dataset = NASAMarineDebris(**self.kwargs) + + class TestObjectDetectionTask: @pytest.mark.parametrize( "name,classname", [("nasa_marine_debris", NASAMarineDebrisDataModule)] @@ -78,3 +83,11 @@ def test_missing_attributes( model = ObjectDetectionTask(**model_kwargs) trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) trainer.validate(model=model, datamodule=datamodule) + + def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: + datamodule = CustomObjectDetectionDataModule( + root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 + ) + model = ObjectDetectionTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.predict(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index f1156b52730..da7d928f6e6 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -19,6 +19,7 @@ MisconfigurationException, ) +from torchgeo.datasets import TropicalCyclone from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule from torchgeo.models import ResNet18_Weights from torchgeo.trainers import RegressionTask @@ -31,6 +32,11 @@ def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: return state_dict +class CustomRegressionDataModule(TropicalCycloneDataModule): + def setup(self, stage: str) -> None: + self.predict_dataset = TropicalCyclone(split="test", **self.kwargs) + + class TestRegressionTask: @pytest.mark.parametrize( "name,classname", @@ -122,3 +128,11 @@ def test_weight_str( model_kwargs["weights"] = str(mocked_weights) with pytest.warns(UserWarning): RegressionTask(**model_kwargs) + + def test_predict(self, model_kwargs: 
Dict[Any, Any]) -> None: + datamodule = CustomRegressionDataModule( + root="tests/data/cyclone", batch_size=1, num_workers=0 + ) + model = RegressionTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.predict(model=model, datamodule=datamodule) From 6bcc41e02594a4308ae9903c323b35a19f8155ff Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 5 Jan 2023 17:55:24 -0600 Subject: [PATCH 080/108] Test all loss methods --- tests/conf/so2sat_all.yaml | 2 +- tests/conf/so2sat_s2.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/conf/so2sat_all.yaml b/tests/conf/so2sat_all.yaml index 83889f10735..a8d8c0bb8e3 100644 --- a/tests/conf/so2sat_all.yaml +++ b/tests/conf/so2sat_all.yaml @@ -1,7 +1,7 @@ experiment: task: "so2sat" module: - loss: "focal" + loss: "ce" model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 diff --git a/tests/conf/so2sat_s2.yaml b/tests/conf/so2sat_s2.yaml index e8e61dcb739..ab9c573a197 100644 --- a/tests/conf/so2sat_s2.yaml +++ b/tests/conf/so2sat_s2.yaml @@ -1,7 +1,7 @@ experiment: task: "so2sat" module: - loss: "focal" + loss: "jaccard" model: "resnet18" learning_rate: 1e-3 learning_rate_schedule_patience: 6 From b3aaacffd439751d604ed86994d0cf6712eb9281 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 5 Jan 2023 21:14:44 -0600 Subject: [PATCH 081/108] Simplify validation plotting --- tests/trainers/test_classification.py | 70 +++++---------------------- tests/trainers/test_detection.py | 12 ----- tests/trainers/test_regression.py | 19 -------- tests/trainers/test_segmentation.py | 30 ------------ torchgeo/trainers/classification.py | 62 +++++++++++++----------- torchgeo/trainers/detection.py | 42 ++++++++-------- torchgeo/trainers/regression.py | 32 ++++++------ torchgeo/trainers/segmentation.py | 32 ++++++------ 8 files changed, 100 insertions(+), 199 deletions(-) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 76945211692..cb2bbefe957 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -96,25 +96,6 @@ def test_trainer( except MisconfigurationException: pass - def test_no_logger(self) -> None: - conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) - - # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = UCMercedDataModule(**datamodule_kwargs) - - # Instantiate model - model_kwargs = conf_dict["module"] - model = ClassificationTask(**model_kwargs) - - # Instantiate trainer - trainer = Trainer( - logger=False, fast_dev_run=True, log_every_n_steps=1, max_epochs=1 - ) - trainer.fit(model=model, datamodule=datamodule) - @pytest.fixture def model_kwargs(self) -> Dict[str, Any]: return { @@ -160,16 +141,17 @@ def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: with pytest.raises(ValueError, match=match): ClassificationTask(**model_kwargs) - def test_missing_attributes( - self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch - ) -> None: - monkeypatch.delattr(EuroSAT, "plot") - datamodule = EuroSATDataModule( - root="tests/data/eurosat", batch_size=1, num_workers=0 - ) - model = ClassificationTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) - trainer.validate(model=model, datamodule=datamodule) + def test_invalid_model(self, model_kwargs: Dict[Any, 
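The new test_invalid_model and test_invalid_weights cases pin down the error messages raised at task construction time. The model check amounts to validating the name against timm's registry; a plausible sketch of that lookup, not the task's exact code ("resnet18" is a real timm name, and in_chans/num_classes are standard timm.create_model arguments):

```python
import timm

name = "resnet18"
if name in timm.list_models():
    # in_chans builds a first conv sized for e.g. 13-band Sentinel-2 input
    backbone = timm.create_model(name, in_chans=13, num_classes=10)
else:
    raise ValueError(f"Model type '{name}' is not a valid timm model.")
```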
Any]) -> None: + model_kwargs["model"] = "invalid_model" + match = "Model type 'invalid_model' is not a valid timm model." + with pytest.raises(ValueError, match=match): + ClassificationTask(**model_kwargs) + + def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: + model_kwargs["weights"] = "invalid_weights" + match = "Weight type 'invalid_weights' is not valid." + with pytest.raises(ValueError, match=match): + ClassificationTask(**model_kwargs) def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: datamodule = CustomClassificationDataModule( @@ -217,25 +199,6 @@ def test_trainer( except MisconfigurationException: pass - def test_no_logger(self) -> None: - conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) - - # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = BigEarthNetDataModule(**datamodule_kwargs) - - # Instantiate model - model_kwargs = conf_dict["module"] - model = MultiLabelClassificationTask(**model_kwargs) - - # Instantiate trainer - trainer = Trainer( - logger=False, fast_dev_run=True, log_every_n_steps=1, max_epochs=1 - ) - trainer.fit(model=model, datamodule=datamodule) - @pytest.fixture def model_kwargs(self) -> Dict[str, Any]: return { @@ -252,17 +215,6 @@ def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: with pytest.raises(ValueError, match=match): MultiLabelClassificationTask(**model_kwargs) - def test_missing_attributes( - self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch - ) -> None: - monkeypatch.delattr(BigEarthNet, "plot") - datamodule = BigEarthNetDataModule( - root="tests/data/bigearthnet", batch_size=1, num_workers=0 - ) - model = MultiLabelClassificationTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) - trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: datamodule = CustomMultiLabelClassificationDataModule( root="tests/data/bigearthnet", batch_size=1, num_workers=0 diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index f362e2566be..f0da12ffceb 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -5,7 +5,6 @@ from typing import Any, Dict, Type, cast import pytest -from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer @@ -73,17 +72,6 @@ def test_non_pretrained_backbone(self, model_kwargs: Dict[Any, Any]) -> None: model_kwargs["pretrained"] = False ObjectDetectionTask(**model_kwargs) - def test_missing_attributes( - self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch - ) -> None: - monkeypatch.delattr(NASAMarineDebris, "plot") - datamodule = NASAMarineDebrisDataModule( - root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 - ) - model = ObjectDetectionTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) - trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: datamodule = CustomObjectDetectionDataModule( root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index da7d928f6e6..57122ca40a2 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ 
-72,25 +72,6 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: except MisconfigurationException: pass - def test_no_logger(self) -> None: - conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) - - # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = TropicalCycloneDataModule(**datamodule_kwargs) - - # Instantiate model - model_kwargs = conf_dict["module"] - model = RegressionTask(**model_kwargs) - - # Instantiate trainer - trainer = Trainer( - logger=False, fast_dev_run=True, log_every_n_steps=1, max_epochs=1 - ) - trainer.fit(model=model, datamodule=datamodule) - @pytest.fixture def model_kwargs(self) -> Dict[str, Any]: return { diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index edb1be72da6..e0cc5419e15 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -97,25 +97,6 @@ def test_trainer( except MisconfigurationException: pass - def test_no_logger(self) -> None: - conf = OmegaConf.load(os.path.join("tests", "conf", "landcoverai.yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) - - # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = LandCoverAIDataModule(**datamodule_kwargs) - - # Instantiate model - model_kwargs = conf_dict["module"] - model = SemanticSegmentationTask(**model_kwargs) - - # Instantiate trainer - trainer = Trainer( - logger=False, fast_dev_run=True, log_every_n_steps=1, max_epochs=1 - ) - trainer.fit(model=model, datamodule=datamodule) - @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: return { @@ -152,14 +133,3 @@ def test_ignoreindex_with_jaccard(self, model_kwargs: Dict[Any, Any]) -> None: match = "ignore_index has no effect on training when loss='jaccard'" with pytest.warns(UserWarning, match=match): SemanticSegmentationTask(**model_kwargs) - - def test_missing_attributes( - self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch - ) -> None: - monkeypatch.delattr(LandCoverAI, "plot") - datamodule = LandCoverAIDataModule( - root="tests/data/landcoverai", batch_size=1, num_workers=0 - ) - model = SemanticSegmentationTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) - trainer.validate(model=model, datamodule=datamodule) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 35a71e01e22..300a0654844 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -187,21 +187,23 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) - if batch_idx < 10: - try: - datamodule = self.trainer.datamodule # type: ignore[attr-defined] - batch["prediction"] = y_hat_hard - for key in ["image", "label", "prediction"]: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - fig = datamodule.plot(sample) - summary_writer = self.logger.experiment # type: ignore[union-attr] - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() - except (AttributeError, ValueError): - pass + if ( + batch_idx < 10 + and hasattr(self.trainer, "datamodule") + and self.logger + and hasattr(self.logger, "experiment") + ): + datamodule = 
self.trainer.datamodule + batch["prediction"] = y_hat_hard + for key in ["image", "label", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. @@ -366,20 +368,22 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) - if batch_idx < 10: - try: - datamodule = self.trainer.datamodule # type: ignore[attr-defined] - batch["prediction"] = y_hat_hard - for key in ["image", "label", "prediction"]: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - fig = datamodule.plot(sample) - summary_writer = self.logger.experiment # type: ignore[union-attr] - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - except AttributeError: - pass + if ( + batch_idx < 10 + and hasattr(self.trainer, "datamodule") + and self.logger + and hasattr(self.logger, "experiment") + ): + datamodule = self.trainer.datamodule + batch["prediction"] = y_hat_hard + for key in ["image", "label", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) def test_step(self, *args: Any, **kwargs: Any) -> None: """Compute test loss. diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 13a69846bca..7e02777d4b6 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -169,26 +169,28 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: self.val_metrics.update(y_hat, y) - if batch_idx < 10: - try: - datamodule = self.trainer.datamodule # type: ignore[attr-defined] - batch["prediction_boxes"] = [b["boxes"].cpu() for b in y_hat] - batch["prediction_labels"] = [b["labels"].cpu() for b in y_hat] - batch["prediction_scores"] = [b["scores"].cpu() for b in y_hat] - batch["image"] = batch["image"].cpu() - sample = unbind_samples(batch)[0] - # Convert image to uint8 for plotting - if torch.is_floating_point(sample["image"]): - sample["image"] *= 255 - sample["image"] = sample["image"].to(torch.uint8) - fig = datamodule.plot(sample) - summary_writer = self.logger.experiment # type: ignore[union-attr] - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() - except (AttributeError, ValueError): - pass + if ( + batch_idx < 10 + and hasattr(self.trainer, "datamodule") + and self.logger + and hasattr(self.logger, "experiment") + ): + datamodule = self.trainer.datamodule + batch["prediction_boxes"] = [b["boxes"].cpu() for b in y_hat] + batch["prediction_labels"] = [b["labels"].cpu() for b in y_hat] + batch["prediction_scores"] = [b["scores"].cpu() for b in y_hat] + batch["image"] = batch["image"].cpu() + sample = unbind_samples(batch)[0] + # Convert image to uint8 for plotting + if torch.is_floating_point(sample["image"]): + sample["image"] *= 255 + sample["image"] = sample["image"].to(torch.uint8) + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() def 
validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 0d5595fa8a9..959347fc284 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -144,21 +144,23 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: self.log("val_loss", loss) self.val_metrics(y_hat, y) - if batch_idx < 10: - try: - datamodule = self.trainer.datamodule # type: ignore[attr-defined] - batch["prediction"] = y_hat - for key in ["image", "label", "prediction"]: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - fig = datamodule.plot(sample) - summary_writer = self.logger.experiment # type: ignore[union-attr] - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() - except (AttributeError, ValueError): - pass + if ( + batch_idx < 10 + and hasattr(self.trainer, "datamodule") + and self.logger + and hasattr(self.logger, "experiment") + ): + datamodule = self.trainer.datamodule + batch["prediction"] = y_hat + for key in ["image", "label", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 44090009dfb..a31324c7901 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -197,21 +197,23 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) - if batch_idx < 10: - try: - datamodule = self.trainer.datamodule # type: ignore[attr-defined] - batch["prediction"] = y_hat_hard - for key in ["image", "mask", "prediction"]: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - fig = datamodule.plot(sample) - summary_writer = self.logger.experiment # type: ignore[union-attr] - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() - except (AttributeError, ValueError): - pass + if ( + batch_idx < 10 + and hasattr(self.trainer, "datamodule") + and self.logger + and hasattr(self.logger, "experiment") + ): + datamodule = self.trainer.datamodule + batch["prediction"] = y_hat_hard + for key in ["image", "mask", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. From a56e5f3568099bd4e04c27fb83bd878d0a89f107 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Thu, 5 Jan 2023 21:17:54 -0600 Subject: [PATCH 082/108] Document new classes --- torchgeo/datamodules/geo.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 042e5a3e213..bb1d325273b 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -28,7 +28,10 @@ class GeoDataModule(LightningDataModule): - """Base class for data modules containing geospatial information.""" + """Base class for data modules containing geospatial information. + + .. versionadded:: 0.4 + """ mean = torch.tensor(0) std = torch.tensor(255) @@ -331,7 +334,10 @@ def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: class NonGeoDataModule(LightningDataModule): - """Base class for data modules lacking geospatial information.""" + """Base class for data modules lacking geospatial information. + + .. versionadded:: 0.4 + """ mean = torch.tensor(0) std = torch.tensor(255) From 03f4ce92b7e4f8fe8a605813f8c135d74cc3a5d6 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 5 Jan 2023 21:58:12 -0600 Subject: [PATCH 083/108] Fix plotting --- torchgeo/datamodules/naip.py | 4 ++-- torchgeo/datasets/chesapeake.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index d4774d9b242..b0e21b15505 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -67,8 +67,8 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - chesapeake = Chesapeake13(**self.chesapeake_kwargs) - naip = NAIP(**self.naip_kwargs) + self.chesapeake = Chesapeake13(**self.chesapeake_kwargs) + self.naip = NAIP(**self.naip_kwargs) self.dataset = chesapeake & naip roi = self.dataset.bounds diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 3d646c79481..024c384dc0a 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -765,7 +765,11 @@ def plot( .. versionadded:: 0.4 """ image = np.rollaxis(sample["image"].numpy(), 0, 3) - mask = np.rollaxis(sample["mask"].numpy(), 0, 3) + mask = sample["mask"].numpy() + if mask.ndim == 3: + mask = np.rollaxis(mask, 0, 3) + else: + mask = np.expand_dims(mask, 2) num_panels = len(self.layers) showing_predictions = "prediction" in sample From 142bfc49fa41a1abab4a455997cf15a7de33f461 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Thu, 5 Jan 2023 22:01:39 -0600 Subject: [PATCH 084/108] Plotting should be robust in case dataset does not contain RGB bands --- torchgeo/trainers/classification.py | 48 ++++++++++++++++------------- torchgeo/trainers/detection.py | 35 +++++++++++---------- torchgeo/trainers/regression.py | 25 ++++++++------- torchgeo/trainers/segmentation.py | 25 ++++++++------- 4 files changed, 74 insertions(+), 59 deletions(-) diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 300a0654844..e8af2dcb70f 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -193,17 +193,20 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: and self.logger and hasattr(self.logger, "experiment") ): - datamodule = self.trainer.datamodule - batch["prediction"] = y_hat_hard - for key in ["image", "label", "prediction"]: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - fig = datamodule.plot(sample) - summary_writer = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() + try: + datamodule = self.trainer.datamodule + batch["prediction"] = y_hat_hard + for key in ["image", "label", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() + except ValueError: + pass def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. @@ -374,16 +377,19 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: and self.logger and hasattr(self.logger, "experiment") ): - datamodule = self.trainer.datamodule - batch["prediction"] = y_hat_hard - for key in ["image", "label", "prediction"]: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - fig = datamodule.plot(sample) - summary_writer = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) + try: + datamodule = self.trainer.datamodule + batch["prediction"] = y_hat_hard + for key in ["image", "label", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + except ValueError: + pass def test_step(self, *args: Any, **kwargs: Any) -> None: """Compute test loss. 
diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 7e02777d4b6..9037208998e 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -175,22 +175,25 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: and self.logger and hasattr(self.logger, "experiment") ): - datamodule = self.trainer.datamodule - batch["prediction_boxes"] = [b["boxes"].cpu() for b in y_hat] - batch["prediction_labels"] = [b["labels"].cpu() for b in y_hat] - batch["prediction_scores"] = [b["scores"].cpu() for b in y_hat] - batch["image"] = batch["image"].cpu() - sample = unbind_samples(batch)[0] - # Convert image to uint8 for plotting - if torch.is_floating_point(sample["image"]): - sample["image"] *= 255 - sample["image"] = sample["image"].to(torch.uint8) - fig = datamodule.plot(sample) - summary_writer = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() + try: + datamodule = self.trainer.datamodule + batch["prediction_boxes"] = [b["boxes"].cpu() for b in y_hat] + batch["prediction_labels"] = [b["labels"].cpu() for b in y_hat] + batch["prediction_scores"] = [b["scores"].cpu() for b in y_hat] + batch["image"] = batch["image"].cpu() + sample = unbind_samples(batch)[0] + # Convert image to uint8 for plotting + if torch.is_floating_point(sample["image"]): + sample["image"] *= 255 + sample["image"] = sample["image"].to(torch.uint8) + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() + except ValueError: + pass def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 959347fc284..98234bdec5d 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -150,17 +150,20 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: and self.logger and hasattr(self.logger, "experiment") ): - datamodule = self.trainer.datamodule - batch["prediction"] = y_hat - for key in ["image", "label", "prediction"]: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - fig = datamodule.plot(sample) - summary_writer = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() + try: + datamodule = self.trainer.datamodule + batch["prediction"] = y_hat + for key in ["image", "label", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() + except ValueError: + pass def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. 
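One line in the detection hunk above carries an implicit assumption worth spelling out: the dataset's plot method wants uint8 pixel values, so a floating-point image is taken to be normalized to [0, 1] before being rescaled and cast. Pulled out of the trainer, the conversion looks roughly like this (the helper name to_uint8_image is illustrative, not part of the patch):

    import torch

    def to_uint8_image(image: torch.Tensor) -> torch.Tensor:
        # Assumes float images are normalized to [0, 1]; integer images
        # are returned unchanged.
        if torch.is_floating_point(image):
            image = (image * 255).to(torch.uint8)
        return image
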
diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index a31324c7901..b5868196fc2 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -203,17 +203,20 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: and self.logger and hasattr(self.logger, "experiment") ): - datamodule = self.trainer.datamodule - batch["prediction"] = y_hat_hard - for key in ["image", "mask", "prediction"]: - batch[key] = batch[key].cpu() - sample = unbind_samples(batch)[0] - fig = datamodule.plot(sample) - summary_writer = self.logger.experiment - summary_writer.add_figure( - f"image/{batch_idx}", fig, global_step=self.global_step - ) - plt.close() + try: + datamodule = self.trainer.datamodule + batch["prediction"] = y_hat_hard + for key in ["image", "mask", "prediction"]: + batch[key] = batch[key].cpu() + sample = unbind_samples(batch)[0] + fig = datamodule.plot(sample) + summary_writer = self.logger.experiment + summary_writer.add_figure( + f"image/{batch_idx}", fig, global_step=self.global_step + ) + plt.close() + except ValueError: + pass def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. From be594a136024a41e26d6bffe910afeb255d6c56c Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Thu, 5 Jan 2023 22:05:45 -0600 Subject: [PATCH 085/108] Fix flake8 --- torchgeo/datamodules/naip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index b0e21b15505..34082104552 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -69,7 +69,7 @@ def setup(self, stage: str) -> None: """ self.chesapeake = Chesapeake13(**self.chesapeake_kwargs) self.naip = NAIP(**self.naip_kwargs) - self.dataset = chesapeake & naip + self.dataset = self.chesapeake & self.naip roi = self.dataset.bounds midx = roi.minx + (roi.maxx - roi.minx) / 2 From d0d593be8a69a07e301fb9e6c6eb404c3c013467 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 6 Jan 2023 12:35:49 -0600 Subject: [PATCH 086/108] 100% coverage of trainers --- tests/trainers/test_byol.py | 4 +-- tests/trainers/test_classification.py | 38 ++++++++++++++++++++++----- tests/trainers/test_detection.py | 20 ++++++++++++-- tests/trainers/test_regression.py | 19 ++++++++++++-- tests/trainers/test_segmentation.py | 16 +++++++++++ torchgeo/datasets/bigearthnet.py | 3 --- 6 files changed, 85 insertions(+), 15 deletions(-) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index b7690e78ee2..cf84b04f837 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -36,7 +36,7 @@ def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: return state_dict -class CustomBYOLDataModule(ChesapeakeCVPRDataModule): +class PredictBYOLDataModule(ChesapeakeCVPRDataModule): def setup(self, stage: str) -> None: self.predict_dataset = ChesapeakeCVPR( splits=self.test_splits, layers=self.layers, **self.kwargs @@ -129,7 +129,7 @@ def test_weight_str( BYOLTask(**model_kwargs) def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: - datamodule = CustomBYOLDataModule( + datamodule = PredictBYOLDataModule( root="tests/data/chesapeake/cvpr", train_splits=["de-test"], val_splits=["de-test"], diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index cb2bbefe957..03fcf62202e 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -34,12 +34,12 @@ from .test_utils import ClassificationTestModel -class CustomClassificationDataModule(EuroSATDataModule): +class PredictClassificationDataModule(EuroSATDataModule): def setup(self, stage: str) -> None: self.predict_dataset = EuroSAT(split="test", **self.kwargs) -class CustomMultiLabelClassificationDataModule(BigEarthNetDataModule): +class PredictMultiLabelClassificationDataModule(BigEarthNetDataModule): def setup(self, stage: str) -> None: self.predict_dataset = BigEarthNet(split="test", **self.kwargs) @@ -53,6 +53,10 @@ def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: return state_dict +def plot(*args: Any, **kwargs: Any) -> None: + raise ValueError + + class TestClassificationTask: @pytest.mark.parametrize( "name,classname", @@ -153,13 +157,24 @@ def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: with pytest.raises(ValueError, match=match): ClassificationTask(**model_kwargs) + def test_no_rgb( + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any] + ) -> None: + monkeypatch.setattr(EuroSATDataModule, "plot", plot) + datamodule = EuroSATDataModule( + root="tests/data/eurosat", batch_size=1, num_workers=0 + ) + model = ClassificationTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.validate(model=model, datamodule=datamodule) + def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: - datamodule = CustomClassificationDataModule( + datamodule = PredictClassificationDataModule( root="tests/data/eurosat", batch_size=1, num_workers=0 ) model = ClassificationTask(**model_kwargs) trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) - trainer.predict(model=model, dataloaders=datamodule) + trainer.predict(model=model, datamodule=datamodule) class TestMultiLabelClassificationTask: @@ -215,10 +230,21 @@ def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: with pytest.raises(ValueError, match=match): MultiLabelClassificationTask(**model_kwargs) + def test_no_rgb( + self, monkeypatch: 
MonkeyPatch, model_kwargs: Dict[Any, Any] + ) -> None: + monkeypatch.setattr(BigEarthNetDataModule, "plot", plot) + datamodule = BigEarthNetDataModule( + root="tests/data/bigearthnet", batch_size=1, num_workers=0 + ) + model = MultiLabelClassificationTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.validate(model=model, datamodule=datamodule) + def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: - datamodule = CustomMultiLabelClassificationDataModule( + datamodule = PredictMultiLabelClassificationDataModule( root="tests/data/bigearthnet", batch_size=1, num_workers=0 ) model = MultiLabelClassificationTask(**model_kwargs) trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) - trainer.predict(model=model, dataloaders=datamodule) + trainer.predict(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index f0da12ffceb..a6446cf1cc3 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Type, cast import pytest +from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer @@ -18,11 +19,15 @@ from torchgeo.trainers import ObjectDetectionTask -class CustomObjectDetectionDataModule(NASAMarineDebrisDataModule): +class PredictObjectDetectionDataModule(NASAMarineDebrisDataModule): def setup(self, stage: str) -> None: self.predict_dataset = NASAMarineDebris(**self.kwargs) +def plot(*args: Any, **kwargs: Any) -> None: + raise ValueError + + class TestObjectDetectionTask: @pytest.mark.parametrize( "name,classname", [("nasa_marine_debris", NASAMarineDebrisDataModule)] @@ -72,8 +77,19 @@ def test_non_pretrained_backbone(self, model_kwargs: Dict[Any, Any]) -> None: model_kwargs["pretrained"] = False ObjectDetectionTask(**model_kwargs) + def test_no_rgb( + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any] + ) -> None: + monkeypatch.setattr(NASAMarineDebrisDataModule, "plot", plot) + datamodule = NASAMarineDebrisDataModule( + root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 + ) + model = ObjectDetectionTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.validate(model=model, datamodule=datamodule) + def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: - datamodule = CustomObjectDetectionDataModule( + datamodule = PredictObjectDetectionDataModule( root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 ) model = ObjectDetectionTask(**model_kwargs) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 57122ca40a2..92b13c0614e 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -32,11 +32,15 @@ def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: return state_dict -class CustomRegressionDataModule(TropicalCycloneDataModule): +class PredictRegressionDataModule(TropicalCycloneDataModule): def setup(self, stage: str) -> None: self.predict_dataset = TropicalCyclone(split="test", **self.kwargs) +def plot(*args: Any, **kwargs: Any) -> None: + raise ValueError + + class TestRegressionTask: @pytest.mark.parametrize( "name,classname", @@ -110,8 +114,19 @@ def test_weight_str( with pytest.warns(UserWarning): RegressionTask(**model_kwargs) + def test_no_rgb( + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any] + ) -> None: + 
monkeypatch.setattr(TropicalCycloneDataModule, "plot", plot) + datamodule = TropicalCycloneDataModule( + root="tests/data/cyclone", batch_size=1, num_workers=0 + ) + model = RegressionTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.validate(model=model, datamodule=datamodule) + def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: - datamodule = CustomRegressionDataModule( + datamodule = PredictRegressionDataModule( root="tests/data/cyclone", batch_size=1, num_workers=0 ) model = RegressionTask(**model_kwargs) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index e0cc5419e15..a12324c3d95 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -40,6 +40,10 @@ def create_model(**kwargs: Any) -> Module: return SegmentationTestModel(**kwargs) +def plot(*args: Any, **kwargs: Any) -> None: + raise ValueError + + class TestSemanticSegmentationTask: @pytest.mark.parametrize( "name,classname", @@ -133,3 +137,15 @@ def test_ignoreindex_with_jaccard(self, model_kwargs: Dict[Any, Any]) -> None: match = "ignore_index has no effect on training when loss='jaccard'" with pytest.warns(UserWarning, match=match): SemanticSegmentationTask(**model_kwargs) + + def test_no_rgb( + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any] + ) -> None: + model_kwargs["in_channels"] = 15 + monkeypatch.setattr(SEN12MSDataModule, "plot", plot) + datamodule = SEN12MSDataModule( + root="tests/data/sen12ms", batch_size=1, num_workers=0 + ) + model = SemanticSegmentationTask(**model_kwargs) + trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer.validate(model=model, datamodule=datamodule) diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index cbfe0c2203b..e2f49770af1 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -539,9 +539,6 @@ def plot( Returns: a matplotlib Figure with the rendered sample - Raises: - ValueError: if ``self.bands`` is "s1" - .. versionadded:: 0.2 """ if self.bands == "s2": From 04bae75ea15460e5a0cf8e41c6e40499c47ebd1d Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 6 Jan 2023 14:30:07 -0600 Subject: [PATCH 087/108] Add lightning-lite dependency --- environment.yml | 1 + requirements/min.old | 1 + requirements/required.old | 1 + requirements/required.txt | 1 + setup.cfg | 2 ++ tests/trainers/test_byol.py | 6 +----- tests/trainers/test_classification.py | 6 +----- tests/trainers/test_detection.py | 6 +----- tests/trainers/test_regression.py | 2 ++ tests/trainers/test_segmentation.py | 6 +----- torchgeo/datamodules/geo.py | 6 +----- 11 files changed, 13 insertions(+), 25 deletions(-) diff --git a/environment.yml b/environment.yml index 7406b440188..0f750942283 100644 --- a/environment.yml +++ b/environment.yml @@ -24,6 +24,7 @@ dependencies: - isort[colors]>=5.8 - kornia>=0.6.5 - laspy>=2 + - lightning-lite>=1.8 - mypy>=0.900 - nbmake>=0.1 - nbsphinx>=0.8.5 diff --git a/requirements/min.old b/requirements/min.old index 6d22b601bbd..2179dad077f 100644 --- a/requirements/min.old +++ b/requirements/min.old @@ -5,6 +5,7 @@ setuptools==42.0.0 einops==0.3.0 fiona==1.8.0 kornia==0.6.5 +lightning-lite==1.8.0 matplotlib==3.3.0 numpy==1.17.2 omegaconf==2.1.0 diff --git a/requirements/required.old b/requirements/required.old index 1e120e0948c..206f5d6fc85 100644 --- a/requirements/required.old +++ b/requirements/required.old @@ -5,6 +5,7 @@ setuptools==62.6.0 einops==0.4.1 fiona==1.9a2 kornia==0.6.5 +lightning-lite==1.8.6 matplotlib==3.5.2 numpy==1.23.0;python_version>='3.8' numpy==1.21.6;python_version=='3.7' diff --git a/requirements/required.txt b/requirements/required.txt index 548c948e821..afea0259289 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -5,6 +5,7 @@ setuptools==66.1.1 einops==0.6.0 fiona==1.9b1 kornia==0.6.9 +lightning-lite==1.8.6 matplotlib==3.6.3 numpy==1.24.1;python_version>='3.8' omegaconf==2.3.0 diff --git a/setup.cfg b/setup.cfg index ce7b2446bc2..8986c17b18f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,6 +31,8 @@ install_requires = fiona>=1.8,<2 # kornia 0.6.5+ required due to change in kornia.augmentation API kornia>=0.6.5,<0.7 + # lightning-lite 1.8 is oldest version on PyPI + lightning-lite>=1.8,<2 # matplotlib 3.3+ required for (H, W, 1) image support in plt.imshow matplotlib>=3.3,<4 # numpy 1.17.2+ required by pytorch-lightning diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index cf84b04f837..57db4553cca 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -11,13 +11,9 @@ import torch.nn as nn import torchvision from _pytest.monkeypatch import MonkeyPatch +from lightning_lite.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer - -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) from torchvision.models import resnet18 from torchvision.models._api import WeightsEnum diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 03fcf62202e..9f95411e94e 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -10,13 +10,9 @@ import torch import torchvision from _pytest.monkeypatch import MonkeyPatch +from lightning_lite.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer - -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: 
ignore[attr-defined] - MisconfigurationException, -) from torch.nn.modules import Module from torchvision.models._api import WeightsEnum diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index a6446cf1cc3..8ce1354f799 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -6,14 +6,10 @@ import pytest from _pytest.monkeypatch import MonkeyPatch +from lightning_lite.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) - from torchgeo.datamodules import NASAMarineDebrisDataModule from torchgeo.datasets import NASAMarineDebris from torchgeo.trainers import ObjectDetectionTask diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 92b13c0614e..e0782e7f838 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -10,6 +10,7 @@ import torch import torchvision from _pytest.monkeypatch import MonkeyPatch +from lightning_lite.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer from torchvision.models._api import WeightsEnum @@ -19,6 +20,7 @@ MisconfigurationException, ) +from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule from torchgeo.datasets import TropicalCyclone from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule from torchgeo.models import ResNet18_Weights diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index a12324c3d95..4567de21195 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -7,13 +7,9 @@ import pytest import segmentation_models_pytorch as smp from _pytest.monkeypatch import MonkeyPatch +from lightning_lite.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer - -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) from torch.nn.modules import Module from torchgeo.datamodules import ( diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index bb1d325273b..f2ef4515e1c 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -8,12 +8,8 @@ import kornia.augmentation as K import matplotlib.pyplot as plt import torch +from lightning_lite.utilities.exceptions import MisconfigurationException from pytorch_lightning import LightningDataModule - -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) from torch import Tensor from torch.utils.data import DataLoader, Dataset From 77aae2889d97cf94f4fdf230b4ee073bd84349e4 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 6 Jan 2023 14:37:50 -0600 Subject: [PATCH 088/108] Revert "Add lightning-lite dependency" This reverts commit 1df7291ae59f6257a2cabd20a6c767e178bf4f0f. 
--- environment.yml | 1 - requirements/min.old | 1 - requirements/required.old | 1 - requirements/required.txt | 1 - setup.cfg | 2 -- tests/trainers/test_byol.py | 5 +++++ tests/trainers/test_classification.py | 6 +++++- tests/trainers/test_detection.py | 6 +++++- tests/trainers/test_regression.py | 6 +++++- tests/trainers/test_segmentation.py | 6 +++++- torchgeo/datamodules/geo.py | 6 +++++- 11 files changed, 30 insertions(+), 11 deletions(-) diff --git a/environment.yml b/environment.yml index 0f750942283..7406b440188 100644 --- a/environment.yml +++ b/environment.yml @@ -24,7 +24,6 @@ dependencies: - isort[colors]>=5.8 - kornia>=0.6.5 - laspy>=2 - - lightning-lite>=1.8 - mypy>=0.900 - nbmake>=0.1 - nbsphinx>=0.8.5 diff --git a/requirements/min.old b/requirements/min.old index 2179dad077f..6d22b601bbd 100644 --- a/requirements/min.old +++ b/requirements/min.old @@ -5,7 +5,6 @@ setuptools==42.0.0 einops==0.3.0 fiona==1.8.0 kornia==0.6.5 -lightning-lite==1.8.0 matplotlib==3.3.0 numpy==1.17.2 omegaconf==2.1.0 diff --git a/requirements/required.old b/requirements/required.old index 206f5d6fc85..1e120e0948c 100644 --- a/requirements/required.old +++ b/requirements/required.old @@ -5,7 +5,6 @@ setuptools==62.6.0 einops==0.4.1 fiona==1.9a2 kornia==0.6.5 -lightning-lite==1.8.6 matplotlib==3.5.2 numpy==1.23.0;python_version>='3.8' numpy==1.21.6;python_version=='3.7' diff --git a/requirements/required.txt b/requirements/required.txt index afea0259289..548c948e821 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -5,7 +5,6 @@ setuptools==66.1.1 einops==0.6.0 fiona==1.9b1 kornia==0.6.9 -lightning-lite==1.8.6 matplotlib==3.6.3 numpy==1.24.1;python_version>='3.8' omegaconf==2.3.0 diff --git a/setup.cfg b/setup.cfg index 8986c17b18f..ce7b2446bc2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,8 +31,6 @@ install_requires = fiona>=1.8,<2 # kornia 0.6.5+ required due to change in kornia.augmentation API kornia>=0.6.5,<0.7 - # lightning-lite 1.8 is oldest version on PyPI - lightning-lite>=1.8,<2 # matplotlib 3.3+ required for (H, W, 1) image support in plt.imshow matplotlib>=3.3,<4 # numpy 1.17.2+ required by pytorch-lightning diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 57db4553cca..ef44c8a054f 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -14,6 +14,11 @@ from lightning_lite.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer + +# TODO: import from lightning_lite instead +from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] + MisconfigurationException, +) from torchvision.models import resnet18 from torchvision.models._api import WeightsEnum diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 9f95411e94e..03fcf62202e 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -10,9 +10,13 @@ import torch import torchvision from _pytest.monkeypatch import MonkeyPatch -from lightning_lite.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer + +# TODO: import from lightning_lite instead +from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] + MisconfigurationException, +) from torch.nn.modules import Module from torchvision.models._api import WeightsEnum diff --git a/tests/trainers/test_detection.py 
b/tests/trainers/test_detection.py index 8ce1354f799..a6446cf1cc3 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -6,10 +6,14 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from lightning_lite.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer +# TODO: import from lightning_lite instead +from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] + MisconfigurationException, +) + from torchgeo.datamodules import NASAMarineDebrisDataModule from torchgeo.datasets import NASAMarineDebris from torchgeo.trainers import ObjectDetectionTask diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index e0782e7f838..4efed8e716b 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -10,7 +10,6 @@ import torch import torchvision from _pytest.monkeypatch import MonkeyPatch -from lightning_lite.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer from torchvision.models._api import WeightsEnum @@ -20,6 +19,11 @@ MisconfigurationException, ) +# TODO: import from lightning_lite instead +from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] + MisconfigurationException, +) + from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule from torchgeo.datasets import TropicalCyclone from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 4567de21195..a12324c3d95 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -7,9 +7,13 @@ import pytest import segmentation_models_pytorch as smp from _pytest.monkeypatch import MonkeyPatch -from lightning_lite.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer + +# TODO: import from lightning_lite instead +from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] + MisconfigurationException, +) from torch.nn.modules import Module from torchgeo.datamodules import ( diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index f2ef4515e1c..bb1d325273b 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -8,8 +8,12 @@ import kornia.augmentation as K import matplotlib.pyplot as plt import torch -from lightning_lite.utilities.exceptions import MisconfigurationException from pytorch_lightning import LightningDataModule + +# TODO: import from lightning_lite instead +from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] + MisconfigurationException, +) from torch import Tensor from torch.utils.data import DataLoader, Dataset From 5f5782de8264e1d430acf026e1cd245dcd55ffa4 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 6 Jan 2023 14:45:39 -0600 Subject: [PATCH 089/108] Define our own MisconfigurationException --- docs/api/datamodules.rst | 5 +++++ tests/trainers/test_byol.py | 7 +------ tests/trainers/test_classification.py | 6 +----- tests/trainers/test_detection.py | 7 +------ tests/trainers/test_regression.py | 7 +++---- tests/trainers/test_segmentation.py | 6 +----- torchgeo/datamodules/__init__.py | 3 +++ torchgeo/datamodules/geo.py | 6 +----- torchgeo/datamodules/utils.py | 5 +++++ 9 files changed, 21 insertions(+), 31 deletions(-) diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index 69292c1b5d2..d66fee22755 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -141,3 +141,8 @@ NonGeoDataModule ^^^^^^^^^^^^^^^^ .. autoclass:: NonGeoDataModule + +Utilities +--------- + +.. autoclass:: MisconfigurationException diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index ef44c8a054f..5b3641d36d3 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -14,15 +14,10 @@ from lightning_lite.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer - -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) from torchvision.models import resnet18 from torchvision.models._api import WeightsEnum -from torchgeo.datamodules import ChesapeakeCVPRDataModule +from torchgeo.datamodules import ChesapeakeCVPRDataModule, MisconfigurationException from torchgeo.datasets import ChesapeakeCVPR from torchgeo.models import ResNet18_Weights from torchgeo.samplers import GridGeoSampler diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 03fcf62202e..d3b4b77778b 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -12,17 +12,13 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer - -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) from torch.nn.modules import Module from torchvision.models._api import WeightsEnum from torchgeo.datamodules import ( BigEarthNetDataModule, EuroSATDataModule, + MisconfigurationException, RESISC45DataModule, So2SatDataModule, UCMercedDataModule, diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index a6446cf1cc3..fe7fae1a2de 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -9,12 +9,7 @@ from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) - -from torchgeo.datamodules import NASAMarineDebrisDataModule +from torchgeo.datamodules import MisconfigurationException, NASAMarineDebrisDataModule from torchgeo.datasets import NASAMarineDebris from torchgeo.trainers import ObjectDetectionTask diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 4efed8e716b..c83bad7e103 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -19,12 +19,11 @@ MisconfigurationException, ) -# TODO: import from lightning_lite instead -from 
pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] +from torchgeo.datamodules import ( + COWCCountingDataModule, MisconfigurationException, + TropicalCycloneDataModule, ) - -from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule from torchgeo.datasets import TropicalCyclone from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule from torchgeo.models import ResNet18_Weights diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index a12324c3d95..6282152b8da 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -9,11 +9,6 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer - -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) from torch.nn.modules import Module from torchgeo.datamodules import ( @@ -24,6 +19,7 @@ InriaAerialImageLabelingDataModule, LandCoverAIDataModule, LoveDADataModule, + MisconfigurationException, NAIPChesapeakeDataModule, Potsdam2DDataModule, SEN12MSDataModule, diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 2f846c4a5fc..fe4dafaa986 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -26,6 +26,7 @@ from .spacenet import SpaceNet1DataModule from .ucmerced import UCMercedDataModule from .usavars import USAVarsDataModule +from .utils import MisconfigurationException from .vaihingen import Vaihingen2DDataModule from .xview import XView2DataModule @@ -59,4 +60,6 @@ # Base classes "GeoDataModule", "NonGeoDataModule", + # Utilities + "MisconfigurationException", ) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index bb1d325273b..5c94839c33d 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -9,11 +9,6 @@ import matplotlib.pyplot as plt import torch from pytorch_lightning import LightningDataModule - -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) from torch import Tensor from torch.utils.data import DataLoader, Dataset @@ -25,6 +20,7 @@ RandomBatchGeoSampler, ) from ..transforms import AugmentationSequential +from .utils import MisconfigurationException class GeoDataModule(LightningDataModule): diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index d088e493312..be50cd7dc99 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -10,6 +10,11 @@ from ..datasets import NonGeoDataset +# Based on lightning_lite.utilities.exceptions +class MisconfigurationException(Exception): + """Exception used to inform users of misuse with Lightning.""" + + def dataset_split( dataset: Union[TensorDataset, NonGeoDataset], val_pct: float, From 49cc3882b3d90a6b93d8d8ea1bb843f78bc9a94a Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Fri, 6 Jan 2023 17:47:02 -0600 Subject: [PATCH 090/108] Properly test new data module base classes --- tests/datamodules/test_geo.py | 121 ++++++++++++++++++++++++++++++++++ torchgeo/datamodules/geo.py | 6 -- 2 files changed, 121 insertions(+), 6 deletions(-) create mode 100644 tests/datamodules/test_geo.py diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py new file mode 100644 index 00000000000..3d3444b60d1 --- /dev/null +++ b/tests/datamodules/test_geo.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from typing import Any, Dict + +import pytest +import torch +from torch import Tensor + +from torchgeo.datamodules import ( + GeoDataModule, + MisconfigurationException, + NonGeoDataModule, +) +from torchgeo.datasets import BoundingBox, GeoDataset, NonGeoDataset +from torchgeo.samplers import RandomBatchGeoSampler, RandomGeoSampler + + +class CustomGeoDataset(GeoDataset): + def __init__(self, split: str = "train", download: bool = False) -> None: + super().__init__() + self.index.insert(0, (0, 1, 2, 3, 4, 5)) + self.res = 1 + + def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + return {"image": torch.arange(3 * 2 * 2).view(3, 2, 2)} + + +class CustomGeoDataModule(GeoDataModule): + def __init__(self) -> None: + super().__init__(CustomGeoDataset, 1, 1, 1, 0, download=True) + + +class SamplerGeoDatModule(CustomGeoDataModule): + def setup(self, stage: str) -> None: + self.dataset = CustomGeoDataset() + self.train_sampler = RandomGeoSampler(self.dataset, 1, 1) + self.val_sampler = RandomGeoSampler(self.dataset, 1, 1) + self.test_sampler = RandomGeoSampler(self.dataset, 1, 1) + self.predict_sampler = RandomGeoSampler(self.dataset, 1, 1) + + +class BatchSamplerGeoDatModule(CustomGeoDataModule): + def setup(self, stage: str) -> None: + self.dataset = CustomGeoDataset() + self.train_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1) + self.val_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1) + self.test_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1) + self.predict_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1) + + +class CustomNonGeoDataset(NonGeoDataset): + def __init__(self, split: str = "train", download: bool = False) -> None: + pass + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + return {"image": torch.arange(3 * 2 * 2).view(3, 2, 2)} + + def __len__(self) -> int: + return 1 + + +class CustomNonGeoDataModule(NonGeoDataModule): + def __init__(self) -> None: + super().__init__(CustomNonGeoDataset, 1, 0, download=True) + + +class TestGeoDataModule: + @pytest.mark.parametrize("stage", ["fit", "validate", "test"]) + def test_setup(self, stage: str) -> None: + dm = CustomGeoDataModule() + dm.prepare_data() + dm.setup(stage) + + def test_sampler(self) -> None: + dm = SamplerGeoDatModule() + dm.setup("train") + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() + dm.predict_dataloader() + + def test_batch_sampler(self) -> None: + dm = BatchSamplerGeoDatModule() + dm.setup("train") + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() + dm.predict_dataloader() + + def test_no_datasets(self) -> None: + dm = CustomGeoDataModule() + msg = "CustomGeoDataModule.setup does not define a '{}_dataset'" + with pytest.raises(MisconfigurationException, match=msg.format("train")): + dm.train_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("val")): + dm.val_dataloader() + 
with pytest.raises(MisconfigurationException, match=msg.format("test")): + dm.test_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("predict")): + dm.predict_dataloader() + + +class TestNonGeoDataModule: + @pytest.mark.parametrize("stage", ["fit", "validate", "test"]) + def test_setup(self, stage: str) -> None: + dm = CustomNonGeoDataModule() + dm.prepare_data() + dm.setup(stage) + + def test_no_datasets(self) -> None: + dm = CustomNonGeoDataModule() + msg = "CustomNonGeoDataModule.setup does not define a '{}_dataset'" + with pytest.raises(MisconfigurationException, match=msg.format("train")): + dm.train_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("val")): + dm.val_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("test")): + dm.test_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("predict")): + dm.predict_dataloader() diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 5c94839c33d..03c40822768 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -154,15 +154,12 @@ def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: batch_sampler = self.train_batch_sampler or self.batch_sampler if dataset is not None and (sampler or batch_sampler) is not None: batch_size = self.train_batch_size or self.batch_size - shuffle = True if batch_sampler is not None: batch_size = 1 - shuffle = False sampler = None return DataLoader( dataset=dataset, batch_size=batch_size, - shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=self.num_workers, @@ -193,7 +190,6 @@ def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: return DataLoader( dataset=dataset, batch_size=batch_size, - shuffle=False, sampler=sampler, batch_sampler=batch_sampler, num_workers=self.num_workers, @@ -224,7 +220,6 @@ def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: return DataLoader( dataset=dataset, batch_size=batch_size, - shuffle=False, sampler=sampler, batch_sampler=batch_sampler, num_workers=self.num_workers, @@ -255,7 +250,6 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: return DataLoader( dataset=dataset, batch_size=batch_size, - shuffle=False, sampler=sampler, batch_sampler=batch_sampler, num_workers=self.num_workers, From 0935cb912dfb09fa117a0cf53167f843b50d5fc0 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sun, 15 Jan 2023 15:18:42 -0600 Subject: [PATCH 091/108] Fix mistake in setup call --- tests/datamodules/test_geo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py index 3d3444b60d1..583e7050f9c 100644 --- a/tests/datamodules/test_geo.py +++ b/tests/datamodules/test_geo.py @@ -74,7 +74,7 @@ def test_setup(self, stage: str) -> None: def test_sampler(self) -> None: dm = SamplerGeoDatModule() - dm.setup("train") + dm.setup("fit") dm.train_dataloader() dm.val_dataloader() dm.test_dataloader() @@ -82,7 +82,7 @@ def test_sampler(self) -> None: def test_batch_sampler(self) -> None: dm = BatchSamplerGeoDatModule() - dm.setup("train") + dm.setup("fit") dm.train_dataloader() dm.val_dataloader() dm.test_dataloader() From 4b52032ecad8d0fa2046d3025532d51cc8cb4c29 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sat, 21 Jan 2023 14:47:16 -0600 Subject: [PATCH 092/108] ExtractTensorPatches runs into OOM errors --- torchgeo/datamodules/deepglobelandcover.py | 18 +++---- torchgeo/datamodules/gid15.py | 15 +++--- torchgeo/datamodules/inria.py | 13 +++-- torchgeo/datamodules/oscd.py | 18 +++---- torchgeo/datamodules/potsdam.py | 18 +++---- torchgeo/datamodules/vaihingen.py | 18 +++---- torchgeo/transforms/transforms.py | 55 ---------------------- 7 files changed, 37 insertions(+), 118 deletions(-) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 26976814baa..b3d38a92f0f 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -10,7 +10,7 @@ from ..datasets import DeepGlobeLandCover from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential -from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule from .utils import dataset_split @@ -33,9 +33,8 @@ def __init__( """Initialize a new DeepGlobeLandCoverDataModule instance. The DeepGlobe Land Cover dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles - during training and chop up image tiles into patch grids during evaluation. - During training, the effective batch size is equal to + directly through a model. Instead, we randomly sample patches from image tiles. + The effective batch size is equal to ``num_tiles_per_batch`` x ``num_patches_per_tile``. .. versionchanged:: 0.4 @@ -43,9 +42,9 @@ def __init__( and *patch_size*. Args: - num_tiles_per_batch: Number of image tiles to sample from during training. + num_tiles_per_batch: Number of image tiles to sample from. num_patches_per_tile: Number of patches to randomly sample from each image - tile during training. + tile. patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures. val_split_pct: Percentage of the dataset to use as a validation set. @@ -60,14 +59,9 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.train_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), - data_keys=["image", "mask"], - ) self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _ExtractTensorPatches(self.patch_size), + _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index b896160c6da..b89037a6aed 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -10,7 +10,7 @@ from ..datasets import GID15 from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential -from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule from .utils import dataset_split @@ -35,15 +35,14 @@ def __init__( """Initialize a new GID15DataModule instance. The GID-15 dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles - during training and chop up image tiles into patch grids during evaluation. 
- During training, the effective batch size is equal to + directly through a model. Instead, we randomly sample patches from image tiles. + The effective batch size is equal to ``num_tiles_per_batch`` x ``num_patches_per_tile``. Args: - num_tiles_per_batch: Number of image tiles to sample from during training. + num_tiles_per_batch: Number of image tiles to sample from. num_patches_per_tile: Number of patches to randomly sample from each image - tile during training. + tile. patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures. val_split_pct: Percentage of the dataset to use as a validation set @@ -65,12 +64,12 @@ def __init__( ) self.val_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _ExtractTensorPatches(self.patch_size), + _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.predict_transform = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _ExtractTensorPatches(self.patch_size), + _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image"], ) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index cd4e6a0e5cf..f8ed7fcd9c9 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -10,7 +10,7 @@ from ..datasets import InriaAerialImageLabeling from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential -from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule from .utils import dataset_split @@ -38,14 +38,13 @@ def __init__( The Inria Aerial Image Labeling dataset contains images that are too large to pass directly through a model. Instead, we randomly sample patches from image - tiles during training and chop up image tiles into patch grids during - evaluation. During training, the effective batch size is equal to + tiles. The effective batch size is equal to ``num_tiles_per_batch`` x ``num_patches_per_tile``. Args: - num_tiles_per_batch: Number of image tiles to sample from during training. + num_tiles_per_batch: Number of image tiles to sample from. num_patches_per_tile: Number of patches to randomly sample from each image - tile during training. + tile. patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures. num_workers: Number of workers for parallel data loading. 
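The "effective batch size" phrase these docstrings now share is easiest to see with numbers: the DataLoader batches whole tiles, and _RandomNCrop then takes a fixed number of random crops from each tile, so the tensor reaching the model holds the product of the two settings. An illustrative calculation (values chosen for the example, not library defaults):

    # Two knobs from the docstrings above; values are illustrative.
    num_tiles_per_batch = 2    # tiles drawn per DataLoader batch
    num_patches_per_tile = 16  # random crops taken from each tile

    # Each step therefore processes 2 * 16 = 32 patches.
    effective_batch_size = num_tiles_per_batch * num_patches_per_tile
    assert effective_batch_size == 32
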
@@ -71,12 +70,12 @@ def __init__( ) self.val_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _ExtractTensorPatches(self.patch_size), + _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) self.predict_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _ExtractTensorPatches(self.patch_size), + _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image"], ) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 8a999de4aa2..cbebc6a3162 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -12,7 +12,7 @@ from ..datasets import OSCD from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential -from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule from .utils import dataset_split @@ -74,15 +74,14 @@ def __init__( """Initialize a new OSCDDataModule instance. The OSCD dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles - during training and chop up image tiles into patch grids during evaluation. - During training, the effective batch size is equal to + directly through a model. Instead, we randomly sample patches from image tiles. + The effective batch size is equal to ``num_tiles_per_batch`` x ``num_patches_per_tile``. Args: - num_tiles_per_batch: Number of image tiles to sample from during training. + num_tiles_per_batch: Number of image tiles to sample from. num_patches_per_tile: Number of patches to randomly sample from each image - tile during training. + tile. patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures. val_split_pct: Percentage of the dataset to use as a validation set. @@ -106,14 +105,9 @@ def __init__( self.mean = repeat(self.mean, "c -> (t c)", t=2) self.std = repeat(self.std, "c -> (t c)", t=2) - self.train_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), - data_keys=["image", "mask"], - ) self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _ExtractTensorPatches(self.patch_size), + _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index e93b09f8ef2..b594fcd7f1c 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -10,7 +10,7 @@ from ..datasets import Potsdam2D from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential -from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule from .utils import dataset_split @@ -35,9 +35,8 @@ def __init__( """Initialize a new Potsdam2DDataModule instance. The Potsdam2D dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles - during training and chop up image tiles into patch grids during evaluation. - During training, the effective batch size is equal to + directly through a model. Instead, we randomly sample patches from image tiles. + The effective batch size is equal to ``num_tiles_per_batch`` x ``num_patches_per_tile``. .. 
versionchanged:: 0.4 @@ -45,9 +44,9 @@ def __init__( and *patch_size*. Args: - num_tiles_per_batch: Number of image tiles to sample from during training. + num_tiles_per_batch: Number of image tiles to sample from. num_patches_per_tile: Number of patches to randomly sample from each image - tile during training. + tile. patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures. val_split_pct: Percentage of the dataset to use as a validation set. @@ -62,14 +61,9 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.train_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), - data_keys=["image", "mask"], - ) self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _ExtractTensorPatches(self.patch_size), + _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index f8a2af9c6ef..d4f6d80816e 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -10,7 +10,7 @@ from ..datasets import Vaihingen2D from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential -from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from ..transforms.transforms import _RandomNCrop from .geo import NonGeoDataModule from .utils import dataset_split @@ -35,9 +35,8 @@ def __init__( """Initialize a new Vaihingen2DDataModule instance. The Vaihingen2D dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles - during training and chop up image tiles into patch grids during evaluation. - During training, the effective batch size is equal to + directly through a model. Instead, we randomly sample patches from image tiles. + The effective batch size is equal to ``num_tiles_per_batch`` x ``num_patches_per_tile``. .. versionchanged:: 0.4 @@ -45,9 +44,9 @@ def __init__( and *patch_size*. Args: - num_tiles_per_batch: Number of image tiles to sample from during training. + num_tiles_per_batch: Number of image tiles to sample from. num_patches_per_tile: Number of patches to randomly sample from each image - tile during training. + tile. patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures. val_split_pct: Percentage of the dataset to use as a validation set. 
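With the train/eval distinction gone, every stage shares one Normalize + _RandomNCrop pipeline that applies identical geometric parameters to image and mask; note this means validation and test patches are now randomly sampled rather than taken from a deterministic grid. A rough equivalent using kornia's own container (torchgeo's AugmentationSequential wraps it to accept sample dicts); the shapes, class count, and normalization constants here are placeholders:

    import kornia.augmentation as K
    import torch

    aug = K.AugmentationSequential(
        K.Normalize(mean=torch.zeros(3), std=torch.full((3,), 255.0)),
        K.RandomCrop((64, 64)),
        data_keys=["input", "mask"],
    )
    image = torch.rand(2, 3, 256, 256)
    mask = torch.randint(0, 7, (2, 1, 256, 256)).float()
    # The crop indices are shared between image and mask; the intensity-only
    # Normalize is applied to the image but skipped for the mask.
    out_image, out_mask = aug(image, mask)
    print(out_image.shape, out_mask.shape)  # (2, 3, 64, 64) (2, 1, 64, 64)
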
@@ -62,14 +61,9 @@ def __init__( self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.train_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), - data_keys=["image", "mask"], - ) self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _ExtractTensorPatches(self.patch_size), + _RandomNCrop(self.patch_size, self.num_patches_per_tile), data_keys=["image", "mask"], ) diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 331b2bc0127..c202a8859d2 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -10,7 +10,6 @@ from einops import rearrange from kornia.augmentation import GeometricAugmentationBase2D from kornia.augmentation.random_generator import CropGenerator -from kornia.contrib import compute_padding, extract_tensor_patches from kornia.geometry import crop_by_indices from torch import Tensor from torch.nn.modules import Module @@ -85,60 +84,6 @@ def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: return batch -class _ExtractTensorPatches(GeometricAugmentationBase2D): - """Chop up a tensor into a grid.""" - - def __init__(self, window_size: Union[int, Tuple[int, int]]) -> None: - """Initialize a new _ExtractTensorPatches instance. - - Args: - window_size: the size of each patch - """ - super().__init__(p=1) - self.flags = {"window_size": window_size} - - def compute_transformation( - self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any] - ) -> Tensor: - """Compute the transformation. - - Args: - input: the input tensor - params: generated parameters - flags: static parameters - - Returns: - the transformation - """ - out: Tensor = self.identity_matrix(input) - return out - - def apply_transform( - self, - input: Tensor, - params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, - ) -> Tensor: - """Apply the transform. - - Args: - input: the input tensor - params: generated parameters - flags: static parameters - transform: the geometric transformation tensor - - Returns: - the augmented input - """ - size = flags["window_size"] - h, w = input.shape[-2:] - padding = compute_padding((h, w), size) - input = extract_tensor_patches(input, size, size, padding) - input = torch.flatten(input, 0, 1) # [B, N, C?, H, W] -> [B*N, C?, H, W] - return input - - class _RandomNCrop(GeometricAugmentationBase2D): """Take N random crops of a tensor.""" From 7ad9e5beb0ebcf64f146617415aa2189e5f281be Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sat, 21 Jan 2023 15:03:34 -0600 Subject: [PATCH 093/108] Test both fast_dev_run True and False --- tests/trainers/conftest.py | 8 +++++++ tests/trainers/test_byol.py | 10 +++++---- tests/trainers/test_classification.py | 32 +++++++++++++++++---------- tests/trainers/test_detection.py | 14 +++++++----- tests/trainers/test_regression.py | 14 +++++++----- tests/trainers/test_segmentation.py | 12 ++++++---- 6 files changed, 58 insertions(+), 32 deletions(-) diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py index ee0135380a1..6bae33ffdd6 100644 --- a/tests/trainers/conftest.py +++ b/tests/trainers/conftest.py @@ -15,6 +15,14 @@ from torch.nn.modules import Module +@pytest.fixture( + scope="package", params=[True, pytest.param(False, marks=pytest.mark.slow)] +) +def fast_dev_run(request: SubRequest) -> bool: + flag: bool = request.param + return flag + + @pytest.fixture(scope="package") def model() -> Module: kwargs: Dict[str, Optional[bool]] = {} diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 5b3641d36d3..18bf781a04f 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -67,7 +67,9 @@ class TestBYOLTask: ("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule), ], ) - def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: + def test_trainer( + self, name: str, classname: Type[LightningDataModule], fast_dev_run: bool + ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) @@ -83,7 +85,7 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: model.backbone = SegmentationTestModel(**model_kwargs) # Instantiate trainer - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) try: trainer.test(model=model, datamodule=datamodule) @@ -124,7 +126,7 @@ def test_weight_str( model_kwargs["weights"] = str(mocked_weights) BYOLTask(**model_kwargs) - def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: + def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: datamodule = PredictBYOLDataModule( root="tests/data/chesapeake/cvpr", train_splits=["de-test"], @@ -136,5 +138,5 @@ def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: ) model_kwargs["in_channels"] = 4 model = BYOLTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.predict(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index d3b4b77778b..bbd19dd4c9c 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -66,7 +66,11 @@ class TestClassificationTask: ], ) def test_trainer( - self, monkeypatch: MonkeyPatch, name: str, classname: Type[LightningDataModule] + self, + monkeypatch: MonkeyPatch, + name: str, + classname: Type[LightningDataModule], + fast_dev_run: bool, ) -> None: if name.startswith("so2sat"): pytest.importorskip("h5py", minversion="2.6") @@ -85,7 +89,7 @@ def test_trainer( model = ClassificationTask(**model_kwargs) # Instantiate trainer - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = 
Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) try: trainer.test(model=model, datamodule=datamodule) @@ -154,22 +158,22 @@ def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: ClassificationTask(**model_kwargs) def test_no_rgb( - self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any] + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool ) -> None: monkeypatch.setattr(EuroSATDataModule, "plot", plot) datamodule = EuroSATDataModule( root="tests/data/eurosat", batch_size=1, num_workers=0 ) model = ClassificationTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: + def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: datamodule = PredictClassificationDataModule( root="tests/data/eurosat", batch_size=1, num_workers=0 ) model = ClassificationTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.predict(model=model, datamodule=datamodule) @@ -183,7 +187,11 @@ class TestMultiLabelClassificationTask: ], ) def test_trainer( - self, monkeypatch: MonkeyPatch, name: str, classname: Type[LightningDataModule] + self, + monkeypatch: MonkeyPatch, + name: str, + classname: Type[LightningDataModule], + fast_dev_run: bool, ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) @@ -199,7 +207,7 @@ def test_trainer( model = MultiLabelClassificationTask(**model_kwargs) # Instantiate trainer - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) try: trainer.test(model=model, datamodule=datamodule) @@ -227,20 +235,20 @@ def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: MultiLabelClassificationTask(**model_kwargs) def test_no_rgb( - self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any] + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool ) -> None: monkeypatch.setattr(BigEarthNetDataModule, "plot", plot) datamodule = BigEarthNetDataModule( root="tests/data/bigearthnet", batch_size=1, num_workers=0 ) model = MultiLabelClassificationTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: + def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: datamodule = PredictMultiLabelClassificationDataModule( root="tests/data/bigearthnet", batch_size=1, num_workers=0 ) model = MultiLabelClassificationTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.predict(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index fe7fae1a2de..b3ff70cad0f 100644 --- a/tests/trainers/test_detection.py 
+++ b/tests/trainers/test_detection.py @@ -27,7 +27,9 @@ class TestObjectDetectionTask: @pytest.mark.parametrize( "name,classname", [("nasa_marine_debris", NASAMarineDebrisDataModule)] ) - def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: + def test_trainer( + self, name: str, classname: Type[LightningDataModule], fast_dev_run: bool + ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", f"{name}.yaml")) conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) @@ -41,7 +43,7 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: model = ObjectDetectionTask(**model_kwargs) # Instantiate trainer - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) try: trainer.test(model=model, datamodule=datamodule) @@ -73,20 +75,20 @@ def test_non_pretrained_backbone(self, model_kwargs: Dict[Any, Any]) -> None: ObjectDetectionTask(**model_kwargs) def test_no_rgb( - self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any] + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool ) -> None: monkeypatch.setattr(NASAMarineDebrisDataModule, "plot", plot) datamodule = NASAMarineDebrisDataModule( root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 ) model = ObjectDetectionTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: + def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: datamodule = PredictObjectDetectionDataModule( root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 ) model = ObjectDetectionTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.predict(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index c83bad7e103..3fb81ecf951 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -54,7 +54,9 @@ class TestRegressionTask: ("cyclone", TropicalCycloneDataModule), ], ) - def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: + def test_trainer( + self, name: str, classname: Type[LightningDataModule], fast_dev_run: bool + ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) @@ -70,7 +72,7 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: model.model = RegressionTestModel() # Instantiate trainer - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) try: trainer.test(model=model, datamodule=datamodule) @@ -120,20 +122,20 @@ def test_weight_str( RegressionTask(**model_kwargs) def test_no_rgb( - self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any] + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool ) -> None: 
monkeypatch.setattr(TropicalCycloneDataModule, "plot", plot) datamodule = TropicalCycloneDataModule( root="tests/data/cyclone", batch_size=1, num_workers=0 ) model = RegressionTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.validate(model=model, datamodule=datamodule) - def test_predict(self, model_kwargs: Dict[Any, Any]) -> None: + def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: datamodule = PredictRegressionDataModule( root="tests/data/cyclone", batch_size=1, num_workers=0 ) model = RegressionTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.predict(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 6282152b8da..aa458c883a4 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -62,7 +62,11 @@ class TestSemanticSegmentationTask: ], ) def test_trainer( - self, monkeypatch: MonkeyPatch, name: str, classname: Type[LightningDataModule] + self, + monkeypatch: MonkeyPatch, + name: str, + classname: Type[LightningDataModule], + fast_dev_run: bool, ) -> None: if name == "naipchesapeake": pytest.importorskip("zipfile_deflate64") @@ -86,7 +90,7 @@ def test_trainer( model = SemanticSegmentationTask(**model_kwargs) # Instantiate trainer - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) try: trainer.test(model=model, datamodule=datamodule) @@ -135,7 +139,7 @@ def test_ignoreindex_with_jaccard(self, model_kwargs: Dict[Any, Any]) -> None: SemanticSegmentationTask(**model_kwargs) def test_no_rgb( - self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any] + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool ) -> None: model_kwargs["in_channels"] = 15 monkeypatch.setattr(SEN12MSDataModule, "plot", plot) @@ -143,5 +147,5 @@ def test_no_rgb( root="tests/data/sen12ms", batch_size=1, num_workers=0 ) model = SemanticSegmentationTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.validate(model=model, datamodule=datamodule) From d06d9e98203dc89699753c4e81cd6b344807b15f Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sat, 21 Jan 2023 17:15:44 -0600 Subject: [PATCH 094/108] Fix plot methods --- torchgeo/datamodules/cowc.py | 6 +++--- torchgeo/datamodules/cyclone.py | 8 ++++---- torchgeo/datamodules/deepglobelandcover.py | 4 ++-- torchgeo/datamodules/fair1m.py | 4 ++-- torchgeo/datamodules/geo.py | 2 +- torchgeo/datamodules/gid15.py | 4 ++-- torchgeo/datamodules/inria.py | 4 ++-- torchgeo/datamodules/naip.py | 13 ++++--------- torchgeo/datamodules/nasa_marine_debris.py | 4 ++-- torchgeo/datamodules/oscd.py | 4 ++-- torchgeo/datamodules/potsdam.py | 4 ++-- torchgeo/datamodules/sen12ms.py | 10 +++++----- torchgeo/datamodules/spacenet.py | 4 ++-- torchgeo/datamodules/vaihingen.py | 4 ++-- torchgeo/datamodules/xview.py | 4 ++-- torchgeo/datasets/cyclone.py | 2 +- torchgeo/datasets/sen12ms.py | 6 +++--- 17 files changed, 41 insertions(+), 46 deletions(-) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index da2a9a14870..111beaa6686 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -33,9 +33,9 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - train_val_dataset = COWCCounting(split="train", **self.kwargs) + self.dataset = COWCCounting(split="train", **self.kwargs) self.test_dataset = COWCCounting(split="test", **self.kwargs) self.train_dataset, self.val_dataset = random_split( - train_val_dataset, - [len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)], + self.dataset, + [len(self.dataset) - len(self.test_dataset), len(self.test_dataset)], ) diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 06e77380867..2442bb201ec 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -43,10 +43,10 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. """ if stage in ["fit", "validate"]: - dataset = TropicalCyclone(split="train", **self.kwargs) + self.dataset = TropicalCyclone(split="train", **self.kwargs) storm_ids = [] - for item in dataset.collection: + for item in self.dataset.collection: storm_id = item["href"].split("/")[0].split("_")[-2] storm_ids.append(storm_id) @@ -56,7 +56,7 @@ def setup(self, stage: str) -> None: ) ) - self.train_dataset = Subset(dataset, train_indices) - self.val_dataset = Subset(dataset, val_indices) + self.train_dataset = Subset(self.dataset, train_indices) + self.val_dataset = Subset(self.dataset, val_indices) if stage in ["test"]: self.test_dataset = TropicalCyclone(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index b3d38a92f0f..142f5fbfad3 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -72,9 +72,9 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" if stage in ["fit", "validate"]: - dataset = DeepGlobeLandCover(split="train", **self.kwargs) + self.dataset = DeepGlobeLandCover(split="train", **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( - dataset, self.val_split_pct + self.dataset, self.val_split_pct ) if stage in ["test"]: self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index 6c684c45455..24b6f46eae1 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -45,7 +45,7 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - dataset = FAIR1M(**self.kwargs) + self.dataset = FAIR1M(**self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index 03c40822768..b15aaf71627 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -534,7 +534,7 @@ def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: Returns: A matplotlib Figure with the image, ground truth, and predictions. """ - dataset = self.val_dataset or self.dataset + dataset = self.dataset or self.val_dataset if dataset is not None: if hasattr(dataset, "plot"): return dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index b89037a6aed..bafa1ec7985 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -80,9 +80,9 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. """ if stage in ["fit", "validate"]: - dataset = GID15(split="train", **self.kwargs) + self.dataset = GID15(split="train", **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( - dataset, self.val_split_pct + self.dataset, self.val_split_pct ) if stage in ["test"]: # Test set masks are not public, use for prediction instead diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index f8ed7fcd9c9..4285c07151e 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -86,9 +86,9 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. """ if stage in ["fit", "validate", "test"]: - dataset = InriaAerialImageLabeling(split="train", **self.kwargs) + self.dataset = InriaAerialImageLabeling(split="train", **self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - dataset, self.val_split_pct, self.test_split_pct + self.dataset, self.val_split_pct, self.test_split_pct ) if stage in ["predict"]: # Test set masks are not public, use for prediction instead diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 34082104552..63ef8c8501e 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -95,21 +95,16 @@ def setup(self, stage: str) -> None: self.dataset, self.patch_size, self.patch_size, test_roi ) - def plot(self, *args: Any, **kwargs: Any) -> Tuple[plt.Figure, plt.Figure]: - """Run NAIP and Chesapeake plot methods. - - See :meth:`torchgeo.datasets.NAIP.plot` and - :meth:`torchgeo.datasets.Chesapeake.plot`. + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run NAIP plot method. Args: *args: Arguments passed to plot method. 
**kwargs: Keyword arguments passed to plot method. Returns: - A list of matplotlib Figures with the image, ground truth, and predictions. + A matplotlib Figure with the image, ground truth, and predictions. .. versionadded:: 0.4 """ - image = self.naip.plot(*args, **kwargs) - label = self.chesapeake.plot(*args, **kwargs) - return image, label + return self.naip.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index b07c464eabf..54119a804a5 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -45,7 +45,7 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - dataset = NASAMarineDebris(**self.kwargs) + self.dataset = NASAMarineDebris(**self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct + self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index cbebc6a3162..d36b61a9fb0 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -118,9 +118,9 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. """ if stage in ["fit", "validate"]: - dataset = OSCD(split="train", **self.kwargs) + self.dataset = OSCD(split="train", **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( - dataset, val_pct=self.val_split_pct + self.dataset, val_pct=self.val_split_pct ) if stage in ["test"]: self.test_dataset = OSCD(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index b594fcd7f1c..6c924ea427d 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -74,9 +74,9 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. """ if stage in ["fit", "validate"]: - dataset = Potsdam2D(split="train", **self.kwargs) + self.dataset = Potsdam2D(split="train", **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( - dataset, self.val_split_pct + self.dataset, self.val_split_pct ) if stage in ["test"]: self.test_dataset = Potsdam2D(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 3692c4a757c..4ff04580eab 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -72,7 +72,7 @@ def setup(self, stage: str) -> None: if stage in ["fit", "validate"]: season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000} - dataset = SEN12MS(split="train", **self.kwargs) + self.dataset = SEN12MS(split="train", **self.kwargs) # A patch is a filename like: # "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" @@ -81,7 +81,7 @@ def setup(self, stage: str) -> None: # simply give each season a large number and representing a unique_scene_id # as (season_id + scene_id). 
scenes = [] - for scene_fn in dataset.ids: + for scene_fn in self.dataset.ids: parts = scene_fn.split("_") season_id = season_to_int[parts[1]] scene_id = int(parts[3]) @@ -93,8 +93,8 @@ def setup(self, stage: str) -> None: ) ) - self.train_dataset = Subset(dataset, train_indices) - self.val_dataset = Subset(dataset, val_indices) + self.train_dataset = Subset(self.dataset, train_indices) + self.val_dataset = Subset(self.dataset, val_indices) if stage in ["test"]: self.test_dataset = SEN12MS(split="test", **self.kwargs) @@ -110,6 +110,6 @@ def on_after_batch_transfer( Returns: A batch of data. """ - batch["mask"] = torch.take(self.DFC2020_CLASS_MAPPING, batch["mask"][:, 0]) + batch["mask"] = torch.take(self.DFC2020_CLASS_MAPPING, batch["mask"]) return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 5bd7cdba476..802cc7a26c6 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -74,9 +74,9 @@ def setup(self, stage: str) -> None: Args: stage: Either 'fit', 'validate', 'test', or 'predict'. """ - dataset = SpaceNet1(**self.kwargs) + self.dataset = SpaceNet1(**self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - dataset, self.val_split_pct, self.test_split_pct + self.dataset, self.val_split_pct, self.test_split_pct ) def on_after_batch_transfer( diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index d4f6d80816e..faa559cb454 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -74,9 +74,9 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. """ if stage in ["fit", "validate"]: - dataset = Vaihingen2D(split="train", **self.kwargs) + self.dataset = Vaihingen2D(split="train", **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( - dataset, self.val_split_pct + self.dataset, self.val_split_pct ) if stage in ["test"]: self.test_dataset = Vaihingen2D(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index 183a5f6b7c2..8f96d786bea 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -45,9 +45,9 @@ def setup(self, stage: str) -> None: stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" if stage in ["fit", "validate"]: - dataset = XView2(split="train", **self.kwargs) + self.dataset = XView2(split="train", **self.kwargs) self.train_dataset, self.val_dataset = dataset_split( - dataset, val_pct=self.val_split_pct + self.dataset, val_pct=self.val_split_pct ) if stage in ["test"]: self.test_dataset = XView2(split="test", **self.kwargs) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index ebcc48c37c3..54d55111468 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -241,7 +241,7 @@ def plot( fig, ax = plt.subplots(1, 1, figsize=(10, 10)) - ax.imshow(image, cmap="gray") + ax.imshow(image.permute(1, 2, 0)) ax.axis("off") if show_titles: diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index fba2475b0eb..0e4db7adf73 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -230,7 +230,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: image = torch.cat(tensors=[s1, s2], dim=0) image = torch.index_select(image, dim=0, index=self.band_indices) - sample: Dict[str, Tensor] = {"image": image, "mask": lc} + sample: Dict[str, Tensor] = {"image": image, "mask": lc[0]} if self.transforms is not None: sample = self.transforms(sample) @@ -336,13 +336,13 @@ def plot( else: raise ValueError("Dataset doesn't contain some of the RGB bands") - image, mask = sample["image"][rgb_indices].numpy(), sample["mask"][0] + image, mask = sample["image"][rgb_indices].numpy(), sample["mask"] image = percentile_normalization(image) ncols = 2 showing_predictions = "prediction" in sample if showing_predictions: - prediction = sample["prediction"][0] + prediction = sample["prediction"] ncols += 1 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) From 647e60138afaf1c61f3be12a4685e5bfbcc5963e Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 21 Jan 2023 17:19:03 -0600 Subject: [PATCH 095/108] Fix OSCD tests --- tests/datamodules/test_oscd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index f155920dc02..7d5d70698ce 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -52,7 +52,7 @@ def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: batch = datamodule.on_after_batch_transfer(batch, 0) if datamodule.val_split_pct > 0.0: assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image"].shape[0] == batch["mask"].shape[0] == 1024 + assert batch["image"].shape[0] == batch["mask"].shape[0] == 2 if datamodule.bands == "all": assert batch["image"].shape[1] == 26 else: @@ -64,7 +64,7 @@ def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: batch = next(iter(datamodule.test_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image"].shape[0] == batch["mask"].shape[0] == 1024 + assert batch["image"].shape[0] == batch["mask"].shape[0] == 2 if datamodule.bands == "all": assert batch["image"].shape[1] == 26 else: From deca721f6850fc9713158366a764342d5372b63b Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sat, 21 Jan 2023 23:42:29 -0600 Subject: [PATCH 096/108] Fix bug with inconsistent train/val/test splits between stages --- torchgeo/datamodules/cowc.py | 2 ++ torchgeo/datamodules/cyclone.py | 2 +- torchgeo/datamodules/sen12ms.py | 2 +- torchgeo/datamodules/utils.py | 15 +++++++++++++-- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index 111beaa6686..799850f4e4a 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -5,6 +5,7 @@ from typing import Any +from torch import Generator from torch.utils.data import random_split from ..datasets import COWCCounting @@ -38,4 +39,5 @@ def setup(self, stage: str) -> None: self.train_dataset, self.val_dataset = random_split( self.dataset, [len(self.dataset) - len(self.test_dataset), len(self.test_dataset)], + generator=Generator().manual_seed(0), ) diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 2442bb201ec..b3c8d3121a9 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -51,7 +51,7 @@ def setup(self, stage: str) -> None: storm_ids.append(storm_id) train_indices, val_indices = next( - GroupShuffleSplit(test_size=0.2, n_splits=2).split( + GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=0).split( storm_ids, groups=storm_ids ) ) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 4ff04580eab..d412c93d978 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -88,7 +88,7 @@ def setup(self, stage: str) -> None: scenes.append(season_id + scene_id) train_indices, val_indices = next( - GroupShuffleSplit(test_size=0.2, n_splits=2).split( + GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=0).split( scenes, groups=scenes ) ) diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index be50cd7dc99..b1df01721c3 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -5,6 +5,7 @@ from typing import Any, List, Optional, Union +from torch import Generator from torch.utils.data import Subset, TensorDataset, random_split from ..datasets import NonGeoDataset @@ -24,6 +25,10 @@ def dataset_split( If ``test_pct`` is not set then only train and validation splits are returned. + .. deprecated:: 0.4 + Use :func:`torch.utils.data.random_split` instead, ``random_split`` + now supports percentages as of PyTorch 1.13. + Args: dataset: dataset to be split into train/val or train/val/test subsets val_pct: percentage of samples to be in validation set @@ -35,9 +40,15 @@ def dataset_split( if test_pct is None: val_length = round(len(dataset) * val_pct) train_length = len(dataset) - val_length - return random_split(dataset, [train_length, val_length]) + return random_split( + dataset, [train_length, val_length], generator=Generator().manual_seed(0) + ) else: val_length = round(len(dataset) * val_pct) test_length = round(len(dataset) * test_pct) train_length = len(dataset) - (val_length + test_length) - return random_split(dataset, [train_length, val_length, test_length]) + return random_split( + dataset, + [train_length, val_length, test_length], + generator=Generator().manual_seed(0), + ) From 5ae2f4c18b896a3fcc0a499431fe98ce961f1d42 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sun, 22 Jan 2023 10:37:49 -0600 Subject: [PATCH 097/108] Fix issues with images of different sizes --- conf/deepglobelandcover.yaml | 3 +-- conf/gid15.yaml | 3 +-- conf/inria.yaml | 3 +-- conf/oscd.yaml | 3 +-- conf/potsdam2d.yaml | 3 +-- conf/vaihingen2d.yaml | 3 +-- tests/conf/deepglobelandcover.yaml | 3 +-- tests/conf/gid15.yaml | 3 +-- tests/conf/inria.yaml | 3 +-- tests/conf/potsdam2d.yaml | 3 +-- tests/conf/vaihingen2d.yaml | 3 +-- torchgeo/datamodules/deepglobelandcover.py | 20 +++------------- torchgeo/datamodules/gid15.py | 27 +++++----------------- torchgeo/datamodules/inria.py | 20 ++++------------ torchgeo/datamodules/oscd.py | 16 +++---------- torchgeo/datamodules/potsdam.py | 20 +++------------- torchgeo/datamodules/vaihingen.py | 20 +++------------- 17 files changed, 34 insertions(+), 122 deletions(-) diff --git a/conf/deepglobelandcover.yaml b/conf/deepglobelandcover.yaml index 7732406da45..2e09eca0e4b 100644 --- a/conf/deepglobelandcover.yaml +++ b/conf/deepglobelandcover.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "data/deepglobelandcover" - num_tiles_per_batch: 16 - num_patches_per_tile: 16 + batch_size: 1 patch_size: 64 val_split_pct: 0.5 num_workers: 0 diff --git a/conf/gid15.yaml b/conf/gid15.yaml index dd69143e574..420c6b2f0e9 100644 --- a/conf/gid15.yaml +++ b/conf/gid15.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "data/gid15" - num_tiles_per_batch: 16 - num_patches_per_tile: 16 + batch_size: 1 patch_size: 64 val_split_pct: 0.5 num_workers: 0 diff --git a/conf/inria.yaml b/conf/inria.yaml index 234ddffcb01..9321fb6972e 100644 --- a/conf/inria.yaml +++ b/conf/inria.yaml @@ -24,7 +24,6 @@ experiment: ignore_index: null datamodule: root: "data/inria" - num_tiles_per_batch: 2 - num_patches_per_tile: 4 + batch_size: 1 patch_size: 512 num_workers: 32 diff --git a/conf/oscd.yaml b/conf/oscd.yaml index 48634f24878..be13cbd40c0 100644 --- a/conf/oscd.yaml +++ b/conf/oscd.yaml @@ -19,8 +19,7 @@ experiment: ignore_index: 0 datamodule: root: "data/oscd" - num_tiles_per_batch: 32 - num_patches_per_tile: 128 + batch_size: 1 patch_size: 64 val_split_pct: 0.1 num_workers: 4 diff --git a/conf/potsdam2d.yaml b/conf/potsdam2d.yaml index bd5f4e9228a..e1312fa57d4 100644 --- a/conf/potsdam2d.yaml +++ b/conf/potsdam2d.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "data/potsdam" - num_tiles_per_batch: 16 - num_patches_per_tile: 16 + batch_size: 1 patch_size: 64 val_split_pct: 0.5 num_workers: 0 diff --git a/conf/vaihingen2d.yaml b/conf/vaihingen2d.yaml index 0f3015faf66..c6fd448c6dd 100644 --- a/conf/vaihingen2d.yaml +++ b/conf/vaihingen2d.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "data/vaihingen" - num_tiles_per_batch: 16 - num_patches_per_tile: 16 + batch_size: 1 patch_size: 64 val_split_pct: 0.5 num_workers: 0 diff --git a/tests/conf/deepglobelandcover.yaml b/tests/conf/deepglobelandcover.yaml index 2bb2cc5b53b..e27fe1271c2 100644 --- a/tests/conf/deepglobelandcover.yaml +++ b/tests/conf/deepglobelandcover.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "tests/data/deepglobelandcover" - num_tiles_per_batch: 1 - num_patches_per_tile: 1 + batch_size: 1 patch_size: 2 val_split_pct: 0.5 num_workers: 0 diff --git a/tests/conf/gid15.yaml b/tests/conf/gid15.yaml index 56e25c7261a..baaea0e1ba2 100644 --- a/tests/conf/gid15.yaml +++ b/tests/conf/gid15.yaml @@ -15,8 +15,7 @@ experiment: datamodule: root: "tests/data/gid15" download: true - 
num_tiles_per_batch: 1 - num_patches_per_tile: 1 + batch_size: 1 patch_size: 2 val_split_pct: 0.5 num_workers: 0 diff --git a/tests/conf/inria.yaml b/tests/conf/inria.yaml index 7cb05607bff..995c073146b 100644 --- a/tests/conf/inria.yaml +++ b/tests/conf/inria.yaml @@ -12,8 +12,7 @@ experiment: ignore_index: null datamodule: root: "tests/data/inria" - num_tiles_per_batch: 1 - num_patches_per_tile: 2 + batch_size: 1 patch_size: 2 num_workers: 0 val_split_pct: 0.2 diff --git a/tests/conf/potsdam2d.yaml b/tests/conf/potsdam2d.yaml index fcdfd07b37c..7492a8c0c86 100644 --- a/tests/conf/potsdam2d.yaml +++ b/tests/conf/potsdam2d.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "tests/data/potsdam" - num_tiles_per_batch: 1 - num_patches_per_tile: 1 + batch_size: 1 patch_size: 2 val_split_pct: 0.5 num_workers: 0 diff --git a/tests/conf/vaihingen2d.yaml b/tests/conf/vaihingen2d.yaml index 0184b929a0a..7f542f3310b 100644 --- a/tests/conf/vaihingen2d.yaml +++ b/tests/conf/vaihingen2d.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "tests/data/vaihingen" - num_tiles_per_batch: 1 - num_patches_per_tile: 1 + batch_size: 1 patch_size: 2 val_split_pct: 0.5 num_workers: 0 diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 142f5fbfad3..fdf00265016 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -23,8 +23,7 @@ class DeepGlobeLandCoverDataModule(NonGeoDataModule): def __init__( self, - num_tiles_per_batch: int = 16, - num_patches_per_tile: int = 16, + batch_size: int = 64, patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, @@ -32,19 +31,8 @@ def __init__( ) -> None: """Initialize a new DeepGlobeLandCoverDataModule instance. - The DeepGlobe Land Cover dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles. - The effective batch size is equal to - ``num_tiles_per_batch`` x ``num_patches_per_tile``. - - .. versionchanged:: 0.4 - *batch_size* was replaced by *num_tile_per_batch*, *num_patches_per_tile*, - and *patch_size*. - Args: - num_tiles_per_batch: Number of image tiles to sample from. - num_patches_per_tile: Number of patches to randomly sample from each image - tile. + batch_size: Size of each mini-batch. patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures. val_split_pct: Percentage of the dataset to use as a validation set. 
@@ -54,14 +42,12 @@ def __init__( """ super().__init__(DeepGlobeLandCover, 1, num_workers, **kwargs) - self.train_batch_size = num_tiles_per_batch - self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index bafa1ec7985..1297741eb2f 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -25,8 +25,7 @@ class GID15DataModule(NonGeoDataModule): def __init__( self, - num_tiles_per_batch: int = 16, - num_patches_per_tile: int = 16, + batch_size: int = 64, patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, @@ -34,15 +33,8 @@ def __init__( ) -> None: """Initialize a new GID15DataModule instance. - The GID-15 dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles. - The effective batch size is equal to - ``num_tiles_per_batch`` x ``num_patches_per_tile``. - Args: - num_tiles_per_batch: Number of image tiles to sample from. - num_patches_per_tile: Number of patches to randomly sample from each image - tile. + batch_size: Size of each mini-batch. patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures. val_split_pct: Percentage of the dataset to use as a validation set @@ -52,24 +44,17 @@ def __init__( """ super().__init__(GID15, 1, num_workers, **kwargs) - self.train_batch_size = num_tiles_per_batch - self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.train_aug = AugmentationSequential( - K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), - data_keys=["image", "mask"], - ) - self.val_aug = AugmentationSequential( + self.train_aug = self.val_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image", "mask"], ) - self.predict_transform = AugmentationSequential( + self.predict_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image"], ) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 4285c07151e..bc524e5e921 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -26,8 +26,7 @@ class InriaAerialImageLabelingDataModule(NonGeoDataModule): def __init__( self, - num_tiles_per_batch: int = 16, - num_patches_per_tile: int = 16, + batch_size: int = 64, patch_size: Union[Tuple[int, int], int] = 64, num_workers: int = 0, val_split_pct: float = 0.1, @@ -36,15 +35,8 @@ def __init__( ) -> None: """Initialize a new InriaAerialImageLabelingDataModule instance. - The Inria Aerial Image Labeling dataset contains images that are too large to - pass directly through a model. Instead, we randomly sample patches from image - tiles. The effective batch size is equal to - ``num_tiles_per_batch`` x ``num_patches_per_tile``. 
- Args: - num_tiles_per_batch: Number of image tiles to sample from. - num_patches_per_tile: Number of patches to randomly sample from each image - tile. + batch_size: Size of each mini-batch. patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures. num_workers: Number of workers for parallel data loading. @@ -55,8 +47,6 @@ def __init__( """ super().__init__(InriaAerialImageLabeling, 1, num_workers, **kwargs) - self.train_batch_size = num_tiles_per_batch - self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct @@ -65,17 +55,17 @@ def __init__( K.Normalize(mean=self.mean, std=self.std), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image", "mask"], ) self.val_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image", "mask"], ) self.predict_aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image"], ) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index d36b61a9fb0..624a47c615f 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -64,8 +64,7 @@ class OSCDDataModule(NonGeoDataModule): def __init__( self, - num_tiles_per_batch: int = 16, - num_patches_per_tile: int = 16, + batch_size: int = 64, patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, @@ -73,15 +72,8 @@ def __init__( ) -> None: """Initialize a new OSCDDataModule instance. - The OSCD dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles. - The effective batch size is equal to - ``num_tiles_per_batch`` x ``num_patches_per_tile``. - Args: - num_tiles_per_batch: Number of image tiles to sample from. - num_patches_per_tile: Number of patches to randomly sample from each image - tile. + batch_size: Size of each mini-batch. patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures. val_split_pct: Percentage of the dataset to use as a validation set. 
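After this change the constructor takes a plain batch_size alongside patch_size: the DataLoader delivers one tile at a time (the 1 passed to super().__init__ above), and _RandomNCrop cuts batch_size patches from that tile after the batch is transferred to the GPU, so batch_size is the effective mini-batch size seen by the model. A usage sketch with example values only:

    from torchgeo.datamodules import OSCDDataModule

    dm = OSCDDataModule(
        root="data/oscd",  # example data directory
        batch_size=32,     # patches sampled per tile = effective mini-batch size
        patch_size=64,     # height/width of each sampled patch
        val_split_pct=0.1,
        num_workers=4,
    )
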
@@ -91,8 +83,6 @@ def __init__( """ super().__init__(OSCD, 1, num_workers, **kwargs) - self.train_batch_size = num_tiles_per_batch - self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct @@ -107,7 +97,7 @@ def __init__( self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index 6c924ea427d..f22558a5908 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -25,8 +25,7 @@ class Potsdam2DDataModule(NonGeoDataModule): def __init__( self, - num_tiles_per_batch: int = 16, - num_patches_per_tile: int = 16, + batch_size: int = 64, patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, @@ -34,19 +33,8 @@ def __init__( ) -> None: """Initialize a new Potsdam2DDataModule instance. - The Potsdam2D dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles. - The effective batch size is equal to - ``num_tiles_per_batch`` x ``num_patches_per_tile``. - - .. versionchanged:: 0.4 - *batch_size* was replaced by *num_tile_per_batch*, *num_patches_per_tile*, - and *patch_size*. - Args: - num_tiles_per_batch: Number of image tiles to sample from. - num_patches_per_tile: Number of patches to randomly sample from each image - tile. + batch_size: Size of each mini-batch. patch_size: Size of each patch, either ``size`` or ``(height, width)``. Should be a multiple of 32 for most segmentation architectures. val_split_pct: Percentage of the dataset to use as a validation set. @@ -56,14 +44,12 @@ def __init__( """ super().__init__(Potsdam2D, 1, num_workers, **kwargs) - self.train_batch_size = num_tiles_per_batch - self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image", "mask"], ) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index faa559cb454..1128ea76655 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -25,8 +25,7 @@ class Vaihingen2DDataModule(NonGeoDataModule): def __init__( self, - num_tiles_per_batch: int = 16, - num_patches_per_tile: int = 16, + batch_size: int = 64, patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, @@ -34,19 +33,8 @@ def __init__( ) -> None: """Initialize a new Vaihingen2DDataModule instance. - The Vaihingen2D dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles. - The effective batch size is equal to - ``num_tiles_per_batch`` x ``num_patches_per_tile``. - - .. versionchanged:: 0.4 - *batch_size* was replaced by *num_tile_per_batch*, *num_patches_per_tile*, - and *patch_size*. - Args: - num_tiles_per_batch: Number of image tiles to sample from. - num_patches_per_tile: Number of patches to randomly sample from each image - tile. + batch_size: Size of each mini-batch. patch_size: Size of each patch, either ``size`` or ``(height, width)``. 
Should be a multiple of 32 for most segmentation architectures. val_split_pct: Percentage of the dataset to use as a validation set. @@ -56,14 +44,12 @@ def __init__( """ super().__init__(Vaihingen2D, 1, num_workers, **kwargs) - self.train_batch_size = num_tiles_per_batch - self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct self.aug = AugmentationSequential( K.Normalize(mean=self.mean, std=self.std), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image", "mask"], ) From 6a1571a536be2709bee1758e8bddc6ca12fa2fef Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sun, 22 Jan 2023 10:45:51 -0600 Subject: [PATCH 098/108] Fix OSCD tests --- tests/datamodules/test_oscd.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index 7d5d70698ce..d4f28794d8c 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -14,20 +14,15 @@ class TestOSCDDataModule: @pytest.fixture(params=["all", "rgb"]) def datamodule(self, request: SubRequest) -> OSCDDataModule: bands = request.param - num_tiles_per_batch = 1 - num_patches_per_tile = 2 - patch_size = 2 root = os.path.join("tests", "data", "oscd") - num_workers = 0 dm = OSCDDataModule( root=root, download=True, bands=bands, - num_tiles_per_batch=num_tiles_per_batch, - num_patches_per_tile=num_patches_per_tile, - patch_size=patch_size, + batch_size=1, + patch_size=2, val_split_pct=0.5, - num_workers=num_workers, + num_workers=0, ) dm.prepare_data() dm.trainer = Trainer(max_epochs=1) From cc6176e2d16ec86a9186678b89dc6fa16fcc21fb Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sun, 22 Jan 2023 10:50:46 -0600 Subject: [PATCH 099/108] Fix OSCD tests --- tests/datamodules/test_oscd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index d4f28794d8c..66d8500e386 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -34,7 +34,7 @@ def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: batch = next(iter(datamodule.train_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image"].shape[0] == batch["mask"].shape[0] == 2 + assert batch["image"].shape[0] == batch["mask"].shape[0] == 1 if datamodule.bands == "all": assert batch["image"].shape[1] == 26 else: @@ -47,7 +47,7 @@ def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: batch = datamodule.on_after_batch_transfer(batch, 0) if datamodule.val_split_pct > 0.0: assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image"].shape[0] == batch["mask"].shape[0] == 2 + assert batch["image"].shape[0] == batch["mask"].shape[0] == 1 if datamodule.bands == "all": assert batch["image"].shape[1] == 26 else: @@ -59,7 +59,7 @@ def test_test_dataloader(self, datamodule: OSCDDataModule) -> None: batch = next(iter(datamodule.test_dataloader())) batch = datamodule.on_after_batch_transfer(batch, 0) assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) - assert batch["image"].shape[0] == batch["mask"].shape[0] == 2 + assert batch["image"].shape[0] == batch["mask"].shape[0] == 1 if datamodule.bands == "all": assert batch["image"].shape[1] == 26 else: From 40d31bbd34aaf80254cded2880b1501451604c01 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sun, 22 Jan 2023 23:10:51 -0600 Subject: [PATCH 100/108] Bad rebase --- tests/trainers/test_byol.py | 1 - tests/trainers/test_classification.py | 2 +- tests/trainers/test_regression.py | 6 ------ 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 18bf781a04f..2f5dbd8d2a4 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -11,7 +11,6 @@ import torch.nn as nn import torchvision from _pytest.monkeypatch import MonkeyPatch -from lightning_lite.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer from torchvision.models import resnet18 diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index bbd19dd4c9c..900216e1e1e 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -23,8 +23,8 @@ So2SatDataModule, UCMercedDataModule, ) -from torchgeo.models import ResNet18_Weights from torchgeo.datasets import BigEarthNet, EuroSAT +from torchgeo.models import ResNet18_Weights from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask from .test_utils import ClassificationTestModel diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 3fb81ecf951..65fbeabfeca 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -14,18 +14,12 @@ from pytorch_lightning import LightningDataModule, Trainer from torchvision.models._api import WeightsEnum -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) - from torchgeo.datamodules import ( COWCCountingDataModule, MisconfigurationException, TropicalCycloneDataModule, ) from torchgeo.datasets import TropicalCyclone -from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule from torchgeo.models import ResNet18_Weights from torchgeo.trainers import RegressionTask From f6a30618cc233ef9b66a3de2fef86301d72adf8f Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sun, 22 Jan 2023 23:11:30 -0600 Subject: [PATCH 101/108] No trainer for OSCD so no need for config --- conf/oscd.yaml | 25 ------------------------- 1 file changed, 25 deletions(-) delete mode 100644 conf/oscd.yaml diff --git a/conf/oscd.yaml b/conf/oscd.yaml deleted file mode 100644 index be13cbd40c0..00000000000 --- a/conf/oscd.yaml +++ /dev/null @@ -1,25 +0,0 @@ -trainer: - gpus: 1 - min_epochs: 20 - max_epochs: 500 - benchmark: True -experiment: - task: "oscd" - module: - loss: "jaccard" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 26 - num_classes: 2 - num_filters: 256 - ignore_index: 0 - datamodule: - root: "data/oscd" - batch_size: 1 - patch_size: 64 - val_split_pct: 0.1 - num_workers: 4 From ed374296263009c69f108d6c9f9e54d01d348bbb Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Sun, 22 Jan 2023 23:13:45 -0600 Subject: [PATCH 102/108] Bad rebase --- tests/trainers/test_classification.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 900216e1e1e..6f8596da35b 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -145,18 +145,6 @@ def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: with pytest.raises(ValueError, match=match): ClassificationTask(**model_kwargs) - def test_invalid_model(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["model"] = "invalid_model" - match = "Model type 'invalid_model' is not a valid timm model." - with pytest.raises(ValueError, match=match): - ClassificationTask(**model_kwargs) - - def test_invalid_weights(self, model_kwargs: Dict[Any, Any]) -> None: - model_kwargs["weights"] = "invalid_weights" - match = "Weight type 'invalid_weights' is not valid." - with pytest.raises(ValueError, match=match): - ClassificationTask(**model_kwargs) - def test_no_rgb( self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool ) -> None: From 66093a6dcc0688c476788fa5e8fd5b494abafe85 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 23 Jan 2023 11:15:04 -0600 Subject: [PATCH 103/108] plot: only works during validation --- torchgeo/datamodules/geo.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index b15aaf71627..a3d204c66fd 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -308,7 +308,10 @@ def on_after_batch_transfer( return batch def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run the plot method of the dataset if one exists. + """Run the plot method of the validation dataset if one exists. + + Should only be called during 'fit' or 'validate' stages as ``val_dataset`` + may not exist during other stages. Args: *args: Arguments passed to plot method. @@ -525,7 +528,10 @@ def on_after_batch_transfer( return batch def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run the plot method of the dataset if one exists. + """Run the plot method of the validation dataset if one exists. + + Should only be called during 'fit' or 'validate' stages as ``val_dataset`` + may not exist during other stages. Args: *args: Arguments passed to plot method. From 1776fd62095034173d4a44896512f6a05924b169 Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Mon, 23 Jan 2023 11:56:01 -0600 Subject: [PATCH 104/108] Fix collation of NASA Marine Debris dataset --- tests/datasets/test_nasa_marine_debris.py | 1 - torchgeo/datamodules/geo.py | 20 +++++++++++++----- torchgeo/datamodules/nasa_marine_debris.py | 24 +++++++++++++++++++++- torchgeo/datasets/nasa_marine_debris.py | 1 - 4 files changed, 38 insertions(+), 8 deletions(-) diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py index 1a428cdf858..706dcf52552 100644 --- a/tests/datasets/test_nasa_marine_debris.py +++ b/tests/datasets/test_nasa_marine_debris.py @@ -59,7 +59,6 @@ def test_already_downloaded_not_extracted( shutil.rmtree(dataset.root) os.makedirs(str(tmp_path), exist_ok=True) Dataset().download(output_dir=str(tmp_path)) - print(os.listdir(str(tmp_path))) NASAMarineDebris(root=str(tmp_path), download=False) def test_not_downloaded(self, tmp_path: Path) -> None: diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index a3d204c66fd..c1b1944c933 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -10,7 +10,7 @@ import torch from pytorch_lightning import LightningDataModule from torch import Tensor -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader, Dataset, default_collate from ..datasets import GeoDataset, NonGeoDataset, stack_samples from ..samplers import ( @@ -87,6 +87,9 @@ def __init__( self.test_batch_size: Optional[int] = None self.predict_batch_size: Optional[int] = None + # Collation + self.collate_fn = stack_samples + # Data augmentation Transform = Callable[[Dict[str, Tensor]], Dict[str, Tensor]] self.aug: Transform = AugmentationSequential( @@ -163,7 +166,7 @@ def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: sampler=sampler, batch_sampler=batch_sampler, num_workers=self.num_workers, - collate_fn=stack_samples, + collate_fn=self.collate_fn, ) else: msg = f"{self.__class__.__name__}.setup does not define a 'train_dataset'" @@ -193,7 +196,7 @@ def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: sampler=sampler, batch_sampler=batch_sampler, num_workers=self.num_workers, - collate_fn=stack_samples, + collate_fn=self.collate_fn, ) else: msg = f"{self.__class__.__name__}.setup does not define a 'val_dataset'" @@ -223,7 +226,7 @@ def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: sampler=sampler, batch_sampler=batch_sampler, num_workers=self.num_workers, - collate_fn=stack_samples, + collate_fn=self.collate_fn, ) else: msg = f"{self.__class__.__name__}.setup does not define a 'test_dataset'" @@ -253,7 +256,7 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: sampler=sampler, batch_sampler=batch_sampler, num_workers=self.num_workers, - collate_fn=stack_samples, + collate_fn=self.collate_fn, ) else: msg = f"{self.__class__.__name__}.setup does not define a 'predict_dataset'" @@ -370,6 +373,9 @@ def __init__( self.test_batch_size: Optional[int] = None self.predict_batch_size: Optional[int] = None + # Collation + self.collate_fn = default_collate + # Data augmentation Transform = Callable[[Dict[str, Tensor]], Dict[str, Tensor]] self.aug: Transform = AugmentationSequential( @@ -430,6 +436,7 @@ def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: batch_size=self.train_batch_size or self.batch_size, shuffle=True, num_workers=self.num_workers, + collate_fn=self.collate_fn, ) else: msg = f"{self.__class__.__name__}.setup does not define a 'train_dataset'" @@ -452,6 +459,7 @@ def val_dataloader(self) 
-> DataLoader[Dict[str, Tensor]]: batch_size=self.val_batch_size or self.batch_size, shuffle=False, num_workers=self.num_workers, + collate_fn=self.collate_fn, ) else: msg = f"{self.__class__.__name__}.setup does not define a 'val_dataset'" @@ -474,6 +482,7 @@ def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: batch_size=self.test_batch_size or self.batch_size, shuffle=False, num_workers=self.num_workers, + collate_fn=self.collate_fn, ) else: msg = f"{self.__class__.__name__}.setup does not define a 'test_dataset'" @@ -496,6 +505,7 @@ def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: batch_size=self.predict_batch_size or self.batch_size, shuffle=False, num_workers=self.num_workers, + collate_fn=self.collate_fn, ) else: msg = f"{self.__class__.__name__}.setup does not define a 'predict_dataset'" diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index 54119a804a5..584154e44fa 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -3,13 +3,33 @@ """NASA Marine Debris datamodule.""" -from typing import Any +from typing import Any, Dict, List + +import torch +from torch import Tensor +from torch.utils.data import DataLoader from ..datasets import NASAMarineDebris from .geo import NonGeoDataModule from .utils import dataset_split +def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: + """Custom object detection collate fn to handle variable boxes. + + Args: + batch: list of sample dicts return by dataset + + Returns: + batch dict output + """ + output: Dict[str, Any] = {} + output["image"] = torch.stack([sample["image"] for sample in batch]) + output["boxes"] = [sample["boxes"] for sample in batch] + output["labels"] = [torch.tensor([1] * len(sample["boxes"])) for sample in batch] + return output + + class NASAMarineDebrisDataModule(NonGeoDataModule): """LightningDataModule implementation for the NASA Marine Debris dataset. @@ -39,6 +59,8 @@ def __init__( self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct + self.collate_fn = collate_fn + def setup(self, stage: str) -> None: """Set up datasets. diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 83dae527b17..4a506e5c701 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -104,7 +104,6 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: h_check = (sample["boxes"][:, 3] - sample["boxes"][:, 1]) > 0 indices = w_check & h_check sample["boxes"] = sample["boxes"][indices] - sample["labels"] = torch.ones(len(indices), dtype=torch.int64) if self.transforms is not None: sample = self.transforms(sample) From 38797f15cf8f01ac8a124160402c187e353e6a58 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 23 Jan 2023 12:08:03 -0600 Subject: [PATCH 105/108] flake8 fix --- torchgeo/datamodules/nasa_marine_debris.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index 584154e44fa..cdcab7f7e9c 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -7,7 +7,6 @@ import torch from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import NASAMarineDebris from .geo import NonGeoDataModule From f465efcbef904b8a5bc2257f2800eed931c491ab Mon Sep 17 00:00:00 2001 From: "Adam J. 
Stewart" Date: Mon, 23 Jan 2023 12:09:38 -0600 Subject: [PATCH 106/108] Quick test --- requirements/required.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/required.txt b/requirements/required.txt index 548c948e821..f7f8a6cb2ae 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -11,7 +11,7 @@ omegaconf==2.3.0 packaging==23.0 pillow==9.4.0 pyproj==3.4.1;python_version>='3.8' -pytorch-lightning[extra]==1.9.0 +pytorch-lightning==1.9.0 rasterio==1.3.4;python_version>='3.8' rtree==1.0.1 scikit-learn==1.2.0;python_version>='3.8' From 514ad2ff05f03e883e6c70c766d0359f4dc19cc6 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Mon, 23 Jan 2023 12:17:14 -0600 Subject: [PATCH 107/108] Revert "Quick test" This reverts commit f465efcbef904b8a5bc2257f2800eed931c491ab. --- requirements/required.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/required.txt b/requirements/required.txt index f7f8a6cb2ae..548c948e821 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -11,7 +11,7 @@ omegaconf==2.3.0 packaging==23.0 pillow==9.4.0 pyproj==3.4.1;python_version>='3.8' -pytorch-lightning==1.9.0 +pytorch-lightning[extra]==1.9.0 rasterio==1.3.4;python_version>='3.8' rtree==1.0.1 scikit-learn==1.2.0;python_version>='3.8' From 6c4dafa861c231273289fe93e189ac75647a1a39 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Mon, 23 Jan 2023 20:49:31 +0000 Subject: [PATCH 108/108] 56 workers is a bit excessive --- conf/nasa_marine_debris.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conf/nasa_marine_debris.yaml b/conf/nasa_marine_debris.yaml index 48b0e8b0285..205a9d47362 100644 --- a/conf/nasa_marine_debris.yaml +++ b/conf/nasa_marine_debris.yaml @@ -21,5 +21,5 @@ experiment: datamodule: root: "data/nasamr/nasa_marine_debris" batch_size: 4 - num_workers: 56 + num_workers: 6 val_split_pct: 0.2