# utils.py (forked from MKFMIKU/RAW2RGBNet)
import math
import os
from os import listdir
from os.path import exists, join

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.lines import Line2D
from torch.nn import init


def plot_grad_flow(named_parameters):
    '''Plots the gradients flowing through the different layers in the net during training.

    Can be used for checking for possible gradient vanishing / exploding problems.
    Usage: plug this function into the Trainer class after loss.backward() as
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow.'''
    plt.switch_backend('agg')
    figure = plt.figure(figsize=(16, 4))
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if p.requires_grad and ("bias" not in n):
            layers.append(n)
            # .item() detaches the scalar from the graph (and the GPU) so
            # matplotlib can plot it.
            ave_grads.append(p.grad.abs().mean().item())
            max_grads.append(p.grad.abs().max().item())
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=0.02)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)],
               ['max-gradient', 'mean-gradient', 'zero-gradient'])
    return figure
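
# A minimal usage sketch for plot_grad_flow (the `model`, `criterion`, `inp`,
# and `target` names below are hypothetical stand-ins, not part of this module):
#
#     loss = criterion(model(inp), target)
#     loss.backward()
#     fig = plot_grad_flow(model.named_parameters())
#     fig.savefig("grad_flow.png")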


def is_image_file(filename):
    # Case-insensitive check against the supported image extensions.
    filename_lower = filename.lower()
    return any(filename_lower.endswith(extension)
               for extension in ['.png', '.jpg', '.jpeg', '.tif'])


def load_all_image(path):
    # Return the full path of every image file directly under `path`.
    return [join(path, x) for x in listdir(path) if is_image_file(x)]
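
# A minimal usage sketch (the "dataset/train" directory is a hypothetical example):
#
#     paths = load_all_image("dataset/train")
#     # e.g. ['dataset/train/0001.png', 'dataset/train/0002.png', ...]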


def save_checkpoint(model, name, discriminator, epoch, model_folder):
    if not exists(model_folder):
        os.makedirs(model_folder)
    if not exists(join(model_folder, name)):
        os.makedirs(join(model_folder, name))
    model_out_path = "%s/%d.pth" % (join(model_folder, name), epoch)
    # `model.module` assumes the networks are wrapped in nn.DataParallel (or
    # DistributedDataParallel); all tensors are moved to the CPU before saving.
    state_dict_model = model.module.state_dict()
    for key in state_dict_model.keys():
        state_dict_model[key] = state_dict_model[key].cpu()
    if discriminator:
        state_dict_discriminator = discriminator.module.state_dict()
        for key in state_dict_discriminator.keys():
            state_dict_discriminator[key] = state_dict_discriminator[key].cpu()
        torch.save({"epoch": epoch,
                    "state_dict_model": state_dict_model,
                    "state_dict_discriminator": state_dict_discriminator}, model_out_path)
    else:
        torch.save({"epoch": epoch,
                    "state_dict_model": state_dict_model}, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    'normal' was used in the original pix2pix and CycleGAN paper, but xavier and
    kaiming might work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # the per-module initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_uniform_(m.weight.data, a=math.sqrt(5))
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                # Uniform bias bound, matching PyTorch's default Conv/Linear reset.
                fan_in, _ = init._calculate_fan_in_and_fan_out(m.weight.data)
                bound = 1 / math.sqrt(fan_in)
                init.uniform_(m.bias.data, -bound, bound)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm affine parameters: scale ~ N(1, init_gain), shift = 0.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
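
# A minimal usage sketch (the small Sequential net is a hypothetical example;
# net.apply() visits the Conv2d and BatchNorm2d children recursively):
#
#     net = torch.nn.Sequential(
#         torch.nn.Conv2d(3, 16, 3, padding=1),
#         torch.nn.BatchNorm2d(16),
#         torch.nn.ReLU(),
#     )
#     init_weights(net, init_type='kaiming')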


def quantize(img, rgb_range):
    # Scale `img` by `rgb_range`, then clamp and round to valid 8-bit pixel values.
    return img.mul(rgb_range).clamp(0, 255).round()
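
# A minimal usage sketch (assumes network outputs live in [0, 1], so that
# rgb_range=255 maps them onto 8-bit intensities):
#
#     out = torch.rand(1, 3, 8, 8)  # stand-in for a network output
#     pixels = quantize(out, 255)   # integer-valued tensor in [0, 255]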