trainer.py
import sys

import torch
from torch.optim import SGD, Adam
from tqdm.auto import tqdm


class Trainer:
    """Minimal full-batch training loop with optional early stopping and logging."""

    def __init__(
        self,
        # note: the dict(...) annotations below are metadata rather than type
        # hints (presumably consumed elsewhere to auto-generate a CLI)
        optimizer: dict(help='optimization algorithm', choices=['sgd', 'adam']) = 'adam',
        max_epochs: dict(help='maximum number of training epochs') = 500,
        learning_rate: dict(help='learning rate') = 0.01,
        weight_decay: dict(help='weight decay (L2 penalty)') = 0.0,
        patience: dict(help='early-stopping patience window size') = 0,
        device='cuda',
        logger=None,
    ):
        self.optimizer_name = optimizer
        self.max_epochs = max_epochs
        self.device = device
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.patience = patience
        self.logger = logger
        self.model = None

    def configure_optimizers(self):
        if self.optimizer_name == 'sgd':
            return SGD(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        elif self.optimizer_name == 'adam':
            return Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        else:
            raise ValueError(f'unsupported optimizer: {self.optimizer_name}')

    def fit(self, model, data):
        self.model = model.to(self.device)
        data = data.to(self.device)
        optimizer = self.configure_optimizers()

        num_epochs_without_improvement = 0
        best_metrics = None

        epoch_progbar = tqdm(range(1, self.max_epochs + 1), desc='Epoch: ', leave=False, position=1, file=sys.stdout)
        for epoch in epoch_progbar:
            metrics = {'epoch': epoch}

            train_metrics = self._train(data, optimizer)
            metrics.update(train_metrics)

            val_metrics = self._validation(data)
            metrics.update(val_metrics)

            if self.logger:
                self.logger.log(metrics)

            # count this epoch as an improvement only if validation loss drops
            # while both validation and training accuracy increase and remain
            # within the bounds set by 'train/maxacc'
            if best_metrics is None or (
                metrics['val/loss'] < best_metrics['val/loss'] and
                best_metrics['val/acc'] < metrics['val/acc'] <= metrics['train/maxacc'] and
                best_metrics['train/acc'] < metrics['train/acc'] <= 1.05 * metrics['train/maxacc']
            ):
                best_metrics = metrics
                num_epochs_without_improvement = 0
            else:
                num_epochs_without_improvement += 1
                # stop early once the patience window is exhausted (patience=0 disables this)
                if num_epochs_without_improvement >= self.patience > 0:
                    break

            # display metrics on the progress bar
            epoch_progbar.set_postfix(metrics)

        if self.logger:
            self.logger.log_summary(best_metrics)

        return best_metrics

    def _train(self, data, optimizer):
        """Perform a single full-batch gradient step and return training metrics."""
        self.model.train()
        optimizer.zero_grad()
        loss, metrics = self.model.training_step(data)
        loss.backward()
        optimizer.step()
        return metrics

    @torch.no_grad()
    def _validation(self, data):
        """Evaluate the model without tracking gradients and return validation metrics."""
        self.model.eval()
        return self.model.validation_step(data)
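

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# model and dataset implementing the interface `Trainer.fit` expects, namely
# `training_step`/`validation_step` methods and a `.to(device)` on the data.
# The metric keys ('train/loss', 'train/acc', 'train/maxacc', 'val/loss',
# 'val/acc') mirror those read inside `fit`; the classes below are toy
# stand-ins, not the project's own model or data objects.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch import nn

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 2)

        def _metrics(self, x, y, prefix):
            logits = self.linear(x)
            loss = nn.functional.cross_entropy(logits, y)
            acc = (logits.argmax(dim=1) == y).float().mean().item()
            return loss, {f'{prefix}/loss': loss.item(), f'{prefix}/acc': acc}

        def training_step(self, data):
            loss, metrics = self._metrics(data.x_train, data.y_train, 'train')
            # assumed upper bound on attainable accuracy, used by the early-stopping rule
            metrics['train/maxacc'] = 1.0
            return loss, metrics

        def validation_step(self, data):
            _, metrics = self._metrics(data.x_val, data.y_val, 'val')
            return metrics

    class ToyData:
        """Tiny stand-in for a dataset object exposing a `.to(device)` method."""

        def __init__(self):
            self.x_train = torch.randn(32, 4)
            self.y_train = torch.randint(0, 2, (32,))
            self.x_val = torch.randn(16, 4)
            self.y_val = torch.randint(0, 2, (16,))

        def to(self, device):
            for name in ('x_train', 'y_train', 'x_val', 'y_val'):
                setattr(self, name, getattr(self, name).to(device))
            return self

    trainer = Trainer(optimizer='adam', max_epochs=10, device='cpu')
    best = trainer.fit(ToyModel(), ToyData())
    print(best)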