diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 15b0f8b175..9a129a04f2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -5,5 +5,6 @@ repos:
       - id: isort
 
   - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v0.790
     hooks:
       - id: mypy
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 26a8c2779f..e986a3a5fc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added data monitor callbacks `ModuleDataMonitor` and `TrainingDataMonitor` ([#285](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/285))
 
+- Added `VisionDataModule` as parent class for `BinaryMNISTDataModule`, `CIFAR10DataModule`, `FashionMNISTDataModule`,
+  and `MNISTDataModule` ([#400](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/400))
+
 ### Changed
 
 - Decoupled datamodules from models ([#332](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/332), [#270](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/270))
diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 191ea45a17..142b3d54ef 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -1,7 +1,6 @@
-import torch
-from pytorch_lightning import LightningDataModule
-from torch.utils.data import DataLoader, random_split
+from typing import Any, Optional, Union
 
+from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.datasets.mnist_dataset import BinaryMNIST
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
@@ -12,7 +11,7 @@
     warn_missing_pkg('torchvision')
 
 
-class BinaryMNISTDataModule(LightningDataModule):
+class BinaryMNISTDataModule(VisionDataModule):
     """
     .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
         :width: 400
@@ -41,136 +40,70 @@ class BinaryMNISTDataModule(LightningDataModule):
     """
 
     name = "binary_mnist"
+    dataset_cls = BinaryMNIST
+    dims = (1, 28, 28)
 
     def __init__(
-        self,
-        data_dir: str,
-        val_split: int = 5000,
-        num_workers: int = 16,
-        normalize: bool = False,
-        batch_size: int = 32,
-        seed: int = 42,
-        shuffle: bool = False,
-        pin_memory: bool = False,
-        drop_last: bool = False,
-        *args,
-        **kwargs,
-    ):
+        self,
+        data_dir: Optional[str] = None,
+        val_split: Union[int, float] = 0.2,
+        num_workers: int = 16,
+        normalize: bool = False,
+        batch_size: int = 32,
+        seed: int = 42,
+        shuffle: bool = False,
+        pin_memory: bool = False,
+        drop_last: bool = False,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         """
         Args:
-            data_dir: where to save/load the data
-            val_split: how many of the training images to use for the validation split
-            num_workers: how many workers to use for loading data
+            data_dir: Where to save/load the data
+            val_split: Percent (float) or number (int) of samples to use for the validation split
+            num_workers: How many workers to use for loading data
             normalize: If true applies image normalize
-            batch_size: size of batch
-            seed: random seed to be used for train/val/test splits
-            shuffle: If true shuffles the data every epoch
+            batch_size: How many samples per batch to load
+            seed: Random seed to be used for train/val/test splits
+            shuffle: If true shuffles the train data every epoch
             pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                         returning them
             drop_last: If true drops the last incomplete batch
         """
-        super().__init__(*args, **kwargs)
 
         if not _TORCHVISION_AVAILABLE:
             raise ModuleNotFoundError(  # pragma: no-cover
-                'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.'
+                "You want to use transforms loaded from `torchvision` which is not installed yet."
             )
 
-        self.dims = (1, 28, 28)
-        self.data_dir = data_dir
-        self.val_split = val_split
-        self.num_workers = num_workers
-        self.normalize = normalize
-        self.batch_size = batch_size
-        self.seed = seed
-        self.shuffle = shuffle
-        self.pin_memory = pin_memory
-        self.drop_last = drop_last
+        super().__init__(
+            data_dir=data_dir,
+            val_split=val_split,
+            num_workers=num_workers,
+            normalize=normalize,
+            batch_size=batch_size,
+            seed=seed,
+            shuffle=shuffle,
+            pin_memory=pin_memory,
+            drop_last=drop_last,
+            *args,
+            **kwargs,
+        )
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             10
         """
         return 10
 
-    def prepare_data(self):
-        """
-        Saves MNIST files to data_dir
-        """
-        BinaryMNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor())
-        BinaryMNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor())
-
-    def train_dataloader(self):
-        """
-        MNIST train set removes a subset to use for validation
-        """
-        transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
-
-        dataset = BinaryMNIST(self.data_dir, train=True, download=False, transform=transforms)
-        train_length = len(dataset)
-        dataset_train, _ = random_split(
-            dataset,
-            [train_length - self.val_split, self.val_split],
-            generator=torch.Generator().manual_seed(self.seed)
-        )
-        loader = DataLoader(
-            dataset_train,
-            batch_size=self.batch_size,
-            shuffle=self.shuffle,
-            num_workers=self.num_workers,
-            drop_last=self.drop_last,
-            pin_memory=self.pin_memory
-        )
-        return loader
-
-    def val_dataloader(self):
-        """
-        MNIST val set uses a subset of the training set for validation
-        """
-        transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
-        dataset = BinaryMNIST(self.data_dir, train=True, download=False, transform=transforms)
-        train_length = len(dataset)
-        _, dataset_val = random_split(
-            dataset,
-            [train_length - self.val_split, self.val_split],
-            generator=torch.Generator().manual_seed(self.seed)
-        )
-        loader = DataLoader(
-            dataset_val,
-            batch_size=self.batch_size,
-            shuffle=False,
-            num_workers=self.num_workers,
-            drop_last=self.drop_last,
-            pin_memory=self.pin_memory
-        )
-        return loader
-
-    def test_dataloader(self):
-        """
-        MNIST test set uses the test split
-        """
-        transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
-
-        dataset = BinaryMNIST(self.data_dir, train=False, download=False, transform=transforms)
-        loader = DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            shuffle=False,
-            num_workers=self.num_workers,
-            drop_last=self.drop_last,
-            pin_memory=self.pin_memory
-        )
-        return loader
-
-    def _default_transforms(self):
+    def default_transforms(self):
         if self.normalize:
-            mnist_transforms = transform_lib.Compose([
-                transform_lib.ToTensor(),
-                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-            ])
+            mnist_transforms = transform_lib.Compose(
+                [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
+            )
         else:
-            mnist_transforms = transform_lib.ToTensor()
+            mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])
 
         return mnist_transforms
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 29c075654f..afb2df8c9a 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -1,10 +1,6 @@
-import os
-from typing import Optional, Sequence
-
-import torch
-from pytorch_lightning import LightningDataModule
-from torch.utils.data import DataLoader, random_split
+from typing import Any, Optional, Sequence, Union
 
+from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10
 from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
@@ -15,9 +11,10 @@
     from torchvision.datasets import CIFAR10
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    CIFAR10 = None
 
 
-class CIFAR10DataModule(LightningDataModule):
+class CIFAR10DataModule(VisionDataModule):
     """
     .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/
         Plot-of-a-Subset-of-Images-from-the-CIFAR-10-Dataset.png
@@ -57,137 +54,70 @@ class CIFAR10DataModule(LightningDataModule):
         dm.test_transforms = ...
         dm.val_transforms = ...
     """
-
-    name = 'cifar10'
-    extra_args = {}
+    name = "cifar10"
+    dataset_cls = CIFAR10
+    dims = (3, 32, 32)
 
     def __init__(
-        self,
-        data_dir: Optional[str] = None,
-        val_split: int = 5000,
-        num_workers: int = 16,
-        batch_size: int = 32,
-        seed: int = 42,
-        shuffle: bool = False,
-        pin_memory: bool = False,
-        drop_last: bool = False,
-        *args,
-        **kwargs,
-    ):
+        self,
+        data_dir: Optional[str] = None,
+        val_split: Union[int, float] = 0.2,
+        num_workers: int = 16,
+        normalize: bool = False,
+        batch_size: int = 32,
+        seed: int = 42,
+        shuffle: bool = False,
+        pin_memory: bool = False,
+        drop_last: bool = False,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         """
         Args:
-            data_dir: where to save/load the data
-            val_split: how many of the training images to use for the validation split
-            num_workers: how many workers to use for loading data
-            batch_size: number of examples per training/eval step
-            seed: random seed to be used for train/val/test splits
-            shuffle: If true shuffles the data every epoch
+            data_dir: Where to save/load the data
+            val_split: Percent (float) or number (int) of samples to use for the validation split
+            num_workers: How many workers to use for loading data
+            normalize: If true applies image normalize
+            batch_size: How many samples per batch to load
+            seed: Random seed to be used for train/val/test splits
+            shuffle: If true shuffles the train data every epoch
             pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                         returning them
             drop_last: If true drops the last incomplete batch
         """
-        super().__init__(*args, **kwargs)
-
-        if not _TORCHVISION_AVAILABLE:  # pragma: no cover
-            raise ModuleNotFoundError(
-                'You want to use CIFAR10 dataset loaded from `torchvision` which is not installed yet.'
-            )
+        super().__init__(
+            data_dir=data_dir,
+            val_split=val_split,
+            num_workers=num_workers,
+            normalize=normalize,
+            batch_size=batch_size,
+            seed=seed,
+            shuffle=shuffle,
+            pin_memory=pin_memory,
+            drop_last=drop_last,
+            *args,
+            **kwargs,
+        )
 
-        self.dims = (3, 32, 32)
-        self.DATASET = CIFAR10
-        self.val_split = val_split
-        self.num_workers = num_workers
-        self.batch_size = batch_size
-        self.seed = seed
-        self.shuffle = shuffle
-        self.pin_memory = pin_memory
-        self.drop_last = drop_last
-        self.data_dir = data_dir if data_dir is not None else os.getcwd()
-        self.num_samples = 50000 - val_split
+    @property
+    def num_samples(self) -> int:
+        train_len, _ = self._get_splits(len_dataset=50_000)
+        return train_len
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             10
         """
         return 10
 
-    def prepare_data(self):
-        """
-        Saves CIFAR10 files to data_dir
-        """
-        self.DATASET(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor(), **self.extra_args)
-        self.DATASET(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor(), **self.extra_args)
-
-    def train_dataloader(self):
-        """
-        CIFAR train set removes a subset to use for validation
-        """
-        transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
-
-        dataset = self.DATASET(self.data_dir, train=True, download=False, transform=transforms, **self.extra_args)
-        train_length = len(dataset)
-        dataset_train, _ = random_split(
-            dataset,
-            [train_length - self.val_split, self.val_split],
-            generator=torch.Generator().manual_seed(self.seed)
-        )
-        loader = DataLoader(
-            dataset_train,
-            batch_size=self.batch_size,
-            shuffle=self.shuffle,
-            num_workers=self.num_workers,
-            drop_last=self.drop_last,
-            pin_memory=self.pin_memory
-        )
-        return loader
-
-    def val_dataloader(self):
-        """
-        CIFAR10 val set uses a subset of the training set for validation
-        """
-        transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
-
-        dataset = self.DATASET(self.data_dir, train=True, download=False, transform=transforms, **self.extra_args)
-        train_length = len(dataset)
-        _, dataset_val = random_split(
-            dataset,
-            [train_length - self.val_split, self.val_split],
-            generator=torch.Generator().manual_seed(self.seed)
-        )
-        loader = DataLoader(
-            dataset_val,
-            batch_size=self.batch_size,
-            shuffle=False,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            drop_last=self.drop_last
-        )
-        return loader
-
-    def test_dataloader(self):
-        """
-        CIFAR10 test set uses the test split
-        """
-        transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
-
-        dataset = self.DATASET(self.data_dir, train=False, download=False, transform=transforms, **self.extra_args)
-        loader = DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            shuffle=False,
-            num_workers=self.num_workers,
-            drop_last=self.drop_last,
-            pin_memory=self.pin_memory
-        )
-        return loader
+    def default_transforms(self):
+        if self.normalize:
+            cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
+        else:
+            cf10_transforms = transform_lib.Compose([transform_lib.ToTensor()])
 
-    def _default_transforms(self):
-        cf10_transforms = transform_lib.Compose([
-            transform_lib.ToTensor(),
-            cifar10_normalization()
-        ])
         return cf10_transforms
@@ -211,16 +141,19 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):
         model = LitModel(datamodule=dm)
     """
 
+    dataset_cls = TrialCIFAR10
+    dims = (3, 32, 32)
+
     def __init__(
-        self,
-        data_dir: str,
-        val_split: int = 50,
-        num_workers: int = 16,
-        num_samples: int = 100,
-        labels: Optional[Sequence] = (1, 5, 8),
-        *args,
-        **kwargs,
-    ):
+        self,
+        data_dir: str,
+        val_split: int = 50,
+        num_workers: int = 16,
+        num_samples: int = 100,
+        labels: Optional[Sequence] = (1, 5, 8),
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         """
         Args:
             data_dir: where to save/load the data
@@ -230,8 +163,7 @@ def __init__(
             labels: list selected CIFAR10 classes/labels
         """
         super().__init__(data_dir, val_split, num_workers, *args, **kwargs)
-        self.dims = (3, 32, 32)
-        self.DATASET = TrialCIFAR10
+
         self.num_samples = num_samples
         self.labels = sorted(labels) if labels is not None else set(range(10))
-        self.extra_args = dict(num_samples=self.num_samples, labels=self.labels)
+        self.EXTRA_ARGS = dict(num_samples=self.num_samples, labels=self.labels)
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index fd60786ef2..833c4599a6 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -1,7 +1,6 @@
-import torch
-from pytorch_lightning import LightningDataModule
-from torch.utils.data import DataLoader, random_split
+from typing import Any, Optional, Union
 
+from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
 
@@ -10,9 +9,10 @@
     from torchvision.datasets import FashionMNIST
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    FashionMNIST = None
 
 
-class FashionMNISTDataModule(LightningDataModule):
+class FashionMNISTDataModule(VisionDataModule):
     """
     .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/
         wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-Fashion-MNIST-Dataset.png
@@ -40,131 +40,65 @@ class FashionMNISTDataModule(LightningDataModule):
         Trainer().fit(model, dm)
     """
-
-    name = 'fashion_mnist'
+    name = "fashion_mnist"
+    dataset_cls = FashionMNIST
+    dims = (1, 28, 28)
 
     def __init__(
-        self,
-        data_dir: str,
-        val_split: int = 5000,
-        num_workers: int = 16,
-        batch_size: int = 32,
-        seed: int = 42,
-        shuffle: bool = False,
-        pin_memory: bool = False,
-        drop_last: bool = False,
-        *args,
-        **kwargs,
-    ):
+        self,
+        data_dir: Optional[str] = None,
+        val_split: Union[int, float] = 0.2,
+        num_workers: int = 16,
+        normalize: bool = False,
+        batch_size: int = 32,
+        seed: int = 42,
+        shuffle: bool = False,
+        pin_memory: bool = False,
+        drop_last: bool = False,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         """
         Args:
-            data_dir: where to save/load the data
-            val_split: how many of the training images to use for the validation split
-            num_workers: how many workers to use for loading data
-            batch_size: size of batch
-            seed: random seed to be used for train/val/test splits
-            shuffle: If true shuffles the data every epoch
+            data_dir: Where to save/load the data
+            val_split: Percent (float) or number (int) of samples to use for the validation split
+            num_workers: How many workers to use for loading data
+            normalize: If true applies image normalize
+            batch_size: How many samples per batch to load
+            seed: Random seed to be used for train/val/test splits
+            shuffle: If true shuffles the train data every epoch
             pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                         returning them
             drop_last: If true drops the last incomplete batch
         """
-        super().__init__(*args, **kwargs)
-
-        if not _TORCHVISION_AVAILABLE:
-            raise ModuleNotFoundError(  # pragma: no-cover
-                'You want to use fashion MNIST dataset loaded from `torchvision` which is not installed yet.'
-            )
-
-        self.dims = (1, 28, 28)
-        self.data_dir = data_dir
-        self.val_split = val_split
-        self.num_workers = num_workers
-        self.batch_size = batch_size
-        self.seed = seed
-        self.shuffle = shuffle
-        self.pin_memory = pin_memory
-        self.drop_last = drop_last
+        super().__init__(
+            data_dir=data_dir,
+            val_split=val_split,
+            num_workers=num_workers,
+            normalize=normalize,
+            batch_size=batch_size,
+            seed=seed,
+            shuffle=shuffle,
+            pin_memory=pin_memory,
+            drop_last=drop_last,
+            *args,
+            **kwargs,
+        )
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             10
         """
         return 10
 
-    def prepare_data(self):
-        """
-        Saves FashionMNIST files to data_dir
-        """
-        FashionMNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor())
-        FashionMNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor())
-
-    def train_dataloader(self):
-        """
-        FashionMNIST train set removes a subset to use for validation
-        """
-        transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
-
-        dataset = FashionMNIST(self.data_dir, train=True, download=False, transform=transforms)
-        train_length = len(dataset)
-        dataset_train, _ = random_split(
-            dataset,
-            [train_length - self.val_split, self.val_split],
-            generator=torch.Generator().manual_seed(self.seed)
-        )
-        loader = DataLoader(
-            dataset_train,
-            batch_size=self.batch_size,
-            shuffle=self.shuffle,
-            num_workers=self.num_workers,
-            drop_last=self.drop_last,
-            pin_memory=self.pin_memory
-        )
-        return loader
-
-    def val_dataloader(self):
-        """
-        FashionMNIST val set uses a subset of the training set for validation
-        """
-        transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
-
-        dataset = FashionMNIST(self.data_dir, train=True, download=False, transform=transforms)
-        train_length = len(dataset)
-        _, dataset_val = random_split(
-            dataset,
-            [train_length - self.val_split, self.val_split],
-            generator=torch.Generator().manual_seed(self.seed)
-        )
-        loader = DataLoader(
-            dataset_val,
-            batch_size=self.batch_size,
-            shuffle=False,
-            num_workers=self.num_workers,
-            drop_last=self.drop_last,
-            pin_memory=self.pin_memory
-        )
-        return loader
-
-    def test_dataloader(self):
-        """
-        FashionMNIST test set uses the test split
-        """
-        transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
-
-        dataset = FashionMNIST(self.data_dir, train=False, download=False, transform=transforms)
-        loader = DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            shuffle=False,
-            num_workers=self.num_workers,
-            drop_last=self.drop_last,
-            pin_memory=self.pin_memory
-        )
-        return loader
+    def default_transforms(self):
+        if self.normalize:
+            mnist_transforms = transform_lib.Compose(
+                [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
+            )
+        else:
+            mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])
 
-    def _default_transforms(self):
-        mnist_transforms = transform_lib.Compose([
-            transform_lib.ToTensor()
-        ])
         return mnist_transforms
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 1a7399e45e..1dd5e927b6 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -1,7 +1,6 @@
-import torch
-from pytorch_lightning import LightningDataModule
-from torch.utils.data import DataLoader, random_split
+from typing import Any, Optional, Union
 
+from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
 
@@ -10,9 +9,10 @@
     from torchvision.datasets import MNIST
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    MNIST = None
 
 
-class MNISTDataModule(LightningDataModule):
+class MNISTDataModule(VisionDataModule):
     """
     .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
         :width: 400
@@ -39,13 +39,14 @@ class MNISTDataModule(LightningDataModule):
         Trainer().fit(model, dm)
     """
 
-    name = "mnist"
+    dataset_cls = MNIST
+    dims = (1, 28, 28)
 
     def __init__(
         self,
-        data_dir: str = "./",
-        val_split: int = 5000,
+        data_dir: Optional[str] = None,
+        val_split: Union[int, float] = 0.2,
         num_workers: int = 16,
         normalize: bool = False,
         batch_size: int = 32,
@@ -53,119 +54,50 @@ def __init__(
         shuffle: bool = False,
         pin_memory: bool = False,
         drop_last: bool = False,
-        *args,
-        **kwargs,
-    ):
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
         """
         Args:
-            data_dir: where to save/load the data
-            val_split: how many of the training images to use for the validation split
-            num_workers: how many workers to use for loading data
+            data_dir: Where to save/load the data
+            val_split: Percent (float) or number (int) of samples to use for the validation split
+            num_workers: How many workers to use for loading data
             normalize: If true applies image normalize
-            batch_size: size of batch
-            seed: random seed to be used for train/val/test splits
-            shuffle: If true shuffles the data every epoch
+            batch_size: How many samples per batch to load
+            seed: Random seed to be used for train/val/test splits
+            shuffle: If true shuffles the train data every epoch
             pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                         returning them
             drop_last: If true drops the last incomplete batch
         """
-        super().__init__(*args, **kwargs)
-
-        if not _TORCHVISION_AVAILABLE:
-            raise ModuleNotFoundError(  # pragma: no-cover
-                'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.'
-            )
-
-        self.dims = (1, 28, 28)
-        self.data_dir = data_dir
-        self.val_split = val_split
-        self.num_workers = num_workers
-        self.normalize = normalize
-        self.batch_size = batch_size
-        self.seed = seed
-        self.shuffle = shuffle
-        self.pin_memory = pin_memory
-        self.drop_last = drop_last
+        super().__init__(
+            data_dir=data_dir,
+            val_split=val_split,
+            num_workers=num_workers,
+            normalize=normalize,
+            batch_size=batch_size,
+            seed=seed,
+            shuffle=shuffle,
+            pin_memory=pin_memory,
+            drop_last=drop_last,
+            *args,
+            **kwargs,
+        )
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             10
         """
         return 10
 
-    def prepare_data(self):
-        """
-        Saves MNIST files to data_dir
-        """
-        MNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor())
-        MNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor())
-
-    def train_dataloader(self):
-        """
-        MNIST train set removes a subset to use for validation
-        """
-        transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
-
-        dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms)
-        train_length = len(dataset)
-        dataset_train, _ = random_split(
-            dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(self.seed)
-        )
-        loader = DataLoader(
-            dataset_train,
-            batch_size=self.batch_size,
-            shuffle=self.shuffle,
-            num_workers=self.num_workers,
-            drop_last=self.drop_last,
-            pin_memory=self.pin_memory
-        )
-        return loader
-
-    def val_dataloader(self):
-        """
-        MNIST val set uses a subset of the training set for validation
-        """
-        transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
-        dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms)
-        train_length = len(dataset)
-        _, dataset_val = random_split(
-            dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(self.seed)
-        )
-        loader = DataLoader(
-            dataset_val,
-            batch_size=self.batch_size,
-            shuffle=False,
-            num_workers=self.num_workers,
-            drop_last=self.drop_last,
-            pin_memory=self.pin_memory,
-        )
-        return loader
-
-    def test_dataloader(self):
-        """
-        MNIST test set uses the test split
-        """
-        transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
-
-        dataset = MNIST(self.data_dir, train=False, download=False, transform=transforms)
-        loader = DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            shuffle=False,
-            num_workers=self.num_workers,
-            drop_last=self.drop_last,
-            pin_memory=self.pin_memory,
-        )
-        return loader
-
-    def _default_transforms(self):
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
             )
         else:
-            mnist_transforms = transform_lib.ToTensor()
+            mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])
 
         return mnist_transforms
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
new file mode 100644
index 0000000000..2144f0f509
--- /dev/null
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -0,0 +1,141 @@
+import os
+from abc import abstractmethod
+from typing import Any, List, Optional, Union
+
+import torch
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset, random_split
+
+
+class VisionDataModule(LightningDataModule):
+
+    EXTRA_ARGS = {}
+    name: str = ""
+    #: Dataset class to use
+    dataset_cls = ...
+    #: A tuple describing the shape of the data
+    dims: tuple = ...
+
+    def __init__(
+        self,
+        data_dir: Optional[str] = None,
+        val_split: Union[int, float] = 0.2,
+        num_workers: int = 16,
+        normalize: bool = False,
+        batch_size: int = 32,
+        seed: int = 42,
+        shuffle: bool = False,
+        pin_memory: bool = False,
+        drop_last: bool = False,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Args:
+            data_dir: Where to save/load the data
+            val_split: Percent (float) or number (int) of samples to use for the validation split
+            num_workers: How many workers to use for loading data
+            normalize: If true applies image normalize
+            batch_size: How many samples per batch to load
+            seed: Random seed to be used for train/val/test splits
+            shuffle: If true shuffles the train data every epoch
+            pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
+                        returning them
+            drop_last: If true drops the last incomplete batch
+        """
+
+        super().__init__(*args, **kwargs)
+
+        self.data_dir = data_dir if data_dir is not None else os.getcwd()
+        self.val_split = val_split
+        self.num_workers = num_workers
+        self.normalize = normalize
+        self.batch_size = batch_size
+        self.seed = seed
+        self.shuffle = shuffle
+        self.pin_memory = pin_memory
+        self.drop_last = drop_last
+
+    def prepare_data(self) -> None:
+        """
+        Saves files to data_dir
+        """
+        self.dataset_cls(self.data_dir, train=True, download=True)
+        self.dataset_cls(self.data_dir, train=False, download=True)
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        """
+        Creates train, val, and test dataset
+        """
+        if stage == "fit" or stage is None:
+            train_transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
+            val_transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
+
+            dataset_train = self.dataset_cls(self.data_dir, train=True, transform=train_transforms, **self.EXTRA_ARGS)
+            dataset_val = self.dataset_cls(self.data_dir, train=True, transform=val_transforms, **self.EXTRA_ARGS)
+
+            # Split
+            self.dataset_train = self._split_dataset(dataset_train)
+            self.dataset_val = self._split_dataset(dataset_val, train=False)
+
+        if stage == "test" or stage is None:
+            test_transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms
+            self.dataset_test = self.dataset_cls(
+                self.data_dir, train=False, transform=test_transforms, **self.EXTRA_ARGS
+            )
+
+    def _split_dataset(self, dataset: Dataset, train: bool = True) -> Dataset:
+        """
+        Splits the dataset into train and validation set
+        """
+        len_dataset = len(dataset)
+        splits = self._get_splits(len_dataset)
+        dataset_train, dataset_val = random_split(
+            dataset, splits, generator=torch.Generator().manual_seed(self.seed)
+        )
+
+        if train:
+            return dataset_train
+        return dataset_val
+
+    def _get_splits(self, len_dataset: int) -> List[int]:
+        """
+        Computes split lengths for train and validation set
+        """
+        if isinstance(self.val_split, int):
+            train_len = len_dataset - self.val_split
+            splits = [train_len, self.val_split]
+        elif isinstance(self.val_split, float):
+            val_len = int(self.val_split * len_dataset)
+            train_len = len_dataset - val_len
+            splits = [train_len, val_len]
+        else:
+            raise ValueError(f'Unsupported type {type(self.val_split)}')
+
+        return splits
+
+    @abstractmethod
+    def default_transforms(self):
+        """ Default transform for the dataset """
+
+    def train_dataloader(self) -> DataLoader:
+        """ The train dataloader """
+        return self._data_loader(self.dataset_train, shuffle=self.shuffle)
+
+    def val_dataloader(self) -> DataLoader:
+        """ The val dataloader """
+        return self._data_loader(self.dataset_val)
+
+    def test_dataloader(self) -> DataLoader:
+        """ The test dataloader """
+        return self._data_loader(self.dataset_test)
+
+    def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
+        return DataLoader(
+            dataset,
+            batch_size=self.batch_size,
+            shuffle=shuffle,
+            num_workers=self.num_workers,
+            drop_last=self.drop_last,
+            pin_memory=self.pin_memory
+        )
diff --git a/tests/datamodules/test_datamodules.py b/tests/datamodules/test_datamodules.py
index 130c1970c7..2fc6a4c3df 100644
--- a/tests/datamodules/test_datamodules.py
+++ b/tests/datamodules/test_datamodules.py
@@ -1,10 +1,17 @@
 import uuid
 from pathlib import Path
 
+import pytest
 import torch
 from PIL import Image
 
-from pl_bolts.datamodules import CityscapesDataModule
+from pl_bolts.datamodules import (
+    BinaryMNISTDataModule,
+    CIFAR10DataModule,
+    CityscapesDataModule,
+    FashionMNISTDataModule,
+    MNISTDataModule,
+)
 from pl_bolts.datasets.cifar10_dataset import CIFAR10
 
@@ -63,3 +70,24 @@ def test_cityscapes_datamodule(datadir):
     img, mask = next(iter(loader))
     assert img.size() == torch.Size([batch_size, 3, 1024, 2048])
     assert mask.size() == torch.Size([batch_size, 1024, 2048])
+
+
+@pytest.mark.parametrize("val_split, train_len", [(0.2, 48_000), (5_000, 55_000)])
+def test_vision_data_module(datadir, val_split, train_len):
+    dm = _create_dm(MNISTDataModule, datadir, val_split=val_split)
+    assert len(dm.dataset_train) == train_len
+
+
+@pytest.mark.parametrize("dm_cls", [BinaryMNISTDataModule, CIFAR10DataModule, FashionMNISTDataModule, MNISTDataModule])
+def test_data_modules(datadir, dm_cls):
+    dm = _create_dm(dm_cls, datadir)
+    loader = dm.train_dataloader()
+    img, _ = next(iter(loader))
+    assert img.size() == torch.Size([2, *dm.size()])
+
+
+def _create_dm(dm_cls, datadir, val_split=0.2):
+    dm = dm_cls(data_dir=datadir, val_split=val_split, num_workers=1, batch_size=2)
+    dm.prepare_data()
+    dm.setup()
+    return dm
diff --git a/tests/models/test_classic_ml.py b/tests/models/test_classic_ml.py
index 565773a95e..978273d3db 100644
--- a/tests/models/test_classic_ml.py
+++ b/tests/models/test_classic_ml.py
@@ -41,6 +41,7 @@ def test_logistic_regression_model(tmpdir, datadir):
     model = LogisticRegression(input_dim=28 * 28, num_classes=10, learning_rate=0.001)
 
     model.prepare_data = dm.prepare_data
+    model.setup = dm.setup
     model.train_dataloader = dm.train_dataloader
     model.val_dataloader = dm.val_dataloader
     model.test_dataloader = dm.test_dataloader
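Usage notes (editor's addition, not part of the diff above). A minimal sketch of how the refactored datamodules are driven after this change; it assumes torchvision is installed and uses only names defined in the diff:

import torch

from pl_bolts.datamodules import CIFAR10DataModule

# val_split now accepts a float (fraction of the train set) or an int (absolute sample count)
dm = CIFAR10DataModule(data_dir=".", val_split=0.2, num_workers=4, batch_size=32)

dm.prepare_data()  # downloads the train/test sets to data_dir
dm.setup()         # builds dataset_train / dataset_val / dataset_test

images, labels = next(iter(dm.train_dataloader()))
assert images.shape[1:] == torch.Size(dm.dims)  # dims == (3, 32, 32)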
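The int/float branching in `_get_splits` is what makes both the old absolute counts and the new fractional default work. A quick worked example, calling the private helper directly purely for illustration (expected values follow from the code in the diff):

dm = CIFAR10DataModule(val_split=0.2)
assert dm._get_splits(len_dataset=50_000) == [40_000, 10_000]  # float: 20% held out for validation

dm = CIFAR10DataModule(val_split=5_000)
assert dm._get_splits(len_dataset=50_000) == [45_000, 5_000]   # int: exact number of samples held out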
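Supporting a further torchvision-style dataset now reduces to setting the class attributes and implementing `default_transforms`. A sketch using torchvision's KMNIST; this subclass is hypothetical and not part of the diff:

from torchvision import transforms as transform_lib
from torchvision.datasets import KMNIST

from pl_bolts.datamodules.vision_datamodule import VisionDataModule


class KMNISTDataModule(VisionDataModule):
    name = "kmnist"
    # dataset_cls must accept (root, train=..., download=..., transform=...),
    # since prepare_data() and setup() instantiate it with exactly those arguments
    dataset_cls = KMNIST
    dims = (1, 28, 28)

    def default_transforms(self):
        return transform_lib.Compose([transform_lib.ToTensor()])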