diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index f2679a50f2..d90184ea82 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -2,14 +2,23 @@ Datamodules for RL models that rely on experiences generated during training Based on implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py """ +import importlib from abc import ABC from collections import deque, namedtuple from typing import Callable, Iterable, List, Tuple import torch -from gym import Env from torch.utils.data import IterableDataset +from pl_bolts.utils.warnings import warn_missing_pkg + +_GYM_AVAILABLE = importlib.util.find_spec("gym") is not None +if _GYM_AVAILABLE: + from gym import Env +else: + warn_missing_pkg("gym") # pragma: no-cover + + # Datasets Experience = namedtuple( @@ -172,7 +181,7 @@ def env_actions(self, device) -> List[List[int]]: return actions - def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience: + def env_step(self, env_idx: int, env: "Env", action: List[int]) -> Experience: """ Carries out a step through the given environment using the given action @@ -236,7 +245,7 @@ def pop_rewards_steps(self): class DiscountedExperienceSource(ExperienceSource): """Outputs experiences with a discounted reward over N steps""" - def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99): + def __init__(self, env: "Env", agent, n_steps: int = 1, gamma: float = 0.99): super().__init__(env, agent, (n_steps + 1)) self.gamma = gamma self.steps = n_steps diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 5b39228742..54ba6a0bcf 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -1,12 +1,19 @@ +import importlib import os import torch -import torchvision.transforms as transforms from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from torch.utils.data.dataset import random_split from pl_bolts.datasets.kitti_dataset import KittiDataset +from pl_bolts.utils.warnings import warn_missing_pkg + +_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None +if _TORCHVISION_AVAILABLE: + import torchvision.transforms as transforms +else: + warn_missing_pkg('torchvision') # pragma: no-cover class KittiDataModule(LightningDataModule): @@ -56,6 +63,11 @@ def __init__( batch_size: the batch size seed: random seed to be used for train/val/test splits """ + if not _TORCHVISION_AVAILABLE: + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `torchvision` which is not installed yet.' + ) + super().__init__(*args, **kwargs) self.data_dir = data_dir if data_dir is not None else os.getcwd() self.batch_size = batch_size diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index dd66a40678..bd05d81c90 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -6,11 +6,12 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset +from pl_bolts.utils.warnings import warn_missing_pkg + try: from sklearn.utils import shuffle as sk_shuffle except ModuleNotFoundError: - raise ModuleNotFoundError('You want to use `sklearn` which is not installed yet,' # pragma: no-cover - ' install it with `pip install sklearn`.') + warn_missing_pkg("sklearn") # pragma: no-cover _SKLEARN_AVAILABLE = False else: _SKLEARN_AVAILABLE = True diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 008d859c35..9e75e71918 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,13 +1,12 @@ import torch -import torchvision.transforms as T from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader from pl_bolts.utils.warnings import warn_missing_pkg try: + import torchvision.transforms as T from torchvision.datasets import VOCDetection - except ModuleNotFoundError: warn_missing_pkg('torchvision') # pragma: no-cover _TORCHVISION_AVAILABLE = False diff --git a/pl_bolts/datasets/imagenet_dataset.py b/pl_bolts/datasets/imagenet_dataset.py index f070bebe1a..5ed72189d2 100644 --- a/pl_bolts/datasets/imagenet_dataset.py +++ b/pl_bolts/datasets/imagenet_dataset.py @@ -11,13 +11,14 @@ import torch from torch._six import PY3 +from pl_bolts.utils.warnings import warn_missing_pkg + try: from torchvision.datasets import ImageNet from torchvision.datasets.imagenet import load_meta_file -except ModuleNotFoundError as err: - raise ModuleNotFoundError( # pragma: no-cover - 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' - ) from err +except ModuleNotFoundError: + warn_missing_pkg('torchvision') # pragma: no-cover + ImageNet = object class UnlabeledImagenet(ImageNet): diff --git a/pl_bolts/datasets/kitti_dataset.py b/pl_bolts/datasets/kitti_dataset.py index bd8c774c39..a63c00739b 100644 --- a/pl_bolts/datasets/kitti_dataset.py +++ b/pl_bolts/datasets/kitti_dataset.py @@ -1,9 +1,18 @@ +import importlib import os import numpy as np -from PIL import Image from torch.utils.data import Dataset +from pl_bolts.utils.warnings import warn_missing_pkg + +_PIL_AVAILABLE = importlib.util.find_spec("PIL") is not None +if _PIL_AVAILABLE: + from PIL import Image +else: + warn_missing_pkg('PIL') # pragma: no-cover + + DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) @@ -41,6 +50,11 @@ def __init__( void_labels: useless classes to be excluded from training valid_labels: useful classes to include """ + if not _PIL_AVAILABLE: + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `PIL` which is not installed yet.' + ) + self.img_size = img_size self.void_labels = void_labels self.valid_labels = valid_labels diff --git a/tests/datamodules/test_imports.py b/tests/datamodules/test_imports.py new file mode 100644 index 0000000000..e51a817804 --- /dev/null +++ b/tests/datamodules/test_imports.py @@ -0,0 +1,37 @@ +import importlib +from unittest import mock + +import pytest + + +@pytest.mark.parametrize("dm_cls,deps", [ + ("AsynchronousLoader", []), + ("BinaryMNISTDataModule", ["torchvision"]), + ("CIFAR10DataModule", ["torchvision"]), + ("TinyCIFAR10DataModule", ["torchvision"]), + ("DiscountedExperienceSource", ["gym"]), + ("ExperienceSource", ["gym"]), + ("ExperienceSourceDataset", ["gym"]), + ("FashionMNISTDataModule", ["torchvision"]), + ("ImagenetDataModule", ["torchvision"]), + ("MNISTDataModule", ["torchvision"]), + ("SklearnDataModule", ["sklearn"]), + ("SklearnDataset", []), + ("TensorDataset", []), + ("SSLImagenetDataModule", ["torchvision"]), + ("STL10DataModule", ["torchvision"]), + ("VOCDetectionDataModule", ["torchvision"]), + ("CityscapesDataModule", ["torchvision"]), + ("KittiDataset", ["PIL"]), + ("KittiDataModule", ["torchvision"]), +]) +def test_import(dm_cls, deps): + """Tests importing when dependencies are not met. + + Set the followings in @pytest.mark.parametrize: + dm_cls: class to test importing + deps: packages required for dm_cls + """ + with mock.patch.dict("sys.modules", {pkg: None for pkg in deps}): + dms_module = importlib.import_module("pl_bolts.datamodules") + assert hasattr(dms_module, dm_cls), f"`from pl_bolts.datamodules import {dm_cls}` failed."