Skip to content

Commit

Permalink
Fix datamodules ImportError and add tests (#380)
Browse files Browse the repository at this point in the history
* Warn when unavailable

* Add tests

* Add more to the tests

* Fix experience_source

* Fix vocdetection_datamodule

* Fix imagenet

* Update test_imports

* Fix kitti

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Add docstring to tests

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
akihironitta and Borda authored Nov 24, 2020
1 parent f0e2bee commit ecbb82a
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 13 deletions.
15 changes: 12 additions & 3 deletions pl_bolts/datamodules/experience_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,23 @@
Datamodules for RL models that rely on experiences generated during training
Based on implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py
"""
import importlib
from abc import ABC
from collections import deque, namedtuple
from typing import Callable, Iterable, List, Tuple

import torch
from gym import Env
from torch.utils.data import IterableDataset

from pl_bolts.utils.warnings import warn_missing_pkg

_GYM_AVAILABLE = importlib.util.find_spec("gym") is not None
if _GYM_AVAILABLE:
from gym import Env
else:
warn_missing_pkg("gym") # pragma: no-cover


# Datasets

Experience = namedtuple(
Expand Down Expand Up @@ -172,7 +181,7 @@ def env_actions(self, device) -> List[List[int]]:

return actions

def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience:
def env_step(self, env_idx: int, env: "Env", action: List[int]) -> Experience:
"""
Carries out a step through the given environment using the given action
Expand Down Expand Up @@ -236,7 +245,7 @@ def pop_rewards_steps(self):
class DiscountedExperienceSource(ExperienceSource):
"""Outputs experiences with a discounted reward over N steps"""

def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99):
def __init__(self, env: "Env", agent, n_steps: int = 1, gamma: float = 0.99):
super().__init__(env, agent, (n_steps + 1))
self.gamma = gamma
self.steps = n_steps
Expand Down
14 changes: 13 additions & 1 deletion pl_bolts/datamodules/kitti_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import importlib
import os

import torch
import torchvision.transforms as transforms
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split

from pl_bolts.datasets.kitti_dataset import KittiDataset
from pl_bolts.utils.warnings import warn_missing_pkg

_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
if _TORCHVISION_AVAILABLE:
import torchvision.transforms as transforms
else:
warn_missing_pkg('torchvision') # pragma: no-cover


class KittiDataModule(LightningDataModule):
Expand Down Expand Up @@ -56,6 +63,11 @@ def __init__(
batch_size: the batch size
seed: random seed to be used for train/val/test splits
"""
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use `torchvision` which is not installed yet.'
)

super().__init__(*args, **kwargs)
self.data_dir = data_dir if data_dir is not None else os.getcwd()
self.batch_size = batch_size
Expand Down
5 changes: 3 additions & 2 deletions pl_bolts/datamodules/sklearn_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from pl_bolts.utils.warnings import warn_missing_pkg

try:
from sklearn.utils import shuffle as sk_shuffle
except ModuleNotFoundError:
raise ModuleNotFoundError('You want to use `sklearn` which is not installed yet,' # pragma: no-cover
' install it with `pip install sklearn`.')
warn_missing_pkg("sklearn") # pragma: no-cover
_SKLEARN_AVAILABLE = False
else:
_SKLEARN_AVAILABLE = True
Expand Down
3 changes: 1 addition & 2 deletions pl_bolts/datamodules/vocdetection_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import torch
import torchvision.transforms as T
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from pl_bolts.utils.warnings import warn_missing_pkg

try:
import torchvision.transforms as T
from torchvision.datasets import VOCDetection

except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
_TORCHVISION_AVAILABLE = False
Expand Down
9 changes: 5 additions & 4 deletions pl_bolts/datasets/imagenet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
import torch
from torch._six import PY3

from pl_bolts.utils.warnings import warn_missing_pkg

try:
from torchvision.datasets import ImageNet
from torchvision.datasets.imagenet import load_meta_file
except ModuleNotFoundError as err:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.'
) from err
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
ImageNet = object


class UnlabeledImagenet(ImageNet):
Expand Down
16 changes: 15 additions & 1 deletion pl_bolts/datasets/kitti_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import importlib
import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

from pl_bolts.utils.warnings import warn_missing_pkg

_PIL_AVAILABLE = importlib.util.find_spec("PIL") is not None
if _PIL_AVAILABLE:
from PIL import Image
else:
warn_missing_pkg('PIL') # pragma: no-cover


DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)

Expand Down Expand Up @@ -41,6 +50,11 @@ def __init__(
void_labels: useless classes to be excluded from training
valid_labels: useful classes to include
"""
if not _PIL_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use `PIL` which is not installed yet.'
)

self.img_size = img_size
self.void_labels = void_labels
self.valid_labels = valid_labels
Expand Down
37 changes: 37 additions & 0 deletions tests/datamodules/test_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import importlib
from unittest import mock

import pytest


@pytest.mark.parametrize("dm_cls,deps", [
("AsynchronousLoader", []),
("BinaryMNISTDataModule", ["torchvision"]),
("CIFAR10DataModule", ["torchvision"]),
("TinyCIFAR10DataModule", ["torchvision"]),
("DiscountedExperienceSource", ["gym"]),
("ExperienceSource", ["gym"]),
("ExperienceSourceDataset", ["gym"]),
("FashionMNISTDataModule", ["torchvision"]),
("ImagenetDataModule", ["torchvision"]),
("MNISTDataModule", ["torchvision"]),
("SklearnDataModule", ["sklearn"]),
("SklearnDataset", []),
("TensorDataset", []),
("SSLImagenetDataModule", ["torchvision"]),
("STL10DataModule", ["torchvision"]),
("VOCDetectionDataModule", ["torchvision"]),
("CityscapesDataModule", ["torchvision"]),
("KittiDataset", ["PIL"]),
("KittiDataModule", ["torchvision"]),
])
def test_import(dm_cls, deps):
"""Tests importing when dependencies are not met.
Set the followings in @pytest.mark.parametrize:
dm_cls: class to test importing
deps: packages required for dm_cls
"""
with mock.patch.dict("sys.modules", {pkg: None for pkg in deps}):
dms_module = importlib.import_module("pl_bolts.datamodules")
assert hasattr(dms_module, dm_cls), f"`from pl_bolts.datamodules import {dm_cls}` failed."

0 comments on commit ecbb82a

Please sign in to comment.