-
-
Notifications
You must be signed in to change notification settings - Fork 1
/
overall_driver.py
210 lines (170 loc) · 6.76 KB
/
overall_driver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# -*- coding: utf-8 -*-
import argparse
import configparser
import logging
import json
import sys
from wbia_lca import ga_driver
from wbia_lca import db_interface_sim
from wbia_lca import edge_generator_sim
logger = logging.getLogger('wbia_lca')
"""
This is a top-level driver for the LCA graph algorithm, written mostly
to illustrate use of the graph algorithm through small examples and
simulations. Of particular note there are two objects created here
that must be replaced by objects that are connected to "real"
information about animals: the edge and id database, and the
generator for edges through calls to verifiers or calls to human
reviewers.
Three key files are needed here:
1. The configuration file. See the config.ini example
2. The JSON file of recent verifier ground truth positive and negative
probability results. Note as a reminder that both the values of the
probabilities AND the relative fraction of positive and negative
ground truth samples are important here.
3. The request JSON file, which includes the simulated database, the
simulated edge generator and the actual query request. See
request_example.json
Note that the first two will always be needed even if this is running
"for real", as will the actual query (with the request JSON). So in
an non-simulation, only the database and edge generator object need to
be replaced.
"""
def form_database(request):
"""
From the request JSON object extract the database if it is there.
If not, return an empty database. The JSON includes edge quads
(n0, n1, w, aug_name) and a clustering dictionary.
"""
edge_quads = []
clustering_dict = dict()
if 'database' not in request:
return edge_quads, clustering_dict
req_db = request['database']
if 'quads' in req_db:
edge_quads = req_db['quads']
if 'clustering' in req_db:
clustering_dict = {str(cid): c for cid, c in req_db['clustering'].items()}
db = db_interface_sim.db_interface_sim(edge_quads, clustering_dict)
return db
def form_edge_generator(request, db, wgtr):
"""
Form the edge generator object. Unlike the database, the generator
must be there for the small example / simulator to run.
"""
try:
gen_dict = request['generator']
except KeyError:
logger.info('Information about the edge generator must be in the request.')
sys.exit(1)
# Get hand-specified results from the verifier that aren't in the
# database yet. These are prob_quads of the form (n0, n1, prob,
# aug_name). The weighter will be used to turn the prob into a
# weight.
prob_quads = []
if 'verifier' in gen_dict:
prob_quads = gen_dict['verifier']
# Get human decisions of the form (n0, n1, bool). These will be
# returned as new edges when first requested
human_triples = []
if 'human' in gen_dict:
human_triples = gen_dict['human']
# Get the ground truth clusters - used to generate edges that
# aren't listed explicitly
gt_clusters = []
if 'gt_clusters' in gen_dict:
gt_clusters = gen_dict['gt_clusters']
# Get the nodes to be removed early in the computation.
nodes_to_remove = []
if 'nodes_to_remove' in gen_dict:
nodes_to_remove = gen_dict['nodes_to_remove']
# Get the number of steps between returning edge generation
# results. If this value is 0 then they are returned immediately
# upon request.
delay_steps = 0
if 'delay_steps' in gen_dict:
delay_steps = gen_dict['delay_steps']
edge_gen = edge_generator_sim.edge_generator_sim(
db, wgtr, prob_quads, human_triples, gt_clusters, nodes_to_remove, delay_steps
)
return edge_gen
def extract_requests(request, db):
try:
req_dict = request['query']
except KeyError:
logger.info('Information about the GA query itself must be in the request JSON.')
sys.exit(1)
# 1. Get the verifier result quads (n0, n1, prob, aug_name).
verifier_results = []
if 'verifier' in req_dict:
verifier_results = req_dict['verifier']
# 2. Get the human decision result triples (n0, n1, bool)
# No error checking is used
human_decisions = []
if 'human' in req_dict:
human_decisions = req_dict['human']
# 3. Get the list of existing cluster ids to check
cluster_ids_to_check = []
if 'cluster_ids' in req_dict:
cluster_ids_to_check = req_dict['cluster_ids']
for cid in cluster_ids_to_check:
logger.info(cid)
logger.info(cluster_ids_to_check)
if not db.cluster_exists(cid):
logger.info('GA request cluster id %s does not exist' % cid)
raise ValueError('Cluster id does not exist')
return verifier_results, human_decisions, cluster_ids_to_check
if __name__ == '__main__':
parser = argparse.ArgumentParser('overall_driver.py')
parser.add_argument(
'--ga_config', type=str, required=True, help='graph algorithm config INI file'
)
parser.add_argument(
'--verifier_gt',
type=str,
required=True,
help='json file containing verification algorithm ground truth',
)
parser.add_argument(
'--request',
type=str,
required=True,
help='json file continain graph algorithm request info',
)
parser.add_argument(
'--db_result', type=str, help='file to write resulting json database'
)
# 1. Configuration
args = parser.parse_args()
config_ini = configparser.ConfigParser()
config_ini.read(args.ga_config)
# 2. Recent results from verification ground truth tests. Used to
# establish the weighter.
with open(args.verifier_gt, 'r') as fn:
verifier_gt = json.loads(fn.read())
# 3. Form the parameters dictionary and weight objects (one per
# verification algorithm).
ga_params, wgtrs = ga_driver.params_and_weighters(config_ini, verifier_gt)
if len(wgtrs) > 1:
logger.info('Not currently handling more than one weighter!!')
sys.exit(1)
wgtr = wgtrs[0]
# 4. Get the request dictionary, which includes the database, the
# actual request edges and clusters, and the edge generator edges
# and ground truth (for simulation).
with open(args.request, 'r') as fn:
request = json.loads(fn.read())
db = form_database(request)
edge_gen = form_edge_generator(request, db, wgtr)
verifier_req, human_req, cluster_req = extract_requests(request, db)
# 5. Form the graph algorithm driver
driver = ga_driver.ga_driver(
verifier_req, human_req, cluster_req, db, edge_gen, ga_params
)
# 6. Run it. Changes are logged.
ccPIC_gen = driver.run_all_ccPICs()
changes_to_review = list(ccPIC_gen)
logger.info(changes_to_review)
# 7. Commit changes. Record them in the database and the log
# file.
# TBD