-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_hourglass.py
110 lines (90 loc) · 4.3 KB
/
train_hourglass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import torch.nn as nn
import os
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from tensorboardX import SummaryWriter
import MPII
import model.hourglass_torch7
from util import config
from util.visualize import colorize, overlap
from util.log import get_logger
# Set up logging for this run. get_logger returns the logger together with
# the run's log directory and the comment string that labels the run.
logger, log_dir, comment = get_logger(comment=config.hourglass.comment)

# The configuration banner is skipped only when a comment was configured
# and a parameter directory for that comment already exists on disk.
# NOTE(review): the indentation of this script was lost in transit; the whole
# banner is assumed to sit inside this condition -- confirm.
continuing = (
    config.hourglass.comment is not None
    and os.path.exists('save/{comment}/parameter'.format(comment=comment))
)
if not continuing:
    rule = '==========================================================='
    log = logger.info

    log(' ')
    log(' ')
    log(rule)
    log('Comment : ' + comment + ' ')
    log(rule)
    log('Architecture : ' + 'Stacked hourglass' + ' ')
    log(' -task : ' + MPII.Task.Train + ' ')
    log(' -device : ' + str(config.hourglass.device) + ' ')
    log(rule)
    log('Data : ' + 'MPII' + ' ')
    log(' -directory : ' + config.hourglass.data_dir + ' ')
    log(' -mini batch : ' + str(config.hourglass.batch_size) + ' ')
    log(' -shuffle : ' + 'True' + ' ')
    log(' -worker : ' + str(config.hourglass.num_workers) + ' ')
    log(rule)
# MPII training split, read from the configured data directory.
train_set = MPII.Dataset(
    root=config.hourglass.data_dir,
    task=MPII.Task.Train,
)

# Shuffled mini-batch loader over the training split. Batch size and worker
# count come from the configuration; pin_memory=True requests page-locked
# host memory for the returned batches.
data = DataLoader(
    train_set,
    batch_size=config.hourglass.batch_size,
    num_workers=config.hourglass.num_workers,
    shuffle=True,
    pin_memory=True,
)
# A parameter directory is handed to the loader only when the run has an
# explicitly configured comment; otherwise None is passed.
if config.hourglass.comment is not None:
    parameter_dir = '{log_dir}/parameter'.format(log_dir=log_dir)
else:
    parameter_dir = None

# load() returns the network, its optimizer, the step counter used for
# TensorBoard records and the epoch number after which training continues.
hourglass, optimizer, step, train_epoch = model.hourglass_torch7.load(
    device=config.hourglass.device,
    parameter_dir=parameter_dir,
)

# Mean squared error between predicted and target heatmaps.
criterion = nn.MSELoss()

# TensorBoard event files are written below the run's log directory.
visualize_dir = '{log_dir}/visualize'.format(log_dir=log_dir)
writer = SummaryWriter(log_dir=visualize_dir)

# Put the network into training mode.
hourglass.train()
# Resize one image tensor to 256x256 by converting it to a PIL image,
# resizing that image and converting the result back to a tensor.
resize = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(size=[256, 256]),
    transforms.ToTensor(),
])


def upscale(heatmaps):
    """Resize every item of a batch to 256x256 on the configured device.

    The batch is moved to the CPU, each item along its first dimension is
    passed through ``resize``, and the resized items are stacked into a
    single tensor that is sent to ``config.hourglass.device``.
    """
    resized = [resize(heatmap) for heatmap in heatmaps.cpu()]
    return torch.stack(resized).to(config.hourglass.device)
# Number of additional epochs trained by one invocation of this script.
NUM_EPOCHS = 10

# Prefix for the TensorBoard image tags. The configured comment may be None,
# which would produce tags such as 'None/prediction'; fall back to the
# comment resolved by get_logger in that case.
tag_prefix = config.hourglass.comment if config.hourglass.comment is not None else comment

# Checkpoints are written below the run's log directory, one file per epoch.
save_dir = '{log_dir}/parameter'.format(log_dir=log_dir)

try:
    for epoch in range(train_epoch + 1, train_epoch + NUM_EPOCHS + 1):
        with tqdm(total=len(data), desc='%d epoch' % epoch) as progress:
            with torch.set_grad_enabled(True):
                # Only the first two fields of every batch are used here.
                for images, heatmaps, _, _, _, _ in data:
                    images = images.to(config.hourglass.device)
                    heatmaps = heatmaps.to(config.hourglass.device)

                    optimizer.zero_grad()
                    outputs = hourglass(images)

                    # The loss is the sum of the criterion over every output
                    # the network returns, each compared with the same target.
                    loss = sum(criterion(output, heatmaps) for output in outputs)
                    loss.backward()
                    nn.utils.clip_grad_norm_(hourglass.parameters(), max_norm=1)
                    optimizer.step()

                    # Read the scalar once and reuse it for every consumer.
                    loss_value = loss.item()
                    writer.add_scalar('SH/loss', loss_value, step)

                    # Every 100 steps, record target and last-output heatmaps
                    # drawn over the input images.
                    if step % 100 == 0:
                        ground_truth = overlap(images=images, heatmaps=upscale(colorize(heatmaps)))
                        prediction = overlap(images=images, heatmaps=upscale(colorize(outputs[-1])))
                        writer.add_image('{comment}/ground-truth'.format(comment=tag_prefix), ground_truth.data, step)
                        writer.add_image('{comment}/prediction'.format(comment=tag_prefix), prediction.data, step)

                    progress.set_postfix(loss=loss_value)
                    progress.update(1)
                    step = step + 1

        # Create the directory right before saving; exist_ok avoids the
        # check-then-create race of a separate os.path.exists test.
        os.makedirs(save_dir, exist_ok=True)
        save_to = '{save_dir}/{epoch}.save'.format(save_dir=save_dir, epoch=epoch)
        torch.save(
            {
                'epoch': epoch,
                'step': step,
                'state': hourglass.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            save_to,
        )

        # The reported loss is that of the last batch of the epoch.
        logger.info('Epoch {epoch} saved (loss: {loss})'.format(epoch=epoch, loss=loss_value))
finally:
    # Close the writer even when training is interrupted or fails, so that
    # records already handed to it are not lost.
    writer.close()