-- model.lua
--- Build the convolutional classifier out of layers supplied by `pkg`.
--
-- Layout: four stages of 3x3 convolution (stride 1, padding 1) + ELU +
-- 2x2 max-pooling that widen the feature planes 64 -> 128 -> 256 -> 512,
-- then View(-1, 512) + Linear(512, 1024) + ELU + Dropout(0.5), and finally
-- View(512*8*8*2) + Linear(512*8*8*2, numClasses) + LogSoftMax.
--
-- Side effects (kept for backward compatibility): assigns the global
-- `backend` to `pkg` and the global `model` to the new container.
--
-- @param pkg        table providing the layer constructors used below
--                   (Sequential, SpatialConvolution, ELU, SpatialMaxPooling,
--                   View, Linear, Dropout, LogSoftMax); presumably `nn` or a
--                   compatible package -- confirm with the caller.
-- @param numClasses output size of the final Linear layer
-- @param imgSize    input-plane count of the first convolution when
--                   phase == 'train'
-- @param phase      'train' selects `imgSize` input planes; any other value
--                   selects 1
-- @return the assembled Sequential container
function loadModel(pkg, numClasses, imgSize, phase)
   -- Fail early with a readable message instead of the opaque
   -- "attempt to index a nil value" raised on the first constructor call.
   if pkg == nil then
      error("loadModel: pkg (layer package) must not be nil", 2)
   end

   -- Deliberately global: earlier versions leaked this name, so other code
   -- may read it. Construction below goes through the local alias only.
   backend = pkg
   local layers = pkg

   -- NOTE(review): this value is passed as the first argument of
   -- SpatialConvolution (the input-plane count in Torch's nn), not as a
   -- batch size -- confirm that imgSize planes in 'train' is intended.
   local inputPlanes
   if phase == 'train' then
      inputPlanes = imgSize
   else
      inputPlanes = 1
   end

   -- Element count expected by the final View/Linear pair.
   -- NOTE(review): a fixed-size View after Linear(512, 1024) looks like it
   -- only fits one specific input size -- confirm against the data loader.
   local flatSize = 512 * 8 * 8 * 2

   local net = layers.Sequential()
   -- Deliberately global, for the same reason as `backend` above.
   model = net

   -- Appends one conv + ELU + max-pool stage to `net`.
   local function addConvStage(nIn, nOut)
      net:add(layers.SpatialConvolution(nIn, nOut, 3, 3, 1, 1, 1, 1))
      net:add(layers.ELU())
      net:add(layers.SpatialMaxPooling(2, 2))
   end

   addConvStage(inputPlanes, 64)
   addConvStage(64, 128)
   addConvStage(128, 256)
   addConvStage(256, 512)

   net:add(layers.View(-1, 512))
   net:add(layers.Linear(512, 1024))
   net:add(layers.ELU())
   net:add(layers.Dropout(0.5))
   net:add(layers.View(flatSize))
   net:add(layers.Linear(flatSize, numClasses))
   net:add(layers.LogSoftMax())

   return net
end
-- Hand the constructor back to whatever loads this file as a chunk.
return loadModel