----------------------------------------------------------------------
-- SGD training of the network
----------------------------------------------------------------------
require 'optim'
require 'xlua'
require 'sys'    -- sys.clock() for timing
require 'paths'  -- paths.concat / paths.dofile
require 'cutorch'
require 'cunn'
----------------------------------------------------------------------
-- parse command line arguments (skipped when a calling script has
-- already built the global `opt` table)
if not opt then
   print '==> processing options'
   cmd = torch.CmdLine()
   cmd:text()
   cmd:text('Options:')
   cmd:option('-save', '/usr/local/data/jtaylor/Deep/save', 'subdirectory to save/log experiments in')
   cmd:option('-LR', 1e-3, 'learning rate at t=0')
   cmd:option('-LRDecay', 1e-5, 'learning rate decay')
   cmd:option('-momentum', 0.9, 'momentum')
   cmd:option('-weightDecay', 1e-7, 'weight decay')
   cmd:text()
   opt = cmd:parse(arg or {})
end
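-- A usage sketch (assumes this file is included from a hypothetical main
-- script; torch.CmdLine options are overridden from the shell):
--   th main.lua -LR 1e-2 -momentum 0.95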
-- training logs
trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
-- flatten all learnable weights and gradients into two 1D tensors for optim;
-- `model` (and `criterion` below) must be defined by the calling script
parameters, gradParameters = model:getParameters()
-- configure SGD
optimState = {
   learningRate = opt.LR,
   weightDecay = opt.weightDecay,
   momentum = opt.momentum,
   learningRateDecay = opt.LRDecay
}
optimMethod = optim.sgd
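-- Note on optim.sgd with this state table: each call anneals the step size
-- as clr = learningRate / (1 + nEvals * learningRateDecay), adds
-- weightDecay * x to the gradient (L2 regularization), and applies
-- classical momentum before updating the flattened parameters in place.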
-- data loading (dataset.lua defines dataLoader); `opt.fps` and
-- `opt.datapath` are expected to be set by the calling script
paths.dofile('dataset.lua')
loader = dataLoader(opt.fps, opt.datapath)
printFreq = 10 -- how often (in iterations) to print accuracy during an epoch
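-- Loader interface assumed below (inferred from the calls in train()):
--   loader.classes       -- class labels for the confusion matrix
--   loader.trainIndeces  -- array of training sample indices
--   loader:get(idx)      -- returns {frame tensors...}, label tensor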
function train()
   -- time the whole epoch
   local time = sys.clock()
   epoch = epoch or 1
   confusion = optim.ConfusionMatrix(loader.classes)
   -- set model to training mode (for modules that differ in training
   -- and testing, like Dropout)
   model:training()
   -- shuffle sample order at each epoch
   local shuffle = torch.randperm(#loader.trainIndeces)
   print("==> online epoch # " .. epoch)
   for t = 1, #loader.trainIndeces do
      -- progress bar
      xlua.progress(t, #loader.trainIndeces)
      -- load data and labels
      local inputCPU, labelsCPU = loader:get(loader.trainIndeces[shuffle[t]])
      if #inputCPU > 0 then
         local labels = labelsCPU:cuda()
         -- move each frame of the sequence to the GPU
         local input = {}
         for i = 1, #inputCPU do
            input[i] = inputCPU[i]:cuda()
         end
         -- closure returning the loss and flattened gradients, as optim expects
         local feval = function(x)
            -- reset gradients and any recurrent state
            gradParameters:zero()
            model:forget()
            -- forward pass
            local output = model:forward(input)
            local err = criterion:forward(output, labels)
            -- update confusion matrix; it requires tensors instead of
            -- tables, so add one timestep at a time
            for i = 1, #output do
               confusion:add(output[i], labels[i])
            end
            -- backprop
            local gradOutputs = criterion:backward(output, labels)
            model:backward(input, gradOutputs)
            -- average loss and gradients over the sequence length
            gradParameters:div(#input)
            err = err / #input
            return err, gradParameters
         end
         optimMethod(feval, parameters, optimState)
         model:forget()
         collectgarbage()
      end
      -- print an accuracy update periodically throughout the epoch
      if t % printFreq == 0 then
         confusion:updateValids()
         print('mean class accuracy = ' .. confusion.totalValid * 100 .. '%')
      end
   end
   -- time taken
   time = sys.clock() - time
   print("\n==> training time = " .. (time * 1000) .. 'ms')
   print(confusion)
   -- print accuracy update
   confusion:updateValids() -- make sure .totalValid reflects the full epoch
   print('==> mean class accuracy = ' .. confusion.totalValid * 100 .. '%')
   -- update training log
   trainLogger:add{['% mean class accuracy (train set)'] = confusion.totalValid * 100}
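   -- Optional (a sketch using optim.Logger's built-in gnuplot support):
   --   trainLogger:style{['% mean class accuracy (train set)'] = '-'}
   --   trainLogger:plot()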
   -- save/log current net
   print('==> saving model to ' .. opt.save)
   collectgarbage()
   --model:clearState() -- drops cached activations/gradients; saves a lot of space
   --torch.save(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), model)
   -- advance the epoch counter so successive calls report distinct epochs
   epoch = epoch + 1
   -- return global accuracy
   return 100 * confusion.totalValid
end
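
----------------------------------------------------------------------
-- Example driver (a sketch, not part of the original script): assumes a
-- hypothetical main script has already built `model`, `criterion`, and the
-- extra `opt` fields (`fps`, `datapath`) before including this file:
--
--   paths.dofile('model.lua')   -- hypothetical file defining model/criterion
--   paths.dofile('train.lua')
--   for e = 1, 30 do
--      local acc = train()
--      print(string.format('epoch %d: %.2f%% train accuracy', e, acc))
--   end
----------------------------------------------------------------------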