import torch

from pycls.models.nas.nas import Cell


# Thin wrapper around a NAS Cell: records the drop probability p alongside the
# wrapped module and forwards its inputs unchanged.
class DropChannel(torch.nn.Module):
    def __init__(self, p, mod):
        super(DropChannel, self).__init__()
        self.mod = mod
        self.p = p

    def forward(self, s0, s1, droppath):
        return self.mod(s0, s1, droppath)


# DropConnect: zeroes whole per-sample channels at random and rescales the
# survivors so the expected activation is unchanged.
class DropConnect(torch.nn.Module):
    def __init__(self, p):
        super(DropConnect, self).__init__()
        self.p = p

    def forward(self, inputs):
        batch_size = inputs.shape[0]
        channel_size = inputs.shape[1]
        keep_prob = 1 - self.p
        # generate binary_tensor mask according to probability (p for 0, 1 - p for 1)
        random_tensor = keep_prob
        random_tensor += torch.rand([batch_size, channel_size, 1, 1],
                                    dtype=inputs.dtype, device=inputs.device)
        binary_tensor = torch.floor(random_tensor)
        output = inputs / keep_prob * binary_tensor
        return output
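

# Quick sanity check of the inverted scaling above (illustrative example):
# entries are zeroed with probability p and the survivors are scaled by
# 1 / (1 - p), so the mean activation is preserved in expectation, e.g.
#
#   dc = DropConnect(0.5)
#   x = torch.ones(10000, 4, 1, 1)
#   dc(x).mean()   # close to 1.0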


def add_dropout(network, p, prefix=''):
    #p = 0.5
    # First pass: wrap Conv2d / Cell modules reachable as attributes of this module.
    for attr_str in dir(network):
        target_attr = getattr(network, attr_str)
        if isinstance(target_attr, torch.nn.Conv2d):
            setattr(network, attr_str, torch.nn.Sequential(target_attr, DropConnect(p)))
        elif isinstance(target_attr, Cell):
            setattr(network, attr_str, DropChannel(p, target_attr))
    # Second pass: walk the registered children, wrapping Conv2d / Cell modules
    # and recursing into everything else.
    for n, ch in list(network.named_children()):
        #print(f'{prefix}add_dropout {n}')
        if isinstance(ch, torch.nn.Conv2d):
            setattr(network, n, torch.nn.Sequential(ch, DropConnect(p)))
        elif isinstance(ch, Cell):
            setattr(network, n, DropChannel(p, ch))
        else:
            add_dropout(ch, p, prefix + '\t')


# Per-module initialisers, meant to be passed to nn.Module.apply().
def orth_init(m):
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.orthogonal_(m.weight)


def uni_init(m):
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.uniform_(m.weight)


def uni2_init(m):
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.uniform_(m.weight, -1., 1.)


def uni3_init(m):
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.uniform_(m.weight, -.5, .5)


def norm_init(m):
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        torch.nn.init.normal_(m.weight)


def eye_init(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.eye_(m.weight)
    elif isinstance(m, torch.nn.Conv2d):
        torch.nn.init.dirac_(m.weight)


def fixup_init(m):
    if isinstance(m, torch.nn.Conv2d):
        torch.nn.init.zeros_(m.weight)
    elif isinstance(m, torch.nn.Linear):
        torch.nn.init.zeros_(m.weight)
        torch.nn.init.zeros_(m.bias)


def init_network(network, init):
    if init == 'orthogonal':
        network.apply(orth_init)
    elif init == 'uniform':
        print('uniform')
        network.apply(uni_init)
    elif init == 'uniform2':
        network.apply(uni2_init)
    elif init == 'uniform3':
        network.apply(uni3_init)
    elif init == 'normal':
        network.apply(norm_init)
    elif init == 'identity':
        network.apply(eye_init)
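

# Minimal usage sketch. The toy Sequential below is a hypothetical stand-in;
# the helpers above are written for pycls NAS models (see the Cell import).
if __name__ == '__main__':
    toy = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.ReLU(),
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(),
        torch.nn.Linear(8, 10),
    )
    init_network(toy, 'orthogonal')  # orthogonal init for all Conv2d/Linear weights
    add_dropout(toy, 0.05)           # wraps each Conv2d with DropConnect(p=0.05)
    out = toy(torch.randn(4, 3, 32, 32))
    print(out.shape)                 # torch.Size([4, 10])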