-
Notifications
You must be signed in to change notification settings - Fork 7
/
train.py
143 lines (109 loc) · 4.59 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Here, we will run everything that is related to the training procedure.
"""
import time
import torch
import torch.nn as nn
from tqdm import tqdm
from utils import train_utils
from torch.utils.data import DataLoader
from utils.types import Scores, Metrics
from utils.train_utils import TrainParams
from utils.train_logger import TrainLogger
def get_metrics(best_eval_score: float, eval_score: float, train_loss: float) -> Metrics:
    """
    Build the summary metrics dictionary reported at the end of training.
    Replace the tags/values with your own metrics if needed.

    :param best_eval_score: highest validation accuracy observed over all epochs
    :param eval_score: validation accuracy of the final epoch
    :param train_loss: training loss of the final epoch
    :return: mapping of tensorboard tag -> value
    """
    tags = ('Metrics/BestAccuracy', 'Metrics/LastAccuracy', 'Metrics/LastLoss')
    values = (best_eval_score, eval_score, train_loss)
    return dict(zip(tags, values))
def train(model: nn.Module, train_loader: DataLoader, eval_loader: DataLoader, train_params: TrainParams,
          logger: TrainLogger) -> Metrics:
    """
    Training procedure. Change each part if needed (optimizer, loss, etc.)

    :param model: model to optimize (caller moves it to CUDA if available)
    :param train_loader: dataloader over the training set
    :param eval_loader: dataloader over the validation set
    :param train_params: hyper-parameters (lr, scheduler step/gamma, epochs, grad clip, save flag)
    :param logger: logger used for tensorboard reporting and checkpointing
    :return: final metrics dictionary (best accuracy, last accuracy, last loss)
    """
    # Initialize before the loop as well so the return statement is valid even
    # when num_epochs == 0.
    metrics = train_utils.get_zeroed_metrics_dict()
    best_eval_score = 0

    # Create optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=train_params.lr)

    # Create learning rate scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=train_params.lr_step_size,
                                                gamma=train_params.lr_gamma)

    for epoch in tqdm(range(train_params.num_epochs)):
        t = time.time()
        metrics = train_utils.get_zeroed_metrics_dict()
        # Make sure dropout/batch-norm layers are in train mode at the start of
        # every epoch (evaluate() below switches the model to eval mode).
        model.train(True)

        for i, (x, y) in enumerate(train_loader):
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()

            y_hat = model(x)
            loss = nn.functional.binary_cross_entropy_with_logits(y_hat, y)

            # Optimization step
            optimizer.zero_grad()
            loss.backward()
            # BUG FIX: gradients must be clipped BEFORE optimizer.step(); the
            # original clipped after the step, so the parameter update was
            # always performed with unclipped gradients.
            metrics['total_norm'] += nn.utils.clip_grad_norm_(model.parameters(), train_params.grad_clip)
            metrics['count_norm'] += 1
            optimizer.step()

            # NOTE! This function computes scores correctly only for a one-hot
            # encoding representation of the logits.
            batch_score = train_utils.compute_score_with_logits(y_hat, y.data).sum()
            metrics['train_score'] += batch_score.item()
            metrics['train_loss'] += loss.item()

            # Report the model graph to tensorboard once, on the first batch.
            if epoch == 0 and i == 0:
                logger.report_graph(model, x)

        # Learning rate scheduler step (once per epoch)
        scheduler.step()

        # Average accumulated metrics over the dataset.
        metrics['train_loss'] /= len(train_loader.dataset)
        metrics['train_score'] /= len(train_loader.dataset)
        metrics['train_score'] *= 100
        norm = metrics['total_norm'] / metrics['count_norm']

        model.train(False)
        metrics['eval_score'], metrics['eval_loss'] = evaluate(model, eval_loader)
        model.train(True)

        epoch_time = time.time() - t
        logger.write_epoch_statistics(epoch, epoch_time, metrics['train_loss'], norm,
                                      metrics['train_score'], metrics['eval_score'])

        scalars = {'Accuracy/Train': metrics['train_score'],
                   'Accuracy/Validation': metrics['eval_score'],
                   'Loss/Train': metrics['train_loss'],
                   'Loss/Validation': metrics['eval_loss']}
        logger.report_scalars(scalars, epoch)

        # Checkpoint on validation improvement.
        if metrics['eval_score'] > best_eval_score:
            best_eval_score = metrics['eval_score']
            if train_params.save_model:
                logger.save_model(model, epoch, optimizer)

    return get_metrics(best_eval_score, metrics['eval_score'], metrics['train_loss'])
@torch.no_grad()
def evaluate(model: nn.Module, dataloader: DataLoader) -> Scores:
    """
    Evaluate a model without gradient calculation

    :param model: instance of a model (caller is expected to set eval mode beforehand)
    :param dataloader: dataloader to evaluate the model on
    :return: tuple of (accuracy in percent, mean per-sample loss) values
    """
    score = 0.0
    loss = 0.0

    for x, y in tqdm(dataloader):
        if torch.cuda.is_available():
            x = x.cuda()
            y = y.cuda()

        y_hat = model(x)
        # BUG FIX: accumulate a per-sample loss sum (reduction='sum') so that
        # dividing by the dataset size below yields the true mean loss. The
        # original summed per-batch *means* and divided by the dataset size,
        # under-scaling the reported loss by the batch size. Using .item()
        # also returns a plain float instead of a 0-dim tensor.
        loss += nn.functional.binary_cross_entropy_with_logits(y_hat, y, reduction='sum').item()
        score += train_utils.compute_score_with_logits(y_hat, y).sum().item()

    n = len(dataloader.dataset)
    loss /= n
    score /= n
    score *= 100
    return score, loss