diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 7229462e97..a741203167 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -76,6 +76,7 @@ def __init__(
         self.num_workers = num_workers
         self.normalize = normalize
         self.seed = seed
+        self.batch_size = batch_size
 
     @property
     def num_classes(self):
@@ -92,15 +93,14 @@ def prepare_data(self):
         MNIST(self.data_dir, train=True, download=True, transform=transform_lib.ToTensor())
         MNIST(self.data_dir, train=False, download=True, transform=transform_lib.ToTensor())
 
-    def train_dataloader(self, batch_size=32, transforms=None):
+    def train_dataloader(self):
         """
         MNIST train set removes a subset to use for validation
 
         Args:
-            batch_size: size of batch
            transforms: custom transforms
         """
-        transforms = transforms or self.train_transforms or self._default_transforms()
+        transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
 
         dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms)
         train_length = len(dataset)
@@ -109,7 +109,7 @@ def train_dataloader(self, batch_size=32, transforms=None):
         )
         loader = DataLoader(
             dataset_train,
-            batch_size=batch_size,
+            batch_size=self.batch_size,
             shuffle=True,
             num_workers=self.num_workers,
             drop_last=True,
@@ -117,15 +117,14 @@ def train_dataloader(self, batch_size=32, transforms=None):
         )
         return loader
 
-    def val_dataloader(self, batch_size=32, transforms=None):
+    def val_dataloader(self):
         """
         MNIST val set uses a subset of the training set for validation
 
         Args:
-            batch_size: size of batch
            transforms: custom transforms
         """
-        transforms = transforms or self.val_transforms or self._default_transforms()
+        transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
         dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms)
         train_length = len(dataset)
         _, dataset_val = random_split(
@@ -133,7 +132,7 @@ def val_dataloader(self, batch_size=32, transforms=None):
         )
         loader = DataLoader(
             dataset_val,
-            batch_size=batch_size,
+            batch_size=self.batch_size,
             shuffle=False,
             num_workers=self.num_workers,
             drop_last=True,
@@ -141,23 +140,23 @@ def val_dataloader(self, batch_size=32, transforms=None):
         )
         return loader
 
-    def test_dataloader(self, batch_size=32, transforms=None):
+    def test_dataloader(self):
         """
         MNIST test set uses the test split
 
         Args:
-            batch_size: size of batch
            transforms: custom transforms
         """
-        transforms = transforms or self.val_transforms or self._default_transforms()
+        transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms
 
         dataset = MNIST(self.data_dir, train=False, download=False, transform=transforms)
         loader = DataLoader(
-            dataset, batch_size=batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, pin_memory=True
+            dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True,
+            pin_memory=True
        )
        return loader
 
-    def _default_transforms(self):
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
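
With this change, the batch size is configured once on the datamodule and the dataloader hooks take no arguments, matching the standard `LightningDataModule` call pattern. Below is a minimal usage sketch, assuming the class defined in `mnist_datamodule.py` is `MNISTDataModule` and that its constructor accepts `data_dir` and `batch_size` keyword arguments (only `self.batch_size = batch_size` is visible in the diff, so the exact signature is an assumption):

```python
# Hypothetical usage sketch for the refactored datamodule; import path and
# constructor keywords are assumptions based on the file shown in the diff.
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule

dm = MNISTDataModule(data_dir=".", batch_size=64)
dm.prepare_data()  # downloads the MNIST train/test splits if missing

# Batch size now comes from the constructor, so the hooks are called
# without arguments; transforms fall back to default_transforms() unless
# train/val/test_transforms were set on the datamodule.
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
test_loader = dm.test_dataloader()
```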