Removed datamodule as an input parameter #270

Merged · 7 commits · Oct 15, 2020
26 changes: 11 additions & 15 deletions pl_bolts/models/self_supervised/moco/moco2_module.py
@@ -70,7 +70,6 @@ def __init__(self,
                  learning_rate: float = 0.03,
                  momentum: float = 0.9,
                  weight_decay: float = 1e-4,
-                 datamodule: pl.LightningDataModule = None,
                  data_dir: str = './',
                  batch_size: int = 256,
                  use_mlp: bool = False,
@@ -96,14 +95,6 @@ def __init__(self,
         super().__init__()
         self.save_hyperparameters()
 
-        # use CIFAR-10 by default if no datamodule passed in
-        # if datamodule is None:
-        #     datamodule = CIFAR10DataModule(data_dir)
-        #     datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
-        #     datamodule.val_transforms = Moco2EvalCIFAR10Transforms()
-        assert datamodule
-        self.datamodule = datamodule
-
         # create the encoders
         # num_classes is the output fc dimension
         self.encoder_q, self.encoder_k = self.init_encoders(base_encoder)
@@ -259,7 +250,7 @@ def forward(self, img_q, img_k):
 
     def training_step(self, batch, batch_idx):
         # in STL10 we pass in both lab+unl for online ft
-        if self.hparams.datamodule.name == 'stl10':
+        if self.trainer.datamodule.name == 'stl10':
             labeled_batch = batch[1]
             unlabeled_batch = batch[0]
             batch = unlabeled_batch
@@ -276,11 +267,12 @@ def training_step(self, batch, batch_idx):
             'train_acc1': acc1,
             'train_acc5': acc5
         }
-        return {'loss': loss, 'log': log, 'progress_bar': log}
+        self.log_dict(log)
+        return loss
 
     def validation_step(self, batch, batch_idx):
         # in STL10 we pass in both lab+unl for online ft
-        if self.hparams.datamodule.name == 'stl10':
+        if self.trainer.datamodule.name == 'stl10':
             labeled_batch = batch[1]
             unlabeled_batch = batch[0]
             batch = unlabeled_batch
@@ -309,7 +301,7 @@ def validation_epoch_end(self, outputs):
             'val_acc1': val_acc1,
             'val_acc5': val_acc5
         }
-        return {'val_loss': val_loss, 'log': log, 'progress_bar': log}
+        self.log_dict(log)
 
     def configure_optimizers(self):
         optimizer = torch.optim.SGD(self.parameters(), self.hparams.learning_rate,
@@ -382,10 +374,14 @@ def cli_main():
         datamodule.train_transforms = Moco2TrainImagenetTransforms()
         datamodule.val_transforms = Moco2EvalImagenetTransforms()
-
-    model = MocoV2(**args.__dict__, datamodule=datamodule)
+    else:
+        # replace with your own dataset, otherwise CIFAR-10 will be used by default if `None` passed in
+        datamodule = None
+
+    model = MocoV2(**args.__dict__)
 
     trainer = pl.Trainer.from_argparse_args(args)
-    trainer.fit(model)
+    trainer.fit(model, datamodule=datamodule)
 
 
 if __name__ == '__main__':
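The net effect for MoCo: the datamodule is built in the script and handed to `Trainer.fit`, rather than being a constructor argument the module asserts on. A minimal sketch of the new wiring, assuming the pl_bolts import layout of this release (the datamodule and transform classes are the same ones referenced in the diff above):

import pytorch_lightning as pl
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised import MocoV2
from pl_bolts.models.self_supervised.moco.transforms import (
    Moco2TrainCIFAR10Transforms,
    Moco2EvalCIFAR10Transforms,
)

# Build and configure the datamodule outside the model.
datamodule = CIFAR10DataModule('./data')
datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
datamodule.val_transforms = Moco2EvalCIFAR10Transforms()

# The model no longer accepts a `datamodule` argument.
model = MocoV2(batch_size=256)

# The trainer owns the datamodule and exposes it to the module
# as `self.trainer.datamodule` during fit.
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule=datamodule)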
10 changes: 3 additions & 7 deletions pl_bolts/models/vision/image_gpt/igpt_module.py
@@ -98,7 +98,6 @@ class ImageGPT(pl.LightningModule):
     """
     def __init__(
         self,
-        datamodule: pl.LightningDataModule = None,
         embed_dim: int = 16,
         heads: int = 2,
         layers: int = 2,
@@ -115,7 +114,6 @@ def __init__(
     ):
         """
         Args:
-            datamodule: LightningDataModule
             embed_dim: the embedding dim
             heads: number of attention heads
             layers: number of layers
@@ -129,7 +127,7 @@ def __init__(
             data_dir: where to store data
             num_workers: num_data workers
         """
-        super(ImageGPT, self).__init__()
+        super().__init__()
         self.save_hyperparameters()
 
         # default to MNIST if no datamodule given
@@ -139,8 +137,6 @@ def __init__(
         #     )
         # self.hparams.pixels = datamodule.size(1)
         # self.hparams.num_classes = datamodule.num_classes
-        assert datamodule
-        self.datamodule = datamodule
 
         self.gpt = GPT2(
             embed_dim=self.hparams.embed_dim,
@@ -259,10 +255,10 @@ def cli_main():
     elif args.dataset == "imagenet128":
         datamodule = ImagenetDataModule.from_argparse_args(args)
 
-    model = ImageGPT(**args.__dict__, datamodule=datamodule)
+    model = ImageGPT(**args.__dict__)
 
     trainer = pl.Trainer.from_argparse_args(args)
-    trainer.fit(model)
+    trainer.fit(model, datamodule=datamodule)
 
 
 if __name__ == '__main__':
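The ImageGPT change mirrors the MoCo one, and it frees any LightningModule from storing its own datamodule reference: hooks read `self.trainer.datamodule`, and metrics go through `self.log_dict` instead of returned `log`/`progress_bar` dicts. A minimal sketch of the in-module pattern; the module, its layers, and the metric names are illustrative, not part of this PR:

import torch
from torch.nn import functional as F
import pytorch_lightning as pl

class ExampleModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, 10)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # The datamodule passed to `trainer.fit` is reachable from any hook,
        # so the module no longer needs a `self.datamodule` attribute.
        if getattr(self.trainer.datamodule, 'name', None) == 'stl10':
            batch = batch[0]  # keep only the unlabeled split

        x, y = batch
        loss = F.cross_entropy(self(x), y)

        # Log directly instead of returning {'loss', 'log', 'progress_bar'} dicts.
        self.log_dict({'train_loss': loss})
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)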
4 changes: 2 additions & 2 deletions tests/models/self_supervised/test_models.py
@@ -63,9 +63,9 @@ def test_moco(tmpdir):
     datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
     datamodule.val_transforms = Moco2EvalCIFAR10Transforms()
 
-    model = MocoV2(data_dir=tmpdir, batch_size=2, datamodule=datamodule, online_ft=True)
+    model = MocoV2(data_dir=tmpdir, batch_size=2, online_ft=True)
     trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir, callbacks=[MocoLRScheduler()])
-    trainer.fit(model)
+    trainer.fit(model, datamodule=datamodule)
     loss = trainer.progress_bar_dict['loss']
 
     assert float(loss) > 0
10 changes: 5 additions & 5 deletions tests/models/test_vision.py
@@ -8,27 +8,27 @@
 def test_igpt(tmpdir):
     pl.seed_everything(0)
     dm = MNISTDataModule(tmpdir, normalize=False)
-    model = ImageGPT(datamodule=dm)
+    model = ImageGPT()
 
     trainer = pl.Trainer(
         limit_train_batches=2,
         limit_val_batches=2,
         limit_test_batches=2,
         max_epochs=1,
     )
-    trainer.fit(model)
-    trainer.test()
+    trainer.fit(model, datamodule=dm)
+    trainer.test(datamodule=dm)
     assert trainer.callback_metrics["test_loss"] < 1.7
 
     dm = FashionMNISTDataModule(tmpdir, num_workers=1)
-    model = ImageGPT(classify=True, datamodule=dm)
+    model = ImageGPT(classify=True)
     trainer = pl.Trainer(
         limit_train_batches=2,
         limit_val_batches=2,
         limit_test_batches=2,
         max_epochs=1,
     )
-    trainer.fit(model)
+    trainer.fit(model, datamodule=dm)
 
 
 def test_gpt2(tmpdir):