From 6c307c1292d26562b849c35695819cc93c104a8f Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Tue, 19 Jan 2021 16:18:53 +0900 Subject: [PATCH] Minor refactors - cleaning models (#524) * autoencoders * detection * gans * vision * rl * regression * mnist_module * models/__init__ * yapf * yapf * minor refactor * Remove re-import --- pl_bolts/__init__.py | 6 ++--- pl_bolts/callbacks/knn_online.py | 2 +- pl_bolts/callbacks/vision/confused_logit.py | 6 ++--- pl_bolts/callbacks/vision/image_generation.py | 10 ++++--- pl_bolts/models/__init__.py | 13 +++++++++ pl_bolts/models/autoencoders/__init__.py | 9 +++++++ pl_bolts/models/detection/__init__.py | 12 +++++---- .../faster_rcnn/faster_rcnn_module.py | 9 +++++-- pl_bolts/models/gans/__init__.py | 5 ++++ pl_bolts/models/gans/dcgan/dcgan_module.py | 2 +- pl_bolts/models/mnist_module.py | 10 ++++--- pl_bolts/models/regression/__init__.py | 5 ++++ .../models/regression/linear_regression.py | 9 ++++--- .../models/regression/logistic_regression.py | 9 ++++--- pl_bolts/models/rl/__init__.py | 27 ++++++++++++------- pl_bolts/models/rl/common/gym_wrappers.py | 24 ++++++++++++----- pl_bolts/models/rl/dqn_model.py | 4 +-- pl_bolts/models/rl/reinforce_model.py | 6 ++--- .../rl/vanilla_policy_gradient_model.py | 6 ++--- .../self_supervised/amdim/transforms.py | 6 ++--- .../self_supervised/cpc/cpc_finetuner.py | 2 +- .../models/self_supervised/cpc/cpc_module.py | 10 ++----- .../self_supervised/moco/moco2_module.py | 18 ++++++------- pl_bolts/models/self_supervised/resnets.py | 7 ++--- .../self_supervised/simclr/transforms.py | 4 +-- pl_bolts/models/vision/__init__.py | 6 +++++ pl_bolts/transforms/dataset_normalizations.py | 2 +- .../self_supervised/ssl_transforms.py | 11 ++++---- 28 files changed, 152 insertions(+), 88 deletions(-) diff --git a/pl_bolts/__init__.py b/pl_bolts/__init__.py index aab6e6ca42..61d1929f33 100644 --- a/pl_bolts/__init__.py +++ b/pl_bolts/__init__.py @@ -38,10 +38,10 @@ except NameError: __LIGHTNING_BOLT_SETUP__: bool = False -if __LIGHTNING_BOLT_SETUP__: - import sys # pragma: no-cover +if __LIGHTNING_BOLT_SETUP__: # pragma: no cover + import sys - sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover + sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # We are not importing the rest of the lightning during the build process, as it may not be compiled yet else: from pl_bolts import callbacks, datamodules, datasets, losses, metrics, models, optimizers, transforms, utils diff --git a/pl_bolts/callbacks/knn_online.py b/pl_bolts/callbacks/knn_online.py index 32eda875bb..248db6b497 100644 --- a/pl_bolts/callbacks/knn_online.py +++ b/pl_bolts/callbacks/knn_online.py @@ -14,7 +14,7 @@ warn_missing_pkg("sklearn", pypi_name="scikit-learn") -class KNNOnlineEvaluator(Callback): # pragma: no-cover +class KNNOnlineEvaluator(Callback): # pragma: no cover """ Evaluates self-supervised K nearest neighbors. diff --git a/pl_bolts/callbacks/vision/confused_logit.py b/pl_bolts/callbacks/vision/confused_logit.py index 992aea8ae8..62cb20ab6d 100644 --- a/pl_bolts/callbacks/vision/confused_logit.py +++ b/pl_bolts/callbacks/vision/confused_logit.py @@ -17,7 +17,7 @@ Figure = object -class ConfusedLogitCallback(Callback): # pragma: no-cover +class ConfusedLogitCallback(Callback): # pragma: no cover """ Takes the logit predictions of a model and when the probabilities of two classes are very close, the model doesn't have high certainty that it should pick one vs the other class. @@ -122,8 +122,8 @@ def _plot( model: LightningModule, mask_idxs: Tensor, ) -> None: - if not _MATPLOTLIB_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover + if not _MATPLOTLIB_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError( 'You want to use `matplotlib` which is not installed yet, install it with `pip install matplotlib`.' ) diff --git a/pl_bolts/callbacks/vision/image_generation.py b/pl_bolts/callbacks/vision/image_generation.py index 83ee748d05..92c66aabef 100644 --- a/pl_bolts/callbacks/vision/image_generation.py +++ b/pl_bolts/callbacks/vision/image_generation.py @@ -3,12 +3,13 @@ import torch from pytorch_lightning import Callback, LightningModule, Trainer +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE: import torchvision -except ModuleNotFoundError: - warn_missing_pkg("torchvision") # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg("torchvision") class TensorboardGenerativeModelImageSampler(Callback): @@ -57,6 +58,9 @@ def __init__( images separately rather than the (min, max) over all images. Default: ``False``. pad_value: Value for the padded pixels. Default: ``0``. """ + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.") + super().__init__() self.num_samples = num_samples self.nrow = nrow diff --git a/pl_bolts/models/__init__.py b/pl_bolts/models/__init__.py index 9d74fb32d6..fe7e2e730f 100644 --- a/pl_bolts/models/__init__.py +++ b/pl_bolts/models/__init__.py @@ -8,3 +8,16 @@ from pl_bolts.models.regression import LinearRegression, LogisticRegression # noqa: F401 from pl_bolts.models.vision import PixelCNN, SemSegment, UNet # noqa: F401 from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT # noqa: F401 + +__all__ = [ + "AE", + "VAE", + "LitMNIST", + "LinearRegression", + "LogisticRegression", + "PixelCNN", + "SemSegment", + "UNet", + "GPT2", + "ImageGPT", +] diff --git a/pl_bolts/models/autoencoders/__init__.py b/pl_bolts/models/autoencoders/__init__.py index 7eca48f64a..d665041bff 100644 --- a/pl_bolts/models/autoencoders/__init__.py +++ b/pl_bolts/models/autoencoders/__init__.py @@ -11,3 +11,12 @@ resnet50_decoder, resnet50_encoder, ) + +__all__ = [ + "AE", + "VAE", + "resnet18_decoder", + "resnet18_encoder", + "resnet50_decoder", + "resnet50_encoder", +] diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index c181ee0649..4c09eac3d4 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -1,5 +1,7 @@ -try: - from pl_bolts.models.detection import components # noqa: F401 - from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401 -except ModuleNotFoundError: # pragma: no-cover - pass # pragma: no-cover +from pl_bolts.models.detection import components # noqa: F401 +from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401 + +__all__ = [ + "components", + "FasterRCNN", +] diff --git a/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py b/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py index 0cff1d800d..7757bfe543 100644 --- a/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py +++ b/pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py @@ -4,6 +4,7 @@ import pytorch_lightning as pl import torch +from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg @@ -11,8 +12,6 @@ from torchvision.models.detection.faster_rcnn import FasterRCNN as torchvision_FasterRCNN from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn, FastRCNNPredictor from torchvision.ops import box_iou - - from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone else: # pragma: no cover warn_missing_pkg("torchvision") @@ -22,6 +21,9 @@ def _evaluate_iou(target, pred): Evaluate intersection over union (IOU) for target from dataset and output prediction from model """ + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.') + if pred["boxes"].shape[0] == 0: # no box detected, 0 IOU return torch.tensor(0.0, device=pred["boxes"].device) @@ -69,6 +71,9 @@ def __init__( pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet trainable_backbone_layers: number of trainable resnet layers starting from final block """ + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.') + super().__init__() self.learning_rate = learning_rate diff --git a/pl_bolts/models/gans/__init__.py b/pl_bolts/models/gans/__init__.py index c28eb32124..5cca383df1 100644 --- a/pl_bolts/models/gans/__init__.py +++ b/pl_bolts/models/gans/__init__.py @@ -1,2 +1,7 @@ from pl_bolts.models.gans.basic.basic_gan_module import GAN # noqa: F401 from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN # noqa: F401 + +__all__ = [ + "GAN", + "DCGAN", +] diff --git a/pl_bolts/models/gans/dcgan/dcgan_module.py b/pl_bolts/models/gans/dcgan/dcgan_module.py index b99f7f4f99..11d0e2aa67 100644 --- a/pl_bolts/models/gans/dcgan/dcgan_module.py +++ b/pl_bolts/models/gans/dcgan/dcgan_module.py @@ -14,7 +14,7 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import LSUN, MNIST -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg("torchvision") diff --git a/pl_bolts/models/mnist_module.py b/pl_bolts/models/mnist_module.py index 39c128d424..be6c4ee623 100644 --- a/pl_bolts/models/mnist_module.py +++ b/pl_bolts/models/mnist_module.py @@ -5,18 +5,22 @@ from torch.nn import functional as F from torch.utils.data import DataLoader, random_split +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE: from torchvision import transforms from torchvision.datasets import MNIST -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg('torchvision') class LitMNIST(LightningModule): def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, data_dir='', **kwargs): + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.') + super().__init__() self.save_hyperparameters() diff --git a/pl_bolts/models/regression/__init__.py b/pl_bolts/models/regression/__init__.py index ab61583091..68d6d1ced3 100644 --- a/pl_bolts/models/regression/__init__.py +++ b/pl_bolts/models/regression/__init__.py @@ -1,2 +1,7 @@ from pl_bolts.models.regression.linear_regression import LinearRegression # noqa: F401 from pl_bolts.models.regression.logistic_regression import LogisticRegression # noqa: F401 + +__all__ = [ + "LinearRegression", + "LogisticRegression", +] diff --git a/pl_bolts/models/regression/linear_regression.py b/pl_bolts/models/regression/linear_regression.py index 4e935085af..7bee23db74 100644 --- a/pl_bolts/models/regression/linear_regression.py +++ b/pl_bolts/models/regression/linear_regression.py @@ -110,16 +110,17 @@ def add_model_specific_args(parent_parser): def cli_main(): from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule + from pl_bolts.utils import _SKLEARN_AVAILABLE pl.seed_everything(1234) # create dataset - try: + if _SKLEARN_AVAILABLE: from sklearn.datasets import load_boston - except ModuleNotFoundError as err: - raise ModuleNotFoundError( # pragma: no-cover + else: # pragma: no cover + raise ModuleNotFoundError( 'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.' - ) from err + ) # args parser = ArgumentParser() diff --git a/pl_bolts/models/regression/logistic_regression.py b/pl_bolts/models/regression/logistic_regression.py index f57bd85cfe..2601eadd60 100644 --- a/pl_bolts/models/regression/logistic_regression.py +++ b/pl_bolts/models/regression/logistic_regression.py @@ -116,16 +116,17 @@ def add_model_specific_args(parent_parser): def cli_main(): from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule + from pl_bolts.utils import _SKLEARN_AVAILABLE pl.seed_everything(1234) # Example: Iris dataset in Sklearn (4 features, 3 class labels) - try: + if _SKLEARN_AVAILABLE: from sklearn.datasets import load_iris - except ModuleNotFoundError as err: - raise ModuleNotFoundError( # pragma: no-cover + else: # pragma: no cover + raise ModuleNotFoundError( 'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.' - ) from err + ) # args parser = ArgumentParser() diff --git a/pl_bolts/models/rl/__init__.py b/pl_bolts/models/rl/__init__.py index 0ef392113b..070ec666be 100644 --- a/pl_bolts/models/rl/__init__.py +++ b/pl_bolts/models/rl/__init__.py @@ -1,10 +1,17 @@ -try: - from pl_bolts.models.rl.double_dqn_model import DoubleDQN # noqa: F401 - from pl_bolts.models.rl.dqn_model import DQN # noqa: F401 - from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN # noqa: F401 - from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN # noqa: F401 - from pl_bolts.models.rl.per_dqn_model import PERDQN # noqa: F401 - from pl_bolts.models.rl.reinforce_model import Reinforce # noqa: F401 - from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient # noqa: F401 -except ModuleNotFoundError: - pass +from pl_bolts.models.rl.double_dqn_model import DoubleDQN # noqa: F401 +from pl_bolts.models.rl.dqn_model import DQN # noqa: F401 +from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN # noqa: F401 +from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN # noqa: F401 +from pl_bolts.models.rl.per_dqn_model import PERDQN # noqa: F401 +from pl_bolts.models.rl.reinforce_model import Reinforce # noqa: F401 +from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient # noqa: F401 + +__all__ = [ + "DoubleDQN", + "DQN", + "DuelingDQN", + "NoisyDQN", + "PERDQN", + "Reinforce", + "VanillaPolicyGradient", +] diff --git a/pl_bolts/models/rl/common/gym_wrappers.py b/pl_bolts/models/rl/common/gym_wrappers.py index cd389e2fbd..d14a736b9d 100644 --- a/pl_bolts/models/rl/common/gym_wrappers.py +++ b/pl_bolts/models/rl/common/gym_wrappers.py @@ -14,21 +14,24 @@ import gym.spaces from gym import make as gym_make from gym import ObservationWrapper, Wrapper -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('gym') Wrapper = object ObservationWrapper = object if _OPENCV_AVAILABLE: import cv2 -else: - warn_missing_pkg('cv2', pypi_name='opencv-python') # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg('cv2', pypi_name='opencv-python') class ToTensor(Wrapper): """For environments where the user need to press FIRE for the game to start.""" def __init__(self, env=None): + if not _GYM_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError('You want to use `gym` which is not installed yet.') + super(ToTensor, self).__init__(env) def step(self, action): @@ -45,6 +48,9 @@ class FireResetEnv(Wrapper): """For environments where the user need to press FIRE for the game to start.""" def __init__(self, env=None): + if not _GYM_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError('You want to use `gym` which is not installed yet.') + super(FireResetEnv, self).__init__(env) assert env.unwrapped.get_action_meanings()[1] == "FIRE" assert len(env.unwrapped.get_action_meanings()) >= 3 @@ -69,6 +75,9 @@ class MaxAndSkipEnv(Wrapper): """Return only every `skip`-th frame""" def __init__(self, env=None, skip=4): + if not _GYM_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError('You want to use `gym` which is not installed yet.') + super(MaxAndSkipEnv, self).__init__(env) # most recent raw observations (for max pooling across time steps) self._obs_buffer = collections.deque(maxlen=2) @@ -99,8 +108,7 @@ class ProcessFrame84(ObservationWrapper): """preprocessing images from env""" def __init__(self, env=None): - - if not _OPENCV_AVAILABLE: + if not _OPENCV_AVAILABLE: # pragma: no cover raise ModuleNotFoundError('This class uses OpenCV which it is not installed yet.') super(ProcessFrame84, self).__init__(env) @@ -130,8 +138,7 @@ class ImageToPyTorch(ObservationWrapper): """converts image to pytorch format""" def __init__(self, env): - - if not _OPENCV_AVAILABLE: + if not _OPENCV_AVAILABLE: # pragma: no cover raise ModuleNotFoundError('This class uses OpenCV which it is not installed yet.') super(ImageToPyTorch, self).__init__(env) @@ -188,6 +195,9 @@ class DataAugmentation(ObservationWrapper): """ def __init__(self, env=None): + if not _GYM_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError('You want to use `gym` which is not installed yet.') + super().__init__(env) self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8) diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py index 2c98d5efa2..0d43bbe6a2 100644 --- a/pl_bolts/models/rl/dqn_model.py +++ b/pl_bolts/models/rl/dqn_model.py @@ -25,8 +25,8 @@ if _GYM_AVAILABLE: from gym import Env -else: - warn_missing_pkg('gym') # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg('gym') Env = object diff --git a/pl_bolts/models/rl/reinforce_model.py b/pl_bolts/models/rl/reinforce_model.py index 9207fdf92b..02cf4e5732 100644 --- a/pl_bolts/models/rl/reinforce_model.py +++ b/pl_bolts/models/rl/reinforce_model.py @@ -21,8 +21,8 @@ if _GYM_AVAILABLE: import gym -else: - warn_missing_pkg('gym') # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg('gym') class Reinforce(pl.LightningModule): @@ -80,7 +80,7 @@ def __init__( """ super().__init__() - if not _GYM_AVAILABLE: + if not _GYM_AVAILABLE: # pragma: no cover raise ModuleNotFoundError('This Module requires gym environment which is not installed yet.') # Hyperparameters diff --git a/pl_bolts/models/rl/vanilla_policy_gradient_model.py b/pl_bolts/models/rl/vanilla_policy_gradient_model.py index 52abc96b61..1e6aaa4199 100644 --- a/pl_bolts/models/rl/vanilla_policy_gradient_model.py +++ b/pl_bolts/models/rl/vanilla_policy_gradient_model.py @@ -20,8 +20,8 @@ if _GYM_AVAILABLE: import gym -else: - warn_missing_pkg('gym') # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg('gym') class VanillaPolicyGradient(pl.LightningModule): @@ -79,7 +79,7 @@ def __init__( """ super().__init__() - if not _GYM_AVAILABLE: + if not _GYM_AVAILABLE: # pragma: no cover raise ModuleNotFoundError('This Module requires gym environment which is not installed yet.') # Hyperparameters diff --git a/pl_bolts/models/self_supervised/amdim/transforms.py b/pl_bolts/models/self_supervised/amdim/transforms.py index 3dbc99b123..9076a9cfeb 100644 --- a/pl_bolts/models/self_supervised/amdim/transforms.py +++ b/pl_bolts/models/self_supervised/amdim/transforms.py @@ -30,10 +30,8 @@ class AMDIMTrainTransformsCIFAR10: """ def __init__(self): - if not _TORCHVISION_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover - 'You want to use `transforms` from `torchvision` which is not installed yet.' - ) + if not _TORCHVISION_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError('You want to use `transforms` from `torchvision` which is not installed yet.') # flipping image along vertical axis self.flip_lr = transforms.RandomHorizontalFlip(p=0.5) diff --git a/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py b/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py index a1decdf7c6..69895e8eb5 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py @@ -12,7 +12,7 @@ ) -def cli_main(): # pragma: no-cover +def cli_main(): # pragma: no cover from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule pl.seed_everything(1234) diff --git a/pl_bolts/models/self_supervised/cpc/cpc_module.py b/pl_bolts/models/self_supervised/cpc/cpc_module.py index f261719269..54a65ff6c2 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -11,6 +11,7 @@ from pytorch_lightning.utilities import rank_zero_warn from torch import optim as optim +from pl_bolts.datamodules.stl10_datamodule import STL10DataModule from pl_bolts.losses.self_supervised_learning import CPCTask from pl_bolts.models.self_supervised.cpc.networks import cpc_resnet101 from pl_bolts.models.self_supervised.cpc.transforms import ( @@ -150,13 +151,6 @@ def validation_step(self, batch, batch_nb): return nce_loss def shared_step(self, batch): - try: - from pl_bolts.datamodules.stl10_datamodule import STL10DataModule - except ModuleNotFoundError as err: - raise ModuleNotFoundError( # pragma: no-cover - 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' - ) from err - if isinstance(self.datamodule, STL10DataModule): unlabeled_batch = batch[0] batch = unlabeled_batch @@ -201,7 +195,7 @@ def add_model_specific_args(parent_parser): def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator - from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule + from pl_bolts.datamodules import CIFAR10DataModule from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule pl.seed_everything(1234) diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index 4712fb92c2..7bd63d27e2 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -17,13 +17,6 @@ from torch import nn from torch.nn import functional as F -from pl_bolts.utils.warnings import warn_missing_pkg - -try: - import torchvision -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover - from pl_bolts.metrics import mean, precision_at_k from pl_bolts.models.self_supervised.moco.transforms import ( Moco2EvalCIFAR10Transforms, @@ -33,6 +26,13 @@ Moco2TrainImagenetTransforms, Moco2TrainSTL10Transforms, ) +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + import torchvision +else: # pragma: no cover + warn_missing_pkg('torchvision') class MocoV2(pl.LightningModule): @@ -159,7 +159,7 @@ def _dequeue_and_enqueue(self, keys): self.queue_ptr[0] = ptr @torch.no_grad() - def _batch_shuffle_ddp(self, x): # pragma: no-cover + def _batch_shuffle_ddp(self, x): # pragma: no cover """ Batch shuffle, for making use of BatchNorm. *** Only support DistributedDataParallel (DDP) model. *** @@ -187,7 +187,7 @@ def _batch_shuffle_ddp(self, x): # pragma: no-cover return x_gather[idx_this], idx_unshuffle @torch.no_grad() - def _batch_unshuffle_ddp(self, x, idx_unshuffle): # pragma: no-cover + def _batch_unshuffle_ddp(self, x, idx_unshuffle): # pragma: no cover """ Undo batch shuffle. *** Only support DistributedDataParallel (DDP) model. *** diff --git a/pl_bolts/models/self_supervised/resnets.py b/pl_bolts/models/self_supervised/resnets.py index 8729e27a73..dbe74fedec 100644 --- a/pl_bolts/models/self_supervised/resnets.py +++ b/pl_bolts/models/self_supervised/resnets.py @@ -1,12 +1,13 @@ import torch from torch import nn as nn +from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -try: +if _TORCHVISION_AVAILABLE: from torchvision.models.utils import load_state_dict_from_url -except ModuleNotFoundError: - warn_missing_pkg('torchvision') # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg('torchvision') __all__ = [ 'ResNet', diff --git a/pl_bolts/models/self_supervised/simclr/transforms.py b/pl_bolts/models/self_supervised/simclr/transforms.py index 34015dc39d..39add46d20 100644 --- a/pl_bolts/models/self_supervised/simclr/transforms.py +++ b/pl_bolts/models/self_supervised/simclr/transforms.py @@ -5,8 +5,8 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transforms -else: - warn_missing_pkg('torchvision') # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg('torchvision') if _OPENCV_AVAILABLE: import cv2 diff --git a/pl_bolts/models/vision/__init__.py b/pl_bolts/models/vision/__init__.py index 00ce072f79..567e04a790 100644 --- a/pl_bolts/models/vision/__init__.py +++ b/pl_bolts/models/vision/__init__.py @@ -1,3 +1,9 @@ from pl_bolts.models.vision.pixel_cnn import PixelCNN # noqa: F401 from pl_bolts.models.vision.segmentation import SemSegment # noqa: F401 from pl_bolts.models.vision.unet import UNet # noqa: F401 + +__all__ = [ + "PixelCNN", + "SemSegment", + "UNet", +] diff --git a/pl_bolts/transforms/dataset_normalizations.py b/pl_bolts/transforms/dataset_normalizations.py index 14192099c0..f07447c82b 100644 --- a/pl_bolts/transforms/dataset_normalizations.py +++ b/pl_bolts/transforms/dataset_normalizations.py @@ -3,7 +3,7 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms -else: # pragma: no-cover +else: # pragma: no cover warn_missing_pkg('torchvision') diff --git a/pl_bolts/transforms/self_supervised/ssl_transforms.py b/pl_bolts/transforms/self_supervised/ssl_transforms.py index d89e4e5900..c612f1b02a 100644 --- a/pl_bolts/transforms/self_supervised/ssl_transforms.py +++ b/pl_bolts/transforms/self_supervised/ssl_transforms.py @@ -6,8 +6,8 @@ if _PIL_AVAILABLE: from PIL import Image -else: - warn_missing_pkg('PIL', pypi_name='Pillow') # pragma: no-cover +else: # pragma: no cover + warn_missing_pkg('PIL', pypi_name='Pillow') class RandomTranslateWithReflect: @@ -20,13 +20,12 @@ class RandomTranslateWithReflect: """ def __init__(self, max_translation): + if not _PIL_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError("You want to use `Pillow` which is not installed yet.") + self.max_translation = max_translation def __call__(self, old_image): - if not _PIL_AVAILABLE: - raise ModuleNotFoundError( # pragma: no-cover - 'You want to use `Pillow` which is not installed yet, install it with `pip install Pillow`.' - ) xtranslation, ytranslation = np.random.randint(-self.max_translation, self.max_translation + 1, size=2) xpad, ypad = abs(xtranslation), abs(ytranslation) xsize, ysize = old_image.size