Skip to content

Commit

Permalink
Minor refactors - cleaning models (#524)
Browse files Browse the repository at this point in the history
* autoencoders

* detection

* gans

* vision

* rl

* regression

* mnist_module

* models/__init__

* yapf

* yapf

* minor refactor

* Remove re-import
  • Loading branch information
akihironitta authored Jan 19, 2021
1 parent da35d3d commit 6c307c1
Show file tree
Hide file tree
Showing 28 changed files with 152 additions and 88 deletions.
6 changes: 3 additions & 3 deletions pl_bolts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@
except NameError:
__LIGHTNING_BOLT_SETUP__: bool = False

if __LIGHTNING_BOLT_SETUP__:
import sys # pragma: no-cover
if __LIGHTNING_BOLT_SETUP__: # pragma: no cover
import sys

sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover
sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n')
# We are not importing the rest of the lightning during the build process, as it may not be compiled yet
else:
from pl_bolts import callbacks, datamodules, datasets, losses, metrics, models, optimizers, transforms, utils
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/callbacks/knn_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
warn_missing_pkg("sklearn", pypi_name="scikit-learn")


class KNNOnlineEvaluator(Callback): # pragma: no-cover
class KNNOnlineEvaluator(Callback): # pragma: no cover
"""
Evaluates self-supervised K nearest neighbors.
Expand Down
6 changes: 3 additions & 3 deletions pl_bolts/callbacks/vision/confused_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Figure = object


class ConfusedLogitCallback(Callback): # pragma: no-cover
class ConfusedLogitCallback(Callback): # pragma: no cover
"""
Takes the logit predictions of a model and when the probabilities of two classes are very close, the model
doesn't have high certainty that it should pick one vs the other class.
Expand Down Expand Up @@ -122,8 +122,8 @@ def _plot(
model: LightningModule,
mask_idxs: Tensor,
) -> None:
if not _MATPLOTLIB_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _MATPLOTLIB_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `matplotlib` which is not installed yet, install it with `pip install matplotlib`.'
)

Expand Down
10 changes: 7 additions & 3 deletions pl_bolts/callbacks/vision/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import torch
from pytorch_lightning import Callback, LightningModule, Trainer

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
import torchvision
except ModuleNotFoundError:
warn_missing_pkg("torchvision") # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg("torchvision")


class TensorboardGenerativeModelImageSampler(Callback):
Expand Down Expand Up @@ -57,6 +58,9 @@ def __init__(
images separately rather than the (min, max) over all images. Default: ``False``.
pad_value: Value for the padded pixels. Default: ``0``.
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")

super().__init__()
self.num_samples = num_samples
self.nrow = nrow
Expand Down
13 changes: 13 additions & 0 deletions pl_bolts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,16 @@
from pl_bolts.models.regression import LinearRegression, LogisticRegression # noqa: F401
from pl_bolts.models.vision import PixelCNN, SemSegment, UNet # noqa: F401
from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT # noqa: F401

__all__ = [
"AE",
"VAE",
"LitMNIST",
"LinearRegression",
"LogisticRegression",
"PixelCNN",
"SemSegment",
"UNet",
"GPT2",
"ImageGPT",
]
9 changes: 9 additions & 0 deletions pl_bolts/models/autoencoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,12 @@
resnet50_decoder,
resnet50_encoder,
)

__all__ = [
"AE",
"VAE",
"resnet18_decoder",
"resnet18_encoder",
"resnet50_decoder",
"resnet50_encoder",
]
12 changes: 7 additions & 5 deletions pl_bolts/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
try:
from pl_bolts.models.detection import components # noqa: F401
from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401
except ModuleNotFoundError: # pragma: no-cover
pass # pragma: no-cover
from pl_bolts.models.detection import components # noqa: F401
from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401

__all__ = [
"components",
"FasterRCNN",
]
9 changes: 7 additions & 2 deletions pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
import pytorch_lightning as pl
import torch

from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision.models.detection.faster_rcnn import FasterRCNN as torchvision_FasterRCNN
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn, FastRCNNPredictor
from torchvision.ops import box_iou

from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone
else: # pragma: no cover
warn_missing_pkg("torchvision")

Expand All @@ -22,6 +21,9 @@ def _evaluate_iou(target, pred):
Evaluate intersection over union (IOU) for target from dataset and output prediction
from model
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')

if pred["boxes"].shape[0] == 0:
# no box detected, 0 IOU
return torch.tensor(0.0, device=pred["boxes"].device)
Expand Down Expand Up @@ -69,6 +71,9 @@ def __init__(
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')

super().__init__()

self.learning_rate = learning_rate
Expand Down
5 changes: 5 additions & 0 deletions pl_bolts/models/gans/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from pl_bolts.models.gans.basic.basic_gan_module import GAN # noqa: F401
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN # noqa: F401

__all__ = [
"GAN",
"DCGAN",
]
2 changes: 1 addition & 1 deletion pl_bolts/models/gans/dcgan/dcgan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import LSUN, MNIST
else: # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg("torchvision")


Expand Down
10 changes: 7 additions & 3 deletions pl_bolts/models/mnist_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms
from torchvision.datasets import MNIST
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


class LitMNIST(LightningModule):

def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, data_dir='', **kwargs):
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')

super().__init__()
self.save_hyperparameters()

Expand Down
5 changes: 5 additions & 0 deletions pl_bolts/models/regression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from pl_bolts.models.regression.linear_regression import LinearRegression # noqa: F401
from pl_bolts.models.regression.logistic_regression import LogisticRegression # noqa: F401

__all__ = [
"LinearRegression",
"LogisticRegression",
]
9 changes: 5 additions & 4 deletions pl_bolts/models/regression/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,17 @@ def add_model_specific_args(parent_parser):

def cli_main():
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule
from pl_bolts.utils import _SKLEARN_AVAILABLE

pl.seed_everything(1234)

# create dataset
try:
if _SKLEARN_AVAILABLE:
from sklearn.datasets import load_boston
except ModuleNotFoundError as err:
raise ModuleNotFoundError( # pragma: no-cover
else: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.'
) from err
)

# args
parser = ArgumentParser()
Expand Down
9 changes: 5 additions & 4 deletions pl_bolts/models/regression/logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,17 @@ def add_model_specific_args(parent_parser):

def cli_main():
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule
from pl_bolts.utils import _SKLEARN_AVAILABLE

pl.seed_everything(1234)

# Example: Iris dataset in Sklearn (4 features, 3 class labels)
try:
if _SKLEARN_AVAILABLE:
from sklearn.datasets import load_iris
except ModuleNotFoundError as err:
raise ModuleNotFoundError( # pragma: no-cover
else: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.'
) from err
)

# args
parser = ArgumentParser()
Expand Down
27 changes: 17 additions & 10 deletions pl_bolts/models/rl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
try:
from pl_bolts.models.rl.double_dqn_model import DoubleDQN # noqa: F401
from pl_bolts.models.rl.dqn_model import DQN # noqa: F401
from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN # noqa: F401
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN # noqa: F401
from pl_bolts.models.rl.per_dqn_model import PERDQN # noqa: F401
from pl_bolts.models.rl.reinforce_model import Reinforce # noqa: F401
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient # noqa: F401
except ModuleNotFoundError:
pass
from pl_bolts.models.rl.double_dqn_model import DoubleDQN # noqa: F401
from pl_bolts.models.rl.dqn_model import DQN # noqa: F401
from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN # noqa: F401
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN # noqa: F401
from pl_bolts.models.rl.per_dqn_model import PERDQN # noqa: F401
from pl_bolts.models.rl.reinforce_model import Reinforce # noqa: F401
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient # noqa: F401

__all__ = [
"DoubleDQN",
"DQN",
"DuelingDQN",
"NoisyDQN",
"PERDQN",
"Reinforce",
"VanillaPolicyGradient",
]
24 changes: 17 additions & 7 deletions pl_bolts/models/rl/common/gym_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,24 @@
import gym.spaces
from gym import make as gym_make
from gym import ObservationWrapper, Wrapper
else: # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('gym')
Wrapper = object
ObservationWrapper = object

if _OPENCV_AVAILABLE:
import cv2
else:
warn_missing_pkg('cv2', pypi_name='opencv-python') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('cv2', pypi_name='opencv-python')


class ToTensor(Wrapper):
"""For environments where the user need to press FIRE for the game to start."""

def __init__(self, env=None):
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `gym` which is not installed yet.')

super(ToTensor, self).__init__(env)

def step(self, action):
Expand All @@ -45,6 +48,9 @@ class FireResetEnv(Wrapper):
"""For environments where the user need to press FIRE for the game to start."""

def __init__(self, env=None):
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `gym` which is not installed yet.')

super(FireResetEnv, self).__init__(env)
assert env.unwrapped.get_action_meanings()[1] == "FIRE"
assert len(env.unwrapped.get_action_meanings()) >= 3
Expand All @@ -69,6 +75,9 @@ class MaxAndSkipEnv(Wrapper):
"""Return only every `skip`-th frame"""

def __init__(self, env=None, skip=4):
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `gym` which is not installed yet.')

super(MaxAndSkipEnv, self).__init__(env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = collections.deque(maxlen=2)
Expand Down Expand Up @@ -99,8 +108,7 @@ class ProcessFrame84(ObservationWrapper):
"""preprocessing images from env"""

def __init__(self, env=None):

if not _OPENCV_AVAILABLE:
if not _OPENCV_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('This class uses OpenCV which it is not installed yet.')

super(ProcessFrame84, self).__init__(env)
Expand Down Expand Up @@ -130,8 +138,7 @@ class ImageToPyTorch(ObservationWrapper):
"""converts image to pytorch format"""

def __init__(self, env):

if not _OPENCV_AVAILABLE:
if not _OPENCV_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('This class uses OpenCV which it is not installed yet.')

super(ImageToPyTorch, self).__init__(env)
Expand Down Expand Up @@ -188,6 +195,9 @@ class DataAugmentation(ObservationWrapper):
"""

def __init__(self, env=None):
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `gym` which is not installed yet.')

super().__init__(env)
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

Expand Down
4 changes: 2 additions & 2 deletions pl_bolts/models/rl/dqn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

if _GYM_AVAILABLE:
from gym import Env
else:
warn_missing_pkg('gym') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('gym')
Env = object


Expand Down
6 changes: 3 additions & 3 deletions pl_bolts/models/rl/reinforce_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

if _GYM_AVAILABLE:
import gym
else:
warn_missing_pkg('gym') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('gym')


class Reinforce(pl.LightningModule):
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
"""
super().__init__()

if not _GYM_AVAILABLE:
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('This Module requires gym environment which is not installed yet.')

# Hyperparameters
Expand Down
Loading

0 comments on commit 6c307c1

Please sign in to comment.