Skip to content

Commit

Permalink
Fixed mypy errors in the backward compatibility code of LightningCLI …
Browse files Browse the repository at this point in the history
…import
  • Loading branch information
senarvi committed Jul 2, 2023
1 parent 7c0b0b6 commit 7621b99
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
10 changes: 6 additions & 4 deletions src/pl_bolts/models/detection/retinanet/retinanet_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,16 @@ def configure_optimizers(self):

@under_review()
def cli_main():
import pytorch_lightning

try: # Backward compatibility for Lightning CLI
from pytorch_lightning.cli import LightningCLI # PL v1.9+
except ImportError:
from pytorch_lightning.utilities.cli import LightningCLI # PL v1.8
cli_class: Any = getattr(pytorch_lightning.cli, "LightningCLI") # PL v1.9+
except AttributeError:
cli_class = getattr(pytorch_lightning.utilities.cli, "LightningCLI") # PL v1.8

from pl_bolts.datamodules import VOCDetectionDataModule

LightningCLI(RetinaNet, VOCDetectionDataModule, seed_everything_default=42)
cli_class(RetinaNet, VOCDetectionDataModule, seed_everything_default=42)


if __name__ == "__main__":
Expand Down
10 changes: 8 additions & 2 deletions src/pl_bolts/models/detection/yolo/yolo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import Tensor, optim

Expand Down Expand Up @@ -615,4 +614,11 @@ def _resize(self, image: Tensor, target: TARGET) -> Tuple[Tensor, TARGET]:


if __name__ == "__main__":
LightningCLI(CLIYOLO, ResizedVOCDetectionDataModule, seed_everything_default=42)
import pytorch_lightning

try: # Backward compatibility for Lightning CLI
cli_class: Any = getattr(pytorch_lightning.cli, "LightningCLI") # PL v1.9+
except AttributeError:
cli_class = getattr(pytorch_lightning.utilities.cli, "LightningCLI") # PL v1.8

cli_class(CLIYOLO, ResizedVOCDetectionDataModule, seed_everything_default=42)
10 changes: 6 additions & 4 deletions src/pl_bolts/models/self_supervised/moco/moco_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,12 +319,14 @@ def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:


def cli_main() -> None:
import pytorch_lightning

try: # Backward compatibility for Lightning CLI
from pytorch_lightning.cli import LightningCLI # PL v1.9+
except ImportError:
from pytorch_lightning.utilities.cli import LightningCLI # PL v1.8
cli_class: Any = getattr(pytorch_lightning.cli, "LightningCLI") # PL v1.9+
except AttributeError:
cli_class = getattr(pytorch_lightning.utilities.cli, "LightningCLI") # PL v1.8

LightningCLI(MoCo, CIFAR10ContrastiveDataModule, seed_everything_default=42)
cli_class(MoCo, CIFAR10ContrastiveDataModule, seed_everything_default=42)


if __name__ == "__main__":
Expand Down

0 comments on commit 7621b99

Please sign in to comment.