-
Notifications
You must be signed in to change notification settings - Fork 239
/
util.lua
169 lines (143 loc) · 4.93 KB
/
util.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
-- Modified by Mohammad Rastegari (Allen Institute for Artificial Intelligence (AI2))
local ffi=require 'ffi'
--- Compute top-1 and top-5 accuracy (in percent) for a batch of class scores.
-- When nCrops > 1, every nCrops consecutive rows of `output` are treated as
-- crops of the same sample and their raw scores are summed (no exp() is
-- applied first) before ranking.
-- @param output 2-D score tensor, rows = samples (times crops), cols = classes
-- @param target class indices, one per sample; compared against the 1-based
--               indices returned by sort(), so presumably 1-based labels
-- @param nCrops number of consecutive rows that belong to one sample
-- @return top1, top5 accuracy as percentages in [0, 100]. With fewer than 5
--         classes the second value is top-nClasses accuracy.
function computeScore(output, target, nCrops)
  if nCrops > 1 then
    local nRows = output:size(1)
    assert(nRows % nCrops == 0,
      'computeScore: number of output rows must be a multiple of nCrops')
    output = output:view(math.floor(nRows / nCrops), nCrops, output:size(2))
      :sum(2):squeeze(2)
  end
  local batchSize = output:size(1)
  local nClasses = output:size(2)
  local _, predictions = output:float():sort(2, true) -- descending
  -- correct[b][k] == 1 iff the k-th ranked prediction of sample b is its target
  local correct = predictions:eq(
    target:long():view(batchSize, 1):expandAs(output))
  local top1 = correct:narrow(2, 1, 1):sum() / batchSize
  -- narrow() fails if asked for more columns than exist; cap at nClasses.
  local top5 = correct:narrow(2, 1, math.min(5, nClasses)):sum() / batchSize
  return top1 * 100, top5 * 100
end
--- Optionally replicate `model` across several GPUs.
-- When nGPU > 1, the function builds an nn.DataParallelTable (split along
-- dimension 1). It holds one CUDA clone of `model` per device 1..nGPU, and
-- each clone is created while its own device is current.
-- Regardless of nGPU, the current device is finally reset to the global
-- opt.GPU.
-- @param model the single-device network
-- @param nGPU  number of GPUs to spread the model over
-- @return the DataParallelTable wrapper, or `model` itself when nGPU <= 1
function makeDataParallel(model, nGPU)
  local result = model
  if nGPU > 1 then
    print('converting module to nn.DataParallelTable')
    assert(nGPU <= cutorch.getDeviceCount(), 'number of GPUs less than nGPU specified')
    result = nn.DataParallelTable(1)
    for gpu = 1, nGPU do
      cutorch.setDevice(gpu)
      local replica = model:clone():cuda()
      result:add(replica, gpu)
    end
  end
  cutorch.setDevice(opt.GPU)
  return result
end
-- Strip a parallel wrapper for serialization: hand back whatever the
-- container stores at index 1, as reported by module:get(1).
local function cleanDPT(module)
  local firstReplica = module:get(1)
  return firstReplica
end
--- Serialize a model to `filename` with torch.save.
-- An nn.DataParallelTable is first reduced with cleanDPT (its first inner
-- module), and an nn.Sequential is saved unchanged. Any other module type
-- raises an error and nothing is written.
-- @param filename destination path passed to torch.save
-- @param model    nn.Sequential or nn.DataParallelTable to save
function saveDataParallel(filename, model)
  local kind = torch.type(model)
  local toSave
  if kind == 'nn.DataParallelTable' then
    toSave = cleanDPT(model)
  elseif kind == 'nn.Sequential' then
    toSave = model
  else
    error('This saving function only works with Sequential or DataParallelTable modules.')
  end
  torch.save(filename, toSave)
end
--- Copy learned state from `saved_model` into `model`, in place.
-- Parameter tensors from model:parameters() are copied pairwise by position.
-- Then running_mean / running_var of every nn.SpatialBatchNormalization
-- module are copied, also matched by position. Both models are expected to
-- share the same architecture.
-- @param model       destination network (modified in place)
-- @param saved_model source network
-- @raise if saved_model has fewer parameter tensors or batch-norm modules
--        than model
function loadParams(model, saved_model)
  -- `local` here: the original leaked `params` into the global table.
  local params = model:parameters() or {}
  local saved_params = saved_model:parameters() or {}
  for i = 1, #params do
    local src = saved_params[i]
    assert(src ~= nil,
      ('loadParams: saved model has no parameter tensor #%d'):format(i))
    params[i]:copy(src)
  end
  local bn = model:findModules("nn.SpatialBatchNormalization")
  local saved_bn = saved_model:findModules("nn.SpatialBatchNormalization")
  for i = 1, #bn do
    local src = saved_bn[i]
    assert(src ~= nil,
      ('loadParams: saved model has no SpatialBatchNormalization #%d'):format(i))
    bn[i].running_mean:copy(src.running_mean)
    bn[i].running_var:copy(src.running_var)
  end
end
--- Set the bias tensor of every module in `convNodes` to zero, in place.
-- Modules whose `bias` field is nil are skipped rather than raising.
-- @param convNodes array of modules
function zeroBias(convNodes)
  for i = 1, #convNodes do
    local bias = convNodes[i].bias
    if bias then
      bias:fill(0)
    end
  end
end
--- Rescale, in place, the gradients of the binarized layers.
-- Applies to every node except the first and last in `convNodes`.
-- For a weight element w in a filter of n elements whose mean absolute value
-- is alpha, gradWeight is multiplied elementwise by
--     (alpha * [|w| < 1] + 1/n) * (1 - 1/size(2))
-- and additionally by n when the global opt.optimType == 'sgd'.
-- If opt.nGPU > 1, the global `model` is synchronized afterwards.
-- Assumes 4-D weights (norm over dim 4, sums over dims 3 and 2).
-- @param convNodes array of convolution modules
function updateBinaryGradWeight(convNodes)
  for i = 2, #convNodes - 1 do
    local weight = convNodes[i].weight
    local n = weight[1]:nElement()
    local s = weight:size()
    -- Per-filter mean |w| broadcast to the weight shape. expand() only makes
    -- a stride-0 view sharing one value per filter. clone() gives every
    -- element its own storage, so the masked writes and in-place add/mul
    -- below act per element instead of repeatedly on the shared value.
    local m = weight:norm(1, 4):sum(3):sum(2):div(n):expand(s):clone()
    -- Drop the alpha term where |w| >= 1.
    m[weight:le(-1)] = 0
    m[weight:ge(1)] = 0
    -- NOTE(review): (1 - 1/size(2)) presumably accounts for the mean-centering
    -- over dim 2 done in meancenterConvParms — confirm.
    m:add(1 / n):mul(1 - 1 / s[2])
    if opt.optimType == 'sgd' then
      m:mul(n)
    end
    convNodes[i].gradWeight:cmul(m)
  end
  if opt.nGPU > 1 then
    model:syncParameters()
  end
end
--- Mean-center weights along dimension 2, in place.
-- Applies to every node except the first and last. The mean over dim 2 is
-- negated, tiled back to full size with repeatTensor(1, size(2), 1, 1), and
-- added to the weight. If opt.nGPU > 1, the global `model` is synchronized
-- afterwards.
-- @param convNodes array of convolution modules
function meancenterConvParms(convNodes)
  local lastInner = #convNodes - 1
  for idx = 2, lastInner do
    local w = convNodes[idx].weight
    local planes = w:size(2)
    local shift = w:mean(2):mul(-1):repeatTensor(1, planes, 1, 1)
    w:add(shift)
  end
  local multiGPU = opt.nGPU > 1
  if multiGPU then
    model:syncParameters()
  end
end
--- Binarize weights in place.
-- Applies to every node except the first and last. Each weight becomes
-- sign(w) * alpha, where alpha is the mean absolute value of that weight's
-- slice along dim 1: the L1 norm summed over dims 4, 3 and 2, divided by
-- the slice's element count. If opt.nGPU > 1, the global `model` is
-- synchronized afterwards.
-- @param convNodes array of convolution modules
function binarizeConvParms(convNodes)
  local lastInner = #convNodes - 1
  for idx = 2, lastInner do
    local w = convNodes[idx].weight
    local perFilter = w[1]:nElement()
    local shape = w:size()
    local alpha = w:norm(1, 4):sum(3):sum(2):div(perFilter)
    w:sign()
    w:cmul(alpha:expand(shape))
  end
  local multiGPU = opt.nGPU > 1
  if multiGPU then
    model:syncParameters()
  end
end
--- Clamp weights to [-1, 1], in place.
-- Applies to every node except the first and last. If opt.nGPU > 1, the
-- global `model` is synchronized afterwards.
-- @param convNodes array of convolution modules
function clampConvParms(convNodes)
  local lastInner = #convNodes - 1
  for idx = 2, lastInner do
    local w = convNodes[idx].weight
    w:clamp(-1, 1)
  end
  local multiGPU = opt.nGPU > 1
  if multiGPU then
    model:syncParameters()
  end
end
-- 2-D convolution module types whose fan-in is kH * kW * nInputPlane.
local RAND_INIT_CONV2D = {
  ["cudnn.SpatialConvolution"] = true,
  ["nn.SpatialConvolution"] = true,
  ["nn.BinarySpatialConvolution"] = true,
  ["nn.SpatialConvolutionMM"] = true,
}

-- Batch-normalization module types reset to unit scale / zero shift.
-- (The previous code misspelled these as "SpatialBachNormalization", so the
-- branch never matched.)
local RAND_INIT_BN = {
  ["nn.SpatialBatchNormalization"] = true,
  ["cudnn.SpatialBatchNormalization"] = true,
}

-- Fill layer.weight with N(0, 2 / fanIn) samples and zero the bias if present.
local function heNormalFill(layer, fanIn)
  local c = math.sqrt(2.0 / fanIn)
  layer.weight:copy(torch.randn(layer.weight:size()) * c)
  if layer.bias then
    layer.bias:fill(0)
  end
end

--- Re-initialize a single layer in place.
-- Convolution and nn.Linear weights are drawn from a normal distribution
-- with variance 2 / fan_in, and their biases are zeroed. Fan-in is:
--   * kH*kW*nInputPlane for the 2-D convolutions,
--   * kT*kH*kW*nInputPlane for cudnn.VolumetricConvolution,
--   * weight:size(2) for nn.Linear.
-- Spatial batch-normalization layers get weight = 1 and bias = 0.
-- Any other layer type is left untouched.
-- @param layer module to initialize
function rand_initialize(layer)
  local tn = torch.type(layer)
  if RAND_INIT_CONV2D[tn] then
    heNormalFill(layer, layer.kH * layer.kW * layer.nInputPlane)
  elseif tn == "cudnn.VolumetricConvolution" then
    -- The temporal kernel extent contributes to fan-in as well.
    heNormalFill(layer, (layer.kT or 1) * layer.kH * layer.kW * layer.nInputPlane)
  elseif tn == "nn.Linear" then
    heNormalFill(layer, layer.weight:size(2))
  elseif RAND_INIT_BN[tn] then
    -- weight/bias may be nil for non-affine batch normalization.
    if layer.weight then layer.weight:fill(1) end
    if layer.bias then layer.bias:fill(0) end
  end
end