-
Notifications
You must be signed in to change notification settings - Fork 53
/
utils.py
104 lines (78 loc) · 3.6 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
paper: GridDehazeNet: Attention-Based Multi-Scale Network for Image Dehazing
file: utils.py
about: all utilities
author: Xiaohong Liu
date: 01/08/19
"""
# --- Imports --- #
import os
import time
from math import log10

import torch
import torch.nn.functional as F
import torchvision.utils as utils
from skimage import measure
def to_psnr(dehaze, gt, intensity_max=1.0):
    """Return the PSNR (in dB) of every image in a batch.

    PSNR = 10 * log10(MAX^2 / MSE), with the MSE taken over all elements of
    each sample (everything except dim 0).

    :param dehaze: restored images, batch along dim 0 (at least 2-D)
    :param gt: reference images, expected to have the same shape as ``dehaze``
    :param intensity_max: peak pixel value (1.0 for images scaled to [0, 1])
    :return: list with one PSNR value per image; ``inf`` when an image matches
        its reference exactly
    """
    mse = F.mse_loss(dehaze, gt, reduction='none')
    # Average over every non-batch element; one .tolist() means one host sync
    # for the whole batch instead of one .item() per image.
    per_image_mse = mse.flatten(start_dim=1).mean(dim=1).tolist()
    peak_sq = intensity_max ** 2
    # A perfect reconstruction has zero error: PSNR is unbounded, so report
    # +inf instead of dividing by zero.
    return [float('inf') if err == 0 else 10.0 * log10(peak_sq / err)
            for err in per_image_mse]
def to_ssim_skimage(dehaze, gt):
    """Return the SSIM of every image in a batch, computed with scikit-image.

    The SSIM of a multi-channel image is the mean of its per-channel SSIMs,
    which is how scikit-image defines multichannel SSIM. Computing it channel
    by channel avoids the version-specific ``multichannel`` / ``channel_axis``
    keywords.

    :param dehaze: restored images as an (N, C, H, W) tensor; ``data_range=1``
        assumes values in [0, 1]
    :param gt: reference images with the same shape as ``dehaze``
    :return: list of N SSIM values
    :raises ValueError: if ``dehaze`` and ``gt`` differ in shape
    """
    try:
        from skimage.metrics import structural_similarity as ssim_fn
    except ImportError:  # scikit-image < 0.16 only provides the old name
        from skimage.measure import compare_ssim as ssim_fn

    if dehaze.shape != gt.shape:
        raise ValueError('dehaze and gt must have the same shape, got {} and {}'
                         .format(tuple(dehaze.shape), tuple(gt.shape)))

    # NCHW -> NHWC numpy arrays; indexing (not squeeze) keeps the channel axis
    # even for single-channel images.
    dehaze_np = dehaze.detach().permute(0, 2, 3, 1).cpu().numpy()
    gt_np = gt.detach().permute(0, 2, 3, 1).cpu().numpy()

    ssim_list = []
    for pred, ref in zip(dehaze_np, gt_np):
        channel_scores = [ssim_fn(pred[..., c], ref[..., c], data_range=1)
                          for c in range(pred.shape[-1])]
        ssim_list.append(sum(channel_scores) / len(channel_scores))
    return ssim_list
def validation(net, val_data_loader, device, category, save_tag=False):
    """Run ``net`` over a validation loader and return mean PSNR and SSIM.

    :param net: dehazing network, called as ``net(haze)``
    :param val_data_loader: iterable yielding ``(haze, gt, image_name)`` batches
    :param device: device each batch is moved to before the forward pass
    :param category: dataset tag (indoor or outdoor test set), forwarded to
        ``save_image`` to pick the output directory
    :param save_tag: if True, write every dehazed output via ``save_image``
    :return: ``(avr_psnr, avr_ssim)`` averaged over all validation images
    :raises ValueError: if the loader yields no images
    """
    psnr_list = []
    ssim_list = []
    # Inference only: disable autograd for the whole pass.
    # NOTE(review): this does not put ``net`` into eval mode; presumably the
    # caller does that before validating - confirm.
    with torch.no_grad():
        for haze, gt, image_name in val_data_loader:
            haze = haze.to(device)
            gt = gt.to(device)
            dehaze = net(haze)

            psnr_list.extend(to_psnr(dehaze, gt))
            ssim_list.extend(to_ssim_skimage(dehaze, gt))

            if save_tag:
                save_image(dehaze, image_name, category)

    if not psnr_list:
        raise ValueError('validation loader yielded no images')

    avr_psnr = sum(psnr_list) / len(psnr_list)
    avr_ssim = sum(ssim_list) / len(ssim_list)
    return avr_psnr, avr_ssim
def save_image(dehaze, image_name, category):
    """Save every image of a batch as a PNG under ``./{category}_results/``.

    :param dehaze: batch of images, split along dim 0 and written one per file
    :param image_name: sequence of source file names, indexed in batch order;
        each name's extension (of any length) is replaced by ``.png``
    :param category: dataset tag used to build the output directory name
    """
    out_dir = os.path.join('.', '{}_results'.format(category))
    # The results directory may not exist yet; create it instead of failing
    # on the first write.
    os.makedirs(out_dir, exist_ok=True)
    for ind, img in enumerate(torch.split(dehaze, 1, dim=0)):
        # splitext handles ".jpg", ".jpeg", ".tif", ... and names with no
        # extension, unlike chopping a fixed three characters.
        stem = os.path.splitext(image_name[ind])[0]
        utils.save_image(img, os.path.join(out_dir, stem + '.png'))
def print_log(epoch, num_epochs, one_epoch_time, train_psnr, val_psnr, val_ssim, category):
    """Print one epoch's metrics and append them to the training log.

    The log file is ``./training_log/{category}_log.txt``; the directory is
    created if it does not exist.

    :param epoch: current epoch number
    :param num_epochs: total number of epochs
    :param one_epoch_time: duration of the epoch, reported in seconds
    :param train_psnr: training PSNR for the epoch
    :param val_psnr: validation PSNR for the epoch
    :param val_ssim: validation SSIM for the epoch
    :param category: dataset tag used to name the log file
    """
    print('({0:.0f}s) Epoch [{1}/{2}], Train_PSNR:{3:.2f}, Val_PSNR:{4:.2f}, Val_SSIM:{5:.4f}'
          .format(one_epoch_time, epoch, num_epochs, train_psnr, val_psnr, val_ssim))

    # --- Write the training log --- #
    log_dir = os.path.join('.', 'training_log')
    os.makedirs(log_dir, exist_ok=True)
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    with open(os.path.join(log_dir, '{}_log.txt'.format(category)), 'a') as f:
        print('Date: {0}, Time_Cost: {1:.0f}s, Epoch: [{2}/{3}], Train_PSNR: {4:.2f}, Val_PSNR: {5:.2f}, Val_SSIM: {6:.4f}'
              .format(timestamp, one_epoch_time, epoch, num_epochs, train_psnr, val_psnr, val_ssim), file=f)
def adjust_learning_rate(optimizer, epoch, category, lr_decay=0.5):
    """Decay every param group's learning rate on a fixed epoch schedule.

    The decay fires when ``epoch`` is a positive multiple of the step size:
    every 20 epochs for the ``'indoor'`` category and every 2 epochs for any
    other category. Each group's (possibly updated) rate is printed.

    :param optimizer: object whose ``param_groups`` dicts carry an ``'lr'`` key
    :param epoch: current epoch index
    :param category: ``'indoor'`` selects the 20-epoch step, anything else 2
    :param lr_decay: multiplicative factor applied when the decay fires
    """
    # --- Decay learning rate --- #
    if category == 'indoor':
        decay_every = 20
    else:
        decay_every = 2
    should_decay = epoch % decay_every == 0 and epoch > 0

    for group in optimizer.param_groups:
        if should_decay:
            group['lr'] *= lr_decay
        print('Learning rate sets to {}.'.format(group['lr']))