-
Notifications
You must be signed in to change notification settings - Fork 13
/
doall.lua
50 lines (39 loc) · 1.16 KB
/
doall.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
---------- library ----------
require 'nn'
require 'cunn'
---------- settings ----------
-- Command-line options for the SVHN experiment.
-- `cmd` and `opt` are intentionally left global: the stage scripts loaded
-- with dofile below run in the global environment and presumably read `opt`
-- (TODO confirm in 1_data.lua .. 5_valid.lua before localizing either name).
cmd = torch.CmdLine()
cmd:text()
cmd:text('SVHN Loss Function')
cmd:text()
cmd:text('Options:')
-- global:
cmd:option('-seed', 91, 'fixed input seed for repeatable experiments')
cmd:option('-threads', 2, 'number of threads')
-- path:
cmd:option('-path_models', 'models', 'subdirectory to save models')
cmd:option('-path_saveimg', 'save', 'subdirectory to save images')
cmd:option('-path_submission', 'submission', 'subdirectory to submission file')
-- training:
cmd:option('-learningRate', 1e-3, 'learning rate at t=0')
cmd:option('-weightDecay', 0, 'weight decay')
cmd:option('-momentum', 0, 'momentum')
cmd:option('-batchSize', 1, 'mini-batch size (1 = pure stochastic)')
cmd:text()
-- Fall back to an empty argument list when `arg` is not set
-- (e.g. when this file is loaded interactively).
opt = cmd:parse(arg or {})

-- Apply -threads to Torch's CPU thread pool. Nothing else in this file
-- consumed opt.threads, so the option was previously a no-op here.
torch.setnumthreads(opt.threads)
-- 32-bit floats as the default CPU tensor type.
torch.setdefaulttensortype('torch.FloatTensor')
-- Seed the CPU (torch) and GPU (cutorch) random generators for repeatability.
torch.manualSeed(opt.seed)
cutorch.manualSeed(opt.seed)
---------- read dataset and function ----------
-- Load each pipeline stage script, in numeric order, into the global
-- environment. Later stages presumably rely on globals defined by earlier
-- ones (data, model, train/test/valid functions) — TODO confirm.
local stage_scripts = {
   "1_data.lua",
   "2_model.lua",
   "3_train.lua",
   "4_test.lua",
   "5_valid.lua",
}
for _, script_path in ipairs(stage_scripts) do
   dofile(script_path)
end
---------- execute ----------
-- Announce a phase, then run it. `fn` is one of the global phase functions
-- (train / valid / test) defined by the stage scripts loaded above.
local function run_phase(label, fn)
   print("==> " .. label)
   fn()
end

run_phase("training", train)
run_phase("validation", valid)
run_phase("test", test)