trainer.py
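"""Training script for the ProCodes models.

Trains a ColorizationNet / U-Net on the ProCodes dataset, logs the per-epoch
training loss to TensorBoard, and checkpoints the model every 200 epochs.
Validation and early stopping exist below but are currently commented out.
"""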
import argparse
import time

import torch
import torch.nn as nn
from torch.nn import MSELoss
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from colorization import *
from unet import *
from dataset import ProCodesDataModule
from utils import classification_accuracy  # used by the (currently disabled) accuracy logging
def train(model: ColorizationNet, train_path: str, learning_rate: float, epochs: int, batch_size: int, model_path: str,
          loss_fn: object, early_stop_max: int = 25, continue_training: str = None, parallel: bool = True):
    """
    :param model: model object
    :param train_path: path to the training data; if training a U-Net this should be a list of two paths
        (inputs and labels)
    :param learning_rate: learning rate of the optimizer
    :param epochs: total number of epochs to train for
    :param batch_size: number of samples per batch
    :param model_path: directory where checkpoints are saved
    :param loss_fn: loss function
    :param early_stop_max: epochs without improvement before stopping (only used by the
        currently disabled early-stopping block)
    :param continue_training: path to a checkpoint to resume training from
    :param parallel: whether to wrap the model in nn.DataParallel
    :return: None
    """
    torch.cuda.empty_cache()
    start_time = time.time()
    writer = SummaryWriter(comment='identity_test')
    epoch = 0
    if parallel:
        model = nn.DataParallel(model)
    if continue_training:
        model_data = torch.load(continue_training)
        # not present in the checkpoints saved below, so fall back to an empty list
        train_losses = model_data.get('train_losses', [])
        epoch = model_data['epoch']
        model.load_state_dict(model_data['model_state_dict'])
    if torch.cuda.is_available():
        cuda0 = torch.device('cuda:0')
        model.to(cuda0)
    else:
        cuda0 = torch.device('cpu')
    # set up the optimizer; it must be created after the model is moved to its device,
    # which is why the checkpoint's optimizer state is restored in a second
    # `if continue_training` block rather than alongside the model state above
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if continue_training:
        optimizer.load_state_dict(model_data['optimizer_state_dict'])
    # set up dataloaders
    z = ProCodesDataModule(data_dir=train_path, batch_size=batch_size, test_size=0.2)
    train_loader = z.train_dataloader()
    val_loader = z.validation_dataloader()
    # bookkeeping for per-epoch averages and the (disabled) early stopping
    length_train = len(train_loader)
    length_val = len(val_loader)
    es_counter = 0
    best_loss, best_e = None, 0
    best_model, best_optim = None, None
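    # main training loop: resumes at `epoch` when a checkpoint was loaded,
    # averages the batch losses per epoch, and logs them to TensorBoard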
    for e in tqdm(range(epoch, epochs)):
        running_loss = 0
        running_classification_acc = 0
        for i, image_label in enumerate(train_loader):
            # image, label, zero_mask = image_label
            image, label = image_label
            image = image.to(cuda0)
            label = label.to(cuda0)
            # zero_mask = zero_mask.to(cuda0)
            # forward pass
            output = model(image)
            loss = loss_fn(output, label)
            running_loss += loss.item()
            # running_classification_acc += classification_accuracy(output, label, zero_mask)
            del image
            del label
            # del zero_mask
            # backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        running_val_loss = 0
        # validation is currently disabled; if re-enabled, wrap the loop in
        # `with torch.no_grad():` to avoid building gradient graphs
        # running_classification_val_acc = 0
        # for i, image_label in enumerate(val_loader):
        #     image, label, zero_mask = image_label
        #     image = image.to(cuda0)
        #     label = label.to(cuda0)
        #     zero_mask = zero_mask.to(cuda0)
        #     # forward pass
        #     output = model(image)
        #     loss_v = loss_fn(output, label)
        #     running_val_loss += loss_v.item()
        #     running_classification_val_acc += classification_accuracy(output, label, zero_mask)
        #     del image
        #     del label
        # classification_val_acc_per_epoch = running_classification_val_acc / length_val
        val_loss_per_epoch = running_val_loss / length_val  # 0 while validation is disabled
        # classification_acc_per_epoch = running_classification_acc / length_train
        loss_per_epoch = running_loss / length_train
        writer.add_scalar("Loss/train", loss_per_epoch, e)
        # writer.add_scalar("Accuracy/train", classification_acc_per_epoch, e)
        # writer.add_scalar("Loss/val", val_loss_per_epoch, e)
        # writer.add_scalar("Accuracy/val", classification_val_acc_per_epoch, e)
        # checkpoint every 200 epochs
        if (e + 1) % 200 == 0:
            torch.save({
                'epoch': e,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                'test_images': z.xtest
            }, model_path + f'{e + 1}_test.tar')
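        # a checkpoint written above can be passed back in via `continue_training`,
        # e.g. (illustrative path):
        #   train(unet, path, 0.0005, epochs, batch_size, 'models/unet/',
        #         loss_fn=MSE, continue_training='models/unet/200_test.tar')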
        # print('Epoch [{}/{}], Loss: {:.6f}, Accuracy: {:.3f}, Val Loss: {:.4f}, Val Accuracy: {:.3f}'.format(
        #     e + 1, epochs, loss_per_epoch, classification_acc_per_epoch, val_loss_per_epoch,
        #     classification_val_acc_per_epoch))
        print('Epoch [{}/{}], Loss: {:.6f}'.format(e + 1, epochs, loss_per_epoch))
        # check for early stopping (disabled; tracks the best validation loss and stops
        # after `early_stop_max` epochs without improvement)
        # if best_loss is None:
        #     best_loss = -val_loss_per_epoch
        #     best_model = model.state_dict()
        #     best_optim = optimizer.state_dict()
        #     best_e = e + 1
        # elif -val_loss_per_epoch < best_loss:
        #     es_counter += 1
        #     if es_counter > early_stop_max:
        #         print("Early Stop Initiated")
        #         torch.save({
        #             'epoch': best_e,
        #             'model_state_dict': best_model,
        #             'optimizer_state_dict': best_optim,
        #             'loss': -best_loss,
        #             'test_images': z.xtest
        #         }, model_path + f'{best_e}_best_ES.tar')
        #         break
        # else:
        #     es_counter = 0
        #     best_loss = -val_loss_per_epoch
        #     best_model = model.state_dict()
        #     best_optim = optimizer.state_dict()
        #     best_e = e + 1
    print(f'{(time.time() - start_time) / 60:.2f} minutes to finish')
    writer.flush()
    writer.close()
    return None
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('batch_size', type=int)
    parser.add_argument('epochs', type=int)
    # path to the input data
    parser.add_argument('path')
    # optional path to a pretrained model checkpoint to resume from
    parser.add_argument('--model')
    # path to the labeled data when training a U-Net
    parser.add_argument('--label_path')
    args = parser.parse_args()
    batch_size = args.batch_size
    epochs = args.epochs
    path = args.path
    model_presave = args.model
    label_path = args.label_path
    MSE = MSELoss(reduction='mean')
    BCE = nn.BCEWithLogitsLoss(reduction='mean')
    if label_path:
        path = [path, label_path]
    # unet = UNet(num_class=3, retain_dim=True, out_sz=(256, 256), dropout=0.05)
    unet, _ = create_pretrained('resnet34', None)
    print("BEGIN TRAINING")
    # train() returns None; losses are logged to TensorBoard rather than returned
    train(unet, path, 0.0005, epochs, batch_size, 'models/unet/', early_stop_max=1000, loss_fn=MSE,
          continue_training=model_presave, parallel=True)
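    # Example invocation (paths are hypothetical):
    #   python trainer.py 4 500 /data/procodes/inputs --label_path /data/procodes/labels
    # To resume from a saved checkpoint:
    #   python trainer.py 4 500 /data/procodes/inputs --label_path /data/procodes/labels \
    #       --model models/unet/200_test.tar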