-
Notifications
You must be signed in to change notification settings - Fork 12
/
main.py
executable file
·55 lines (47 loc) · 1.7 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import torch
import torch.backends.cuda
import torch.backends.cudnn
from jsonargparse import lazy_instance
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.loggers import CSVLogger
from src.data import DataModule
from src.model import ClassificationModel
class MyLightningCLI(LightningCLI):
    """LightningCLI subclass that registers checkpoint options and project defaults."""

    def add_arguments_to_parser(self, parser) -> None:
        """Add the checkpoint callback arguments, default values, and argument links.

        Args:
            parser: The argument parser supplied by ``LightningCLI``.
        """
        # Make ModelCheckpoint's arguments configurable under "model_checkpoint".
        parser.add_lightning_class_args(ModelCheckpoint, "model_checkpoint")

        # Default logger: a CSVLogger writing to output/default.
        csv_logger = lazy_instance(CSVLogger, save_dir="output", name="default")

        # Default checkpointing: track the highest "val_acc" and also keep
        # the last checkpoint.
        checkpoint_defaults = {
            "model_checkpoint.monitor": "val_acc",
            "model_checkpoint.mode": "max",
            "model_checkpoint.filename": "best-step-{step}-{val_acc:.4f}",
            "model_checkpoint.save_last": True,
        }
        parser.set_defaults({"trainer.logger": csv_logger, **checkpoint_defaults})

        # Feed data-module settings into the model so they are specified once.
        parser.link_arguments("data.size", "model.image_size")
        # This link is applied at instantiation time (apply_on="instantiate")
        # rather than at parse time -- presumably because num_classes is only
        # known once the data module exists; verify against DataModule.
        parser.link_arguments(
            "data.num_classes",
            "model.n_classes",
            apply_on="instantiate",
        )
# Allow TF32 in cuDNN and in CUDA matrix multiplications.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Build the command-line interface around the project's model and data module.
# The saved config file may overwrite an existing one, and the trainer's
# check_val_every_n_epoch defaults to None.
cli = MyLightningCLI(
    ClassificationModel,
    DataModule,
    trainer_defaults=dict(check_val_every_n_epoch=None),
    save_config_kwargs=dict(overwrite=True),
)
# Move the saved config from the logger's save directory into the experiment
# directory.
# Workaround for https://github.com/Lightning-AI/lightning/issues/17168
try:
    config_src = os.path.join(
        cli.trainer.logger.save_dir, "config.yaml"  # type:ignore
    )
    # The experiment directory is taken to be the parent of the checkpoint
    # directory. This replaces a hard-coded 12-character slice that only
    # worked when the path ended in exactly "/checkpoints".
    experiment_dir = os.path.dirname(
        cli.trainer.checkpoint_callback.dirpath  # type:ignore
    )
    # os.replace overwrites an existing destination on every platform,
    # whereas os.rename raises on Windows if the target already exists.
    os.replace(config_src, os.path.join(experiment_dir, "config.yaml"))
except FileNotFoundError:
    # Best effort: there is no config.yaml to move (for example it was never
    # written or has already been moved), so there is nothing to do.
    pass
except (AttributeError, TypeError, OSError) as err:
    # AttributeError/TypeError: no logger or checkpoint callback is configured,
    # or one of their directories is None. OSError: the move itself failed.
    # Report it rather than hiding it, but do not abort a finished run.
    print(f"Could not move config.yaml into the experiment directory: {err!r}")