Skip to content

Commit

Permalink
Define our own MisconfigurationException
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Jan 15, 2023
1 parent 4f19b26 commit c49080b
Show file tree
Hide file tree
Showing 10 changed files with 21 additions and 32 deletions.
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,8 @@ NonGeoDataModule
^^^^^^^^^^^^^^^^

.. autoclass:: NonGeoDataModule

Utilities
---------

.. autoclass:: MisconfigurationException
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
("py:class", "LightningDataModule"),
("py:class", "pytorch_lightning.core.module.LightningModule"),
# Undocumented class
("py:exc", "MisconfigurationException"),
("py:class", "torchvision.models.resnet.ResNet"),
("py:class", "segmentation_models_pytorch.base.model.SegmentationModel"),
]
Expand Down
7 changes: 1 addition & 6 deletions tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,9 @@
import torch.nn as nn
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer

# TODO: import from lightning_lite instead
from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined]
MisconfigurationException,
)
from torchvision.models import resnet18

from torchgeo.datamodules import ChesapeakeCVPRDataModule
from torchgeo.datamodules import ChesapeakeCVPRDataModule, MisconfigurationException
from torchgeo.datasets import ChesapeakeCVPR
from torchgeo.samplers import GridGeoSampler
from torchgeo.trainers import BYOLTask
Expand Down
6 changes: 1 addition & 5 deletions tests/trainers/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer

# TODO: import from lightning_lite instead
from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined]
MisconfigurationException,
)
from torch.nn.modules import Module

from torchgeo.datamodules import (
BigEarthNetDataModule,
EuroSATDataModule,
MisconfigurationException,
RESISC45DataModule,
So2SatDataModule,
UCMercedDataModule,
Expand Down
7 changes: 1 addition & 6 deletions tests/trainers/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer

# TODO: import from lightning_lite instead
from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined]
MisconfigurationException,
)

from torchgeo.datamodules import NASAMarineDebrisDataModule
from torchgeo.datamodules import MisconfigurationException, NASAMarineDebrisDataModule
from torchgeo.datasets import NASAMarineDebris
from torchgeo.trainers import ObjectDetectionTask

Expand Down
7 changes: 3 additions & 4 deletions tests/trainers/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer

# TODO: import from lightning_lite instead
from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined]
from torchgeo.datamodules import (
COWCCountingDataModule,
MisconfigurationException,
TropicalCycloneDataModule,
)

from torchgeo.datamodules import COWCCountingDataModule, TropicalCycloneDataModule
from torchgeo.datasets import TropicalCyclone
from torchgeo.trainers import RegressionTask

Expand Down
6 changes: 1 addition & 5 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@
from _pytest.monkeypatch import MonkeyPatch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer

# TODO: import from lightning_lite instead
from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined]
MisconfigurationException,
)
from torch.nn.modules import Module

from torchgeo.datamodules import (
Expand All @@ -24,6 +19,7 @@
InriaAerialImageLabelingDataModule,
LandCoverAIDataModule,
LoveDADataModule,
MisconfigurationException,
NAIPChesapeakeDataModule,
Potsdam2DDataModule,
SEN12MSDataModule,
Expand Down
3 changes: 3 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .spacenet import SpaceNet1DataModule
from .ucmerced import UCMercedDataModule
from .usavars import USAVarsDataModule
from .utils import MisconfigurationException
from .vaihingen import Vaihingen2DDataModule
from .xview import XView2DataModule

Expand Down Expand Up @@ -59,4 +60,6 @@
# Base classes
"GeoDataModule",
"NonGeoDataModule",
# Utilities
"MisconfigurationException",
)
6 changes: 1 addition & 5 deletions torchgeo/datamodules/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@
import matplotlib.pyplot as plt
import torch
from pytorch_lightning import LightningDataModule

# TODO: import from lightning_lite instead
from pytorch_lightning.utilities.exceptions import ( # type: ignore[attr-defined]
MisconfigurationException,
)
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

Expand All @@ -25,6 +20,7 @@
RandomBatchGeoSampler,
)
from ..transforms import AugmentationSequential
from .utils import MisconfigurationException


class GeoDataModule(LightningDataModule):
Expand Down
5 changes: 5 additions & 0 deletions torchgeo/datamodules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
from ..datasets import NonGeoDataset


# Based on lightning_lite.utilities.exceptions
class MisconfigurationException(Exception):
"""Exception used to inform users of misuse with Lightning."""


def dataset_split(
dataset: Union[TensorDataset, NonGeoDataset],
val_pct: float,
Expand Down

0 comments on commit c49080b

Please sign in to comment.