-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
34 lines (24 loc) · 972 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import hydra
from omegaconf import DictConfig
from DCNN.datasets import create_torch_dataloaders
from DCNN.trainer import DCNNTrainer
import os
import warnings

# Silence every Python warning for the whole process.
# NOTE(review): this also hides deprecation/numerical warnings raised by
# torch and the training stack; consider narrowing it to specific
# categories or modules.
warnings.simplefilter('ignore')

import torch

# Stop the CUDA caching allocator from splitting blocks larger than this
# size (in MB), which can reduce fragmentation-driven out-of-memory errors.
# The allocator reads its configuration from the *environment variable*
# PYTORCH_CUDA_ALLOC_CONF, formatted "max_split_size_mb:<N>". Binding a
# plain Python variable of that name has no effect, so export it via
# os.environ before any CUDA work happens in this process.
max_split_size_mb = 512
PYTORCH_CUDA_ALLOC_CONF = f"max_split_size_mb:{max_split_size_mb}"
# setdefault: an allocator config already supplied by the launching
# environment takes precedence over this script's default.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", PYTORCH_CUDA_ALLOC_CONF)

# Release unoccupied memory held by the CUDA caching allocator.
torch.cuda.empty_cache()
@hydra.main(config_path="config", config_name="config", version_base="1.1")
def train(config: DictConfig):
    """Train a DCNN model, then evaluate it on the held-out test split.

    Hydra composes the configuration from the ``config`` directory (primary
    config name ``config``) and injects it as ``config``.

    Args:
        config (DictConfig): Hydra-composed configuration, forwarded as-is
            to both the dataloader factory and the trainer.
    """
    # The factory yields the train, validation and test loaders, in that order.
    loaders = create_torch_dataloaders(config)
    train_loader, val_loader, test_loader = loaders

    model_trainer = DCNNTrainer(config)

    # Optimise on the training split while validating on the validation split.
    model_trainer.fit(train_loader, val_dataloaders=val_loader)

    # NOTE(review): whether this evaluates the best-validation checkpoint or
    # the final weights depends on DCNNTrainer.test, which is not visible
    # here — confirm.
    model_trainer.test(test_loader)
if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator builds the
    # DictConfig (including any command-line overrides) and passes it in.
    train()