diff --git a/conf/chesapeake_cvpr.yaml b/conf/chesapeake_cvpr.yaml index 2ee411a6b17..858e2624199 100644 --- a/conf/chesapeake_cvpr.yaml +++ b/conf/chesapeake_cvpr.yaml @@ -25,9 +25,8 @@ experiment: - "de-val" test_splits: - "de-test" - patches_per_tile: 200 + batch_size: 200 patch_size: 256 - batch_size: 64 num_workers: 4 class_set: ${experiment.module.num_classes} use_prior_labels: False diff --git a/conf/cowc_counting.yaml b/conf/cowc_counting.yaml index 91f0d9921b6..eb6363d1e99 100644 --- a/conf/cowc_counting.yaml +++ b/conf/cowc_counting.yaml @@ -13,6 +13,5 @@ experiment: learning_rate_schedule_patience: 2 datamodule: root: "data/cowc_counting" - seed: 0 batch_size: 64 num_workers: 4 diff --git a/conf/cyclone.yaml b/conf/cyclone.yaml index c0e038b3834..68a67f9a306 100644 --- a/conf/cyclone.yaml +++ b/conf/cyclone.yaml @@ -13,6 +13,5 @@ experiment: learning_rate_schedule_patience: 2 datamodule: root: "data/cyclone" - seed: 0 batch_size: 32 num_workers: 4 diff --git a/conf/deepglobelandcover.yaml b/conf/deepglobelandcover.yaml index 7732406da45..2e09eca0e4b 100644 --- a/conf/deepglobelandcover.yaml +++ b/conf/deepglobelandcover.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "data/deepglobelandcover" - num_tiles_per_batch: 16 - num_patches_per_tile: 16 + batch_size: 1 patch_size: 64 val_split_pct: 0.5 num_workers: 0 diff --git a/conf/gid15.yaml b/conf/gid15.yaml index dd69143e574..420c6b2f0e9 100644 --- a/conf/gid15.yaml +++ b/conf/gid15.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "data/gid15" - num_tiles_per_batch: 16 - num_patches_per_tile: 16 + batch_size: 1 patch_size: 64 val_split_pct: 0.5 num_workers: 0 diff --git a/conf/inria.yaml b/conf/inria.yaml index 462c873fda6..9321fb6972e 100644 --- a/conf/inria.yaml +++ b/conf/inria.yaml @@ -1,7 +1,7 @@ program: overwrite: True - + trainer: gpus: 1 min_epochs: 5 @@ -23,8 +23,7 @@ experiment: num_classes: 2 ignore_index: null datamodule: - root: "data/inria" - batch_size: 2 - num_workers: 32 - patch_size: 512 - num_patches_per_tile: 4 + root: "data/inria" + batch_size: 1 + patch_size: 512 + num_workers: 32 diff --git a/conf/nasa_marine_debris.yaml b/conf/nasa_marine_debris.yaml index 3b94582652d..205a9d47362 100644 --- a/conf/nasa_marine_debris.yaml +++ b/conf/nasa_marine_debris.yaml @@ -1,5 +1,4 @@ program: - seed: 0 overwrite: True trainer: @@ -22,5 +21,5 @@ experiment: datamodule: root: "data/nasamr/nasa_marine_debris" batch_size: 4 - num_workers: 56 + num_workers: 6 val_split_pct: 0.2 diff --git a/conf/oscd.yaml b/conf/oscd.yaml deleted file mode 100644 index 486dbb9b64b..00000000000 --- a/conf/oscd.yaml +++ /dev/null @@ -1,29 +0,0 @@ -trainer: - gpus: 1 - min_epochs: 20 - max_epochs: 500 - benchmark: True -experiment: - task: "oscd" - module: - loss: "jaccard" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 26 - num_classes: 2 - num_filters: 256 - ignore_index: 0 - datamodule: - root: "data/oscd" - train_batch_size: 32 - num_workers: 4 - val_split_pct: 0.1 - bands: "all" - pad_size: - - 1028 - - 1028 - num_patches_per_tile: 128 diff --git a/conf/potsdam2d.yaml b/conf/potsdam2d.yaml index bd5f4e9228a..e1312fa57d4 100644 --- a/conf/potsdam2d.yaml +++ b/conf/potsdam2d.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "data/potsdam" - num_tiles_per_batch: 16 - num_patches_per_tile: 16 + batch_size: 1 patch_size: 64 val_split_pct: 0.5 num_workers: 0 diff --git 
a/conf/sen12ms.yaml b/conf/sen12ms.yaml index 1dba884add1..3946774328a 100644 --- a/conf/sen12ms.yaml +++ b/conf/sen12ms.yaml @@ -20,4 +20,3 @@ experiment: band_set: "all" batch_size: 32 num_workers: 4 - seed: 0 diff --git a/conf/so2sat.yaml b/conf/so2sat.yaml index e8157c5c138..e27cbf02e82 100644 --- a/conf/so2sat.yaml +++ b/conf/so2sat.yaml @@ -11,11 +11,10 @@ experiment: learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: null - in_channels: 3 + in_channels: 18 num_classes: 17 datamodule: root: "data/so2sat" batch_size: 128 num_workers: 4 - band_set: "rgb" - unsupervised_mode: False + band_set: "all" diff --git a/conf/spacenet1.yaml b/conf/spacenet1.yaml index ae9b9eeb68a..05949f26010 100644 --- a/conf/spacenet1.yaml +++ b/conf/spacenet1.yaml @@ -1,6 +1,5 @@ program: overwrite: False - seed: 0 trainer: gpus: [3] min_epochs: 50 @@ -22,4 +21,4 @@ experiment: datamodule: root: "data/spacenet" batch_size: 32 - num_workers: 4 \ No newline at end of file + num_workers: 4 diff --git a/conf/vaihingen2d.yaml b/conf/vaihingen2d.yaml index 0f3015faf66..c6fd448c6dd 100644 --- a/conf/vaihingen2d.yaml +++ b/conf/vaihingen2d.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "data/vaihingen" - num_tiles_per_batch: 16 - num_patches_per_tile: 16 + batch_size: 1 patch_size: 64 val_split_pct: 0.5 num_workers: 0 diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index 4833ff815e4..d66fee22755 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -128,3 +128,21 @@ xView2 ^^^^^^ .. autoclass:: XView2DataModule + +Base Classes +------------ + +GeoDataModule +^^^^^^^^^^^^^ + +.. autoclass:: GeoDataModule + +NonGeoDataModule +^^^^^^^^^^^^^^^^ + +.. autoclass:: NonGeoDataModule + +Utilities +--------- + +.. 
autoclass:: MisconfigurationException diff --git a/tests/conf/chesapeake_cvpr_5.yaml b/tests/conf/chesapeake_cvpr_5.yaml index 7e1663c9493..7ef269dd661 100644 --- a/tests/conf/chesapeake_cvpr_5.yaml +++ b/tests/conf/chesapeake_cvpr_5.yaml @@ -20,9 +20,8 @@ experiment: - "de-test" test_splits: - "de-test" - patches_per_tile: 2 - patch_size: 64 batch_size: 2 + patch_size: 64 num_workers: 0 class_set: ${experiment.module.num_classes} use_prior_labels: False diff --git a/tests/conf/chesapeake_cvpr_7.yaml b/tests/conf/chesapeake_cvpr_7.yaml index 9a34e401d19..653f4934ca0 100644 --- a/tests/conf/chesapeake_cvpr_7.yaml +++ b/tests/conf/chesapeake_cvpr_7.yaml @@ -20,9 +20,8 @@ experiment: - "de-test" test_splits: - "de-test" - patches_per_tile: 2 - patch_size: 64 batch_size: 2 + patch_size: 64 num_workers: 0 class_set: ${experiment.module.num_classes} use_prior_labels: False diff --git a/tests/conf/chesapeake_cvpr_prior.yaml b/tests/conf/chesapeake_cvpr_prior.yaml index 907b17ac76d..3e9713fbb59 100644 --- a/tests/conf/chesapeake_cvpr_prior.yaml +++ b/tests/conf/chesapeake_cvpr_prior.yaml @@ -20,9 +20,8 @@ experiment: - "de-test" test_splits: - "de-test" - patches_per_tile: 2 - patch_size: 64 batch_size: 2 + patch_size: 64 num_workers: 0 class_set: ${experiment.module.num_classes} use_prior_labels: True diff --git a/tests/conf/cowc_counting.yaml b/tests/conf/cowc_counting.yaml index a4f25698100..fc3218e8fef 100644 --- a/tests/conf/cowc_counting.yaml +++ b/tests/conf/cowc_counting.yaml @@ -10,6 +10,5 @@ experiment: datamodule: root: "tests/data/cowc_counting" download: true - seed: 0 batch_size: 1 num_workers: 0 diff --git a/tests/conf/cyclone.yaml b/tests/conf/cyclone.yaml index fd0fd42b412..b3323d28999 100644 --- a/tests/conf/cyclone.yaml +++ b/tests/conf/cyclone.yaml @@ -10,6 +10,5 @@ experiment: datamodule: root: "tests/data/cyclone" download: true - seed: 0 batch_size: 1 num_workers: 0 diff --git a/tests/conf/deepglobelandcover.yaml b/tests/conf/deepglobelandcover.yaml index 2bb2cc5b53b..e27fe1271c2 100644 --- a/tests/conf/deepglobelandcover.yaml +++ b/tests/conf/deepglobelandcover.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "tests/data/deepglobelandcover" - num_tiles_per_batch: 1 - num_patches_per_tile: 1 + batch_size: 1 patch_size: 2 val_split_pct: 0.5 num_workers: 0 diff --git a/tests/conf/gid15.yaml b/tests/conf/gid15.yaml index 56e25c7261a..baaea0e1ba2 100644 --- a/tests/conf/gid15.yaml +++ b/tests/conf/gid15.yaml @@ -15,8 +15,7 @@ experiment: datamodule: root: "tests/data/gid15" download: true - num_tiles_per_batch: 1 - num_patches_per_tile: 1 + batch_size: 1 patch_size: 2 val_split_pct: 0.5 num_workers: 0 diff --git a/tests/conf/inria_test.yaml b/tests/conf/inria.yaml similarity index 59% rename from tests/conf/inria_test.yaml rename to tests/conf/inria.yaml index 6bc4b5bd7cc..995c073146b 100644 --- a/tests/conf/inria_test.yaml +++ b/tests/conf/inria.yaml @@ -11,10 +11,9 @@ experiment: num_classes: 2 ignore_index: null datamodule: - root: "tests/data/inria" - batch_size: 1 - num_workers: 0 - val_split_pct: 0.2 - test_split_pct: 0.2 - patch_size: 2 - num_patches_per_tile: 2 + root: "tests/data/inria" + batch_size: 1 + patch_size: 2 + num_workers: 0 + val_split_pct: 0.2 + test_split_pct: 0.2 diff --git a/tests/conf/inria_train.yaml b/tests/conf/inria_train.yaml deleted file mode 100644 index 99db7925f27..00000000000 --- a/tests/conf/inria_train.yaml +++ /dev/null @@ -1,20 +0,0 @@ -experiment: - task: "inria" - module: - loss: "ce" - model: "unet" - backbone: 
"resnet18" - weights: "imagenet" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 3 - num_classes: 2 - ignore_index: null - datamodule: - root: "tests/data/inria" - batch_size: 1 - num_workers: 0 - val_split_pct: 0.0 - test_split_pct: 0.0 - patch_size: 2 - num_patches_per_tile: 2 diff --git a/tests/conf/inria_val.yaml b/tests/conf/inria_val.yaml deleted file mode 100644 index c20f8923439..00000000000 --- a/tests/conf/inria_val.yaml +++ /dev/null @@ -1,20 +0,0 @@ -experiment: - task: "inria" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: "imagenet" - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - in_channels: 3 - num_classes: 2 - ignore_index: null - datamodule: - root: "tests/data/inria" - batch_size: 1 - num_workers: 0 - val_split_pct: 0.2 - test_split_pct: 0.0 - patch_size: 2 - num_patches_per_tile: 2 diff --git a/tests/conf/potsdam2d.yaml b/tests/conf/potsdam2d.yaml index fcdfd07b37c..7492a8c0c86 100644 --- a/tests/conf/potsdam2d.yaml +++ b/tests/conf/potsdam2d.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "tests/data/potsdam" - num_tiles_per_batch: 1 - num_patches_per_tile: 1 + batch_size: 1 patch_size: 2 val_split_pct: 0.5 num_workers: 0 diff --git a/tests/conf/sen12ms_all.yaml b/tests/conf/sen12ms_all.yaml index ecf01b3bc72..e5676876550 100644 --- a/tests/conf/sen12ms_all.yaml +++ b/tests/conf/sen12ms_all.yaml @@ -15,4 +15,3 @@ experiment: band_set: "all" batch_size: 1 num_workers: 0 - seed: 0 diff --git a/tests/conf/sen12ms_s1.yaml b/tests/conf/sen12ms_s1.yaml index a5a4f083d9c..5289c3c8b63 100644 --- a/tests/conf/sen12ms_s1.yaml +++ b/tests/conf/sen12ms_s1.yaml @@ -16,4 +16,3 @@ experiment: band_set: "s1" batch_size: 1 num_workers: 0 - seed: 0 diff --git a/tests/conf/sen12ms_s2_all.yaml b/tests/conf/sen12ms_s2_all.yaml index dafe781db89..f1499b523e3 100644 --- a/tests/conf/sen12ms_s2_all.yaml +++ b/tests/conf/sen12ms_s2_all.yaml @@ -15,4 +15,3 @@ experiment: band_set: "s2-all" batch_size: 1 num_workers: 0 - seed: 0 diff --git a/tests/conf/sen12ms_s2_reduced.yaml b/tests/conf/sen12ms_s2_reduced.yaml index 9891c4b44ef..72e85b56fc3 100644 --- a/tests/conf/sen12ms_s2_reduced.yaml +++ b/tests/conf/sen12ms_s2_reduced.yaml @@ -15,4 +15,3 @@ experiment: band_set: "s2-reduced" batch_size: 1 num_workers: 0 - seed: 0 diff --git a/tests/conf/so2sat_all.yaml b/tests/conf/so2sat_all.yaml new file mode 100644 index 00000000000..a8d8c0bb8e3 --- /dev/null +++ b/tests/conf/so2sat_all.yaml @@ -0,0 +1,15 @@ +experiment: + task: "so2sat" + module: + loss: "ce" + model: "resnet18" + learning_rate: 1e-3 + learning_rate_schedule_patience: 6 + weights: null + in_channels: 18 + num_classes: 17 + datamodule: + root: "tests/data/so2sat" + batch_size: 1 + num_workers: 0 + band_set: "all" diff --git a/tests/conf/so2sat_supervised.yaml b/tests/conf/so2sat_s1.yaml similarity index 79% rename from tests/conf/so2sat_supervised.yaml rename to tests/conf/so2sat_s1.yaml index 0cbe484d6fc..8c87ff55a53 100644 --- a/tests/conf/so2sat_supervised.yaml +++ b/tests/conf/so2sat_s1.yaml @@ -6,11 +6,10 @@ experiment: learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: null - in_channels: 3 + in_channels: 8 num_classes: 17 datamodule: root: "tests/data/so2sat" batch_size: 1 num_workers: 0 - band_set: "rgb" - unsupervised_mode: False + band_set: "s1" diff --git a/tests/conf/so2sat_unsupervised.yaml b/tests/conf/so2sat_s2.yaml similarity index 79% rename from tests/conf/so2sat_unsupervised.yaml rename to tests/conf/so2sat_s2.yaml 
index 02c1e6a32e7..ab9c573a197 100644 --- a/tests/conf/so2sat_unsupervised.yaml +++ b/tests/conf/so2sat_s2.yaml @@ -6,11 +6,10 @@ experiment: learning_rate: 1e-3 learning_rate_schedule_patience: 6 weights: null - in_channels: 3 + in_channels: 10 num_classes: 17 datamodule: root: "tests/data/so2sat" batch_size: 1 num_workers: 0 - band_set: "rgb" - unsupervised_mode: True + band_set: "s2" diff --git a/tests/conf/vaihingen2d.yaml b/tests/conf/vaihingen2d.yaml index 0184b929a0a..7f542f3310b 100644 --- a/tests/conf/vaihingen2d.yaml +++ b/tests/conf/vaihingen2d.yaml @@ -14,8 +14,7 @@ experiment: ignore_index: null datamodule: root: "tests/data/vaihingen" - num_tiles_per_batch: 1 - num_patches_per_tile: 1 + batch_size: 1 patch_size: 2 val_split_pct: 0.5 num_workers: 0 diff --git a/tests/datamodules/test_chesapeake.py b/tests/datamodules/test_chesapeake.py index fab7b9088a1..ecf30775544 100644 --- a/tests/datamodules/test_chesapeake.py +++ b/tests/datamodules/test_chesapeake.py @@ -2,34 +2,13 @@ # Licensed under the MIT License. import os -from typing import Any, Dict, cast import pytest -import torch -from omegaconf import OmegaConf from torchgeo.datamodules import ChesapeakeCVPRDataModule class TestChesapeakeCVPRDataModule: - @pytest.fixture(scope="class") - def datamodule(self) -> ChesapeakeCVPRDataModule: - conf = OmegaConf.load(os.path.join("tests", "conf", "chesapeake_cvpr_5.yaml")) - kwargs = OmegaConf.to_object(conf.experiment.datamodule) - kwargs = cast(Dict[str, Any], kwargs) - - datamodule = ChesapeakeCVPRDataModule(**kwargs) - datamodule.prepare_data() - datamodule.setup() - return datamodule - - def test_nodata_check(self, datamodule: ChesapeakeCVPRDataModule) -> None: - nodata_check = datamodule.nodata_check(4) - sample = {"image": torch.ones(1, 2, 2), "mask": torch.ones(2, 2)} - out = nodata_check(sample) - assert torch.equal(out["image"], torch.zeros(1, 4, 4)) - assert torch.equal(out["mask"], torch.zeros(4, 4)) - def test_invalid_param_config(self) -> None: with pytest.raises(ValueError, match="The pre-generated prior labels"): ChesapeakeCVPRDataModule( @@ -37,9 +16,9 @@ def test_invalid_param_config(self) -> None: train_splits=["de-test"], val_splits=["de-test"], test_splits=["de-test"], - patch_size=32, - patches_per_tile=2, batch_size=2, + patch_size=32, + length=4, num_workers=0, class_set=7, use_prior_labels=True, diff --git a/tests/datamodules/test_fair1m.py b/tests/datamodules/test_fair1m.py index ac9d196d58f..dd3f2bd13f2 100644 --- a/tests/datamodules/test_fair1m.py +++ b/tests/datamodules/test_fair1m.py @@ -11,7 +11,7 @@ class TestFAIR1MDataModule: - @pytest.fixture(scope="class") + @pytest.fixture def datamodule(self) -> FAIR1MDataModule: root = os.path.join("tests", "data", "fair1m") batch_size = 2 @@ -23,20 +23,23 @@ def datamodule(self) -> FAIR1MDataModule: val_split_pct=0.33, test_split_pct=0.33, ) - dm.setup() return dm def test_train_dataloader(self, datamodule: FAIR1MDataModule) -> None: + datamodule.setup("fit") next(iter(datamodule.train_dataloader())) def test_val_dataloader(self, datamodule: FAIR1MDataModule) -> None: + datamodule.setup("validate") next(iter(datamodule.val_dataloader())) def test_test_dataloader(self, datamodule: FAIR1MDataModule) -> None: + datamodule.setup("test") next(iter(datamodule.test_dataloader())) def test_plot(self, datamodule: FAIR1MDataModule) -> None: - batch = next(iter(datamodule.train_dataloader())) + datamodule.setup("validate") + batch = next(iter(datamodule.val_dataloader())) sample = unbind_samples(batch)[0] 
datamodule.plot(sample) plt.close() diff --git a/tests/datamodules/test_geo.py b/tests/datamodules/test_geo.py new file mode 100644 index 00000000000..583e7050f9c --- /dev/null +++ b/tests/datamodules/test_geo.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from typing import Any, Dict + +import pytest +import torch +from torch import Tensor + +from torchgeo.datamodules import ( + GeoDataModule, + MisconfigurationException, + NonGeoDataModule, +) +from torchgeo.datasets import BoundingBox, GeoDataset, NonGeoDataset +from torchgeo.samplers import RandomBatchGeoSampler, RandomGeoSampler + + +class CustomGeoDataset(GeoDataset): + def __init__(self, split: str = "train", download: bool = False) -> None: + super().__init__() + self.index.insert(0, (0, 1, 2, 3, 4, 5)) + self.res = 1 + + def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + return {"image": torch.arange(3 * 2 * 2).view(3, 2, 2)} + + +class CustomGeoDataModule(GeoDataModule): + def __init__(self) -> None: + super().__init__(CustomGeoDataset, 1, 1, 1, 0, download=True) + + +class SamplerGeoDatModule(CustomGeoDataModule): + def setup(self, stage: str) -> None: + self.dataset = CustomGeoDataset() + self.train_sampler = RandomGeoSampler(self.dataset, 1, 1) + self.val_sampler = RandomGeoSampler(self.dataset, 1, 1) + self.test_sampler = RandomGeoSampler(self.dataset, 1, 1) + self.predict_sampler = RandomGeoSampler(self.dataset, 1, 1) + + +class BatchSamplerGeoDatModule(CustomGeoDataModule): + def setup(self, stage: str) -> None: + self.dataset = CustomGeoDataset() + self.train_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1) + self.val_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1) + self.test_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1) + self.predict_batch_sampler = RandomBatchGeoSampler(self.dataset, 1, 1, 1) + + +class CustomNonGeoDataset(NonGeoDataset): + def __init__(self, split: str = "train", download: bool = False) -> None: + pass + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + return {"image": torch.arange(3 * 2 * 2).view(3, 2, 2)} + + def __len__(self) -> int: + return 1 + + +class CustomNonGeoDataModule(NonGeoDataModule): + def __init__(self) -> None: + super().__init__(CustomNonGeoDataset, 1, 0, download=True) + + +class TestGeoDataModule: + @pytest.mark.parametrize("stage", ["fit", "validate", "test"]) + def test_setup(self, stage: str) -> None: + dm = CustomGeoDataModule() + dm.prepare_data() + dm.setup(stage) + + def test_sampler(self) -> None: + dm = SamplerGeoDatModule() + dm.setup("fit") + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() + dm.predict_dataloader() + + def test_batch_sampler(self) -> None: + dm = BatchSamplerGeoDatModule() + dm.setup("fit") + dm.train_dataloader() + dm.val_dataloader() + dm.test_dataloader() + dm.predict_dataloader() + + def test_no_datasets(self) -> None: + dm = CustomGeoDataModule() + msg = "CustomGeoDataModule.setup does not define a '{}_dataset'" + with pytest.raises(MisconfigurationException, match=msg.format("train")): + dm.train_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("val")): + dm.val_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("test")): + dm.test_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("predict")): + dm.predict_dataloader() + + +class TestNonGeoDataModule: + @pytest.mark.parametrize("stage", ["fit", "validate", 
"test"]) + def test_setup(self, stage: str) -> None: + dm = CustomNonGeoDataModule() + dm.prepare_data() + dm.setup(stage) + + def test_no_datasets(self) -> None: + dm = CustomNonGeoDataModule() + msg = "CustomNonGeoDataModule.setup does not define a '{}_dataset'" + with pytest.raises(MisconfigurationException, match=msg.format("train")): + dm.train_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("val")): + dm.val_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("test")): + dm.test_dataloader() + with pytest.raises(MisconfigurationException, match=msg.format("predict")): + dm.predict_dataloader() diff --git a/tests/datamodules/test_oscd.py b/tests/datamodules/test_oscd.py index 3a8fdefc4e9..66d8500e386 100644 --- a/tests/datamodules/test_oscd.py +++ b/tests/datamodules/test_oscd.py @@ -5,59 +5,62 @@ import pytest from _pytest.fixtures import SubRequest +from pytorch_lightning import Trainer from torchgeo.datamodules import OSCDDataModule class TestOSCDDataModule: - @pytest.fixture(scope="class", params=zip(["all", "rgb"], [0.0, 0.5])) + @pytest.fixture(params=["all", "rgb"]) def datamodule(self, request: SubRequest) -> OSCDDataModule: - bands, val_split_pct = request.param - patch_size = (2, 2) - num_patches_per_tile = 2 + bands = request.param root = os.path.join("tests", "data", "oscd") - batch_size = 1 - num_workers = 0 dm = OSCDDataModule( root=root, download=True, bands=bands, - train_batch_size=batch_size, - num_workers=num_workers, - val_split_pct=val_split_pct, - patch_size=patch_size, - num_patches_per_tile=num_patches_per_tile, + batch_size=1, + patch_size=2, + val_split_pct=0.5, + num_workers=0, ) dm.prepare_data() - dm.setup() + dm.trainer = Trainer(max_epochs=1) return dm def test_train_dataloader(self, datamodule: OSCDDataModule) -> None: - sample = next(iter(datamodule.train_dataloader())) - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (2, 2) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 2 - if datamodule.test_dataset.bands == "all": - assert sample["image"].shape[1] == 26 + datamodule.setup("fit") + datamodule.trainer.training = True # type: ignore[union-attr] + batch = next(iter(datamodule.train_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image"].shape[0] == batch["mask"].shape[0] == 1 + if datamodule.bands == "all": + assert batch["image"].shape[1] == 26 else: - assert sample["image"].shape[1] == 6 + assert batch["image"].shape[1] == 6 def test_val_dataloader(self, datamodule: OSCDDataModule) -> None: - sample = next(iter(datamodule.val_dataloader())) + datamodule.setup("validate") + datamodule.trainer.validating = True # type: ignore[union-attr] + batch = next(iter(datamodule.val_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) if datamodule.val_split_pct > 0.0: - assert ( - sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) - ) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 - if datamodule.test_dataset.bands == "all": - assert sample["image"].shape[1] == 26 + assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image"].shape[0] == batch["mask"].shape[0] == 1 + if datamodule.bands == "all": + assert batch["image"].shape[1] == 26 else: - assert sample["image"].shape[1] == 6 + assert batch["image"].shape[1] == 6 def test_test_dataloader(self, datamodule: OSCDDataModule) 
-> None: - sample = next(iter(datamodule.test_dataloader())) - assert sample["image"].shape[-2:] == sample["mask"].shape[-2:] == (1280, 1280) - assert sample["image"].shape[0] == sample["mask"].shape[0] == 1 - if datamodule.test_dataset.bands == "all": - assert sample["image"].shape[1] == 26 + datamodule.setup("test") + datamodule.trainer.testing = True # type: ignore[union-attr] + batch = next(iter(datamodule.test_dataloader())) + batch = datamodule.on_after_batch_transfer(batch, 0) + assert batch["image"].shape[-2:] == batch["mask"].shape[-2:] == (2, 2) + assert batch["image"].shape[0] == batch["mask"].shape[0] == 1 + if datamodule.bands == "all": + assert batch["image"].shape[1] == 26 else: - assert sample["image"].shape[1] == 6 + assert batch["image"].shape[1] == 6 diff --git a/tests/datamodules/test_usavars.py b/tests/datamodules/test_usavars.py index ddcde7e6b26..7f04644cc20 100644 --- a/tests/datamodules/test_usavars.py +++ b/tests/datamodules/test_usavars.py @@ -12,7 +12,7 @@ class TestUSAVarsDataModule: - @pytest.fixture() + @pytest.fixture def datamodule(self, request: SubRequest) -> USAVarsDataModule: pytest.importorskip("pandas", minversion="0.23.2") root = os.path.join("tests", "data", "usavars") @@ -23,26 +23,29 @@ def datamodule(self, request: SubRequest) -> USAVarsDataModule: root=root, batch_size=batch_size, num_workers=num_workers, download=True ) dm.prepare_data() - dm.setup() return dm def test_train_dataloader(self, datamodule: USAVarsDataModule) -> None: + datamodule.setup("fit") assert len(datamodule.train_dataloader()) == 3 - sample = next(iter(datamodule.train_dataloader())) - assert sample["image"].shape[0] == datamodule.batch_size + batch = next(iter(datamodule.train_dataloader())) + assert batch["image"].shape[0] == datamodule.batch_size def test_val_dataloader(self, datamodule: USAVarsDataModule) -> None: + datamodule.setup("validate") assert len(datamodule.val_dataloader()) == 2 - sample = next(iter(datamodule.val_dataloader())) - assert sample["image"].shape[0] == datamodule.batch_size + batch = next(iter(datamodule.val_dataloader())) + assert batch["image"].shape[0] == datamodule.batch_size def test_test_dataloader(self, datamodule: USAVarsDataModule) -> None: + datamodule.setup("test") assert len(datamodule.test_dataloader()) == 1 - sample = next(iter(datamodule.test_dataloader())) - assert sample["image"].shape[0] == datamodule.batch_size + batch = next(iter(datamodule.test_dataloader())) + assert batch["image"].shape[0] == datamodule.batch_size def test_plot(self, datamodule: USAVarsDataModule) -> None: - batch = next(iter(datamodule.train_dataloader())) + datamodule.setup("validate") + batch = next(iter(datamodule.val_dataloader())) sample = unbind_samples(batch)[0] datamodule.plot(sample) plt.close() diff --git a/tests/datamodules/test_xview2.py b/tests/datamodules/test_xview2.py index c190b5d2acd..1a0e158b366 100644 --- a/tests/datamodules/test_xview2.py +++ b/tests/datamodules/test_xview2.py @@ -5,40 +5,38 @@ import matplotlib.pyplot as plt import pytest -from _pytest.fixtures import SubRequest from torchgeo.datamodules import XView2DataModule from torchgeo.datasets import unbind_samples class TestXView2DataModule: - @pytest.fixture(scope="class", params=[0.0, 0.5]) - def datamodule(self, request: SubRequest) -> XView2DataModule: + @pytest.fixture + def datamodule(self) -> XView2DataModule: root = os.path.join("tests", "data", "xview2") batch_size = 1 num_workers = 0 - val_split_size = request.param dm = XView2DataModule( - root=root, - 
batch_size=batch_size, - num_workers=num_workers, - val_split_pct=val_split_size, + root=root, batch_size=batch_size, num_workers=num_workers, val_split_pct=0.5 ) dm.prepare_data() - dm.setup() return dm def test_train_dataloader(self, datamodule: XView2DataModule) -> None: + datamodule.setup("fit") next(iter(datamodule.train_dataloader())) def test_val_dataloader(self, datamodule: XView2DataModule) -> None: + datamodule.setup("validate") next(iter(datamodule.val_dataloader())) def test_test_dataloader(self, datamodule: XView2DataModule) -> None: + datamodule.setup("test") next(iter(datamodule.test_dataloader())) def test_plot(self, datamodule: XView2DataModule) -> None: - batch = next(iter(datamodule.train_dataloader())) + datamodule.setup("validate") + batch = next(iter(datamodule.val_dataloader())) sample = unbind_samples(batch)[0] datamodule.plot(sample) plt.close() diff --git a/tests/datasets/test_bigearthnet.py b/tests/datasets/test_bigearthnet.py index 349530e9b18..b6b7324ad5c 100644 --- a/tests/datasets/test_bigearthnet.py +++ b/tests/datasets/test_bigearthnet.py @@ -75,7 +75,7 @@ def test_getitem(self, dataset: BigEarthNet) -> None: assert isinstance(x["image"], torch.Tensor) assert isinstance(x["label"], torch.Tensor) assert x["label"].shape == (dataset.num_classes,) - assert x["image"].dtype == torch.int32 + assert x["image"].dtype == torch.float32 assert x["label"].dtype == torch.int64 if dataset.bands == "all": diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py index a7310cc8cd9..1452ad18225 100644 --- a/tests/datasets/test_cyclone.py +++ b/tests/datasets/test_cyclone.py @@ -61,8 +61,8 @@ def test_getitem(self, dataset: TropicalCyclone, index: int) -> None: assert isinstance(x["storm_id"], str) assert isinstance(x["relative_time"], int) assert isinstance(x["ocean"], int) - assert isinstance(x["label"], int) - assert x["image"].shape == (dataset.size, dataset.size) + assert isinstance(x["label"], torch.Tensor) + assert x["image"].shape == (3, dataset.size, dataset.size) def test_len(self, dataset: TropicalCyclone) -> None: assert len(dataset) == 5 @@ -88,6 +88,6 @@ def test_plot(self, dataset: TropicalCyclone) -> None: plt.close() sample = dataset[0] - sample["prediction"] = torch.tensor(sample["label"]) + sample["prediction"] = sample["label"] dataset.plot(sample) plt.close() diff --git a/tests/datasets/test_geo.py b/tests/datasets/test_geo.py index b691961806e..5a27522f404 100644 --- a/tests/datasets/test_geo.py +++ b/tests/datasets/test_geo.py @@ -211,7 +211,7 @@ def test_getitem_uint_dtype(self, custom_dtype_ds: RasterDataset) -> None: x = custom_dtype_ds[custom_dtype_ds.bounds] assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) - assert x["image"].dtype == torch.int64 + assert x["image"].dtype == torch.float32 def test_invalid_query(self, sentinel: Sentinel2) -> None: query = BoundingBox(0, 0, 0, 0, 0, 0) diff --git a/tests/datasets/test_nasa_marine_debris.py b/tests/datasets/test_nasa_marine_debris.py index 1a428cdf858..706dcf52552 100644 --- a/tests/datasets/test_nasa_marine_debris.py +++ b/tests/datasets/test_nasa_marine_debris.py @@ -59,7 +59,6 @@ def test_already_downloaded_not_extracted( shutil.rmtree(dataset.root) os.makedirs(str(tmp_path), exist_ok=True) Dataset().download(output_dir=str(tmp_path)) - print(os.listdir(str(tmp_path))) NASAMarineDebris(root=str(tmp_path), download=False) def test_not_downloaded(self, tmp_path: Path) -> None: diff --git a/tests/datasets/test_oscd.py b/tests/datasets/test_oscd.py index 
f2f0582dc7e..365bcd97a60 100644 --- a/tests/datasets/test_oscd.py +++ b/tests/datasets/test_oscd.py @@ -73,14 +73,14 @@ def test_getitem(self, dataset: OSCD) -> None: x = dataset[0] assert isinstance(x, dict) assert isinstance(x["image"], torch.Tensor) - assert x["image"].ndim == 4 + assert x["image"].ndim == 3 assert isinstance(x["mask"], torch.Tensor) assert x["mask"].ndim == 2 if dataset.bands == "rgb": - assert x["image"].shape[:2] == (2, 3) + assert x["image"].shape[0] == 6 else: - assert x["image"].shape[:2] == (2, 13) + assert x["image"].shape[0] == 26 def test_len(self, dataset: OSCD) -> None: if dataset.split == "train": diff --git a/tests/trainers/conftest.py b/tests/trainers/conftest.py index ee0135380a1..6bae33ffdd6 100644 --- a/tests/trainers/conftest.py +++ b/tests/trainers/conftest.py @@ -15,6 +15,14 @@ from torch.nn.modules import Module +@pytest.fixture( + scope="package", params=[True, pytest.param(False, marks=pytest.mark.slow)] +) +def fast_dev_run(request: SubRequest) -> bool: + flag: bool = request.param + return flag + + @pytest.fixture(scope="package") def model() -> Module: kwargs: Dict[str, Optional[bool]] = {} diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 492ca87eda4..2f5dbd8d2a4 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -16,8 +16,10 @@ from torchvision.models import resnet18 from torchvision.models._api import WeightsEnum -from torchgeo.datamodules import ChesapeakeCVPRDataModule +from torchgeo.datamodules import ChesapeakeCVPRDataModule, MisconfigurationException +from torchgeo.datasets import ChesapeakeCVPR from torchgeo.models import ResNet18_Weights +from torchgeo.samplers import GridGeoSampler from torchgeo.trainers import BYOLTask from torchgeo.trainers.byol import BYOL, SimCLRAugmentation @@ -29,6 +31,16 @@ def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: return state_dict +class PredictBYOLDataModule(ChesapeakeCVPRDataModule): + def setup(self, stage: str) -> None: + self.predict_dataset = ChesapeakeCVPR( + splits=self.test_splits, layers=self.layers, **self.kwargs + ) + self.predict_sampler = GridGeoSampler( + self.predict_dataset, self.original_patch_size, self.original_patch_size + ) + + class TestBYOL: def test_custom_augment_fn(self) -> None: backbone = resnet18() @@ -54,7 +66,9 @@ class TestBYOLTask: ("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule), ], ) - def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: + def test_trainer( + self, name: str, classname: Type[LightningDataModule], fast_dev_run: bool + ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) @@ -70,10 +84,16 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: model.backbone = SegmentationTestModel(**model_kwargs) # Instantiate trainer - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - trainer.test(model=model, datamodule=datamodule) - trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) + try: + trainer.test(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass @pytest.fixture def model_kwargs(self) -> Dict[str, 
Any]: @@ -104,3 +124,18 @@ def test_weight_str( ) -> None: model_kwargs["weights"] = str(mocked_weights) BYOLTask(**model_kwargs) + + def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: + datamodule = PredictBYOLDataModule( + root="tests/data/chesapeake/cvpr", + train_splits=["de-test"], + val_splits=["de-test"], + test_splits=["de-test"], + batch_size=1, + patch_size=64, + num_workers=0, + ) + model_kwargs["in_channels"] = 4 + model = BYOLTask(**model_kwargs) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) + trainer.predict(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index 922e171d422..6f8596da35b 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -18,16 +18,28 @@ from torchgeo.datamodules import ( BigEarthNetDataModule, EuroSATDataModule, + MisconfigurationException, RESISC45DataModule, So2SatDataModule, UCMercedDataModule, ) +from torchgeo.datasets import BigEarthNet, EuroSAT from torchgeo.models import ResNet18_Weights from torchgeo.trainers import ClassificationTask, MultiLabelClassificationTask from .test_utils import ClassificationTestModel +class PredictClassificationDataModule(EuroSATDataModule): + def setup(self, stage: str) -> None: + self.predict_dataset = EuroSAT(split="test", **self.kwargs) + + +class PredictMultiLabelClassificationDataModule(BigEarthNetDataModule): + def setup(self, stage: str) -> None: + self.predict_dataset = BigEarthNet(split="test", **self.kwargs) + + def create_model(*args: Any, **kwargs: Any) -> Module: return ClassificationTestModel(**kwargs) @@ -37,19 +49,28 @@ def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: return state_dict +def plot(*args: Any, **kwargs: Any) -> None: + raise ValueError + + class TestClassificationTask: @pytest.mark.parametrize( "name,classname", [ ("eurosat", EuroSATDataModule), ("resisc45", RESISC45DataModule), - ("so2sat_supervised", So2SatDataModule), - ("so2sat_unsupervised", So2SatDataModule), + ("so2sat_all", So2SatDataModule), + ("so2sat_s1", So2SatDataModule), + ("so2sat_s2", So2SatDataModule), ("ucmerced", UCMercedDataModule), ], ) def test_trainer( - self, monkeypatch: MonkeyPatch, name: str, classname: Type[LightningDataModule] + self, + monkeypatch: MonkeyPatch, + name: str, + classname: Type[LightningDataModule], + fast_dev_run: bool, ) -> None: if name.startswith("so2sat"): pytest.importorskip("h5py", minversion="2.6") @@ -68,29 +89,16 @@ def test_trainer( model = ClassificationTask(**model_kwargs) # Instantiate trainer - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) - trainer.fit(model=model, datamodule=datamodule) - trainer.test(model=model, datamodule=datamodule) - trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) - - def test_no_logger(self) -> None: - conf = OmegaConf.load(os.path.join("tests", "conf", "ucmerced.yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) - - # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = UCMercedDataModule(**datamodule_kwargs) - - # Instantiate model - model_kwargs = conf_dict["module"] - model = ClassificationTask(**model_kwargs) - - # Instantiate trainer - trainer = Trainer( - logger=False, fast_dev_run=True, log_every_n_steps=1, max_epochs=1 - ) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) 
trainer.fit(model=model, datamodule=datamodule) + try: + trainer.test(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass @pytest.fixture def model_kwargs(self) -> Dict[str, Any]: @@ -137,17 +145,25 @@ def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: with pytest.raises(ValueError, match=match): ClassificationTask(**model_kwargs) - def test_missing_attributes( - self, model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch + def test_no_rgb( + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool ) -> None: - monkeypatch.delattr(EuroSATDataModule, "plot") + monkeypatch.setattr(EuroSATDataModule, "plot", plot) datamodule = EuroSATDataModule( root="tests/data/eurosat", batch_size=1, num_workers=0 ) model = ClassificationTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.validate(model=model, datamodule=datamodule) + def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: + datamodule = PredictClassificationDataModule( + root="tests/data/eurosat", batch_size=1, num_workers=0 + ) + model = ClassificationTask(**model_kwargs) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) + trainer.predict(model=model, datamodule=datamodule) + class TestMultiLabelClassificationTask: @pytest.mark.parametrize( @@ -159,7 +175,11 @@ class TestMultiLabelClassificationTask: ], ) def test_trainer( - self, monkeypatch: MonkeyPatch, name: str, classname: Type[LightningDataModule] + self, + monkeypatch: MonkeyPatch, + name: str, + classname: Type[LightningDataModule], + fast_dev_run: bool, ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) @@ -175,29 +195,16 @@ def test_trainer( model = MultiLabelClassificationTask(**model_kwargs) # Instantiate trainer - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) - trainer.fit(model=model, datamodule=datamodule) - trainer.test(model=model, datamodule=datamodule) - trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) - - def test_no_logger(self) -> None: - conf = OmegaConf.load(os.path.join("tests", "conf", "bigearthnet_s1.yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) - - # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = BigEarthNetDataModule(**datamodule_kwargs) - - # Instantiate model - model_kwargs = conf_dict["module"] - model = MultiLabelClassificationTask(**model_kwargs) - - # Instantiate trainer - trainer = Trainer( - logger=False, fast_dev_run=True, log_every_n_steps=1, max_epochs=1 - ) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) + try: + trainer.test(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass @pytest.fixture def model_kwargs(self) -> Dict[str, Any]: @@ -215,13 +222,21 @@ def test_invalid_loss(self, model_kwargs: Dict[str, Any]) -> None: with pytest.raises(ValueError, match=match): MultiLabelClassificationTask(**model_kwargs) - def test_missing_attributes( - self, 
model_kwargs: Dict[str, Any], monkeypatch: MonkeyPatch + def test_no_rgb( + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool ) -> None: - monkeypatch.delattr(BigEarthNetDataModule, "plot") + monkeypatch.setattr(BigEarthNetDataModule, "plot", plot) datamodule = BigEarthNetDataModule( root="tests/data/bigearthnet", batch_size=1, num_workers=0 ) model = MultiLabelClassificationTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.validate(model=model, datamodule=datamodule) + + def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: + datamodule = PredictMultiLabelClassificationDataModule( + root="tests/data/bigearthnet", batch_size=1, num_workers=0 + ) + model = MultiLabelClassificationTask(**model_kwargs) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) + trainer.predict(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 68e48e92817..b3ff70cad0f 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -9,15 +9,27 @@ from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer -from torchgeo.datamodules import NASAMarineDebrisDataModule +from torchgeo.datamodules import MisconfigurationException, NASAMarineDebrisDataModule +from torchgeo.datasets import NASAMarineDebris from torchgeo.trainers import ObjectDetectionTask +class PredictObjectDetectionDataModule(NASAMarineDebrisDataModule): + def setup(self, stage: str) -> None: + self.predict_dataset = NASAMarineDebris(**self.kwargs) + + +def plot(*args: Any, **kwargs: Any) -> None: + raise ValueError + + class TestObjectDetectionTask: @pytest.mark.parametrize( "name,classname", [("nasa_marine_debris", NASAMarineDebrisDataModule)] ) - def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: + def test_trainer( + self, name: str, classname: Type[LightningDataModule], fast_dev_run: bool + ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", f"{name}.yaml")) conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) @@ -31,10 +43,16 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: model = ObjectDetectionTask(**model_kwargs) # Instantiate trainer - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - trainer.test(model=model, datamodule=datamodule) - trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) + try: + trainer.test(model=model, datamodule=datamodule) + except MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: @@ -56,13 +74,21 @@ def test_non_pretrained_backbone(self, model_kwargs: Dict[Any, Any]) -> None: model_kwargs["pretrained"] = False ObjectDetectionTask(**model_kwargs) - def test_missing_attributes( - self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch + def test_no_rgb( + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool ) -> None: - monkeypatch.delattr(NASAMarineDebrisDataModule, "plot") + 
monkeypatch.setattr(NASAMarineDebrisDataModule, "plot", plot) datamodule = NASAMarineDebrisDataModule( root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 ) model = ObjectDetectionTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.validate(model=model, datamodule=datamodule) + + def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: + datamodule = PredictObjectDetectionDataModule( + root="tests/data/nasa_marine_debris", batch_size=1, num_workers=0 + ) + model = ObjectDetectionTask(**model_kwargs) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) + trainer.predict(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 22524b0a8e2..65fbeabfeca 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -14,7 +14,12 @@ from pytorch_lightning import LightningDataModule, Trainer from torchvision.models._api import WeightsEnum -from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule +from torchgeo.datamodules import ( + COWCCountingDataModule, + MisconfigurationException, + TropicalCycloneDataModule, +) +from torchgeo.datasets import TropicalCyclone from torchgeo.models import ResNet18_Weights from torchgeo.trainers import RegressionTask @@ -26,6 +31,15 @@ def load(url: str, *args: Any, **kwargs: Any) -> Dict[str, Any]: return state_dict +class PredictRegressionDataModule(TropicalCycloneDataModule): + def setup(self, stage: str) -> None: + self.predict_dataset = TropicalCyclone(split="test", **self.kwargs) + + +def plot(*args: Any, **kwargs: Any) -> None: + raise ValueError + + class TestRegressionTask: @pytest.mark.parametrize( "name,classname", @@ -34,7 +48,9 @@ class TestRegressionTask: ("cyclone", TropicalCycloneDataModule), ], ) - def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: + def test_trainer( + self, name: str, classname: Type[LightningDataModule], fast_dev_run: bool + ) -> None: conf = OmegaConf.load(os.path.join("tests", "conf", name + ".yaml")) conf_dict = OmegaConf.to_object(conf.experiment) conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) @@ -50,29 +66,16 @@ def test_trainer(self, name: str, classname: Type[LightningDataModule]) -> None: model.model = RegressionTestModel() # Instantiate trainer - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) - trainer.fit(model=model, datamodule=datamodule) - trainer.test(model=model, datamodule=datamodule) - trainer.predict(model=model, dataloaders=datamodule.val_dataloader()) - - def test_no_logger(self) -> None: - conf = OmegaConf.load(os.path.join("tests", "conf", "cyclone.yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[str, Dict[str, Any]], conf_dict) - - # Instantiate datamodule - datamodule_kwargs = conf_dict["datamodule"] - datamodule = TropicalCycloneDataModule(**datamodule_kwargs) - - # Instantiate model - model_kwargs = conf_dict["module"] - model = RegressionTask(**model_kwargs) - - # Instantiate trainer - trainer = Trainer( - logger=False, fast_dev_run=True, log_every_n_steps=1, max_epochs=1 - ) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) + try: + trainer.test(model=model, datamodule=datamodule) + except 
MisconfigurationException: + pass + try: + trainer.predict(model=model, datamodule=datamodule) + except MisconfigurationException: + pass @pytest.fixture def model_kwargs(self) -> Dict[str, Any]: @@ -111,3 +114,22 @@ def test_weight_str( model_kwargs["weights"] = str(mocked_weights) with pytest.warns(UserWarning): RegressionTask(**model_kwargs) + + def test_no_rgb( + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool + ) -> None: + monkeypatch.setattr(TropicalCycloneDataModule, "plot", plot) + datamodule = TropicalCycloneDataModule( + root="tests/data/cyclone", batch_size=1, num_workers=0 + ) + model = RegressionTask(**model_kwargs) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) + trainer.validate(model=model, datamodule=datamodule) + + def test_predict(self, model_kwargs: Dict[Any, Any], fast_dev_run: bool) -> None: + datamodule = PredictRegressionDataModule( + root="tests/data/cyclone", batch_size=1, num_workers=0 + ) + model = RegressionTask(**model_kwargs) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) + trainer.predict(model=model, datamodule=datamodule) diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 35467c179d9..aa458c883a4 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -19,6 +19,7 @@ InriaAerialImageLabelingDataModule, LandCoverAIDataModule, LoveDADataModule, + MisconfigurationException, NAIPChesapeakeDataModule, Potsdam2DDataModule, SEN12MSDataModule, @@ -35,6 +36,10 @@ def create_model(**kwargs: Any) -> Module: return SegmentationTestModel(**kwargs) +def plot(*args: Any, **kwargs: Any) -> None: + raise ValueError + + class TestSemanticSegmentationTask: @pytest.mark.parametrize( "name,classname", @@ -43,9 +48,7 @@ class TestSemanticSegmentationTask: ("deepglobelandcover", DeepGlobeLandCoverDataModule), ("etci2021", ETCI2021DataModule), ("gid15", GID15DataModule), - ("inria_train", InriaAerialImageLabelingDataModule), - ("inria_val", InriaAerialImageLabelingDataModule), - ("inria_test", InriaAerialImageLabelingDataModule), + ("inria", InriaAerialImageLabelingDataModule), ("landcoverai", LandCoverAIDataModule), ("loveda", LoveDADataModule), ("naipchesapeake", NAIPChesapeakeDataModule), @@ -59,7 +62,11 @@ class TestSemanticSegmentationTask: ], ) def test_trainer( - self, monkeypatch: MonkeyPatch, name: str, classname: Type[LightningDataModule] + self, + monkeypatch: MonkeyPatch, + name: str, + classname: Type[LightningDataModule], + fast_dev_run: bool, ) -> None: if name == "naipchesapeake": pytest.importorskip("zipfile_deflate64") @@ -83,33 +90,16 @@ def test_trainer( model = SemanticSegmentationTask(**model_kwargs) # Instantiate trainer - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.fit(model=model, datamodule=datamodule) - - if hasattr(datamodule, "test_dataset") or hasattr(datamodule, "test_sampler"): + try: trainer.test(model=model, datamodule=datamodule) - - if hasattr(datamodule, "predict_dataset"): + except MisconfigurationException: + pass + try: trainer.predict(model=model, datamodule=datamodule) - - def test_no_logger(self) -> None: - conf = OmegaConf.load(os.path.join("tests", "conf", "landcoverai.yaml")) - conf_dict = OmegaConf.to_object(conf.experiment) - conf_dict = cast(Dict[Any, Dict[Any, Any]], conf_dict) - - # Instantiate datamodule - datamodule_kwargs = 
conf_dict["datamodule"] - datamodule = LandCoverAIDataModule(**datamodule_kwargs) - - # Instantiate model - model_kwargs = conf_dict["module"] - model = SemanticSegmentationTask(**model_kwargs) - - # Instantiate trainer - trainer = Trainer( - logger=False, fast_dev_run=True, log_every_n_steps=1, max_epochs=1 - ) - trainer.fit(model=model, datamodule=datamodule) + except MisconfigurationException: + pass @pytest.fixture def model_kwargs(self) -> Dict[Any, Any]: @@ -148,13 +138,14 @@ def test_ignoreindex_with_jaccard(self, model_kwargs: Dict[Any, Any]) -> None: with pytest.warns(UserWarning, match=match): SemanticSegmentationTask(**model_kwargs) - def test_missing_attributes( - self, model_kwargs: Dict[Any, Any], monkeypatch: MonkeyPatch + def test_no_rgb( + self, monkeypatch: MonkeyPatch, model_kwargs: Dict[Any, Any], fast_dev_run: bool ) -> None: - monkeypatch.delattr(LandCoverAIDataModule, "plot") - datamodule = LandCoverAIDataModule( - root="tests/data/landcoverai", batch_size=1, num_workers=0 + model_kwargs["in_channels"] = 15 + monkeypatch.setattr(SEN12MSDataModule, "plot", plot) + datamodule = SEN12MSDataModule( + root="tests/data/sen12ms", batch_size=1, num_workers=0 ) model = SemanticSegmentationTask(**model_kwargs) - trainer = Trainer(fast_dev_run=True, log_every_n_steps=1, max_epochs=1) + trainer = Trainer(fast_dev_run=fast_dev_run, log_every_n_steps=1, max_epochs=1) trainer.validate(model=model, datamodule=datamodule) diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 68d9632d1b9..4b0c2630f99 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -24,7 +24,7 @@ def batch_gray() -> Dict[str, Tensor]: return { "image": torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float), - "mask": torch.tensor([[[[0, 0, 1], [0, 1, 1], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -43,7 +43,7 @@ def batch_rgb() -> Dict[str, Tensor]: ], dtype=torch.float, ), - "mask": torch.tensor([[[[0, 0, 1], [0, 1, 1], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -64,7 +64,7 @@ def batch_multispectral() -> Dict[str, Tensor]: ], dtype=torch.float, ), - "mask": torch.tensor([[[[0, 0, 1], [0, 1, 1], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -80,7 +80,7 @@ def assert_matching(output: Dict[str, Tensor], expected: Dict[str, Tensor]) -> N def test_augmentation_sequential_gray(batch_gray: Dict[str, Tensor]) -> None: expected = { "image": torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float), - "mask": torch.tensor([[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -103,7 +103,7 @@ def test_augmentation_sequential_rgb(batch_rgb: Dict[str, Tensor]) -> None: ], dtype=torch.float, ), - "mask": torch.tensor([[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]]], dtype=torch.long), + 
"mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -130,7 +130,7 @@ def test_augmentation_sequential_multispectral( ], dtype=torch.float, ), - "mask": torch.tensor([[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -157,7 +157,7 @@ def test_augmentation_sequential_image_only( ], dtype=torch.float, ), - "mask": torch.tensor([[[[0, 0, 1], [0, 1, 1], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[0, 1], [1, 1], [1, 0], [0, 0]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } @@ -189,7 +189,7 @@ def test_sequential_transforms_augmentations( ], dtype=torch.float, ), - "mask": torch.tensor([[[[1, 0, 0], [1, 1, 0], [1, 1, 1]]]], dtype=torch.long), + "mask": torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), "boxes": torch.tensor([[[1, 0], [2, 0], [2, 1], [1, 1]]], dtype=torch.float), "labels": torch.tensor([[0, 1]]), } diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 14e50ad7d45..fe4dafaa986 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -11,6 +11,7 @@ from .etci2021 import ETCI2021DataModule from .eurosat import EuroSATDataModule from .fair1m import FAIR1MDataModule +from .geo import GeoDataModule, NonGeoDataModule from .gid15 import GID15DataModule from .inria import InriaAerialImageLabelingDataModule from .landcoverai import LandCoverAIDataModule @@ -25,6 +26,7 @@ from .spacenet import SpaceNet1DataModule from .ucmerced import UCMercedDataModule from .usavars import USAVarsDataModule +from .utils import MisconfigurationException from .vaihingen import Vaihingen2DDataModule from .xview import XView2DataModule @@ -55,4 +57,9 @@ "USAVarsDataModule", "Vaihingen2DDataModule", "XView2DataModule", + # Base classes + "GeoDataModule", + "NonGeoDataModule", + # Utilities + "MisconfigurationException", ) diff --git a/torchgeo/datamodules/bigearthnet.py b/torchgeo/datamodules/bigearthnet.py index 348f7773fb0..b75c38af3dd 100644 --- a/torchgeo/datamodules/bigearthnet.py +++ b/torchgeo/datamodules/bigearthnet.py @@ -3,18 +3,15 @@ """BigEarthNet datamodule.""" -from typing import Any, Dict, Optional +from typing import Any -import matplotlib.pyplot as plt -import pytorch_lightning as pl import torch -from torch.utils.data import DataLoader -from torchvision.transforms import Compose from ..datasets import BigEarthNet +from .geo import NonGeoDataModule -class BigEarthNetDataModule(pl.LightningDataModule): +class BigEarthNetDataModule(NonGeoDataModule): """LightningDataModule implementation for the BigEarthNet dataset. Uses the train/val/test splits from the dataset. 
@@ -22,10 +19,10 @@ class BigEarthNetDataModule(pl.LightningDataModule): # (VV, VH, B01, B02, B03, B04, B05, B06, B07, B08, B8A, B09, B11, B12) # min/max band statistics computed on 100k random samples - band_mins_raw = torch.tensor( + mins_raw = torch.tensor( [-70.0, -72.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0] ) - band_maxs_raw = torch.tensor( + maxs_raw = torch.tensor( [ 31.0, 35.0, @@ -46,10 +43,10 @@ class BigEarthNetDataModule(pl.LightningDataModule): # min/max band statistics computed by percentile clipping the # above to samples to [2, 98] - band_mins = torch.tensor( + mins = torch.tensor( [-48.0, -42.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] ) - band_maxs = torch.tensor( + maxs = torch.tensor( [ 6.0, 16.0, @@ -71,91 +68,25 @@ class BigEarthNetDataModule(pl.LightningDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for BigEarthNet based DataLoaders. + """Initialize a new BigEarthNetDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.BigEarthNet` + :class:`~torchgeo.datasets.BigEarthNet`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs - bands = kwargs.get("bands", "all") if bands == "all": - self.mins = self.band_mins[:, None, None] - self.maxs = self.band_maxs[:, None, None] + mins = self.mins + maxs = self.maxs elif bands == "s1": - self.mins = self.band_mins[:2, None, None] - self.maxs = self.band_maxs[:2, None, None] + mins = self.mins[:2] + maxs = self.maxs[:2] else: - self.mins = self.band_mins[2:, None, None] - self.maxs = self.band_maxs[2:, None, None] - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset.""" - sample["image"] = sample["image"].float() - sample["image"] = (sample["image"] - self.mins) / (self.maxs - self.mins) - sample["image"] = torch.clip(sample["image"], min=0.0, max=1.0) - return sample + mins = self.mins[2:] + maxs = self.maxs[2:] + self.mean = mins + self.std = maxs - mins - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - BigEarthNet(split="train", **self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. 
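With the removals above, BigEarthNetDataModule no longer implements its own preprocess/clip step; instead it assigns self.mean = mins and self.std = maxs - mins so that the normalization applied by the new base class reduces to min-max scaling. A quick sketch of that equivalence using the two clipped S1 band statistics shown above (the random input is illustrative):

import kornia.augmentation as K
import torch

mins = torch.tensor([-48.0, -42.0])
maxs = torch.tensor([6.0, 16.0])

# (x - mean) / std with mean=mins and std=maxs - mins is (x - mins) / (maxs - mins),
# i.e. min-max scaling of each band into [0, 1].
norm = K.Normalize(mean=mins, std=maxs - mins)
x = torch.rand(1, 2, 4, 4) * (maxs - mins)[:, None, None] + mins[:, None, None]
scaled = norm(x)
assert 0 <= scaled.min() and scaled.max() <= 1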
- """ - transforms = Compose([self.preprocess]) - self.train_dataset = BigEarthNet( - split="train", transforms=transforms, **self.kwargs - ) - self.val_dataset = BigEarthNet( - split="val", transforms=transforms, **self.kwargs - ) - self.test_dataset = BigEarthNet( - split="test", transforms=transforms, **self.kwargs - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training.""" - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation.""" - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing.""" - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.BigEarthNet.plot`. - - .. versionadded:: 0.2 - """ - return self.val_dataset.plot(*args, **kwargs) + super().__init__(BigEarthNet, batch_size, num_workers, **kwargs) diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 3163e2683c7..50e699aa188 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -3,21 +3,53 @@ """Chesapeake Bay High-Resolution Land Cover Project datamodule.""" -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List -import torch +import kornia.augmentation as K +import torch.nn as nn import torch.nn.functional as F -from pytorch_lightning.core.datamodule import LightningDataModule +from einops import rearrange from torch import Tensor -from torch.utils.data import DataLoader -from torchvision.transforms import Compose -from ..datasets import ChesapeakeCVPR, stack_samples -from ..samplers.batch import RandomBatchGeoSampler -from ..samplers.single import GridGeoSampler +from ..datasets import ChesapeakeCVPR +from ..samplers import GridGeoSampler, RandomBatchGeoSampler +from ..transforms import AugmentationSequential +from .geo import GeoDataModule -class ChesapeakeCVPRDataModule(LightningDataModule): +class _Transform(nn.Module): + """Version of AugmentationSequential designed for samples, not batches.""" + + def __init__(self, aug: nn.Module) -> None: + """Initialize a new _Transform instance. + + Args: + aug: Augmentation to apply. + """ + super().__init__() + self.aug = aug + + def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Apply the augmentation. + + Args: + sample: Input sample. + + Returns: + Augmented sample. + """ + for key in ["image", "mask"]: + dtype = sample[key].dtype + # All inputs must be float + sample[key] = sample[key].float() + sample[key] = self.aug(sample[key]) + sample[key] = sample[key].to(dtype) + # Kornia adds batch dimension + sample[key] = rearrange(sample[key], "() c h w -> c h w") + return sample + + +class ChesapeakeCVPRDataModule(GeoDataModule): """LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset. 
Uses the random splits defined per state to partition tiles into train, val, @@ -29,41 +61,47 @@ def __init__( train_splits: List[str], val_splits: List[str], test_splits: List[str], - patches_per_tile: int = 200, - patch_size: int = 256, batch_size: int = 64, + patch_size: int = 256, + length: int = 1000, num_workers: int = 0, class_set: int = 7, use_prior_labels: bool = False, prior_smoothing_constant: float = 1e-4, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for Chesapeake CVPR based DataLoaders. + """Initialize a new ChesapeakeCVPRDataModule instance. Args: - train_splits: The splits used to train the model, e.g. ["ny-train"] - val_splits: The splits used to validate the model, e.g. ["ny-val"] - test_splits: The splits used to test the model, e.g. ["ny-test"] - patches_per_tile: The number of patches per tile to sample - patch_size: The size of each patch in pixels (test patches will be 1.5 times - this size) - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - class_set: The high-resolution land cover class set to use - 5 or 7 + train_splits: Splits used to train the model, e.g., ["ny-train"]. + val_splits: Splits used to validate the model, e.g., ["ny-val"]. + test_splits: Splits used to test the model, e.g., ["ny-test"]. + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + length: Length of each training epoch. + num_workers: Number of workers for parallel data loading. + class_set: The high-resolution land cover class set to use (5 or 7). use_prior_labels: Flag for using a prior over high-resolution classes - instead of the high-resolution labels themselves - prior_smoothing_constant: additive smoothing to add when using prior labels + instead of the high-resolution labels themselves. + prior_smoothing_constant: Additive smoothing to add when using prior labels. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.ChesapeakeCVPR` + :class:`~torchgeo.datasets.ChesapeakeCVPR`. Raises: - ValueError: if ``use_prior_labels`` is used with ``class_set==7`` + ValueError: If ``use_prior_labels=True`` is used with ``class_set=7``. """ - super().__init__() - for state in train_splits + val_splits + test_splits: - assert state in ChesapeakeCVPR.splits + # This is a rough estimate of how large of a patch we will need to sample in + # EPSG:3857 in order to guarantee a large enough patch in the local CRS. + self.original_patch_size = patch_size * 2 + kwargs["transforms"] = _Transform(K.CenterCrop(patch_size)) + + super().__init__( + ChesapeakeCVPR, batch_size, patch_size, length, num_workers, **kwargs + ) + assert class_set in [5, 7] - if use_prior_labels and class_set != 5: + if use_prior_labels and class_set == 7: raise ValueError( "The pre-generated prior labels are only valid for the 5" + " class set of labels" @@ -72,17 +110,9 @@ def __init__( self.train_splits = train_splits self.val_splits = val_splits self.test_splits = test_splits - self.patches_per_tile = patches_per_tile - self.patch_size = patch_size - # This is a rough estimate of how large of a patch we will need to sample in - # EPSG:3857 in order to guarantee a large enough patch in the local CRS. 
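The constructor above over-samples each patch at twice the requested size (original_patch_size = patch_size * 2, to survive reprojection from EPSG:3857 into the local CRS) and hands the dataset a _Transform(K.CenterCrop(patch_size)) to trim every sample back down. A rough sketch of what that per-sample wrapper does, using the private helper defined earlier in this file and illustrative tensor sizes:

import kornia.augmentation as K
import torch
from torchgeo.datamodules.chesapeake import _Transform

crop = _Transform(K.CenterCrop(256))
sample = {
    "image": torch.randint(0, 256, (4, 512, 512), dtype=torch.uint8),
    "mask": torch.randint(0, 7, (1, 512, 512), dtype=torch.long),
}
out = crop(sample)
# Inputs are cast to float for Kornia, cropped, cast back to their original
# dtypes, and the batch dimension Kornia adds is squeezed away again.
assert out["image"].shape == (4, 256, 256) and out["mask"].shape == (1, 256, 256)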
- self.original_patch_size = patch_size * 2 - self.batch_size = batch_size - self.num_workers = num_workers self.class_set = class_set self.use_prior_labels = use_prior_labels self.prior_smoothing_constant = prior_smoothing_constant - self.kwargs = kwargs if self.use_prior_labels: self.layers = [ @@ -92,244 +122,61 @@ def __init__( else: self.layers = ["naip-new", "lc"] - def pad_to( - self, size: int = 512, image_value: int = 0, mask_value: int = 0 - ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: - """Returns a function to perform a padding transform on a single sample. + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] + ) - Args: - size: output image size - image_value: value to pad image with - mask_value: value to pad mask with + def setup(self, stage: str) -> None: + """Set up datasets and samplers. - Returns: - function to perform padding + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - - def pad_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - _, height, width = sample["image"].shape - assert height <= size and width <= size - - height_pad = size - height - width_pad = size - width - - # See https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html - # for a description of the format of the padding tuple - sample["image"] = F.pad( - sample["image"], - (0, width_pad, 0, height_pad), - mode="constant", - value=image_value, + if stage in ["fit"]: + self.train_dataset = ChesapeakeCVPR( + splits=self.train_splits, layers=self.layers, **self.kwargs ) - sample["mask"] = F.pad( - sample["mask"], - (0, width_pad, 0, height_pad), - mode="constant", - value=mask_value, + self.train_batch_sampler = RandomBatchGeoSampler( + self.train_dataset, + self.original_patch_size, + self.batch_size, + self.length, + ) + if stage in ["fit", "validate"]: + self.val_dataset = ChesapeakeCVPR( + splits=self.val_splits, layers=self.layers, **self.kwargs + ) + self.val_sampler = GridGeoSampler( + self.val_dataset, self.original_patch_size, self.original_patch_size + ) + if stage in ["test"]: + self.test_dataset = ChesapeakeCVPR( + splits=self.test_splits, layers=self.layers, **self.kwargs + ) + self.test_sampler = GridGeoSampler( + self.test_dataset, self.original_patch_size, self.original_patch_size ) - return sample - - return pad_inner - - def center_crop( - self, size: int = 512 - ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: - """Returns a function to perform a center crop transform on a single sample. - - Args: - size: output image size - - Returns: - function to perform center crop - """ - - def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - _, height, width = sample["image"].shape - - y1 = round((height - size) / 2) - x1 = round((width - size) / 2) - sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size] - sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size] - - return sample - - return center_crop_inner - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Preprocesses a single sample. 
- - Args: - sample: sample dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - - if "mask" in sample: - sample["mask"] = sample["mask"].squeeze() - if self.use_prior_labels: - sample["mask"] = F.normalize(sample["mask"].float(), p=1, dim=0) - sample["mask"] = F.normalize( - sample["mask"] + self.prior_smoothing_constant, p=1, dim=0 - ) - else: - if self.class_set == 5: - sample["mask"][sample["mask"] == 5] = 4 - sample["mask"][sample["mask"] == 6] = 4 - sample["mask"] = sample["mask"].long() - - return sample - - def remove_bbox(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Removes the bounding box property from a sample. - - Args: - sample: dictionary with geographic metadata - - Returns - sample without the bbox property - """ - del sample["bbox"] - return sample - - def nodata_check( - self, size: int = 512 - ) -> Callable[[Dict[str, Tensor]], Dict[str, Tensor]]: - """Returns a function to check for nodata or mis-sized input. - - Args: - size: output image size - - Returns: - function to check for nodata values - """ - - def nodata_check_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: - num_channels, height, width = sample["image"].shape - - if height < size or width < size: - sample["image"] = torch.zeros((num_channels, size, size)) - sample["mask"] = torch.zeros((size, size)) - - return sample - - return nodata_check_inner - - def prepare_data(self) -> None: - """Confirms that the dataset is downloaded on the local node. - - This method is called once per node, while :func:`setup` is called once per GPU. - """ - if self.kwargs.get("download", False): - ChesapeakeCVPR(splits=self.train_splits, layers=self.layers, **self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. Args: - stage: stage to set up - """ - train_transforms = Compose( - [ - self.center_crop(self.patch_size), - self.nodata_check(self.patch_size), - self.preprocess, - self.remove_bbox, - ] - ) - val_transforms = Compose( - [ - self.center_crop(self.patch_size), - self.nodata_check(self.patch_size), - self.preprocess, - self.remove_bbox, - ] - ) - test_transforms = Compose( - [ - self.pad_to(self.original_patch_size, image_value=0, mask_value=0), - self.preprocess, - self.remove_bbox, - ] - ) - - self.train_dataset = ChesapeakeCVPR( - splits=self.train_splits, - layers=self.layers, - transforms=train_transforms, - **self.kwargs, - ) - self.val_dataset = ChesapeakeCVPR( - splits=self.val_splits, - layers=self.layers, - transforms=val_transforms, - **self.kwargs, - ) - self.test_dataset = ChesapeakeCVPR( - splits=self.test_splits, - layers=self.layers, - transforms=test_transforms, - **self.kwargs, - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. 
- - Returns: - training data loader - """ - sampler = RandomBatchGeoSampler( - self.train_dataset, - size=self.original_patch_size, - batch_size=self.batch_size, - length=self.patches_per_tile * len(self.train_dataset), - ) - return DataLoader( - self.train_dataset, - batch_sampler=sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. Returns: - validation data loader + A batch of data. """ - sampler = GridGeoSampler( - self.val_dataset, - size=self.original_patch_size, - stride=self.original_patch_size, - ) - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - sampler=sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. + if self.use_prior_labels: + batch["mask"] = F.normalize(batch["mask"].float(), p=1, dim=1) + batch["mask"] = F.normalize( + batch["mask"] + self.prior_smoothing_constant, p=1, dim=1 + ).long() + else: + if self.class_set == 5: + batch["mask"][batch["mask"] == 5] = 4 + batch["mask"][batch["mask"] == 6] = 4 - Returns: - testing data loader - """ - sampler = GridGeoSampler( - self.test_dataset, - size=self.original_patch_size, - stride=self.original_patch_size, - ) - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - sampler=sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) + return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index f26db807f94..799850f4e4a 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -3,124 +3,41 @@ """COWC datamodule.""" -from typing import Any, Dict, Optional +from typing import Any -import matplotlib.pyplot as plt -import pytorch_lightning as pl from torch import Generator -from torch.utils.data import DataLoader, random_split +from torch.utils.data import random_split from ..datasets import COWCCounting +from .geo import NonGeoDataModule -class COWCCountingDataModule(pl.LightningDataModule): +class COWCCountingDataModule(NonGeoDataModule): """LightningDataModule implementation for the COWC Counting dataset.""" def __init__( - self, seed: int = 0, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for COWC Counting based DataLoaders. + """Initialize a new COWCCountingDataModule instance. Args: - seed: The seed value to use when doing the dataset random_split - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.COWCCounting` + :class:`~torchgeo.datasets.COWCCounting`. """ - super().__init__() - self.seed = seed - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs + super().__init__(COWCCounting, batch_size, num_workers, **kwargs) - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. + def setup(self, stage: str) -> None: + """Set up datasets. 
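When use_prior_labels is enabled, the prior handling moves from the old per-sample preprocess into the on_after_batch_transfer hook above, where the smoothing constant is added and the prior is renormalized along the class dimension of the batched mask (dim=1 instead of the old per-sample dim=0). A tiny sketch of that arithmetic with a made-up five-class prior:

import torch
import torch.nn.functional as F

# One pixel whose prior puts all mass on class 2; shape is (B, C, H, W) = (1, 5, 1, 1).
prior = torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0]).reshape(1, 5, 1, 1)
smoothed = F.normalize(prior + 1e-4, p=1, dim=1)
# Every class now has a small non-zero probability and the prior still sums to one.
assert torch.allclose(smoothed.sum(dim=1), torch.ones(1, 1, 1))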
Args: - sample: dictionary containing image and target - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 # scale to [0, 1] - if "label" in sample: - sample["label"] = sample["label"].float() - return sample - - def prepare_data(self) -> None: - """Initialize the main ``Dataset`` objects for use in :func:`setup`. - - This includes optionally downloading the dataset. This is done once per node, - while :func:`setup` is done once per GPU. + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if self.kwargs.get("download", False): - COWCCounting(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. - - Args: - stage: stage to set up - """ - train_val_dataset = COWCCounting( - split="train", transforms=self.preprocess, **self.kwargs - ) - self.test_dataset = COWCCounting( - split="test", transforms=self.preprocess, **self.kwargs - ) + self.dataset = COWCCounting(split="train", **self.kwargs) + self.test_dataset = COWCCounting(split="test", **self.kwargs) self.train_dataset, self.val_dataset = random_split( - train_val_dataset, - [len(train_val_dataset) - len(self.test_dataset), len(self.test_dataset)], - generator=Generator().manual_seed(self.seed), - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, + self.dataset, + [len(self.dataset) - len(self.test_dataset), len(self.test_dataset)], + generator=Generator().manual_seed(0), ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.COWC.plot`. - - .. versionadded:: 0.2 - """ - return self.test_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 0a7f1a4eab4..b3c8d3121a9 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -3,149 +3,60 @@ """Tropical Cyclone Wind Estimation Competition datamodule.""" -from typing import Any, Dict, Optional +from typing import Any -import pytorch_lightning as pl -import torch from sklearn.model_selection import GroupShuffleSplit -from torch.utils.data import DataLoader, Subset +from torch.utils.data import Subset from ..datasets import TropicalCyclone +from .geo import NonGeoDataModule -class TropicalCycloneDataModule(pl.LightningDataModule): +class TropicalCycloneDataModule(NonGeoDataModule): """LightningDataModule implementation for the NASA Cyclone dataset. Implements 80/20 train/val splits based on hurricane storm ids. See :func:`setup` for more details. - .. versionchanged:: 0.4.0 + .. 
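The seed argument disappears from the COWC datamodule above: the train/val partition is now always drawn with a generator seeded to 0, so every run reproduces the same split. The same pattern on a toy dataset (sizes are illustrative):

import torch
from torch import Generator
from torch.utils.data import TensorDataset, random_split

toy = TensorDataset(torch.arange(10))
# A fixed generator seed makes the 8/2 partition identical across runs.
train, val = random_split(toy, [8, 2], generator=Generator().manual_seed(0))
print(sorted(train.indices), sorted(val.indices))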
versionchanged:: 0.4 Class name changed from CycloneDataModule to TropicalCycloneDataModule to be consistent with TropicalCyclone dataset. """ def __init__( - self, seed: int = 0, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for NASA Cyclone based DataLoaders. + """Initialize a new TropicalCycloneDataModule instance. Args: - seed: The seed value to use when doing the sklearn based GroupShuffleSplit - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.TropicalCyclone` + :class:`~torchgeo.datasets.TropicalCyclone`. """ - super().__init__() - self.seed = seed - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs + super().__init__(TropicalCyclone, batch_size, num_workers, **kwargs) - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - sample: dictionary containing image and target - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - sample["image"] = ( - sample["image"].unsqueeze(0).repeat(3, 1, 1) - ) # convert from grayscale to 3 channel - if "label" in sample: - sample["label"] = torch.as_tensor(sample["label"]).float() - - return sample - - def prepare_data(self) -> None: - """Initialize the main ``Dataset`` objects for use in :func:`setup`. - - This includes optionally downloading the dataset. This is done once per node, - while :func:`setup` is done once per GPU. - """ - if self.kwargs.get("download", False): - TropicalCyclone(split="train", **self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. - - We split samples between train/val by the ``storm_id`` property. I.e. all - samples with the same ``storm_id`` value will be either in the train or the val - split. This is important to test one type of generalizability -- given a new - storm, can we predict its windspeed. The test set, however, contains *some* - storms from the training set (specifically, the latter parts of the storms) as - well as some novel storms. - - Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - self.all_train_dataset = TropicalCyclone( - split="train", transforms=self.preprocess, **self.kwargs - ) - - self.all_test_dataset = TropicalCyclone( - split="test", transforms=self.preprocess, **self.kwargs - ) - - storm_ids = [] - for item in self.all_train_dataset.collection: - storm_id = item["href"].split("/")[0].split("_")[-2] - storm_ids.append(storm_id) - - train_indices, val_indices = next( - GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( - storm_ids, groups=storm_ids + if stage in ["fit", "validate"]: + self.dataset = TropicalCyclone(split="train", **self.kwargs) + + storm_ids = [] + for item in self.dataset.collection: + storm_id = item["href"].split("/")[0].split("_")[-2] + storm_ids.append(storm_id) + + train_indices, val_indices = next( + GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=0).split( + storm_ids, groups=storm_ids + ) ) - ) - - self.train_dataset = Subset(self.all_train_dataset, train_indices) - self.val_dataset = Subset(self.all_train_dataset, val_indices) - self.test_dataset = Subset( - self.all_test_dataset, range(len(self.all_test_dataset)) - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) + self.train_dataset = Subset(self.dataset, train_indices) + self.val_dataset = Subset(self.dataset, val_indices) + if stage in ["test"]: + self.test_dataset = TropicalCyclone(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 08871f38ad7..fdf00265016 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -3,23 +3,19 @@ """DeepGlobe Land Cover Classification Challenge datamodule.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Tuple, Union -import matplotlib.pyplot as plt -import pytorch_lightning as pl -from einops import rearrange -from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader +import kornia.augmentation as K from ..datasets import DeepGlobeLandCover from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential -from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from ..transforms.transforms import _RandomNCrop +from .geo import NonGeoDataModule from .utils import dataset_split -class DeepGlobeLandCoverDataModule(pl.LightningDataModule): +class DeepGlobeLandCoverDataModule(NonGeoDataModule): """LightningDataModule implementation for the DeepGlobe Land Cover dataset. Uses the train/test splits from the dataset. 
@@ -27,133 +23,44 @@ class DeepGlobeLandCoverDataModule(pl.LightningDataModule): def __init__( self, - num_tiles_per_batch: int = 16, - num_patches_per_tile: int = 16, + batch_size: int = 64, patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. - - The DeepGlobe Land Cover dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles - during training and chop up image tiles into patch grids during evaluation. - During training, the effective batch size is equal to - ``num_tiles_per_batch`` x ``num_patches_per_tile``. - - .. versionchanged:: 0.4 - *batch_size* was replaced by *num_tile_per_batch*, *num_patches_per_tile*, - and *patch_size*. + """Initialize a new DeepGlobeLandCoverDataModule instance. Args: - num_tiles_per_batch: The number of image tiles to sample from during - training - num_patches_per_tile: The number of patches to randomly sample from each - image tile during training - patch_size: The size of each patch, either ``size`` or ``(height, width)``. - Should be a multiple of 32 for most segmentation architectures - val_split_pct: The percentage of the dataset to use as a validation set - num_workers: The number of workers to use for parallel data loading + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + val_split_pct: Percentage of the dataset to use as a validation set. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.DeepGlobeLandCover` + :class:`~torchgeo.datasets.DeepGlobeLandCover`. """ - super().__init__() + super().__init__(DeepGlobeLandCover, 1, num_workers, **kwargs) - self.num_tiles_per_batch = num_tiles_per_batch - self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.num_workers = num_workers - self.kwargs = kwargs - self.train_transform = AugmentationSequential( - Normalize(mean=0.0, std=255.0), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), - data_keys=["image", "mask"], - ) - self.test_transform = AugmentationSequential( - Normalize(mean=0.0, std=255.0), - _ExtractTensorPatches(self.patch_size), + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image", "mask"], ) - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up - """ - train_dataset = DeepGlobeLandCover(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - train_dataset, self.val_split_pct - ) - self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.num_tiles_per_batch, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. 
- - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - # Kornia requires masks to have a channel dimension - batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") - - if self.trainer: - if self.trainer.training: - batch = self.train_transform(batch) - elif self.trainer.validating or self.trainer.testing: - batch = self.test_transform(batch) - - # Torchmetrics does not support masks with a channel dimension - batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`. - - .. versionadded:: 0.4 + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - return self.test_dataset.plot(*args, **kwargs) + if stage in ["fit", "validate"]: + self.dataset = DeepGlobeLandCover(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + self.dataset, self.val_split_pct + ) + if stage in ["test"]: + self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index 101c9a3a318..3c4a8dd1e83 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -3,19 +3,16 @@ """ETCI 2021 datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Dict -import matplotlib.pyplot as plt -import pytorch_lightning as pl import torch -from torch import Generator -from torch.utils.data import DataLoader, random_split -from torchvision.transforms import Normalize +from torch import Tensor from ..datasets import ETCI2021 +from .geo import NonGeoDataModule -class ETCI2021DataModule(pl.LightningDataModule): +class ETCI2021DataModule(NonGeoDataModule): """LightningDataModule implementation for the ETCI2021 dataset. Splits the existing train split from the dataset into train/val with 80/20 @@ -24,128 +21,62 @@ class ETCI2021DataModule(pl.LightningDataModule): .. versionadded:: 0.2 """ - band_means = torch.tensor( - [0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701] + mean = torch.tensor( + [ + 128.02253931, + 128.02253931, + 128.02253931, + 128.11221701, + 128.11221701, + 128.11221701, + ] ) - - band_stds = torch.tensor( - [0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622] + std = torch.tensor( + [89.8145088, 89.8145088, 89.8145088, 95.2797861, 95.2797861, 95.2797861] ) def __init__( - self, seed: int = 0, batch_size: int = 64, num_workers: int = 0, **kwargs: Any + self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for ETCI2021 based DataLoaders. + """Initialize a new ETCI2021DataModule instance. 
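In the DeepGlobe datamodule above, the loader now yields one full tile per step (the dataset-level batch size passed to the base class is fixed at 1), and the requested batch_size only controls how many random patch_size crops _RandomNCrop cuts out of that tile on device. A usage sketch, assuming the dataset has already been downloaded to the given root (the path is illustrative):

from torchgeo.datamodules import DeepGlobeLandCoverDataModule

dm = DeepGlobeLandCoverDataModule(
    batch_size=16, patch_size=64, val_split_pct=0.5, root="data/deepglobelandcover"
)
dm.setup("fit")
tile = next(iter(dm.train_dataloader()))  # a single tile per step
# During training, Lightning's on_after_batch_transfer hook applies K.Normalize
# followed by _RandomNCrop, turning that tile into 16 random 64x64 patches.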
Args: - seed: The seed value to use when doing the dataset random_split - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.ETCI2021` + :class:`~torchgeo.datasets.ETCI2021`. """ - super().__init__() - self.seed = seed - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs - - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. + super().__init__(ETCI2021, batch_size, num_workers, **kwargs) - Notably, moves the given water mask to act as an input layer. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - sample["image"] = self.norm(sample["image"]) - - if "mask" in sample: - flood_mask = sample["mask"][1] - flood_mask = (flood_mask > 0).long() - sample["mask"] = flood_mask - - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - if self.kwargs.get("download", False): - ETCI2021(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. + if stage in ["fit"]: + self.train_dataset = ETCI2021(split="train", **self.kwargs) + if stage in ["fit", "validate"]: + self.val_dataset = ETCI2021(split="val", **self.kwargs) + if stage in ["predict"]: + # Test set masks are not public, use for prediction instead + self.predict_dataset = ETCI2021(split="test", **self.kwargs) + + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. Args: - stage: stage to set up - """ - train_val_dataset = ETCI2021( - split="train", transforms=self.preprocess, **self.kwargs - ) - self.test_dataset = ETCI2021( - split="val", transforms=self.preprocess, **self.kwargs - ) - - size_train_val = len(train_val_dataset) - size_train = round(0.8 * size_train_val) - size_val = size_train_val - size_train - - self.train_dataset, self.val_dataset = random_split( - train_val_dataset, - [size_train, size_val], - generator=Generator().manual_seed(self.seed), - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. Returns: - validation data loader + A batch of data. """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. 
+ if self.trainer: + if not self.trainer.predicting: + # Evaluate against flood mask, not water mask + batch["mask"] = (batch["mask"][:, 1] > 0).long() - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.ETCI2021.plot`.""" - return self.test_dataset.plot(*args, **kwargs) + return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datamodules/eurosat.py b/torchgeo/datamodules/eurosat.py index cb672011fb6..44f7c0ec7a4 100644 --- a/torchgeo/datamodules/eurosat.py +++ b/torchgeo/datamodules/eurosat.py @@ -3,18 +3,15 @@ """EuroSAT datamodule.""" -from typing import Any, Dict, Optional +from typing import Any -import matplotlib.pyplot as plt -import pytorch_lightning as pl import torch -from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize from ..datasets import EuroSAT +from .geo import NonGeoDataModule -class EuroSATDataModule(pl.LightningDataModule): +class EuroSATDataModule(NonGeoDataModule): """LightningDataModule implementation for the EuroSAT dataset. Uses the train/val/test splits from the dataset. @@ -22,7 +19,7 @@ class EuroSATDataModule(pl.LightningDataModule): .. versionadded:: 0.2 """ - band_means = torch.tensor( + mean = torch.tensor( [ 1354.40546513, 1118.24399958, @@ -40,7 +37,7 @@ class EuroSATDataModule(pl.LightningDataModule): ] ) - band_stds = torch.tensor( + std = torch.tensor( [ 245.71762908, 333.00778264, @@ -61,97 +58,12 @@ class EuroSATDataModule(pl.LightningDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for EuroSAT based DataLoaders. + """Initialize a new EuroSATDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.EuroSAT` + :class:`~torchgeo.datasets.EuroSAT`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs - - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] = self.norm(sample["image"]) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - EuroSAT(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - self.train_dataset = EuroSAT( - split="train", transforms=transforms, **self.kwargs - ) - self.val_dataset = EuroSAT(split="val", transforms=transforms, **self.kwargs) - self.test_dataset = EuroSAT(split="test", transforms=transforms, **self.kwargs) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. 
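ETCI2021 samples carry a two-channel mask (water and flood); the hook above now selects the flood channel and binarizes it on device, replacing the old per-sample preprocess. The same expression on a dummy batch:

import torch

masks = torch.randint(0, 256, (2, 2, 4, 4))  # (B, mask channels, H, W)
flood = (masks[:, 1] > 0).long()  # channel 1 holds the flood mask
assert flood.shape == (2, 4, 4) and int(flood.max()) <= 1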
- - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.EuroSAT.plot`.""" - return self.val_dataset.plot(*args, **kwargs) + super().__init__(EuroSAT, batch_size, num_workers, **kwargs) diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index da923be83a8..24b6f46eae1 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -3,33 +3,14 @@ """FAIR1M datamodule.""" -from typing import Any, Dict, List, Optional - -import matplotlib.pyplot as plt -import pytorch_lightning as pl -import torch -from torch import Tensor -from torch.utils.data import DataLoader +from typing import Any from ..datasets import FAIR1M +from .geo import NonGeoDataModule from .utils import dataset_split -def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: - """Custom object detection collate fn to handle variable number of boxes. - - Args: - batch: list of sample dicts return by dataset - Returns: - batch dict output - """ - output: Dict[str, Any] = {} - output["image"] = torch.stack([sample["image"] for sample in batch]) - output["boxes"] = [sample["boxes"] for sample in batch] - return output - - -class FAIR1MDataModule(pl.LightningDataModule): +class FAIR1MDataModule(NonGeoDataModule): """LightningDataModule implementation for the FAIR1M dataset. .. versionadded:: 0.2 @@ -43,94 +24,28 @@ def __init__( test_split_pct: float = 0.2, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for FAIR1M based DataLoaders. + """Initialize a new FAIR1MDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - test_split_pct: What percentage of the dataset to use as a test set + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a test set. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.FAIR1M` + :class:`~torchgeo.datasets.FAIR1M`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + super().__init__(FAIR1M, batch_size, num_workers, **kwargs) + self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.kwargs = kwargs - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. 
- - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - self.dataset = FAIR1M(transforms=self.preprocess, **self.kwargs) + self.dataset = FAIR1M(**self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - collate_fn=collate_fn, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_fn, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_fn, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.FAIR1M.plot`. - - .. versionadded:: 0.4 - """ - return self.dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py new file mode 100644 index 00000000000..c1b1944c933 --- /dev/null +++ b/torchgeo/datamodules/geo.py @@ -0,0 +1,556 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Base classes for all :mod:`torchgeo` data modules.""" + +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union + +import kornia.augmentation as K +import matplotlib.pyplot as plt +import torch +from pytorch_lightning import LightningDataModule +from torch import Tensor +from torch.utils.data import DataLoader, Dataset, default_collate + +from ..datasets import GeoDataset, NonGeoDataset, stack_samples +from ..samplers import ( + BatchGeoSampler, + GeoSampler, + GridGeoSampler, + RandomBatchGeoSampler, +) +from ..transforms import AugmentationSequential +from .utils import MisconfigurationException + + +class GeoDataModule(LightningDataModule): + """Base class for data modules containing geospatial information. + + .. versionadded:: 0.4 + """ + + mean = torch.tensor(0) + std = torch.tensor(255) + + def __init__( + self, + dataset_class: Type[GeoDataset], + batch_size: int = 1, + patch_size: Union[int, Tuple[int, int]] = 64, + length: int = 1000, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new GeoDataModule instance. + + Args: + dataset_class: Class used to instantiate a new dataset. + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + length: Length of each training epoch. + num_workers: Number of workers for parallel data loading. 
+ **kwargs: Additional keyword arguments passed to ``dataset_class`` + """ + super().__init__() + + self.dataset_class = dataset_class + self.batch_size = batch_size + self.patch_size = patch_size + self.length = length + self.num_workers = num_workers + self.kwargs = kwargs + + # Datasets + self.dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.train_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.val_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.test_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.predict_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + + # Samplers + self.sampler: Optional[GeoSampler] = None + self.train_sampler: Optional[GeoSampler] = None + self.val_sampler: Optional[GeoSampler] = None + self.test_sampler: Optional[GeoSampler] = None + self.predict_sampler: Optional[GeoSampler] = None + + # Batch samplers + self.batch_sampler: Optional[BatchGeoSampler] = None + self.train_batch_sampler: Optional[BatchGeoSampler] = None + self.val_batch_sampler: Optional[BatchGeoSampler] = None + self.test_batch_sampler: Optional[BatchGeoSampler] = None + self.predict_batch_sampler: Optional[BatchGeoSampler] = None + + # Data loaders + self.train_batch_size: Optional[int] = None + self.val_batch_size: Optional[int] = None + self.test_batch_size: Optional[int] = None + self.predict_batch_size: Optional[int] = None + + # Collation + self.collate_fn = stack_samples + + # Data augmentation + Transform = Callable[[Dict[str, Tensor]], Dict[str, Tensor]] + self.aug: Transform = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image"] + ) + self.train_aug: Optional[Transform] = None + self.val_aug: Optional[Transform] = None + self.test_aug: Optional[Transform] = None + self.predict_aug: Optional[Transform] = None + + def prepare_data(self) -> None: + """Download and prepare data. + + During distributed training, this method is called only within a single process + to avoid corrupted data. This method should not set state since it is not called + on every device, use :meth:`setup` instead. + """ + if self.kwargs.get("download", False): + self.dataset_class(**self.kwargs) + + def setup(self, stage: str) -> None: + """Set up datasets and samplers. + + Called at the beginning of fit, validate, test, or predict. During distributed + training, this method is called from every process across all the nodes. Setting + state here is recommended. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + if stage in ["fit"]: + self.train_dataset = self.dataset_class( # type: ignore[call-arg] + split="train", **self.kwargs + ) + self.train_batch_sampler = RandomBatchGeoSampler( + self.train_dataset, self.patch_size, self.batch_size, self.length + ) + if stage in ["fit", "validate"]: + self.val_dataset = self.dataset_class( # type: ignore[call-arg] + split="val", **self.kwargs + ) + self.val_sampler = GridGeoSampler( + self.val_dataset, self.patch_size, self.patch_size + ) + if stage in ["test"]: + self.test_dataset = self.dataset_class( # type: ignore[call-arg] + split="test", **self.kwargs + ) + self.test_sampler = GridGeoSampler( + self.test_dataset, self.patch_size, self.patch_size + ) + + def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders for training. + + Returns: + A collection of data loaders specifying training samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + 'train_dataset'. 
+ """ + dataset = self.train_dataset or self.dataset + sampler = self.train_sampler or self.sampler + batch_sampler = self.train_batch_sampler or self.batch_sampler + if dataset is not None and (sampler or batch_sampler) is not None: + batch_size = self.train_batch_size or self.batch_size + if batch_sampler is not None: + batch_size = 1 + sampler = None + return DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + batch_sampler=batch_sampler, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + else: + msg = f"{self.__class__.__name__}.setup does not define a 'train_dataset'" + raise MisconfigurationException(msg) + + def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders for validation. + + Returns: + A collection of data loaders specifying validation samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + 'val_dataset'. + """ + dataset = self.val_dataset or self.dataset + sampler = self.val_sampler or self.sampler + batch_sampler = self.val_batch_sampler or self.batch_sampler + if dataset is not None and (sampler or batch_sampler) is not None: + batch_size = self.val_batch_size or self.batch_size + if batch_sampler is not None: + batch_size = 1 + sampler = None + return DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + batch_sampler=batch_sampler, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + else: + msg = f"{self.__class__.__name__}.setup does not define a 'val_dataset'" + raise MisconfigurationException(msg) + + def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders for testing. + + Returns: + A collection of data loaders specifying testing samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + 'test_dataset'. + """ + dataset = self.test_dataset or self.dataset + sampler = self.test_sampler or self.sampler + batch_sampler = self.test_batch_sampler or self.batch_sampler + if dataset is not None and (sampler or batch_sampler) is not None: + batch_size = self.test_batch_size or self.batch_size + if batch_sampler is not None: + batch_size = 1 + sampler = None + return DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + batch_sampler=batch_sampler, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + else: + msg = f"{self.__class__.__name__}.setup does not define a 'test_dataset'" + raise MisconfigurationException(msg) + + def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders for prediction. + + Returns: + A collection of data loaders specifying prediction samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + 'predict_dataset'. 
+ """ + dataset = self.predict_dataset or self.dataset + sampler = self.predict_sampler or self.sampler + batch_sampler = self.predict_batch_sampler or self.batch_sampler + if dataset is not None and (sampler or batch_sampler) is not None: + batch_size = self.predict_batch_size or self.batch_size + if batch_sampler is not None: + batch_size = 1 + sampler = None + return DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + batch_sampler=batch_sampler, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + else: + msg = f"{self.__class__.__name__}.setup does not define a 'predict_dataset'" + raise MisconfigurationException(msg) + + def transfer_batch_to_device( + self, batch: Dict[str, Tensor], device: torch.device, dataloader_idx: int + ) -> Dict[str, Tensor]: + """Transfer batch to device. + + Defines how custom data types are moved to the target device. + + Args: + batch: A batch of data that needs to be transferred to a new device. + device: The target device as defined in PyTorch. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A reference to the data on the new device. + """ + # Non-Tensor values cannot be moved to a device + del batch["crs"] + del batch["bbox"] + + batch = super().transfer_batch_to_device(batch, device, dataloader_idx) + return batch + + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data. + """ + if self.trainer: + if self.trainer.training: + aug = self.train_aug or self.aug + elif self.trainer.validating or self.trainer.sanity_checking: + aug = self.val_aug or self.aug + elif self.trainer.testing: + aug = self.test_aug or self.aug + elif self.trainer.predicting: + aug = self.predict_aug or self.aug + + batch = aug(batch) + + return batch + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run the plot method of the validation dataset if one exists. + + Should only be called during 'fit' or 'validate' stages as ``val_dataset`` + may not exist during other stages. + + Args: + *args: Arguments passed to plot method. + **kwargs: Keyword arguments passed to plot method. + + Returns: + A matplotlib Figure with the image, ground truth, and predictions. + """ + dataset = self.val_dataset or self.dataset + if dataset is not None: + if hasattr(dataset, "plot"): + return dataset.plot(*args, **kwargs) + + +class NonGeoDataModule(LightningDataModule): + """Base class for data modules lacking geospatial information. + + .. versionadded:: 0.4 + """ + + mean = torch.tensor(0) + std = torch.tensor(255) + + def __init__( + self, + dataset_class: Type[NonGeoDataset], + batch_size: int = 1, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new NonGeoDataModule instance. + + Args: + dataset_class: Class used to instantiate a new dataset. + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. 
+ **kwargs: Additional keyword arguments passed to ``dataset_class`` + """ + super().__init__() + + self.dataset_class = dataset_class + self.batch_size = batch_size + self.num_workers = num_workers + self.kwargs = kwargs + + # Datasets + self.dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.train_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.val_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.test_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + self.predict_dataset: Optional[Dataset[Dict[str, Tensor]]] = None + + # Data loaders + self.train_batch_size: Optional[int] = None + self.val_batch_size: Optional[int] = None + self.test_batch_size: Optional[int] = None + self.predict_batch_size: Optional[int] = None + + # Collation + self.collate_fn = default_collate + + # Data augmentation + Transform = Callable[[Dict[str, Tensor]], Dict[str, Tensor]] + self.aug: Transform = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image"] + ) + self.train_aug: Optional[Transform] = None + self.val_aug: Optional[Transform] = None + self.test_aug: Optional[Transform] = None + self.predict_aug: Optional[Transform] = None + + def prepare_data(self) -> None: + """Download and prepare data. + + During distributed training, this method is called only within a single process + to avoid corrupted data. This method should not set state since it is not called + on every device, use :meth:`setup` instead. + """ + if self.kwargs.get("download", False): + self.dataset_class(**self.kwargs) + + def setup(self, stage: str) -> None: + """Set up datasets. + + Called at the beginning of fit, validate, test, or predict. During distributed + training, this method is called from every process across all the nodes. Setting + state here is recommended. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + if stage in ["fit"]: + self.train_dataset = self.dataset_class( # type: ignore[call-arg] + split="train", **self.kwargs + ) + if stage in ["fit", "validate"]: + self.val_dataset = self.dataset_class( # type: ignore[call-arg] + split="val", **self.kwargs + ) + if stage in ["test"]: + self.test_dataset = self.dataset_class( # type: ignore[call-arg] + split="test", **self.kwargs + ) + + def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders for training. + + Returns: + A collection of data loaders specifying training samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + 'train_dataset'. + """ + dataset = self.train_dataset or self.dataset + if dataset is not None: + return DataLoader( + dataset=dataset, + batch_size=self.train_batch_size or self.batch_size, + shuffle=True, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + else: + msg = f"{self.__class__.__name__}.setup does not define a 'train_dataset'" + raise MisconfigurationException(msg) + + def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders for validation. + + Returns: + A collection of data loaders specifying validation samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + 'val_dataset'. 
+ """ + dataset = self.val_dataset or self.dataset + if dataset is not None: + return DataLoader( + dataset=dataset, + batch_size=self.val_batch_size or self.batch_size, + shuffle=False, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + else: + msg = f"{self.__class__.__name__}.setup does not define a 'val_dataset'" + raise MisconfigurationException(msg) + + def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders for testing. + + Returns: + A collection of data loaders specifying testing samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + 'test_dataset'. + """ + dataset = self.test_dataset or self.dataset + if dataset is not None: + return DataLoader( + dataset=dataset, + batch_size=self.test_batch_size or self.batch_size, + shuffle=False, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + else: + msg = f"{self.__class__.__name__}.setup does not define a 'test_dataset'" + raise MisconfigurationException(msg) + + def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: + """Implement one or more PyTorch DataLoaders for prediction. + + Returns: + A collection of data loaders specifying prediction samples. + + Raises: + MisconfigurationException: If :meth:`setup` does not define a + 'predict_dataset'. + """ + dataset = self.predict_dataset or self.dataset + if dataset is not None: + return DataLoader( + dataset=dataset, + batch_size=self.predict_batch_size or self.batch_size, + shuffle=False, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + else: + msg = f"{self.__class__.__name__}.setup does not define a 'predict_dataset'" + raise MisconfigurationException(msg) + + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. + + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. + + Returns: + A batch of data. + """ + if self.trainer: + if self.trainer.training: + aug = self.train_aug or self.aug + elif self.trainer.validating or self.trainer.sanity_checking: + aug = self.val_aug or self.aug + elif self.trainer.testing: + aug = self.test_aug or self.aug + elif self.trainer.predicting: + aug = self.predict_aug or self.aug + + batch = aug(batch) + + return batch + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run the plot method of the validation dataset if one exists. + + Should only be called during 'fit' or 'validate' stages as ``val_dataset`` + may not exist during other stages. + + Args: + *args: Arguments passed to plot method. + **kwargs: Keyword arguments passed to plot method. + + Returns: + A matplotlib Figure with the image, ground truth, and predictions. 
+ """ + dataset = self.dataset or self.val_dataset + if dataset is not None: + if hasattr(dataset, "plot"): + return dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/gid15.py b/torchgeo/datamodules/gid15.py index 045509809e2..1297741eb2f 100644 --- a/torchgeo/datamodules/gid15.py +++ b/torchgeo/datamodules/gid15.py @@ -3,23 +3,19 @@ """GID-15 datamodule.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Tuple, Union -import matplotlib.pyplot as plt -import pytorch_lightning as pl -from einops import rearrange -from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader +import kornia.augmentation as K from ..datasets import GID15 from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential -from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from ..transforms.transforms import _RandomNCrop +from .geo import NonGeoDataModule from .utils import dataset_split -class GID15DataModule(pl.LightningDataModule): +class GID15DataModule(NonGeoDataModule): """LightningDataModule implementation for the GID-15 dataset. Uses the train/test splits from the dataset. @@ -29,149 +25,50 @@ class GID15DataModule(pl.LightningDataModule): def __init__( self, - num_tiles_per_batch: int = 16, - num_patches_per_tile: int = 16, + batch_size: int = 64, patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. - - The GID-15 dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles - during training and chop up image tiles into patch grids during evaluation. - During training, the effective batch size is equal to - ``num_tiles_per_batch`` x ``num_patches_per_tile``. + """Initialize a new GID15DataModule instance. Args: - num_tiles_per_batch: The number of image tiles to sample from during - training - num_patches_per_tile: The number of patches to randomly sample from each - image tile during training - patch_size: The size of each patch, either ``size`` or ``(height, width)``. - Should be a multiple of 32 for most segmentation architectures - val_split_pct: The percentage of the dataset to use as a validation set - num_workers: The number of workers to use for parallel data loading + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + val_split_pct: Percentage of the dataset to use as a validation set + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.GID15` + :class:`~torchgeo.datasets.GID15`. 
""" - super().__init__() + super().__init__(GID15, 1, num_workers, **kwargs) - self.num_tiles_per_batch = num_tiles_per_batch - self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.num_workers = num_workers - self.kwargs = kwargs - self.train_transform = AugmentationSequential( - Normalize(mean=0.0, std=255.0), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), - data_keys=["image", "mask"], - ) - self.val_transform = AugmentationSequential( - Normalize(mean=0.0, std=255.0), - _ExtractTensorPatches(self.patch_size), + self.train_aug = self.val_aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image", "mask"], ) - self.predict_transform = AugmentationSequential( - Normalize(mean=0.0, std=255.0), - _ExtractTensorPatches(self.patch_size), + self.predict_aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image"], ) - def prepare_data(self) -> None: - """Initialize the main Dataset objects for use in :func:`setup`. - - This includes optionally downloading the dataset. This is done once per node, - while :func:`setup` is done once per GPU. - """ - if self.kwargs.get("download", False): - GID15(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - train_dataset = GID15(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - train_dataset, self.val_split_pct - ) - - # Test set masks are not public, use for prediction instead - self.predict_dataset = GID15(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.num_tiles_per_batch, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def predict_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for predicting. - - Returns: - predicting data loader - """ - return DataLoader( - self.predict_dataset, - batch_size=1, - num_workers=self.num_workers, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data + stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - # Kornia requires masks to have a channel dimension - if "mask" in batch: - batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") - - if self.trainer: - if self.trainer.training: - batch = self.train_transform(batch) - elif self.trainer.validating: - batch = self.val_transform(batch) - elif self.trainer.predicting: - batch = self.predict_transform(batch) - - # Torchmetrics does not support masks with a channel dimension - if "mask" in batch: - batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.GID15.plot`.""" - return self.predict_dataset.plot(*args, **kwargs) + if stage in ["fit", "validate"]: + self.dataset = GID15(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + self.dataset, self.val_split_pct + ) + if stage in ["test"]: + # Test set masks are not public, use for prediction instead + self.predict_dataset = GID15(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 2c72de9b69c..bc524e5e921 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -3,34 +3,19 @@ """InriaAerialImageLabeling datamodule.""" -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Tuple, Union import kornia.augmentation as K -import matplotlib.pyplot as plt -import pytorch_lightning as pl -import torch -import torchvision.transforms as T -from einops import rearrange -from kornia.contrib import compute_padding, extract_tensor_patches -from torch.utils.data import DataLoader, Dataset -from torch.utils.data._utils.collate import default_collate from ..datasets import InriaAerialImageLabeling from ..samplers.utils import _to_tuple +from ..transforms import AugmentationSequential +from ..transforms.transforms import _RandomNCrop +from .geo import NonGeoDataModule from .utils import dataset_split -def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: - """Flatten wrapper.""" - r_batch: Dict[str, Any] = default_collate(batch) # type: ignore[no-untyped-call] - r_batch["image"] = torch.flatten(r_batch["image"], 0, 1) - if "mask" in r_batch: - r_batch["mask"] = torch.flatten(r_batch["mask"], 0, 1) - - return r_batch - - -class InriaAerialImageLabelingDataModule(pl.LightningDataModule): +class InriaAerialImageLabelingDataModule(NonGeoDataModule): """LightningDataModule implementation for the InriaAerialImageLabeling dataset. Uses the train/test splits from the dataset and further splits @@ -39,214 +24,62 @@ class InriaAerialImageLabelingDataModule(pl.LightningDataModule): .. versionadded:: 0.3 """ - h, w = 5000, 5000 - def __init__( self, - batch_size: int = 32, + batch_size: int = 64, + patch_size: Union[Tuple[int, int], int] = 64, num_workers: int = 0, val_split_pct: float = 0.1, test_split_pct: float = 0.1, - patch_size: Union[int, Tuple[int, int]] = 512, - num_patches_per_tile: int = 32, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for InriaAerialImageLabeling. + """Initialize a new InriaAerialImageLabelingDataModule instance. 
Args: - batch_size: The batch size used in the train DataLoader - (val_batch_size == test_batch_size == 1) - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - test_split_pct: What percentage of the dataset to use as a test set - patch_size: Size of random patch from image and mask (height, width) - num_patches_per_tile: Number of random patches per sample + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + num_workers: Number of workers for parallel data loading. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a test set. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.InriaAerialImageLabeling` + :class:`~torchgeo.datasets.InriaAerialImageLabeling`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + super().__init__(InriaAerialImageLabeling, 1, num_workers, **kwargs) + + self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.patch_size = _to_tuple(patch_size) - self.num_patches_per_tile = num_patches_per_tile - self.kwargs = kwargs - self.augmentations = K.AugmentationSequential( + self.train_aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), K.RandomHorizontalFlip(p=0.5), K.RandomVerticalFlip(p=0.5), - data_keys=["input", "mask"], + _RandomNCrop(self.patch_size, batch_size), + data_keys=["image", "mask"], ) - self.random_crop = K.AugmentationSequential( - K.RandomCrop(self.patch_size, p=1.0, keepdim=False), - data_keys=["input", "mask"], + self.val_aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), + data_keys=["image", "mask"], ) - - def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Extract patches from single sample.""" - assert sample["image"].ndim == 3 - _, h, w = sample["image"].shape - - padding = compute_padding((h, w), self.patch_size) - sample["original_shape"] = (h, w) - sample["patch_shape"] = self.patch_size - sample["padding"] = padding - sample["image"] = extract_tensor_patches( - sample["image"].unsqueeze(0), - self.patch_size, - self.patch_size, - padding=padding, + self.predict_aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), + data_keys=["image"], ) - # Needed for reconstruction of patches later - sample["num_patches"] = sample["image"].shape[1] - sample["image"] = rearrange(sample["image"], "b n c h w -> (b n) c h w") - return sample - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - sample: input image dictionary - - Returns: - preprocessed sample + stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - sample["image"] = torch.clip(sample["image"], min=0.0, max=1.0) - - if "mask" in sample: - sample["mask"] = rearrange(sample["mask"], "h w -> () h w") - - return sample - - def n_random_crop(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Get n random crops.""" - images, masks = [], [] - for _ in range(self.num_patches_per_tile): - image, mask = sample["image"], sample["mask"] - # RandomCrop needs image and mask to be in float - mask = mask.to(torch.float) - image, mask = self.random_crop(image, mask) - images.append(image.squeeze()) - masks.append(mask.squeeze(0).long()) - sample["image"] = torch.stack(images) # (t,c,h,w) - sample["mask"] = torch.stack(masks) # (t, 1, h, w) - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - """ - train_transforms = T.Compose([self.preprocess, self.n_random_crop]) - test_transforms = T.Compose([self.preprocess, self.patch_sample]) - - self.dataset = InriaAerialImageLabeling( - split="train", transforms=train_transforms, **self.kwargs - ) - - self.train_dataset: Dataset[Any] - self.val_dataset: Dataset[Any] - self.test_dataset: Dataset[Any] - - if self.val_split_pct > 0.0: - if self.test_split_pct > 0.0: - self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - self.dataset, - val_pct=self.val_split_pct, - test_pct=self.test_split_pct, - ) - else: - self.train_dataset, self.val_dataset = dataset_split( - self.dataset, val_pct=self.val_split_pct - ) - self.test_dataset = self.val_dataset - else: - self.train_dataset = self.dataset - self.val_dataset = self.dataset - self.test_dataset = self.dataset - - self.predict_dataset = InriaAerialImageLabeling( - split="test", transforms=test_transforms, **self.kwargs - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training.""" - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - collate_fn=collate_wrapper, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation.""" - return DataLoader( - self.val_dataset, - batch_size=1, - num_workers=self.num_workers, - collate_fn=collate_wrapper, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing.""" - return DataLoader( - self.test_dataset, - batch_size=1, - num_workers=self.num_workers, - collate_fn=collate_wrapper, - shuffle=False, - ) - - def predict_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for prediction.""" - return DataLoader( - self.predict_dataset, - batch_size=1, - num_workers=self.num_workers, - collate_fn=collate_wrapper, - shuffle=False, - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Any], dataloader_idx: int - ) -> Dict[str, Any]: - """Apply augmentations to batch after transferring to GPU. - - Args: - batch (dict): A batch of data that needs to be altered or augmented. - dataloader_idx (int): The index of the dataloader to which the batch - belongs. 
- - Returns: - dict: A batch of data - """ - # Training - if ( - hasattr(self, "trainer") - and self.trainer is not None - and hasattr(self.trainer, "training") - and self.trainer.training - and self.augmentations is not None - ): - batch["mask"] = batch["mask"].to(torch.float) - batch["image"], batch["mask"] = self.augmentations( - batch["image"], batch["mask"] + if stage in ["fit", "validate", "test"]: + self.dataset = InriaAerialImageLabeling(split="train", **self.kwargs) + self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( + self.dataset, self.val_split_pct, self.test_split_pct ) - batch["mask"] = batch["mask"].to(torch.long) - - # Validation - if "mask" in batch: - batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.InriaAerialImageLabeling.plot`. - - .. versionadded:: 0.4 - """ - return self.dataset.plot(*args, **kwargs) + if stage in ["predict"]: + # Test set masks are not public, use for prediction instead + self.predict_dataset = InriaAerialImageLabeling(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index 67f952e37b3..0ac1ec385ac 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -3,17 +3,16 @@ """LandCover.ai datamodule.""" -from typing import Any, Dict, Optional +from typing import Any import kornia.augmentation as K -import matplotlib.pyplot as plt -import pytorch_lightning as pl -from torch.utils.data import DataLoader from ..datasets import LandCoverAI +from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule -class LandCoverAIDataModule(pl.LightningDataModule): +class LandCoverAIDataModule(NonGeoDataModule): """LightningDataModule implementation for the LandCover.ai dataset. Uses the train/val/test splits from the dataset. @@ -22,154 +21,32 @@ class LandCoverAIDataModule(pl.LightningDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for LandCover.ai based DataLoaders. + """Initialize a new LandCoverAIDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.LandCoverAI` - """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs - - def on_after_batch_transfer( - self, batch: Dict[str, Any], batch_idx: int - ) -> Dict[str, Any]: - """Apply batch augmentations after batch is transferred to the device. 
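Several of the converted modules (Inria above, and NASA Marine Debris, OSCD, Potsdam, and GID-15 elsewhere in this patch) now rely on the shared `dataset_split` helper instead of hand-rolled splitting. A small usage sketch, inferred from these call sites (the exact rounding of split sizes is an assumption):

```python
import torch
from torch.utils.data import TensorDataset

from torchgeo.datamodules.utils import dataset_split

ds = TensorDataset(torch.arange(100))

# Two-way split: train/val
train, val = dataset_split(ds, val_pct=0.2)
print(len(train), len(val))  # roughly 80 20

# Three-way split: train/val/test
train, val, test = dataset_split(ds, val_pct=0.1, test_pct=0.1)
print(len(train), len(val), len(test))  # roughly 80 10 10
```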
- - Args: - batch: mini-batch of data - batch_idx: batch index - - Returns: - augmented mini-batch - """ - if ( - hasattr(self, "trainer") - and self.trainer is not None - and hasattr(self.trainer, "training") - and self.trainer.training - ): - # Kornia expects masks to be floats with a channel dimension - x = batch["image"] - y = batch["mask"].float().unsqueeze(1) - - train_augmentations = K.AugmentationSequential( - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomSharpness(p=0.5), - K.ColorJitter( - p=0.5, - brightness=0.1, - contrast=0.1, - saturation=0.1, - hue=0.1, - silence_instantiation_warning=True, - ), - data_keys=["input", "mask"], - ) - x, y = train_augmentations(x, y) - - # torchmetrics expects masks to be longs without a channel dimension - batch["image"] = x - batch["mask"] = y.squeeze(1).long() - - return batch - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - - if "mask" in sample: - sample["mask"] = sample["mask"].long() - - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - LandCoverAI(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - train_transforms = self.preprocess - val_test_transforms = self.preprocess - - self.train_dataset = LandCoverAI( - split="train", transforms=train_transforms, **self.kwargs + :class:`~torchgeo.datasets.LandCoverAI`. + """ + super().__init__(LandCoverAI, batch_size, num_workers, **kwargs) + + self.train_aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + K.ColorJitter( + p=0.5, + brightness=0.1, + contrast=0.1, + saturation=0.1, + hue=0.1, + silence_instantiation_warning=True, + ), + data_keys=["image", "mask"], ) - - self.val_dataset = LandCoverAI( - split="val", transforms=val_test_transforms, **self.kwargs - ) - - self.test_dataset = LandCoverAI( - split="test", transforms=val_test_transforms, **self.kwargs + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.LandCoverAI.plot`. - - .. 
versionadded:: 0.2 - """ - return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index a82233f9844..f8462ff6588 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -3,16 +3,13 @@ """LoveDA datamodule.""" -from typing import Any, Dict, Optional - -import matplotlib.pyplot as plt -import pytorch_lightning as pl -from torch.utils.data import DataLoader +from typing import Any from ..datasets import LoveDA +from .geo import NonGeoDataModule -class LoveDADataModule(pl.LightningDataModule): +class LoveDADataModule(NonGeoDataModule): """LightningDataModule implementation for the LoveDA dataset. Uses the train/val/test splits from the dataset. @@ -23,107 +20,26 @@ class LoveDADataModule(pl.LightningDataModule): def __init__( self, batch_size: int = 32, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for LoveDA based DataLoaders. + """Initialize a new LoveDADataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.LoveDA` - """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. + :class:`~torchgeo.datasets.LoveDA`. """ - if self.kwargs.get("download", False): - LoveDA(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + super().__init__(LoveDA, batch_size, num_workers, **kwargs) - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up - """ - train_transforms = self.preprocess - val_predict_transforms = self.preprocess - - self.train_dataset = LoveDA( - split="train", transforms=train_transforms, **self.kwargs - ) - - self.val_dataset = LoveDA( - split="val", transforms=val_predict_transforms, **self.kwargs - ) - - # Test set masks are not public, use for prediction instead - self.predict_dataset = LoveDA( - split="test", transforms=val_predict_transforms, **self.kwargs - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def predict_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for prediction. 
- - Returns: - predict data loader - """ - return DataLoader( - self.predict_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.LoveDA.plot`. - - .. versionadded:: 0.4 - """ - return self.train_dataset.plot(*args, **kwargs) + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + if stage in ["fit"]: + self.train_dataset = LoveDA(split="train", **self.kwargs) + if stage in ["fit", "validate"]: + self.val_dataset = LoveDA(split="val", **self.kwargs) + if stage in ["predict"]: + # Test set masks are not public, use for prediction instead + self.predict_dataset = LoveDA(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index f9e9038da8d..63ef8c8501e 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -3,51 +3,43 @@ """National Agriculture Imagery Program (NAIP) datamodule.""" -from typing import Any, Dict, Optional, Tuple +from typing import Any, Tuple, Union +import kornia.augmentation as K import matplotlib.pyplot as plt -import pytorch_lightning as pl -from torch.utils.data import DataLoader -from torchvision.transforms import Compose -from ..datasets import NAIP, BoundingBox, Chesapeake13, stack_samples -from ..samplers.batch import RandomBatchGeoSampler -from ..samplers.single import GridGeoSampler +from ..datasets import NAIP, BoundingBox, Chesapeake13 +from ..samplers import GridGeoSampler, RandomBatchGeoSampler +from ..transforms import AugmentationSequential +from .geo import GeoDataModule -class NAIPChesapeakeDataModule(pl.LightningDataModule): +class NAIPChesapeakeDataModule(GeoDataModule): """LightningDataModule implementation for the NAIP and Chesapeake datasets. Uses the train/val/test splits from the dataset. """ - # TODO: tune these hyperparams - length = 1000 - stride = 128 - def __init__( self, batch_size: int = 64, + patch_size: Union[int, Tuple[int, int]] = 256, + length: int = 1000, num_workers: int = 0, - patch_size: int = 256, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for NAIP and Chesapeake based DataLoaders. + """Initialize a new NAIPChesapeakeDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - patch_size: size of patches to sample + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + length: Length of each training epoch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.NAIP` (prefix keys with ``naip_``) and :class:`~torchgeo.datasets.Chesapeake13` - (prefix keys with ``chesapeake_``) + (prefix keys with ``chesapeake_``). """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.patch_size = patch_size - self.naip_kwargs = {} self.chesapeake_kwargs = {} for key, val in kwargs.items(): @@ -56,145 +48,63 @@ def __init__( elif key.startswith("chesapeake_"): self.chesapeake_kwargs[key[11:]] = val - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the NAIP Dataset. 
- - Args: - sample: NAIP image dictionary - - Returns: - preprocessed NAIP data - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - - return sample - - def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Chesapeake Dataset. - - Args: - sample: Chesapeake mask dictionary - - Returns: - preprocessed Chesapeake data - """ - sample["mask"] = sample["mask"].long()[0] - - return sample - - def remove_bbox(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Removes the bounding box property from a sample. - - Args: - sample: dictionary with geographic metadata - - Returns - sample without the bbox property - """ - del sample["bbox"] - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.chesapeake_kwargs.get("download", False): - Chesapeake13(**self.chesapeake_kwargs) + super().__init__( + Chesapeake13, + batch_size, + patch_size, + length, + num_workers, + **self.chesapeake_kwargs, + ) - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] + ) - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets and samplers. Args: - stage: state to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - # TODO: these transforms will be applied independently, this won't work if we - # add things like random horizontal flip - - naip_transforms = Compose([self.preprocess, self.remove_bbox]) - chesapeak_transforms = Compose([self.chesapeake_transform, self.remove_bbox]) - - self.chesapeake = Chesapeake13( - transforms=chesapeak_transforms, **self.chesapeake_kwargs - ) - self.naip = NAIP( - crs=self.chesapeake.crs, - res=self.chesapeake.res, - transforms=naip_transforms, - **self.naip_kwargs, - ) + self.chesapeake = Chesapeake13(**self.chesapeake_kwargs) + self.naip = NAIP(**self.naip_kwargs) self.dataset = self.chesapeake & self.naip - # TODO: figure out better train/val/test split roi = self.dataset.bounds midx = roi.minx + (roi.maxx - roi.minx) / 2 midy = roi.miny + (roi.maxy - roi.miny) / 2 - train_roi = BoundingBox(roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt) - val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt) - test_roi = BoundingBox(roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt) - - self.train_sampler = RandomBatchGeoSampler( - self.naip, self.patch_size, self.batch_size, self.length, train_roi - ) - self.val_sampler = GridGeoSampler( - self.naip, self.patch_size, self.stride, val_roi - ) - self.test_sampler = GridGeoSampler( - self.naip, self.patch_size, self.stride, test_roi - ) - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. 
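Worth noting for the simplified `setup` above: the `&` operator in `self.chesapeake & self.naip` builds an `IntersectionDataset` restricted to the spatial/temporal overlap of the two datasets and, presumably, handles aligning them, which is why the explicit `crs=`/`res=` plumbing from the old version could be dropped. A usage sketch (paths are placeholders and both datasets require data on disk):

```python
from torchgeo.datasets import NAIP, Chesapeake13

naip = NAIP("data/naip")                        # placeholder root
chesapeake = Chesapeake13("data/chesapeake13")  # placeholder root
dataset = chesapeake & naip                     # IntersectionDataset over the shared extent
print(dataset.bounds)                           # BoundingBox of the overlap
```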
+ if stage in ["fit"]: + train_roi = BoundingBox( + roi.minx, midx, roi.miny, roi.maxy, roi.mint, roi.maxt + ) + self.train_batch_sampler = RandomBatchGeoSampler( + self.dataset, self.patch_size, self.batch_size, self.length, train_roi + ) + if stage in ["fit", "validate"]: + val_roi = BoundingBox(midx, roi.maxx, roi.miny, midy, roi.mint, roi.maxt) + self.val_sampler = GridGeoSampler( + self.dataset, self.patch_size, self.patch_size, val_roi + ) + if stage in ["test"]: + test_roi = BoundingBox( + roi.minx, roi.maxx, midy, roi.maxy, roi.mint, roi.maxt + ) + self.test_sampler = GridGeoSampler( + self.dataset, self.patch_size, self.patch_size, test_roi + ) + + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: + """Run NAIP plot method. - Returns: - training data loader - """ - return DataLoader( - self.dataset, - batch_sampler=self.train_sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.dataset, - batch_size=self.batch_size, - sampler=self.val_sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. + Args: + *args: Arguments passed to plot method. + **kwargs: Keyword arguments passed to plot method. Returns: - testing data loader - """ - return DataLoader( - self.dataset, - batch_size=self.batch_size, - sampler=self.test_sampler, - num_workers=self.num_workers, - collate_fn=stack_samples, - ) - - def plot(self, *args: Any, **kwargs: Any) -> Tuple[plt.Figure, plt.Figure]: - """Run NAIP and Chesapeake plot methods. - - See :meth:`torchgeo.datasets.NAIP.plot` and - :meth:`torchgeo.datasets.Chesapeake.plot`. + A matplotlib Figure with the image, ground truth, and predictions. .. versionadded:: 0.4 """ - image = self.naip.plot(*args, **kwargs) - label = self.chesapeake.plot(*args, **kwargs) - return image, label + return self.naip.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index d22324ce2fc..cdcab7f7e9c 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -3,15 +3,13 @@ """NASA Marine Debris datamodule.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List -import matplotlib.pyplot as plt -import pytorch_lightning as pl import torch from torch import Tensor -from torch.utils.data import DataLoader from ..datasets import NASAMarineDebris +from .geo import NonGeoDataModule from .utils import dataset_split @@ -31,7 +29,7 @@ def collate_fn(batch: List[Dict[str, Tensor]]) -> Dict[str, Any]: return output -class NASAMarineDebrisDataModule(pl.LightningDataModule): +class NASAMarineDebrisDataModule(NonGeoDataModule): """LightningDataModule implementation for the NASA Marine Debris dataset. .. versionadded:: 0.2 @@ -45,102 +43,30 @@ def __init__( test_split_pct: float = 0.2, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for NASA Marine Debris based DataLoaders. + """Initialize a new NASAMarineDebrisDataModule instance. 
Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - test_split_pct: What percentage of the dataset to use as a test set + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a test set. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.NASAMarineDebris` + :class:`~torchgeo.datasets.NASAMarineDebris`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + super().__init__(NASAMarineDebris, batch_size, num_workers, **kwargs) + self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.kwargs = kwargs - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. + self.collate_fn = collate_fn - This method is only called once per run. - """ - if self.kwargs.get("download", False): - NASAMarineDebris(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - self.dataset = NASAMarineDebris(transforms=self.preprocess, **self.kwargs) + self.dataset = NASAMarineDebris(**self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - collate_fn=collate_fn, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_fn, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - collate_fn=collate_fn, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.NASAMarineDebris.plot`. - - .. 
versionadded:: 0.4 - """ - return self.dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 90a4c94797b..624a47c615f 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -3,21 +3,21 @@ """OSCD datamodule.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Tuple, Union import kornia.augmentation as K -import pytorch_lightning as pl import torch from einops import repeat -from torch.utils.data import DataLoader, Dataset -from torch.utils.data._utils.collate import default_collate -from torchvision.transforms import Compose, Normalize from ..datasets import OSCD +from ..samplers.utils import _to_tuple +from ..transforms import AugmentationSequential +from ..transforms.transforms import _RandomNCrop +from .geo import NonGeoDataModule from .utils import dataset_split -class OSCDDataModule(pl.LightningDataModule): +class OSCDDataModule(NonGeoDataModule): """LightningDataModule implementation for the OSCD dataset. Uses the train/test splits from the dataset and further splits @@ -26,7 +26,7 @@ class OSCDDataModule(pl.LightningDataModule): .. versionadded:: 0.2 """ - band_means = torch.tensor( + mean = torch.tensor( [ 1583.0741, 1374.3202, @@ -44,7 +44,7 @@ class OSCDDataModule(pl.LightningDataModule): ] ) - band_stds = torch.tensor( + std = torch.tensor( [ 52.1937, 83.4168, @@ -64,136 +64,53 @@ class OSCDDataModule(pl.LightningDataModule): def __init__( self, - train_batch_size: int = 32, - num_workers: int = 0, + batch_size: int = 64, + patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, - patch_size: Tuple[int, int] = (64, 64), - num_patches_per_tile: int = 32, - pad_size: Tuple[int, int] = (1280, 1280), + num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for OSCD based DataLoaders. + """Initialize a new OSCDDataModule instance. Args: - train_batch_size: The batch size used in the train DataLoader - (val_batch_size == test_batch_size == 1) - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - patch_size: Size of random patch from image and mask (height, width) - num_patches_per_tile: number of random patches per sample - pad_size: size to pad images to during val/test steps + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + val_split_pct: Percentage of the dataset to use as a validation set. + num_workers: Number of workers for parallel data loading. 
**kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.OSCD` - """ - super().__init__() - self.train_batch_size = train_batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - self.patch_size = patch_size - self.num_patches_per_tile = num_patches_per_tile - self.kwargs = kwargs - - bands = kwargs.get("bands", "all") - if bands == "rgb": - self.band_means = self.band_means[[3, 2, 1]] - self.band_stds = self.band_stds[[3, 2, 1]] - - self.rcrop = K.AugmentationSequential( - K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True - ) - self.padto = K.PadTo(pad_size) - - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset.""" - sample["image"] = sample["image"].float() - sample["image"] = self.norm(sample["image"]) - sample["image"] = torch.flatten(sample["image"], 0, 1) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. + :class:`~torchgeo.datasets.OSCD`. """ - if self.kwargs.get("download", False): - OSCD(split="train", **self.kwargs) + super().__init__(OSCD, 1, num_workers, **kwargs) - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + self.patch_size = _to_tuple(patch_size) + self.val_split_pct = val_split_pct - This method is called once per GPU per run. - """ + self.bands = kwargs.get("bands", "all") + if self.bands == "rgb": + self.mean = self.mean[[3, 2, 1]] + self.std = self.std[[3, 2, 1]] - def n_random_crop(sample: Dict[str, Any]) -> Dict[str, Any]: - images, masks = [], [] - for i in range(self.num_patches_per_tile): - mask = repeat(sample["mask"], "h w -> t h w", t=2).float() - image, mask = self.rcrop(sample["image"], mask) - mask = mask.squeeze()[0] - images.append(image.squeeze()) - masks.append(mask.long()) - sample["image"] = torch.stack(images) - sample["mask"] = torch.stack(masks) - return sample - - def pad_to(sample: Dict[str, Any]) -> Dict[str, Any]: - sample["image"] = self.padto(sample["image"])[0] - sample["mask"] = self.padto(sample["mask"].float()).long()[0, 0] - return sample - - train_transforms = Compose([self.preprocess, n_random_crop]) - # for testing and validation we pad all inputs to a fixed size to avoid issues - # with the upsampling paths in encoder-decoder architectures - test_transforms = Compose([self.preprocess, pad_to]) - - train_dataset = OSCD(split="train", transforms=train_transforms, **self.kwargs) - - self.train_dataset: Dataset[Any] - self.val_dataset: Dataset[Any] - - if self.val_split_pct > 0.0: - val_dataset = OSCD(split="train", transforms=test_transforms, **self.kwargs) - self.train_dataset, self.val_dataset, _ = dataset_split( - train_dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - self.val_dataset.dataset = val_dataset - else: - self.train_dataset = train_dataset - self.val_dataset = train_dataset + # Change detection, 2 images from different times + self.mean = repeat(self.mean, "c -> (t c)", t=2) + self.std = repeat(self.std, "c -> (t c)", t=2) - self.test_dataset = OSCD( - split="test", transforms=test_transforms, **self.kwargs + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), + data_keys=["image", "mask"], ) - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training.""" + def setup(self, stage: str) 
-> None: + """Set up datasets. - def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: - r_batch: Dict[str, Any] = default_collate( # type: ignore[no-untyped-call] - batch + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + if stage in ["fit", "validate"]: + self.dataset = OSCD(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + self.dataset, val_pct=self.val_split_pct ) - r_batch["image"] = torch.flatten(r_batch["image"], 0, 1) - r_batch["mask"] = torch.flatten(r_batch["mask"], 0, 1) - return r_batch - - return DataLoader( - self.train_dataset, - batch_size=self.train_batch_size, - num_workers=self.num_workers, - collate_fn=collate_wrapper, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation.""" - return DataLoader( - self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing.""" - return DataLoader( - self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) + if stage in ["test"]: + self.test_dataset = OSCD(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/potsdam.py b/torchgeo/datamodules/potsdam.py index d5024f197e9..f22558a5908 100644 --- a/torchgeo/datamodules/potsdam.py +++ b/torchgeo/datamodules/potsdam.py @@ -3,23 +3,19 @@ """Potsdam datamodule.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Tuple, Union -import matplotlib.pyplot as plt -import pytorch_lightning as pl -from einops import rearrange -from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader +import kornia.augmentation as K from ..datasets import Potsdam2D from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential -from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from ..transforms.transforms import _RandomNCrop +from .geo import NonGeoDataModule from .utils import dataset_split -class Potsdam2DDataModule(pl.LightningDataModule): +class Potsdam2DDataModule(NonGeoDataModule): """LightningDataModule implementation for the Potsdam2D dataset. Uses the train/test splits from the dataset. @@ -29,133 +25,44 @@ class Potsdam2DDataModule(pl.LightningDataModule): def __init__( self, - num_tiles_per_batch: int = 16, - num_patches_per_tile: int = 16, + batch_size: int = 64, patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. - - The Potsdam2D dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles - during training and chop up image tiles into patch grids during evaluation. - During training, the effective batch size is equal to - ``num_tiles_per_batch`` x ``num_patches_per_tile``. - - .. versionchanged:: 0.4 - *batch_size* was replaced by *num_tile_per_batch*, *num_patches_per_tile*, - and *patch_size*. + """Initialize a new Potsdam2DDataModule instance. Args: - num_tiles_per_batch: The number of image tiles to sample from during - training - num_patches_per_tile: The number of patches to randomly sample from each - image tile during training - patch_size: The size of each patch, either ``size`` or ``(height, width)``. 
- Should be a multiple of 32 for most segmentation architectures - val_split_pct: The percentage of the dataset to use as a validation set - num_workers: The number of workers to use for parallel data loading + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + val_split_pct: Percentage of the dataset to use as a validation set. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.Potsdam2D` + :class:`~torchgeo.datasets.Potsdam2D`. """ - super().__init__() + super().__init__(Potsdam2D, 1, num_workers, **kwargs) - self.num_tiles_per_batch = num_tiles_per_batch - self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.num_workers = num_workers - self.kwargs = kwargs - self.train_transform = AugmentationSequential( - Normalize(mean=0.0, std=255.0), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), - data_keys=["image", "mask"], - ) - self.test_transform = AugmentationSequential( - Normalize(mean=0.0, std=255.0), - _ExtractTensorPatches(self.patch_size), + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image", "mask"], ) - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up - """ - train_dataset = Potsdam2D(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - train_dataset, self.val_split_pct - ) - self.test_dataset = Potsdam2D(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.num_tiles_per_batch, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. 
- - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - # Kornia requires masks to have a channel dimension - batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") - - if self.trainer: - if self.trainer.training: - batch = self.train_transform(batch) - elif self.trainer.validating or self.trainer.testing: - batch = self.test_transform(batch) - - # Torchmetrics does not support masks with a channel dimension - batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.Potsdam2D.plot`. - - .. versionadded:: 0.4 + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - return self.test_dataset.plot(*args, **kwargs) + if stage in ["fit", "validate"]: + self.dataset = Potsdam2D(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + self.dataset, self.val_split_pct + ) + if stage in ["test"]: + self.test_dataset = Potsdam2D(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index 651bb47704a..cb1e9553553 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -3,167 +3,52 @@ """RESISC45 datamodule.""" -from typing import Any, Dict, Optional +from typing import Any import kornia.augmentation as K -import matplotlib.pyplot as plt -import pytorch_lightning as pl import torch -from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize from ..datasets import RESISC45 +from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule -class RESISC45DataModule(pl.LightningDataModule): +class RESISC45DataModule(NonGeoDataModule): """LightningDataModule implementation for the RESISC45 dataset. Uses the train/val/test splits from the dataset. """ - band_means = torch.tensor([0.36820969, 0.38083247, 0.34341029]) - band_stds = torch.tensor([0.20339924, 0.18524736, 0.18455448]) + mean = torch.tensor([127.86820969, 127.88083247, 127.84341029]) + std = torch.tensor([51.8668062, 47.2380768, 47.0613924]) def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for RESISC45 based DataLoaders. + """Initialize a new RESISC45DataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.RESISC45` - """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs - - self.norm = Normalize(self.band_means, self.band_stds) - - def on_after_batch_transfer( - self, batch: Dict[str, Any], batch_idx: int - ) -> Dict[str, Any]: - """Apply batch augmentations after batch is transferred to the device. 
- - Args: - batch: mini-batch of data - batch_idx: batch index - - Returns: - augmented mini-batch - """ - if ( - hasattr(self, "trainer") - and self.trainer is not None - and hasattr(self.trainer, "training") - and self.trainer.training - ): - x = batch["image"] - - train_augmentations = K.AugmentationSequential( - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomSharpness(p=0.5), - K.RandomErasing(p=0.1), - K.ColorJitter( - p=0.5, - brightness=0.1, - contrast=0.1, - saturation=0.1, - hue=0.1, - silence_instantiation_warning=True, - ), - data_keys=["input"], - ) - x = train_augmentations(x) - - batch["image"] = x - - return batch - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - sample["image"] = self.norm(sample["image"]) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - RESISC45(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - self.train_dataset = RESISC45( - split="train", transforms=transforms, **self.kwargs - ) - self.val_dataset = RESISC45(split="val", transforms=transforms, **self.kwargs) - self.test_dataset = RESISC45(split="test", transforms=transforms, **self.kwargs) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, + :class:`~torchgeo.datasets.RESISC45`. + """ + super().__init__(RESISC45, batch_size, num_workers, **kwargs) + + self.train_aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + K.RandomErasing(p=0.1), + K.ColorJitter( + p=0.5, + brightness=0.1, + contrast=0.1, + saturation=0.1, + hue=0.1, + silence_instantiation_warning=True, + ), + data_keys=["image"], ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.RESISC45.plot`. - - .. 
versionadded:: 0.2 - """ - return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 8253791211c..d412c93d978 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -3,187 +3,113 @@ """SEN12MS datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Dict -import pytorch_lightning as pl import torch from sklearn.model_selection import GroupShuffleSplit -from torch.utils.data import DataLoader, Subset +from torch import Tensor +from torch.utils.data import Subset from ..datasets import SEN12MS +from .geo import NonGeoDataModule -class SEN12MSDataModule(pl.LightningDataModule): +class SEN12MSDataModule(NonGeoDataModule): """LightningDataModule implementation for the SEN12MS dataset. Implements 80/20 geographic train/val splits and uses the test split from the - classification dataset definitions. See :func:`setup` for more details. + classification dataset definitions. Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See https://arxiv.org/abs/2002.08254. """ #: Mapping from the IGBP class definitions to the DFC2020, taken from the dataloader - #: here https://github.com/lukasliebel/dfc2020_baseline. + #: here: https://github.com/lukasliebel/dfc2020_baseline. DFC2020_CLASS_MAPPING = torch.tensor( - [ - 0, # maps 0s to 0 - 1, # maps 1s to 1 - 1, # maps 2s to 1 - 1, # ... - 1, - 1, - 2, - 2, - 3, - 3, - 4, - 5, - 6, - 7, - 6, - 8, - 9, - 10, - ] + [0, 1, 1, 1, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 6, 8, 9, 10] + ) + + std = torch.tensor( + [-25, -25, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4, 1e4] ) def __init__( self, - seed: int = 0, - band_set: str = "all", batch_size: int = 64, num_workers: int = 0, + band_set: str = "all", **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for SEN12MS based DataLoaders. + """Initialize a new SEN12MSDataModule instance. Args: - seed: The seed value to use when doing the sklearn based ShuffleSplit - band_set: The subset of S1/S2 bands to use. Options are: "all", + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + band_set: Subset of S1/S2 bands to use. Options are: "all", "s1", "s2-all", and "s2-reduced" where the "s2-reduced" set includes: B2, B3, B4, B8, B11, and B12. - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.SEN12MS` - """ - super().__init__() - assert band_set in SEN12MS.BAND_SETS.keys() - - self.seed = seed - self.band_set = band_set - self.bands = SEN12MS.BAND_SETS[band_set] - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample + :class:`~torchgeo.datasets.SEN12MS`. 
""" - sample["image"] = sample["image"].float() + kwargs["bands"] = SEN12MS.BAND_SETS[band_set] - if self.band_set == "all": - sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 - sample["image"][2:] = sample["image"][2:].clamp(0, 10000) / 10000 - elif self.band_set == "s1": - sample["image"][:2] = sample["image"][:2].clamp(-25, 0) / -25 - else: - sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000 + if band_set == "s1": + self.std = self.std[:2] + elif band_set == "s2-all": + self.std = self.std[2:] + elif band_set == "s2-reduced": + self.std = self.std[torch.tensor([3, 4, 5, 9, 12, 13])] - if "mask" in sample: - sample["mask"] = sample["mask"][0, :, :].long() - sample["mask"] = torch.take(self.DFC2020_CLASS_MAPPING, sample["mask"]) + super().__init__(SEN12MS, batch_size, num_workers, **kwargs) - return sample - - def setup(self, stage: Optional[str] = None) -> None: - """Create the train/val/test splits based on the original Dataset objects. - - The splits should be done here vs. in :func:`__init__` per the docs: - https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html#setup. - - We split samples between train and val geographically with proportions of 80/20. - This mimics the geographic test set split. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000} - - self.all_train_dataset = SEN12MS( - split="train", bands=self.bands, transforms=self.preprocess, **self.kwargs - ) - - self.all_test_dataset = SEN12MS( - split="test", bands=self.bands, transforms=self.preprocess, **self.kwargs - ) - - # A patch is a filename like: "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" - # This patch will belong to the scene that is uniquelly identified by its - # (season, scene_id) tuple. Because the largest scene_id is 149, we can simply - # give each season a large number and representing a `unique_scene_id` as - # `season_id + scene_id`. - scenes = [] - for scene_fn in self.all_train_dataset.ids: - parts = scene_fn.split("_") - season_id = season_to_int[parts[1]] - scene_id = int(parts[3]) - scenes.append(season_id + scene_id) - - train_indices, val_indices = next( - GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=self.seed).split( - scenes, groups=scenes + if stage in ["fit", "validate"]: + season_to_int = {"winter": 0, "spring": 1000, "summer": 2000, "fall": 3000} + + self.dataset = SEN12MS(split="train", **self.kwargs) + + # A patch is a filename like: + # "ROIs{num}_{season}_s2_{scene_id}_p{patch_id}.tif" + # This patch will belong to the scene that is uniquely identified by its + # (season, scene_id) tuple. Because the largest scene_id is 149, we can + # simply give each season a large number and representing a unique_scene_id + # as (season_id + scene_id). 
+ scenes = [] + for scene_fn in self.dataset.ids: + parts = scene_fn.split("_") + season_id = season_to_int[parts[1]] + scene_id = int(parts[3]) + scenes.append(season_id + scene_id) + + train_indices, val_indices = next( + GroupShuffleSplit(test_size=0.2, n_splits=2, random_state=0).split( + scenes, groups=scenes + ) ) - ) - self.train_dataset = Subset(self.all_train_dataset, train_indices) - self.val_dataset = Subset(self.all_train_dataset, val_indices) - self.test_dataset = Subset( - self.all_test_dataset, range(len(self.all_test_dataset)) - ) + self.train_dataset = Subset(self.dataset, train_indices) + self.val_dataset = Subset(self.dataset, val_indices) + if stage in ["test"]: + self.test_dataset = SEN12MS(split="test", **self.kwargs) - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. Returns: - validation data loader + A batch of data. """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. + batch["mask"] = torch.take(self.DFC2020_CLASS_MAPPING, batch["mask"]) - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) + return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index 1f331ea38c0..e8e9c43f742 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -3,25 +3,31 @@ """So2Sat datamodule.""" -from typing import Any, Dict, Optional, cast +from typing import Any -import matplotlib.pyplot as plt -import pytorch_lightning as pl import torch -from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize from ..datasets import So2Sat +from .geo import NonGeoDataModule -class So2SatDataModule(pl.LightningDataModule): +class So2SatDataModule(NonGeoDataModule): """LightningDataModule implementation for the So2Sat dataset. Uses the train/val/test splits from the dataset. 
""" - band_means = torch.tensor( + # TODO: calculate mean/std dev of s1 bands + mean = torch.tensor( [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, 0.12375696117681859, 0.1092774636368323, 0.1010855203267882, @@ -34,9 +40,16 @@ class So2SatDataModule(pl.LightningDataModule): 0.10905050699570007, ] ) - - band_stds = torch.tensor( + std = torch.tensor( [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, 0.03958795985905458, 0.047778262752410296, 0.06636616706371974, @@ -50,146 +63,46 @@ class So2SatDataModule(pl.LightningDataModule): ] ) - # this reorders the bands to put S2 RGB first, then remainder of S2 - reindex_to_rgb_first = [2, 1, 0, 3, 4, 5, 6, 7, 8, 9] - def __init__( self, batch_size: int = 64, num_workers: int = 0, - band_set: str = "rgb", - unsupervised_mode: bool = False, + band_set: str = "all", **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for So2Sat based DataLoaders. + """Initialize a new So2SatDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - band_set: Collection of So2Sat bands to use - unsupervised_mode: Makes the train dataloader return imagery from the train, - val, and test sets + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + band_set: One of 'all', 's1', or 's2'. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.So2Sat` - """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.band_set = band_set - self.unsupervised_mode = unsupervised_mode - self.kwargs = kwargs - - self.norm = Normalize(self.band_means, self.band_stds) - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image - - Returns: - preprocessed sample + :class:`~torchgeo.datasets.So2Sat`. """ - sample["image"] = sample["image"].float() - sample["image"] = self.norm(sample["image"]) - sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :] + kwargs["bands"] = So2Sat.BAND_SETS[band_set] - if self.band_set == "rgb": - sample["image"] = sample["image"][:3, :, :] + if band_set == "s1": + self.mean = self.mean[:8] + self.std = self.std[:8] + elif band_set == "s2": + self.mean = self.mean[8:] + self.std = self.std[8:] - return sample + super().__init__(So2Sat, batch_size, num_workers, **kwargs) - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + def setup(self, stage: str) -> None: + """Set up datasets. - This method is called once per GPU per run. + Called at the beginning of fit, validate, test, or predict. During distributed + training, this method is called from every process across all the nodes. Setting + state here is recommended. 
Args: - stage: stage to set up - """ - bands = So2Sat.BAND_SETS["s2"] - train_transforms = Compose([self.preprocess]) - val_test_transforms = self.preprocess - - if not self.unsupervised_mode: - self.train_dataset = So2Sat( - split="train", bands=bands, transforms=train_transforms, **self.kwargs - ) - - self.val_dataset = So2Sat( - split="validation", - bands=bands, - transforms=val_test_transforms, - **self.kwargs, - ) - - self.test_dataset = So2Sat( - split="test", bands=bands, transforms=val_test_transforms, **self.kwargs - ) - - else: - - temp_train = So2Sat( - split="train", bands=bands, transforms=train_transforms, **self.kwargs - ) - - self.val_dataset = So2Sat( - split="validation", - bands=bands, - transforms=train_transforms, - **self.kwargs, - ) - - self.test_dataset = So2Sat( - split="test", bands=bands, transforms=train_transforms, **self.kwargs - ) - - self.train_dataset = cast( - So2Sat, temp_train + self.val_dataset + self.test_dataset - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.So2Sat.plot`. - - .. versionadded:: 0.4 + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - return self.test_dataset.plot(*args, **kwargs) + if stage in ["fit"]: + self.train_dataset = So2Sat(split="train", **self.kwargs) + if stage in ["fit", "validate"]: + self.val_dataset = So2Sat(split="validation", **self.kwargs) + if stage in ["test"]: + self.test_dataset = So2Sat(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/spacenet.py b/torchgeo/datamodules/spacenet.py index 4af8ac8091d..802cc7a26c6 100644 --- a/torchgeo/datamodules/spacenet.py +++ b/torchgeo/datamodules/spacenet.py @@ -3,18 +3,18 @@ """SpaceNet datamodules.""" -from typing import Any, Dict, Optional +from typing import Any, Dict import kornia.augmentation as K -import matplotlib.pyplot as plt -import pytorch_lightning as pl -from torch.utils.data import DataLoader +from torch import Tensor from ..datasets import SpaceNet1 +from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule from .utils import dataset_split -class SpaceNet1DataModule(pl.LightningDataModule): +class SpaceNet1DataModule(NonGeoDataModule): """LightningDataModule implementation for the SpaceNet1 dataset. Randomly splits into train/val/test. @@ -30,150 +30,70 @@ def __init__( test_split_pct: float = 0.2, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for SpaceNet1. + """Initialize a new SpaceNet1DataModule instance. 
Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set - test_split_pct: What percentage of the dataset to use as a test set + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. + val_split_pct: Percentage of the dataset to use as a validation set. + test_split_pct: Percentage of the dataset to use as a test set. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.SpaceNet1` + :class:`~torchgeo.datasets.SpaceNet1`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + super().__init__(SpaceNet1, batch_size, num_workers, **kwargs) + self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.kwargs = kwargs - - self.padto = K.PadTo((448, 448)) - - def on_after_batch_transfer( - self, batch: Dict[str, Any], batch_idx: int - ) -> Dict[str, Any]: - """Apply batch augmentations after batch is transferred to the device. - - Args: - batch: mini-batch of data - batch_idx: batch index - - Returns: - augmented mini-batch - """ - if ( - hasattr(self, "trainer") - and self.trainer is not None - and hasattr(self.trainer, "training") - and self.trainer.training - ): - # Kornia expects masks to be floats with a channel dimension - x = batch["image"] - y = batch["mask"].float().unsqueeze(1) - - train_augmentations = K.AugmentationSequential( - K.RandomRotation(p=0.5, degrees=90), - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - K.RandomSharpness(p=0.5), - K.ColorJitter( - p=0.5, - brightness=0.1, - contrast=0.1, - saturation=0.1, - hue=0.1, - silence_instantiation_warning=True, - ), - data_keys=["input", "mask"], - ) - x, y = train_augmentations(x, y) - - # torchmetrics expects masks to be longs without a channel dimension - batch["image"] = x - batch["mask"] = y.squeeze(1).long() - - return batch - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image and mask - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() / 255 - sample["image"] = self.padto(sample["image"]).squeeze() - - if "mask" in sample: - # We add 1 to the mask to map the current {background, building} labels to - # the values {1, 2}. This is necessary because we add 0 padding to the - # mask that we want to ignore in the loss function. - sample["mask"] = self.padto(sample["mask"].float() + 1).squeeze() - sample["mask"] = sample["mask"].long() - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - This method is only called once per run. - """ - if self.kwargs.get("download", False): - SpaceNet1(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. 
+ self.train_aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.PadTo((448, 448)), + K.RandomRotation(p=0.5, degrees=90), + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + K.RandomSharpness(p=0.5), + K.ColorJitter( + p=0.5, + brightness=0.1, + contrast=0.1, + saturation=0.1, + hue=0.1, + silence_instantiation_warning=True, + ), + data_keys=["image", "mask"], + ) + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.PadTo((448, 448)), + data_keys=["image", "mask"], + ) - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - self.dataset = SpaceNet1(transforms=self.preprocess, **self.kwargs) + self.dataset = SpaceNet1(**self.kwargs) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( - self.dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, + self.dataset, self.val_split_pct, self.test_split_pct ) - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) + def on_after_batch_transfer( + self, batch: Dict[str, Tensor], dataloader_idx: int + ) -> Dict[str, Tensor]: + """Apply batch augmentations to the batch after it is transferred to the device. - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. + Args: + batch: A batch of data that needs to be altered or augmented. + dataloader_idx: The index of the dataloader to which the batch belongs. Returns: - testing data loader + A batch of data. """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) + # We add 1 to the mask to map the current {background, building} labels to + # the values {1, 2}. This is necessary because we add 0 padding to the + # mask that we want to ignore in the loss function. + batch["mask"] += 1 - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.SpaceNet.plot`.""" - return self.dataset.plot(*args, **kwargs) + return super().on_after_batch_transfer(batch, dataloader_idx) diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index bc56908814f..dc8f43b5828 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -3,18 +3,16 @@ """UC Merced datamodule.""" -from typing import Any, Dict, Optional +from typing import Any -import matplotlib.pyplot as plt -import pytorch_lightning as pl -import torchvision -from torch.utils.data import DataLoader -from torchvision.transforms import Compose +import kornia.augmentation as K from ..datasets import UCMerced +from ..transforms import AugmentationSequential +from .geo import NonGeoDataModule -class UCMercedDataModule(pl.LightningDataModule): +class UCMercedDataModule(NonGeoDataModule): """LightningDataModule implementation for the UC Merced dataset. Uses random train/val/test splits. 
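The hunk below follows the pattern used throughout this refactor: the per-sample ``preprocess`` transform is removed and replaced by a Kornia pipeline that ``NonGeoDataModule`` applies to whole batches, presumably after they are moved to the device so normalization runs there rather than in the DataLoader workers. A minimal sketch of the equivalent pipeline, assuming (not shown in this diff) that the base class supplies ``mean=0``/``std=255`` defaults reproducing the old ``image /= 255.0`` scaling:

    import kornia.augmentation as K
    import torch

    from torchgeo.transforms import AugmentationSequential

    # Assumed NonGeoDataModule defaults; not part of this diff.
    mean, std = torch.tensor(0.0), torch.tensor(255.0)

    aug = AugmentationSequential(
        K.Normalize(mean=mean, std=std),
        K.Resize(size=256),
        data_keys=["image"],
    )

    # A fake batch of 2 RGB images with raw 0-255 values.
    batch = {"image": torch.rand(2, 3, 256, 256) * 255}
    batch = aug(batch)  # images are now scaled to [0, 1] and sized 256x256
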
@@ -23,103 +21,18 @@ class UCMercedDataModule(pl.LightningDataModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for UCMerced based DataLoaders. + """Initialize a new UCMercedDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.UCMerced` + :class:`~torchgeo.datasets.UCMerced`. """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs + super().__init__(UCMerced, batch_size, num_workers, **kwargs) - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - c, h, w = sample["image"].shape - if h != 256 or w != 256: - sample["image"] = torchvision.transforms.functional.resize( - sample["image"], size=(256, 256) - ) - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - UCMerced(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. - - This method is called once per GPU per run. - - Args: - stage: stage to set up - """ - transforms = Compose([self.preprocess]) - - self.train_dataset = UCMerced( - split="train", transforms=transforms, **self.kwargs - ) - self.val_dataset = UCMerced(split="val", transforms=transforms, **self.kwargs) - self.test_dataset = UCMerced(split="test", transforms=transforms, **self.kwargs) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.Resize(size=256), + data_keys=["image"], ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.UCMerced.plot`. - - .. 
versionadded:: 0.2 - """ - return self.val_dataset.plot(*args, **kwargs) diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py index 9c7fa6d0333..329afe94058 100644 --- a/torchgeo/datamodules/usavars.py +++ b/torchgeo/datamodules/usavars.py @@ -3,16 +3,13 @@ """USAVars datamodule.""" -from typing import Any, Dict, Optional - -import matplotlib.pyplot as plt -import pytorch_lightning as pl -from torch.utils.data import DataLoader +from typing import Any from ..datasets import USAVars +from .geo import NonGeoDataModule -class USAVarsDataModule(pl.LightningModule): +class USAVarsDataModule(NonGeoDataModule): """LightningDataModule implementation for the USAVars dataset. Uses random train/val/test splits. @@ -23,85 +20,12 @@ class USAVarsDataModule(pl.LightningModule): def __init__( self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: - """Initialize a LightningDataModule for USAVars based DataLoaders. + """Initialize a new USAVarsDataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.USAVars` - """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.kwargs = kwargs - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: dictionary containing image - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample - - def prepare_data(self) -> None: - """Make sure that the dataset is downloaded. - - This method is only called once per run. - """ - if self.kwargs.get("download", False): - USAVars(**self.kwargs) - - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. - """ - self.train_dataset = USAVars( - split="train", transforms=self.preprocess, **self.kwargs - ) - self.val_dataset = USAVars( - split="val", transforms=self.preprocess, **self.kwargs - ) - self.test_dataset = USAVars( - split="test", transforms=self.preprocess, **self.kwargs - ) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training.""" - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation.""" - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing.""" - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.USAVars.plot`. - - .. versionadded:: 0.4 + :class:`~torchgeo.datasets.USAVars`. 
""" - return self.train_dataset.plot(*args, **kwargs) + super().__init__(USAVars, batch_size, num_workers, **kwargs) diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index d088e493312..b1df01721c3 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -5,11 +5,17 @@ from typing import Any, List, Optional, Union +from torch import Generator from torch.utils.data import Subset, TensorDataset, random_split from ..datasets import NonGeoDataset +# Based on lightning_lite.utilities.exceptions +class MisconfigurationException(Exception): + """Exception used to inform users of misuse with Lightning.""" + + def dataset_split( dataset: Union[TensorDataset, NonGeoDataset], val_pct: float, @@ -19,6 +25,10 @@ def dataset_split( If ``test_pct`` is not set then only train and validation splits are returned. + .. deprecated:: 0.4 + Use :func:`torch.utils.data.random_split` instead, ``random_split`` + now supports percentages as of PyTorch 1.13. + Args: dataset: dataset to be split into train/val or train/val/test subsets val_pct: percentage of samples to be in validation set @@ -30,9 +40,15 @@ def dataset_split( if test_pct is None: val_length = round(len(dataset) * val_pct) train_length = len(dataset) - val_length - return random_split(dataset, [train_length, val_length]) + return random_split( + dataset, [train_length, val_length], generator=Generator().manual_seed(0) + ) else: val_length = round(len(dataset) * val_pct) test_length = round(len(dataset) * test_pct) train_length = len(dataset) - (val_length + test_length) - return random_split(dataset, [train_length, val_length, test_length]) + return random_split( + dataset, + [train_length, val_length, test_length], + generator=Generator().manual_seed(0), + ) diff --git a/torchgeo/datamodules/vaihingen.py b/torchgeo/datamodules/vaihingen.py index 2df86acdfd4..1128ea76655 100644 --- a/torchgeo/datamodules/vaihingen.py +++ b/torchgeo/datamodules/vaihingen.py @@ -3,23 +3,19 @@ """Vaihingen datamodule.""" -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Tuple, Union -import matplotlib.pyplot as plt -import pytorch_lightning as pl -from einops import rearrange -from kornia.augmentation import Normalize -from torch import Tensor -from torch.utils.data import DataLoader +import kornia.augmentation as K from ..datasets import Vaihingen2D from ..samplers.utils import _to_tuple from ..transforms import AugmentationSequential -from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop +from ..transforms.transforms import _RandomNCrop +from .geo import NonGeoDataModule from .utils import dataset_split -class Vaihingen2DDataModule(pl.LightningDataModule): +class Vaihingen2DDataModule(NonGeoDataModule): """LightningDataModule implementation for the Vaihingen2D dataset. Uses the train/test splits from the dataset. @@ -29,133 +25,44 @@ class Vaihingen2DDataModule(pl.LightningDataModule): def __init__( self, - num_tiles_per_batch: int = 16, - num_patches_per_tile: int = 16, + batch_size: int = 64, patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a new LightningDataModule instance. - - The Vaihingen2D dataset contains images that are too large to pass - directly through a model. Instead, we randomly sample patches from image tiles - during training and chop up image tiles into patch grids during evaluation. 
- During training, the effective batch size is equal to - ``num_tiles_per_batch`` x ``num_patches_per_tile``. - - .. versionchanged:: 0.4 - *batch_size* was replaced by *num_tile_per_batch*, *num_patches_per_tile*, - and *patch_size*. + """Initialize a new Vaihingen2DDataModule instance. Args: - num_tiles_per_batch: The number of image tiles to sample from during - training - num_patches_per_tile: The number of patches to randomly sample from each - image tile during training - patch_size: The size of each patch, either ``size`` or ``(height, width)``. - Should be a multiple of 32 for most segmentation architectures - val_split_pct: The percentage of the dataset to use as a validation set - num_workers: The number of workers to use for parallel data loading + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures. + val_split_pct: Percentage of the dataset to use as a validation set. + num_workers: Number of workers for parallel data loading. **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.Vaihingen2D` + :class:`~torchgeo.datasets.Vaihingen2D`. """ - super().__init__() + super().__init__(Vaihingen2D, 1, num_workers, **kwargs) - self.num_tiles_per_batch = num_tiles_per_batch - self.num_patches_per_tile = num_patches_per_tile self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct - self.num_workers = num_workers - self.kwargs = kwargs - self.train_transform = AugmentationSequential( - Normalize(mean=0.0, std=255.0), - _RandomNCrop(self.patch_size, self.num_patches_per_tile), - data_keys=["image", "mask"], - ) - self.test_transform = AugmentationSequential( - Normalize(mean=0.0, std=255.0), - _ExtractTensorPatches(self.patch_size), + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + _RandomNCrop(self.patch_size, batch_size), data_keys=["image", "mask"], ) - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main Dataset objects. - - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up - """ - train_dataset = Vaihingen2D(split="train", **self.kwargs) - self.train_dataset, self.val_dataset = dataset_split( - train_dataset, self.val_split_pct - ) - self.test_dataset = Vaihingen2D(split="test", **self.kwargs) - - def train_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.num_tiles_per_batch, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def test_dataloader(self) -> DataLoader[Dict[str, Tensor]]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False - ) - - def on_after_batch_transfer( - self, batch: Dict[str, Tensor], dataloader_idx: int - ) -> Dict[str, Tensor]: - """Apply augmentations to batch after transferring to GPU. 
- - Args: - batch: A batch of data that needs to be altered or augmented - dataloader_idx: The index of the dataloader to which the batch belongs - - Returns: - A batch of data - """ - # Kornia requires masks to have a channel dimension - batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") - - if self.trainer: - if self.trainer.training: - batch = self.train_transform(batch) - elif self.trainer.validating or self.trainer.testing: - batch = self.test_transform(batch) - - # Torchmetrics does not support masks with a channel dimension - batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - - return batch - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.Vaihingen2D.plot`. - - .. versionadded:: 0.4 + stage: Either 'fit', 'validate', 'test', or 'predict'. """ - return self.test_dataset.plot(*args, **kwargs) + if stage in ["fit", "validate"]: + self.dataset = Vaihingen2D(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + self.dataset, self.val_split_pct + ) + if stage in ["test"]: + self.test_dataset = Vaihingen2D(split="test", **self.kwargs) diff --git a/torchgeo/datamodules/xview.py b/torchgeo/datamodules/xview.py index 8f3fce7a8c4..8f96d786bea 100644 --- a/torchgeo/datamodules/xview.py +++ b/torchgeo/datamodules/xview.py @@ -3,18 +3,14 @@ """xView2 datamodule.""" -from typing import Any, Dict, Optional - -import matplotlib.pyplot as plt -import pytorch_lightning as pl -from torch.utils.data import DataLoader, Dataset -from torchvision.transforms import Compose +from typing import Any from ..datasets import XView2 +from .geo import NonGeoDataModule from .utils import dataset_split -class XView2DataModule(pl.LightningDataModule): +class XView2DataModule(NonGeoDataModule): """LightningDataModule implementation for the xView2 dataset. Uses the train/val/test splits from the dataset. @@ -29,101 +25,29 @@ def __init__( val_split_pct: float = 0.2, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for xView2 based DataLoaders. + """Initialize a new XView2DataModule instance. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders + batch_size: Size of each mini-batch. + num_workers: Number of workers for parallel data loading. val_split_pct: What percentage of the dataset to use as a validation set **kwargs: Additional keyword arguments passed to - :class:`~torchgeo.datasets.XView2` - """ - super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers - self.val_split_pct = val_split_pct - self.kwargs = kwargs - - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. - - Args: - sample: input image dictionary - - Returns: - preprocessed sample + :class:`~torchgeo.datasets.XView2`. """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample + super().__init__(XView2, batch_size, num_workers, **kwargs) - def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + self.val_split_pct = val_split_pct - This method is called once per GPU per run. + def setup(self, stage: str) -> None: + """Set up datasets. Args: - stage: stage to set up + stage: Either 'fit', 'validate', 'test', or 'predict'. 
""" - transforms = Compose([self.preprocess]) - - dataset = XView2(split="train", transforms=transforms, **self.kwargs) - - self.train_dataset: Dataset[Any] - self.val_dataset: Dataset[Any] - - if self.val_split_pct > 0.0: - self.train_dataset, self.val_dataset, _ = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=0.0 + if stage in ["fit", "validate"]: + self.dataset = XView2(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + self.dataset, val_pct=self.val_split_pct ) - else: - self.train_dataset = dataset - self.val_dataset = dataset - - self.test_dataset = XView2(split="test", transforms=transforms, **self.kwargs) - - def train_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for training. - - Returns: - training data loader - """ - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=True, - ) - - def val_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for validation. - - Returns: - validation data loader - """ - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def test_dataloader(self) -> DataLoader[Any]: - """Return a DataLoader for testing. - - Returns: - testing data loader - """ - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, - ) - - def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: - """Run :meth:`torchgeo.datasets.XView2.plot`. - - .. versionadded:: 0.4 - """ - return self.test_dataset.plot(*args, **kwargs) + if stage in ["test"]: + self.test_dataset = XView2(split="test", **self.kwargs) diff --git a/torchgeo/datasets/bigearthnet.py b/torchgeo/datasets/bigearthnet.py index 81c318baa52..e2f49770af1 100644 --- a/torchgeo/datasets/bigearthnet.py +++ b/torchgeo/datasets/bigearthnet.py @@ -394,7 +394,7 @@ def _load_image(self, index: int) -> Tensor: ) images.append(array) arrays: "np.typing.NDArray[np.int_]" = np.stack(images, axis=0) - tensor = torch.from_numpy(arrays) + tensor = torch.from_numpy(arrays).float() return tensor def _load_target(self, index: int) -> Tensor: @@ -539,9 +539,6 @@ def plot( Returns: a matplotlib Figure with the rendered sample - Raises: - ValueError: if ``self.bands`` is "s1" - .. versionadded:: 0.2 """ if self.bands == "s2": diff --git a/torchgeo/datasets/chesapeake.py b/torchgeo/datasets/chesapeake.py index 454c52f980b..024c384dc0a 100644 --- a/torchgeo/datasets/chesapeake.py +++ b/torchgeo/datasets/chesapeake.py @@ -688,8 +688,8 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: sample["image"] = np.concatenate(sample["image"], axis=0) sample["mask"] = np.concatenate(sample["mask"], axis=0) - sample["image"] = torch.from_numpy(sample["image"]) - sample["mask"] = torch.from_numpy(sample["mask"]) + sample["image"] = torch.from_numpy(sample["image"]).float() + sample["mask"] = torch.from_numpy(sample["mask"]).long() if self.transforms is not None: sample = self.transforms(sample) @@ -765,7 +765,11 @@ def plot( .. 
versionadded:: 0.4 """ image = np.rollaxis(sample["image"].numpy(), 0, 3) - mask = np.rollaxis(sample["mask"].numpy(), 0, 3) + mask = sample["mask"].numpy() + if mask.ndim == 3: + mask = np.rollaxis(mask, 0, 3) + else: + mask = np.expand_dims(mask, 2) num_panels = len(self.layers) showing_predictions = "prediction" in sample diff --git a/torchgeo/datasets/cowc.py b/torchgeo/datasets/cowc.py index 1b40974d6d9..d34749cf2d0 100644 --- a/torchgeo/datasets/cowc.py +++ b/torchgeo/datasets/cowc.py @@ -147,7 +147,7 @@ def _load_image(self, index: int) -> Tensor: filename = os.path.join(self.root, self.images[index]) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor @@ -162,7 +162,7 @@ def _load_target(self, index: int) -> Tensor: the target """ target = int(self.targets[index]) - tensor = torch.tensor(target) + tensor = torch.tensor(target).float() return tensor def _check_integrity(self) -> bool: diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index c87dcc997a7..54d55111468 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -153,10 +153,8 @@ def _load_image(self, directory: str) -> Tensor: except AttributeError: resample = Image.BILINEAR img = img.resize(size=(self.size, self.size), resample=resample) - array: "np.typing.NDArray[np.int_]" = np.array(img) - if len(array.shape) == 3: - array = array[:, :, 0] - tensor = torch.from_numpy(array) + array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) + tensor = torch.from_numpy(array).permute((2, 0, 1)).float() return tensor def _load_features(self, directory: str) -> Dict[str, Any]: @@ -178,7 +176,7 @@ def _load_features(self, directory: str) -> Dict[str, Any]: features["relative_time"] = int(features["relative_time"]) features["ocean"] = int(features["ocean"]) - features["label"] = int(features["wind_speed"]) + features["label"] = torch.tensor(int(features["wind_speed"])).float() return features @@ -243,7 +241,7 @@ def plot( fig, ax = plt.subplots(1, 1, figsize=(10, 10)) - ax.imshow(image, cmap="gray") + ax.imshow(image.permute(1, 2, 0)) ax.axis("off") if show_titles: diff --git a/torchgeo/datasets/etci2021.py b/torchgeo/datasets/etci2021.py index 5e2919102b8..2a11edc9204 100644 --- a/torchgeo/datasets/etci2021.py +++ b/torchgeo/datasets/etci2021.py @@ -207,7 +207,7 @@ def _load_image(self, path: str) -> Tensor: filename = os.path.join(path) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor diff --git a/torchgeo/datasets/eurosat.py b/torchgeo/datasets/eurosat.py index 85133f178c6..fea0ae73f8f 100644 --- a/torchgeo/datasets/eurosat.py +++ b/torchgeo/datasets/eurosat.py @@ -170,7 +170,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: """ image, label = self._load_image(index) - image = torch.index_select(image, dim=0, index=self.band_indices) + image = torch.index_select(image, dim=0, index=self.band_indices).float() sample = {"image": image, "label": label} if self.transforms is not None: diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index dbc4011d0a7..84a531268e8 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -427,8 +427,11 @@ def 
__getitem__(self, query: BoundingBox) -> Dict[str, Any]: else: data = self._merge_files(filepaths, query, self.band_indexes) - key = "image" if self.is_image else "mask" - sample = {key: data, "crs": self.crs, "bbox": query} + sample = {"crs": self.crs, "bbox": query} + if self.is_image: + sample["image"] = data.float() + else: + sample["mask"] = data.long() if self.transforms is not None: sample = self.transforms(sample) @@ -777,7 +780,7 @@ def _load_image(self, index: int) -> Tuple[Tensor, Tensor]: """ img, label = ImageFolder.__getitem__(self, index) array: "np.typing.NDArray[np.int_]" = np.array(img) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) label = torch.tensor(label) diff --git a/torchgeo/datasets/inria.py b/torchgeo/datasets/inria.py index c1698cf6c15..395d0c909bb 100644 --- a/torchgeo/datasets/inria.py +++ b/torchgeo/datasets/inria.py @@ -116,7 +116,7 @@ def _load_image(self, path: str) -> Tensor: """ with rio.open(path) as img: array = img.read().astype(np.int32) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() return tensor def _load_target(self, path: str) -> Tensor: @@ -131,7 +131,7 @@ def _load_target(self, path: str) -> Tensor: with rio.open(path) as img: array = img.read().astype(np.int32) array = np.clip(array, 0, 1) - mask = torch.from_numpy(array[0]) + mask = torch.from_numpy(array[0]).long() return mask def __len__(self) -> int: diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index a6514aae510..002ff595bd6 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -149,7 +149,7 @@ def _load_image(self, id_: str) -> Tensor: filename = os.path.join(self.root, "output", id_ + ".jpg") with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor @@ -167,7 +167,7 @@ def _load_target(self, id_: str) -> Tensor: filename = os.path.join(self.root, "output", id_ + "_m.png") with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L")) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).long() return tensor def _verify(self) -> None: diff --git a/torchgeo/datasets/loveda.py b/torchgeo/datasets/loveda.py index 70d005c0df3..f6e2b83f3d2 100644 --- a/torchgeo/datasets/loveda.py +++ b/torchgeo/datasets/loveda.py @@ -214,7 +214,7 @@ def _load_image(self, path: str) -> Tensor: filename = os.path.join(path) with Image.open(filename) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor diff --git a/torchgeo/datasets/nasa_marine_debris.py b/torchgeo/datasets/nasa_marine_debris.py index 56428a5af84..4a506e5c701 100644 --- a/torchgeo/datasets/nasa_marine_debris.py +++ b/torchgeo/datasets/nasa_marine_debris.py @@ -129,7 +129,7 @@ def _load_image(self, path: str) -> Tensor: """ with rasterio.open(path) as f: array = f.read() - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() return tensor def _load_target(self, path: str) -> Tensor: @@ -233,6 +233,7 @@ def plot( """ ncols = 1 + sample["image"] = sample["image"].byte() image = sample["image"] if "boxes" in sample and 
len(sample["boxes"]): image = draw_bounding_boxes(image=sample["image"], boxes=sample["boxes"]) diff --git a/torchgeo/datasets/oscd.py b/torchgeo/datasets/oscd.py index 3977933432c..b302e8e4d21 100644 --- a/torchgeo/datasets/oscd.py +++ b/torchgeo/datasets/oscd.py @@ -129,7 +129,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: image2 = self._load_image(files["images2"]) mask = self._load_target(str(files["mask"])) - image = torch.stack(tensors=[image1, image2], dim=0) + image = torch.cat([image1, image2]) sample = {"image": image, "mask": mask} if self.transforms is not None: @@ -306,7 +306,9 @@ def get_masked(img: Tensor) -> "np.typing.NDArray[np.uint8]": ) return array - image1, image2 = get_masked(sample["image"][0]), get_masked(sample["image"][1]) + idx = sample["image"].shape[0] // 2 + image1 = get_masked(sample["image"][:idx]) + image2 = get_masked(sample["image"][idx:]) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) axs[0].imshow(image1) axs[0].axis("off") diff --git a/torchgeo/datasets/sen12ms.py b/torchgeo/datasets/sen12ms.py index 787098bf764..0e4db7adf73 100644 --- a/torchgeo/datasets/sen12ms.py +++ b/torchgeo/datasets/sen12ms.py @@ -223,14 +223,14 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: """ filename = self.ids[index] - lc = self._load_raster(filename, "lc") + lc = self._load_raster(filename, "lc").long() s1 = self._load_raster(filename, "s1") s2 = self._load_raster(filename, "s2") image = torch.cat(tensors=[s1, s2], dim=0) image = torch.index_select(image, dim=0, index=self.band_indices) - sample: Dict[str, Tensor] = {"image": image, "mask": lc} + sample: Dict[str, Tensor] = {"image": image, "mask": lc[0]} if self.transforms is not None: sample = self.transforms(sample) @@ -336,13 +336,13 @@ def plot( else: raise ValueError("Dataset doesn't contain some of the RGB bands") - image, mask = sample["image"][rgb_indices].numpy(), sample["mask"][0] + image, mask = sample["image"][rgb_indices].numpy(), sample["mask"] image = percentile_normalization(image) ncols = 2 showing_predictions = "prediction" in sample if showing_predictions: - prediction = sample["prediction"][0] + prediction = sample["prediction"] ncols += 1 fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 7c592f94cde..0770228a229 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -234,7 +234,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: s1 = torch.from_numpy(s1) s2 = torch.from_numpy(s2) - sample = {"image": torch.cat([s1, s2]), "label": label} + sample = {"image": torch.cat([s1, s2]).float(), "label": label} if self.transforms is not None: sample = self.transforms(sample) diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 3e5baa7ad75..f0e3987dd64 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -158,7 +158,7 @@ def _load_image(self, path: str) -> Tuple[Tensor, Affine, CRS]: filename = os.path.join(path) with rio.open(filename) as img: array = img.read().astype(np.int32) - tensor = torch.from_numpy(array) + tensor = torch.from_numpy(array).float() return tensor, img.transform, img.crs def _load_mask( diff --git a/torchgeo/trainers/byol.py b/torchgeo/trainers/byol.py index d733df142ae..3b9e8d0db3e 100644 --- a/torchgeo/trainers/byol.py +++ b/torchgeo/trainers/byol.py @@ -4,18 +4,15 @@ """BYOL tasks.""" import os -import random -from typing import Any, Callable, Dict, Optional, 
Tuple, cast +from typing import Any, Dict, Optional, Tuple, cast import pytorch_lightning as pl import timm import torch +import torch.nn as nn import torch.nn.functional as F from kornia import augmentation as K -from kornia import filters -from kornia.geometry import transform as KorniaTransform from torch import Tensor, optim -from torch.nn.modules import BatchNorm1d, Linear, Module, ReLU, Sequential from torch.optim.lr_scheduler import ReduceLROnPlateau from torchvision.models._api import WeightsEnum @@ -39,38 +36,10 @@ def normalized_mse(x: Tensor, y: Tensor) -> Tensor: return mse -# TODO: Move this to transforms -class RandomApply(Module): - """Applies augmentation function (augm) with probability p.""" - - def __init__(self, augm: Callable[[Tensor], Tensor], p: float) -> None: - """Initialize RandomApply. - - Args: - augm: augmentation function to apply - p: probability with which the augmentation function is applied - """ - super().__init__() - self.augm = augm - self.p = p - - def forward(self, x: Tensor) -> Tensor: - """Applies an augmentation to the input with some probability. - - Args: - x: a batch of imagery - - Returns - augmented version of ``x`` with probability ``self.p`` else an un-augmented - version - """ - return x if random.random() > self.p else self.augm(x) - - # TODO: This isn't _really_ applying the augmentations from SimCLR as we have # multispectral imagery and thus can't naively apply color jittering or grayscale # conversions. We should think more about what makes sense here. -class SimCLRAugmentation(Module): +class SimCLRAugmentation(nn.Module): """A module for applying SimCLR augmentations. SimCLR was one of the first papers to show the effectiveness of random data @@ -87,13 +56,13 @@ def __init__(self, image_size: Tuple[int, int] = (256, 256)) -> None: super().__init__() self.size = image_size - self.augmentation = Sequential( - KorniaTransform.Resize(size=image_size, align_corners=False), + self.augmentation = nn.Sequential( + K.Resize(size=image_size, align_corners=False), # Not suitable for multispectral adapt - # RandomApply(K.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), + # K.ColorJitter(0.8, 0.8, 0.8, 0.8, 0.2), # K.RandomGrayscale(p=0.2), K.RandomHorizontalFlip(), - RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), + K.RandomGaussianBlur((3, 3), (1.5, 1.5), p=0.1), K.RandomResizedCrop(size=image_size), ) @@ -109,7 +78,7 @@ def forward(self, x: Tensor) -> Tensor: return cast(Tensor, self.augmentation(x)) -class MLP(Module): +class MLP(nn.Module): """MLP used in the BYOL projection head.""" def __init__( @@ -123,11 +92,11 @@ def __init__( hidden_size: size of the hidden layer """ super().__init__() - self.mlp = Sequential( - Linear(dim, hidden_size), - BatchNorm1d(hidden_size), - ReLU(inplace=True), - Linear(hidden_size, projection_size), + self.mlp = nn.Sequential( + nn.Linear(dim, hidden_size), + nn.BatchNorm1d(hidden_size), + nn.ReLU(inplace=True), + nn.Linear(hidden_size, projection_size), ) def forward(self, x: Tensor) -> Tensor: @@ -142,7 +111,7 @@ def forward(self, x: Tensor) -> Tensor: return cast(Tensor, self.mlp(x)) -class BackboneWrapper(Module): +class BackboneWrapper(nn.Module): """Backbone wrapper for joining a model and a projection head. 
When we call .forward() on this module the following steps happen: @@ -158,7 +127,7 @@ class BackboneWrapper(Module): def __init__( self, - model: Module, + model: nn.Module, projection_size: int = 256, hidden_size: int = 4096, layer: int = -2, @@ -178,13 +147,13 @@ def __init__( self.hidden_size = hidden_size self.layer = layer - self._projector: Optional[Module] = None + self._projector: Optional[nn.Module] = None self._projector_dim: Optional[int] = None self._encoded = torch.empty(0) self._register_hook() @property - def projector(self) -> Module: + def projector(self) -> nn.Module: """Wrapper module for the projector head.""" assert self._projector_dim is not None if self._projector is None: @@ -234,7 +203,7 @@ def forward(self, x: Tensor) -> Tensor: return self._encoded -class BYOL(Module): +class BYOL(nn.Module): """BYOL implementation. BYOL contains two identical backbone networks. The first is trained as usual, and @@ -247,13 +216,13 @@ class BYOL(Module): def __init__( self, - model: Module, + model: nn.Module, image_size: Tuple[int, int] = (256, 256), hidden_layer: int = -2, in_channels: int = 4, projection_size: int = 256, hidden_size: int = 4096, - augment_fn: Optional[Module] = None, + augment_fn: Optional[nn.Module] = None, beta: float = 0.99, **kwargs: Any, ) -> None: @@ -273,7 +242,7 @@ def __init__( """ super().__init__() - self.augment: Module + self.augment: nn.Module if augment_fn is None: self.augment = SimCLRAugmentation(image_size) else: diff --git a/torchgeo/trainers/classification.py b/torchgeo/trainers/classification.py index 41d7c922ecd..e8af2dcb70f 100644 --- a/torchgeo/trainers/classification.py +++ b/torchgeo/trainers/classification.py @@ -187,20 +187,25 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) - if batch_idx < 10: + if ( + batch_idx < 10 + and hasattr(self.trainer, "datamodule") + and self.logger + and hasattr(self.logger, "experiment") + ): try: - datamodule = self.trainer.datamodule # type: ignore[attr-defined] + datamodule = self.trainer.datamodule batch["prediction"] = y_hat_hard for key in ["image", "label", "prediction"]: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] fig = datamodule.plot(sample) - summary_writer = self.logger.experiment # type: ignore[union-attr] + summary_writer = self.logger.experiment summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step ) plt.close() - except AttributeError: + except ValueError: pass def validation_epoch_end(self, outputs: Any) -> None: @@ -366,19 +371,24 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) - if batch_idx < 10: + if ( + batch_idx < 10 + and hasattr(self.trainer, "datamodule") + and self.logger + and hasattr(self.logger, "experiment") + ): try: - datamodule = self.trainer.datamodule # type: ignore[attr-defined] + datamodule = self.trainer.datamodule batch["prediction"] = y_hat_hard for key in ["image", "label", "prediction"]: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] fig = datamodule.plot(sample) - summary_writer = self.logger.experiment # type: ignore[union-attr] + summary_writer = self.logger.experiment summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step ) - except AttributeError: + except ValueError: pass def test_step(self, *args: Any, **kwargs: Any) -> None: diff --git 
a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index b50d273d419..9037208998e 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -169,9 +169,14 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: self.val_metrics.update(y_hat, y) - if batch_idx < 10: + if ( + batch_idx < 10 + and hasattr(self.trainer, "datamodule") + and self.logger + and hasattr(self.logger, "experiment") + ): try: - datamodule = self.trainer.datamodule # type: ignore[attr-defined] + datamodule = self.trainer.datamodule batch["prediction_boxes"] = [b["boxes"].cpu() for b in y_hat] batch["prediction_labels"] = [b["labels"].cpu() for b in y_hat] batch["prediction_scores"] = [b["scores"].cpu() for b in y_hat] @@ -182,12 +187,12 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: sample["image"] *= 255 sample["image"] = sample["image"].to(torch.uint8) fig = datamodule.plot(sample) - summary_writer = self.logger.experiment # type: ignore[union-attr] + summary_writer = self.logger.experiment summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step ) plt.close() - except AttributeError: + except ValueError: pass def validation_epoch_end(self, outputs: Any) -> None: diff --git a/torchgeo/trainers/regression.py b/torchgeo/trainers/regression.py index 73bffd5934f..98234bdec5d 100644 --- a/torchgeo/trainers/regression.py +++ b/torchgeo/trainers/regression.py @@ -144,20 +144,25 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: self.log("val_loss", loss) self.val_metrics(y_hat, y) - if batch_idx < 10: + if ( + batch_idx < 10 + and hasattr(self.trainer, "datamodule") + and self.logger + and hasattr(self.logger, "experiment") + ): try: - datamodule = self.trainer.datamodule # type: ignore[attr-defined] + datamodule = self.trainer.datamodule batch["prediction"] = y_hat for key in ["image", "label", "prediction"]: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] fig = datamodule.plot(sample) - summary_writer = self.logger.experiment # type: ignore[union-attr] + summary_writer = self.logger.experiment summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step ) plt.close() - except AttributeError: + except ValueError: pass def validation_epoch_end(self, outputs: Any) -> None: diff --git a/torchgeo/trainers/segmentation.py b/torchgeo/trainers/segmentation.py index 1d92c0842b1..b5868196fc2 100644 --- a/torchgeo/trainers/segmentation.py +++ b/torchgeo/trainers/segmentation.py @@ -197,20 +197,25 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) - if batch_idx < 10: + if ( + batch_idx < 10 + and hasattr(self.trainer, "datamodule") + and self.logger + and hasattr(self.logger, "experiment") + ): try: - datamodule = self.trainer.datamodule # type: ignore[attr-defined] + datamodule = self.trainer.datamodule batch["prediction"] = y_hat_hard for key in ["image", "mask", "prediction"]: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] fig = datamodule.plot(sample) - summary_writer = self.logger.experiment # type: ignore[union-attr] + summary_writer = self.logger.experiment summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step ) plt.close() - except AttributeError: + except ValueError: pass def validation_epoch_end(self, outputs: Any) -> None: diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 
3990b9e82dc..c202a8859d2 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -7,9 +7,9 @@ import kornia import torch +from einops import rearrange from kornia.augmentation import GeometricAugmentationBase2D from kornia.augmentation.random_generator import CropGenerator -from kornia.contrib import compute_padding, extract_tensor_patches from kornia.geometry import crop_by_indices from torch import Tensor from torch.nn.modules import Module @@ -17,7 +17,11 @@ # TODO: contribute these to Kornia and delete this file class AugmentationSequential(Module): - """Wrapper around kornia AugmentationSequential to handle input dicts.""" + """Wrapper around kornia AugmentationSequential to handle input dicts. + + .. deprecated:: 0.4 + Use :class:`kornia.augmentation.container.AugmentationSequential` instead. + """ def __init__(self, *args: Module, data_keys: List[str]) -> None: """Initialize a new augmentation sequential instance. @@ -40,24 +44,26 @@ def __init__(self, *args: Module, data_keys: List[str]) -> None: self.augs = kornia.augmentation.AugmentationSequential(*args, data_keys=keys) - def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: + def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]: """Perform augmentations and update data dict. Args: - sample: the input + batch: the input Returns: the augmented input """ - # Kornia augmentations require masks & boxes to be float - if "mask" in self.data_keys: - mask_dtype = sample["mask"].dtype - sample["mask"] = sample["mask"].to(torch.float) - if "boxes" in self.data_keys: - boxes_dtype = sample["boxes"].dtype - sample["boxes"] = sample["boxes"].to(torch.float) - - inputs = [sample[k] for k in self.data_keys] + # Kornia augmentations require all inputs to be float + dtype = {} + for key in self.data_keys: + dtype[key] = batch[key].dtype + batch[key] = batch[key].float() + + # Kornia requires masks to have a channel dimension + if "mask" in batch and len(batch["mask"].shape) == 3: + batch["mask"] = rearrange(batch["mask"], "b h w -> b () h w") + + inputs = [batch[k] for k in self.data_keys] outputs_list: Union[Tensor, List[Tensor]] = self.augs(*inputs) outputs_list = ( outputs_list if isinstance(outputs_list, list) else [outputs_list] @@ -65,69 +71,17 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: outputs: Dict[str, Tensor] = { k: v for k, v in zip(self.data_keys, outputs_list) } - sample.update(outputs) - - # Convert masks & boxes to previous dtype - if "mask" in self.data_keys: - sample["mask"] = sample["mask"].to(mask_dtype) - if "boxes" in self.data_keys: - sample["boxes"] = sample["boxes"].to(boxes_dtype) - - return sample - - -class _ExtractTensorPatches(GeometricAugmentationBase2D): - """Chop up a tensor into a grid.""" - - def __init__(self, window_size: Union[int, Tuple[int, int]]) -> None: - """Initialize a new _ExtractTensorPatches instance. - - Args: - window_size: the size of each patch - """ - super().__init__(p=1) - self.flags = {"window_size": window_size} - - def compute_transformation( - self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any] - ) -> Tensor: - """Compute the transformation. 
- - Args: - input: the input tensor - params: generated parameters - flags: static parameters + batch.update(outputs) - Returns: - the transformation - """ - out: Tensor = self.identity_matrix(input) - return out - - def apply_transform( - self, - input: Tensor, - params: Dict[str, Tensor], - flags: Dict[str, Any], - transform: Optional[Tensor] = None, - ) -> Tensor: - """Apply the transform. + # Convert all inputs back to their previous dtype + for key in self.data_keys: + batch[key] = batch[key].to(dtype[key]) - Args: - input: the input tensor - params: generated parameters - flags: static parameters - transform: the geometric transformation tensor + # Torchmetrics does not support masks with a channel dimension + if "mask" in batch and batch["mask"].shape[1] == 1: + batch["mask"] = rearrange(batch["mask"], "b () h w -> b h w") - Returns: - the augmented input - """ - size = flags["window_size"] - h, w = input.shape[-2:] - padding = compute_padding((h, w), size) - input = extract_tensor_patches(input, size, size, padding) - input = torch.flatten(input, 0, 1) # [B, N, C?, H, W] -> [B*N, C?, H, W] - return input + return batch class _RandomNCrop(GeometricAugmentationBase2D):
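
A quick sketch of the bi-temporal handling introduced in the OSCD change above: the two acquisitions are concatenated along the channel axis rather than stacked along a new time axis, and plot() recovers them by splitting that axis in half. Shapes and band counts below are illustrative, not taken from the dataset code.

import torch

image1 = torch.rand(13, 256, 256)  # first acquisition (band count illustrative)
image2 = torch.rand(13, 256, 256)  # second acquisition, co-registered

stacked = torch.stack([image1, image2], dim=0)  # (2, 13, 256, 256): extra time dim
merged = torch.cat([image1, image2])            # (26, 256, 256): one multi-band image

# A 2*C-channel model consumes `merged` directly, and the two acquisitions can be
# recovered for plotting by splitting the channel axis in half.
idx = merged.shape[0] // 2
assert torch.equal(merged[:idx], image1) and torch.equal(merged[idx:], image2)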
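The sen12ms, so2sat, and spacenet changes all normalize dtypes at load time: images are cast to float32 and segmentation targets to int64, so augmentations and losses such as cross-entropy receive the types they expect. A minimal illustration with made-up shapes follows.

import numpy as np
import torch
import torch.nn.functional as F

array = np.random.randint(0, 2**11, size=(4, 128, 128)).astype(np.int32)
image = torch.from_numpy(array).float()         # float32, ready for augmentation
mask = torch.randint(0, 10, (128, 128)).long()  # int64 class indices, shape (H, W)

# cross_entropy accepts the integer mask directly once it has no channel dimension
loss = F.cross_entropy(torch.rand(1, 10, 128, 128), mask.unsqueeze(0))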
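The BYOL augmentation changes lean on kornia transforms that accept a probability directly, which is what makes the custom RandomApply wrapper unnecessary. A standalone sketch with illustrative sizes:

import torch
import torch.nn as nn
from kornia import augmentation as K

augment = nn.Sequential(
    K.Resize(size=(256, 256), align_corners=False),
    K.RandomHorizontalFlip(),
    K.RandomGaussianBlur((3, 3), (1.5, 1.5), p=0.1),  # applied with probability 0.1
    K.RandomResizedCrop(size=(256, 256)),
)

x = torch.rand(4, 4, 512, 512)  # batch of 4-channel imagery
out = augment(x)                # (4, 4, 256, 256)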
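The classification, detection, regression, and segmentation trainers now share the same guard around figure logging: only plot when a datamodule and a logger with a TensorBoard-style experiment attribute are present, and swallow the ValueError that plot() raises when a sample cannot be rendered (for example, the missing-RGB-bands case in sen12ms). A condensed, hypothetical helper capturing the pattern:

import matplotlib.pyplot as plt

def maybe_log_figure(self, batch_idx: int, sample: dict) -> None:
    """Hypothetical helper; ``self`` stands in for a LightningModule."""
    if not (
        batch_idx < 10
        and hasattr(self.trainer, "datamodule")
        and self.logger
        and hasattr(self.logger, "experiment")
    ):
        return
    try:
        fig = self.trainer.datamodule.plot(sample)
    except ValueError:
        return  # plot() rejected this sample; skip rather than fail validation
    self.logger.experiment.add_figure(
        f"image/{batch_idx}", fig, global_step=self.global_step
    )
    plt.close()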
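The AugmentationSequential rewrite boils down to two bookkeeping steps: cast every keyed tensor to float while remembering its original dtype, and give masks a temporary channel dimension, then undo both afterwards. A minimal sketch of just that round trip, with hypothetical shapes:

import torch
from einops import rearrange

mask = torch.randint(0, 5, (8, 64, 64))  # (B, H, W) long mask, as datasets return it
orig_dtype = mask.dtype

kornia_input = rearrange(mask, "b h w -> b () h w").float()  # (B, 1, H, W) float
# ... kornia augmentations would run here ...
restored = rearrange(kornia_input, "b () h w -> b h w").to(orig_dtype)

assert restored.shape == mask.shape and restored.dtype == orig_dtype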