diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index 69292c1b5d2..d66fee22755 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -141,3 +141,8 @@ NonGeoDataModule ^^^^^^^^^^^^^^^^ .. autoclass:: NonGeoDataModule + +Utilities +--------- + +.. autoclass:: MisconfigurationException diff --git a/docs/conf.py b/docs/conf.py index 58e674beffd..6b7ce0b4ad6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -62,7 +62,6 @@ ("py:class", "LightningDataModule"), ("py:class", "pytorch_lightning.core.module.LightningModule"), # Undocumented class - ("py:exc", "MisconfigurationException"), ("py:class", "torchvision.models.resnet.ResNet"), ("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"), ] diff --git a/tests/trainers/test_byol.py b/tests/trainers/test_byol.py index 9f649f9ed72..ecefc2ba53d 100644 --- a/tests/trainers/test_byol.py +++ b/tests/trainers/test_byol.py @@ -8,14 +8,9 @@ import torch.nn as nn from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer - -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) from torchvision.models import resnet18 -from torchgeo.datamodules import ChesapeakeCVPRDataModule +from torchgeo.datamodules import ChesapeakeCVPRDataModule, MisconfigurationException from torchgeo.datasets import ChesapeakeCVPR from torchgeo.samplers import GridGeoSampler from torchgeo.trainers import BYOLTask diff --git a/tests/trainers/test_classification.py b/tests/trainers/test_classification.py index e88884a3737..cb212cef60f 100644 --- a/tests/trainers/test_classification.py +++ b/tests/trainers/test_classification.py @@ -9,16 +9,12 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer - -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) from torch.nn.modules import Module from torchgeo.datamodules import ( BigEarthNetDataModule, EuroSATDataModule, + MisconfigurationException, RESISC45DataModule, So2SatDataModule, UCMercedDataModule, diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index a6446cf1cc3..fe7fae1a2de 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -9,12 +9,7 @@ from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) - -from torchgeo.datamodules import NASAMarineDebrisDataModule +from torchgeo.datamodules import MisconfigurationException, NASAMarineDebrisDataModule from torchgeo.datasets import NASAMarineDebris from torchgeo.trainers import ObjectDetectionTask diff --git a/tests/trainers/test_regression.py b/tests/trainers/test_regression.py index 91ae464bb21..82e5680899e 100644 --- a/tests/trainers/test_regression.py +++ b/tests/trainers/test_regression.py @@ -9,12 +9,11 @@ from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] +from torchgeo.datamodules import ( + COWCCountingDataModule, MisconfigurationException, + TropicalCycloneDataModule, ) - -from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule from torchgeo.datasets import TropicalCyclone from torchgeo.trainers import RegressionTask diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index a12324c3d95..6282152b8da 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -9,11 +9,6 @@ from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer - -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) from torch.nn.modules import Module from torchgeo.datamodules import ( @@ -24,6 +19,7 @@ InriaAerialImageLabelingDataModule, LandCoverAIDataModule, LoveDADataModule, + MisconfigurationException, NAIPChesapeakeDataModule, Potsdam2DDataModule, SEN12MSDataModule, diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 2f846c4a5fc..fe4dafaa986 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -26,6 +26,7 @@ from .spacenet import SpaceNet1DataModule from .ucmerced import UCMercedDataModule from .usavars import USAVarsDataModule +from .utils import MisconfigurationException from .vaihingen import Vaihingen2DDataModule from .xview import XView2DataModule @@ -59,4 +60,6 @@ # Base classes "GeoDataModule", "NonGeoDataModule", + # Utilities + "MisconfigurationException", ) diff --git a/torchgeo/datamodules/geo.py b/torchgeo/datamodules/geo.py index bb1d325273b..5c94839c33d 100644 --- a/torchgeo/datamodules/geo.py +++ b/torchgeo/datamodules/geo.py @@ -9,11 +9,6 @@ import matplotlib.pyplot as plt import torch from pytorch_lightning import LightningDataModule - -# TODO: import from lightning_lite instead -from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined] - MisconfigurationException, -) from torch import Tensor from torch.utils.data import DataLoader, Dataset @@ -25,6 +20,7 @@ RandomBatchGeoSampler, ) from ..transforms import AugmentationSequential +from .utils import MisconfigurationException class GeoDataModule(LightningDataModule): diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index d088e493312..be50cd7dc99 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -10,6 +10,11 @@ from ..datasets import NonGeoDataset +# Based on lightning_lite.utilities.exceptions +class MisconfigurationException(Exception): + """Exception used to inform users of misuse with Lightning.""" + + def dataset_split( dataset: Union[TensorDataset, NonGeoDataset], val_pct: float,