Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor refactors - datasets/datamodules/optimizers/transforms/metrics/utils #523

Merged
merged 9 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pl_bolts/datamodules/binary_mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
else: # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


Expand Down Expand Up @@ -71,8 +71,8 @@ def __init__(
drop_last: If true drops the last incomplete batch
"""

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use transforms loaded from `torchvision` which is not installed yet."
)

Expand Down
4 changes: 2 additions & 2 deletions pl_bolts/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import CIFAR10
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')
CIFAR10 = None


Expand Down
8 changes: 4 additions & 4 deletions pl_bolts/datamodules/cityscapes_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import Cityscapes
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


class CityscapesDataModule(LightningDataModule):
Expand Down Expand Up @@ -88,8 +88,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use CityScapes dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
9 changes: 7 additions & 2 deletions pl_bolts/datamodules/fashion_mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import FashionMNIST
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')
FashionMNIST = None


Expand Down Expand Up @@ -71,6 +71,11 @@ def __init__(
returning them
drop_last: If true drops the last incomplete batch
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use FashionMNIST dataset loaded from `torchvision` which is not installed yet.'
)

super().__init__(
data_dir=data_dir,
val_split=val_split,
Expand Down
8 changes: 4 additions & 4 deletions pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


class ImagenetDataModule(LightningDataModule):
Expand Down Expand Up @@ -76,8 +76,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use ImageNet dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
9 changes: 7 additions & 2 deletions pl_bolts/datamodules/mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import MNIST
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')
MNIST = None


Expand Down Expand Up @@ -70,6 +70,11 @@ def __init__(
returning them
drop_last: If true drops the last incomplete batch
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.'
)

super().__init__(
data_dir=data_dir,
val_split=val_split,
Expand Down
8 changes: 4 additions & 4 deletions pl_bolts/datamodules/sklearn_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

if _SKLEARN_AVAILABLE:
from sklearn.utils import shuffle as sk_shuffle
else:
warn_missing_pkg("sklearn") # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg("sklearn")


class SklearnDataset(Dataset):
Expand Down Expand Up @@ -172,8 +172,8 @@ def __init__(
# shuffle x and y
if shuffle and _SKLEARN_AVAILABLE:
X, y = sk_shuffle(X, y, random_state=random_state)
elif shuffle and not _SKLEARN_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
elif shuffle and not _SKLEARN_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use shuffle function from `scikit-learn` which is not installed yet.'
)

Expand Down
8 changes: 4 additions & 4 deletions pl_bolts/datamodules/ssl_imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


class SSLImagenetDataModule(LightningDataModule): # pragma: no cover
Expand All @@ -32,8 +32,8 @@ def __init__(
):
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use ImageNet dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
8 changes: 4 additions & 4 deletions pl_bolts/datamodules/stl10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import STL10
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


class STL10DataModule(LightningDataModule): # pragma: no cover
Expand Down Expand Up @@ -81,8 +81,8 @@ def __init__(
"""
super().__init__(*args, **kwargs)

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use STL10 dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
8 changes: 4 additions & 4 deletions pl_bolts/datamodules/vocdetection_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as T
from torchvision.datasets import VOCDetection
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


class Compose(object):
Expand Down Expand Up @@ -116,8 +116,8 @@ def __init__(
*args,
**kwargs,
):
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use VOC dataset loaded from `torchvision` which is not installed yet.'
)

Expand Down
4 changes: 2 additions & 2 deletions pl_bolts/datasets/cifar10_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

if _PIL_AVAILABLE:
from PIL import Image
else:
warn_missing_pkg('PIL', pypi_name='Pillow') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('PIL', pypi_name='Pillow')


class CIFAR10(LightDataset):
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/datasets/kitti_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if _PIL_AVAILABLE:
from PIL import Image
else: # pragma: no cover
warn_missing_pkg('PIL')
warn_missing_pkg('PIL', pypi_name='Pillow')

DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)
Expand Down
11 changes: 7 additions & 4 deletions pl_bolts/datasets/mnist_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

if _TORCHVISION_AVAILABLE:
from torchvision.datasets import MNIST
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')
MNIST = object

if _PIL_AVAILABLE:
from PIL import Image
else:
warn_missing_pkg('PIL', pypi_name='Pillow') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('PIL', pypi_name='Pillow')


class BinaryMNIST(MNIST):
Expand All @@ -22,6 +22,9 @@ def __getitem__(self, idx):
Returns:
tuple: (image, target) where target is index of the target class.
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')

img, target = self.data[idx], int(self.targets[idx])

# doing this so that it is consistent with all other datasets
Expand Down
2 changes: 2 additions & 0 deletions pl_bolts/datasets/ssl_amdim_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def __init__(
nb_labeled_per_class: Optional[int] = None,
val_pct: float = 0.10
):
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')

if nb_labeled_per_class == -1:
nb_labeled_per_class = None
Expand Down
6 changes: 6 additions & 0 deletions pl_bolts/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
from pl_bolts.metrics.aggregation import accuracy, mean, precision_at_k # noqa: F401

__all__ = [
"accuracy",
"mean",
"precision_at_k",
]
5 changes: 5 additions & 0 deletions pl_bolts/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from pl_bolts.optimizers.lars_scheduling import LARSWrapper # noqa: F401
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR # noqa: F401

__all__ = [
"LARSWrapper",
"LinearWarmupCosineAnnealingLR",
]
22 changes: 19 additions & 3 deletions pl_bolts/transforms/dataset_normalizations.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no-cover
warn_missing_pkg('torchvision')


def imagenet_normalization():
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.'
)

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
return normalize


def cifar10_normalization():
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.'
)

normalize = transforms.Normalize(
mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
Expand All @@ -20,5 +31,10 @@ def cifar10_normalization():


def stl10_normalization():
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.'
)

normalize = transforms.Normalize(mean=(0.43, 0.42, 0.39), std=(0.27, 0.26, 0.27))
return normalize
2 changes: 1 addition & 1 deletion pl_bolts/utils/pretrained_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
urls = {'vae-imagenet2012': vae_imagenet2012, 'CPCV2-resnet18': cpcv2_resnet18}


def load_pretrained(model: LightningModule, class_name: Optional[str] = None) -> None: # pragma: no-cover
def load_pretrained(model: LightningModule, class_name: Optional[str] = None) -> None: # pragma: no cover
if class_name is None:
class_name = model.__class__.__name__
ckpt_url = urls[class_name]
Expand Down
4 changes: 1 addition & 3 deletions pl_bolts/utils/self_supervised.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from torch.nn import Module

from pl_bolts.models.self_supervised import resnets
from pl_bolts.utils.semi_supervised import Identity


Expand All @@ -8,9 +9,6 @@ def torchvision_ssl_encoder(
pretrained: bool = False,
return_all_feature_maps: bool = False,
) -> Module:
from pl_bolts.models.self_supervised import resnets

pretrained_model = getattr(resnets, name)(pretrained=pretrained, return_all_feature_maps=return_all_feature_maps)

pretrained_model.fc = Identity()
return pretrained_model
8 changes: 3 additions & 5 deletions pl_bolts/utils/semi_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

if _SKLEARN_AVAILABLE:
from sklearn.utils import shuffle as sk_shuffle
else: # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('sklearn', pypi_name='scikit-learn')


Expand Down Expand Up @@ -45,10 +45,8 @@ def balance_classes(X: Union[Tensor, np.ndarray], Y: Union[Tensor, np.ndarray, S
Y: mixed labels (ints)
batch_size: the ultimate batch size
"""
if not _SKLEARN_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use `shuffle` function from `scikit-learn` which is not installed yet.'
)
if not _SKLEARN_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `shuffle` function from `scikit-learn` which is not installed yet.')

nb_classes = len(set(Y))

Expand Down