-
Notifications
You must be signed in to change notification settings - Fork 0
/
4_trainSvm.py
52 lines (41 loc) · 1.98 KB
/
4_trainSvm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import importlib
from fastRCNN.train_svms import SVMTrainer
import PARAMETERS
locals().update(importlib.import_module("PARAMETERS").__dict__)
#################################################
# Parameters
#################################################
experimentName = "exp1"

# No need to change these. Directory holding the parsed CNTK outputs for the
# training set; it is handed to DummyNet below.
cntkParsedOutputDir = cntkFilesDir + "train_svm_parsed/"


#################################################
# Main
#################################################
# Trains the SVM classifiers on the "train" imdb using fastRCNN's SVMTrainer.
# All svm_* / train_* settings below are not defined in this file; they come
# from the PARAMETERS namespace injected at the top of the script.

# Echo the SVM hyper-parameters so every run's log records its configuration.
# Each label is stored next to its value so the two cannot drift apart.
# Single-argument print(...) produces identical output on Python 2 and 3.
for paramName, paramValue in [("svm_targetNorm", svm_targetNorm),
                              ("svm_retrainLimit", svm_retrainLimit),
                              ("svm_posWeight", svm_posWeight),
                              ("svm_C", svm_C),
                              ("svm_B", svm_B),
                              ("svm_penality", svm_penality),
                              ("svm_loss", svm_loss),
                              ("svm_evictThreshold", svm_evictThreshold),
                              ("svm_nrEpochs", svm_nrEpochs)]:
    print(" " + paramName + " = " + str(paramValue))

# init
# Use an explicit check instead of `assert`. Asserts are stripped when Python
# runs with -O, which would silently let training proceed with a wrong
# classifier setting.
if classifier != 'svm':
    raise ValueError("Error: classifier variable not set to 'svm' but to '{}'".format(classifier))
makeDirectory(trainedSvmDir)
np.random.seed(svm_rngSeed)
imdb = imdbs["train"]
# NOTE(review): 4096 is presumably the dimensionality of the CNTK features read
# from cntkParsedOutputDir. TODO: confirm it matches the network's output layer.
net = DummyNet(4096, imdb.num_classes, cntkParsedOutputDir)
svmWeightsPath, svmBiasPath, svmFeatScalePath = getSvmModelPaths(trainedSvmDir, experimentName)

# Add ROIs which significantly overlap with a ground truth object as positives.
# This is skipped entirely when train_posOverlapThres <= 0.
if train_posOverlapThres > 0:
    print("Adding ROIs with gt overlap >= %2.2f as positives ..." % (train_posOverlapThres))
    existingPosCounter, addedPosCounter = imdbUpdateRoisWithHighGtOverlap(imdb, train_posOverlapThres)
    print(" Number of positives originally: {} (in {} images)".format(existingPosCounter, imdb.num_images))
    print(" Number of additional positives: {}.".format(addedPosCounter))

# Start training. SVMTrainer receives the three model paths from
# getSvmModelPaths (presumably where it stores the trained weights, bias and
# feature scale; verify in fastRCNN.train_svms).
svm = SVMTrainer(net, imdb, im_detect, svmWeightsPath, svmBiasPath, svmFeatScalePath,
                 svm_C, svm_B, svm_nrEpochs, svm_retrainLimit, svm_evictThreshold, svm_posWeight,
                 svm_targetNorm, svm_penality, svm_loss, svm_rngSeed)
svm.train()
print("DONE.")