-
Notifications
You must be signed in to change notification settings - Fork 9
/
test_mRNN.py
131 lines (117 loc) · 5 KB
/
test_mRNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import fasta, preprocessing, model, evaluate
import sys, os, getopt
#########
# USAGE #
#########
def usage():
    """Print the usage statement and all options, then exit the program.

    The defaults listed here mirror the defaults set in ``main``.
    ``sys.exit()`` is called without an argument, so the process exits
    with status 0.
    """
    script = os.path.basename(__file__)
    print("\n\nUsage: " + script + " [options] <positive fasta> <negative fasta>")
    print('''
Options:
-h --help\t\tprints this help message.
-o --output\t\tthe file-base for the output files.
-w --weights\tpkl file of the model/model weights.
-E --epochs\tNumber of epochs to train on (default=50)
-b --batch_size\tbatch size for testing (default=16)
-e --embedding_size\tNumber of dimensions in embedding (default=128)
-r --recurrent_gate_size\tSize of recurrent gate (default=256)
-d --dropout\tThe dropout probability p_dropout (default=0.1)
-t --test\tProportion of data to test on. (default=0.1)
-l --min_length\tminimum length sequence to train on (default=200)
-L --max_length\tmaximum length sequence to train on (default=1000)
-f --file_label\tA text label on the accuracy output files.
''')
    sys.exit()
#########
# MAIN #
#########

# One entry per option that takes a value:
# (short flag, long flag, key in the parameters dict, converter for the argument).
# Both getopt specifications are built from this table, so the options getopt
# accepts and the options that are actually handled cannot drift apart.
_VALUE_OPTIONS = (
    ("-o", "--output", "output", str),
    ("-w", "--weights", "weights", str),
    ("-E", "--epochs", "epochs", int),
    ("-b", "--batch_size", "batch_size", int),
    ("-e", "--embedding_size", "embedding_size", int),
    ("-r", "--recurrent_gate_size", "recurrent_gate_size", int),
    ("-d", "--dropout", "dropout", float),
    ("-t", "--test", "test", float),
    ("-l", "--min_length", "min_length", int),
    ("-L", "--max_length", "max_length", int),
    ("-n", "--num_train", "num_train", int),
    ("-f", "--file_label", "file_label", str),
)


def _default_parameters():
    """Return a fresh dict holding the default run parameters."""
    return {
        'output': None,
        'verbose': False,
        'weights': None,
        'batch_size': 16,
        'embedding_size': 128,
        'recurrent_gate_size': 256,
        'dropout': 0.1,
        'test': 0.1,
        'min_length': 200,
        'max_length': 1000,
        'num_train': 10000,
        'epochs': 50,
        'save_freq': 3,
        'file_label': "",
    }


def _parse_options(argv):
    """Parse command-line arguments (without the program name).

    Returns a ``(parameters, files)`` tuple, where ``parameters`` is the
    default parameter dict updated from the options and ``files`` is the list
    of positional arguments. On -h/--help, an unknown option, a missing option
    argument, or a value that cannot be converted, the usage statement is
    printed (which exits).
    """
    short_spec = "hv" + "".join(short[1] + ":" for short, _, _, _ in _VALUE_OPTIONS)
    long_spec = ["help", "verbose"] + [long_flag[2:] + "=" for _, long_flag, _, _ in _VALUE_OPTIONS]
    try:
        opts, files = getopt.getopt(argv, short_spec, long_spec)
    except getopt.GetoptError as err:
        # Report the bad option and show the help instead of a traceback.
        print("Error: %s" % err)
        usage()

    # Map both spellings of every value option to (parameter key, converter).
    flag_table = {}
    for short, long_flag, key, convert in _VALUE_OPTIONS:
        flag_table[short] = (key, convert)
        flag_table[long_flag] = (key, convert)

    parameters = _default_parameters()
    for option, argument in opts:
        if option in ("-h", "--help"):
            usage()
        elif option in ("-v", "--verbose"):
            parameters['verbose'] = True
        else:
            # getopt only yields options from the specs built above, so every
            # remaining option is guaranteed to be in flag_table.
            key, convert = flag_table[option]
            try:
                parameters[key] = convert(argument)
            except ValueError:
                print("Error: invalid value %r for option %s" % (argument, option))
                usage()
    return parameters, files


def main():
    """Parse the command line, load both FASTA files and evaluate a trained model.

    Expects exactly two positional arguments: <positive fasta> <negative fasta>.
    Prints the usage statement (which exits) if they are missing or the options
    are malformed. Exits with status 1 if no weights file was given with
    -w/--weights. Otherwise builds the model from the weights, evaluates it on
    the (positives, negatives) pair and hands the result to
    ``evaluate.process_results``.
    """
    parameters, files = _parse_options(sys.argv[1:])
    if len(files) != 2:
        usage()
    pos_fasta_file, neg_fasta_file = files
    # Two spaces: matches the original `print "...: ", name` output.
    print("using positive file:  %s" % pos_fasta_file)
    print("using negative file:  %s" % neg_fasta_file)

    # Fail fast, before the potentially slow FASTA loading, if there is no model.
    if not parameters['weights']:
        print("No weights given with -w parameter.\n")
        sys.exit(1)

    ##########
    ## MAIN ##
    ##########
    print("Reading input files...")
    # NOTE(review): max_length is parsed but not passed to load_fasta here;
    # confirm whether fasta.load_fasta should also receive it.
    positives = fasta.load_fasta(pos_fasta_file, parameters['min_length'])
    negatives = fasta.load_fasta(neg_fasta_file, parameters['min_length'])
    test_data = positives, negatives

    print("Building model...")
    # NOTE(review): the meaning of the positional literal 5 is defined by
    # model.build_model (presumably the input alphabet size) -- confirm there.
    mRNN = model.build_model(parameters['weights'], parameters['embedding_size'],
                             parameters['recurrent_gate_size'], 5, parameters['dropout'])

    print("Evaluating model...")
    conf_mat = evaluate.evaluate_model(mRNN, test_data, parameters['batch_size'])
    evaluate.process_results(conf_mat, parameters)


if __name__ == "__main__":
    main()