-
Notifications
You must be signed in to change notification settings - Fork 4
/
eval_mc.py
71 lines (56 loc) · 2.25 KB
/
eval_mc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os.path as osp
from utils import load_file
# Column header printed for each question-type code (and for the three
# per-category totals). Key order here does not drive the report; the
# header row is printed in the order the accuracy dict is built.
map_name = {
    # per-type codes; presumably C = causal, T = temporal, D = descriptive
    # (TODO confirm against the dataset docs)
    'CW': 'Why',
    'CH': 'How',
    'TN': 'Bef&Aft',
    'TC': 'When',
    'DC': 'Cnt',
    'DL': 'Loc',
    'DO': 'Other',
    # per-category aggregates (first letter of the type codes above)
    'C': 'Acc_C',
    'T': 'Acc_T',
    'D': 'Acc_D',
}
def _group_question_ids(sample_list, qtypes):
    """Bucket question ids ("<video>_<qid>") by question type.

    'TP' rows are merged into 'TN' (temporal previous/next questions are
    reported together). A type code not in *qtypes* raises KeyError.
    """
    group = {qtype: [] for qtype in qtypes}
    for _, row in sample_list.iterrows():
        qns_id = str(row['video']) + '_' + str(row['qid'])
        qtype = str(row['type'])
        # combine temporal qns of previous and next as 'TN'
        if qtype == 'TP':
            qtype = 'TN'
        group[qtype].append(qns_id)
    return group


def _percent(correct, total):
    """Format correct/total as a percentage with 2 decimals; 'N/A' if total is 0."""
    if total == 0:
        return 'N/A'
    return '{:.2f}'.format(correct * 100.0 / total)


def accuracy_metric(sample_list_file, result_file):
    """Print per-type, per-category and overall multiple-choice accuracy.

    Loads the split file *sample_list_file* and the predictions
    *result_file* via ``load_file``, buckets question ids by type, and
    prints three lines:
      1. tab-separated column headers (looked up in ``map_name``),
      2. the matching accuracies in percent: first one per question type,
         then one per category (the first letter of the type code),
      3. ``Acc: <overall accuracy>``.

    A type or category with no questions is reported as 'N/A' instead of
    raising ZeroDivisionError.

    Assumes ``load_file`` returns a pandas DataFrame with ``video``, ``qid``
    and ``type`` columns for the split file (it is iterated with
    ``iterrows``), and a mapping keyed by "<video>_<qid>" whose values hold
    ``answer`` and ``prediction`` for the results -- TODO confirm against
    utils.load_file.

    Raises:
        KeyError: a question type is not a known code (after 'TP' -> 'TN'),
            or a question id has no entry in the results.
    """
    qtypes = ('CW', 'CH', 'TN', 'TC', 'DC', 'DL', 'DO')
    sample_list = load_file(sample_list_file)
    group = _group_question_ids(sample_list, qtypes)
    preds = load_file(result_file)

    group_acc = dict.fromkeys(qtypes, 0)
    group_cnt = dict.fromkeys(qtypes, 0)
    overall_acc = {'C': 0, 'T': 0, 'D': 0}
    overall_cnt = {'C': 0, 'T': 0, 'D': 0}
    all_acc = 0
    all_cnt = 0
    for qtype, qns_ids in group.items():
        cnt = len(qns_ids)
        acc = sum(1 for qid in qns_ids
                  if preds[qid]['answer'] == preds[qid]['prediction'])
        group_cnt[qtype] += cnt
        group_acc[qtype] += acc
        # the first letter of a type code is its category (C / T / D)
        overall_acc[qtype[0]] += acc
        overall_cnt[qtype[0]] += cnt
        all_acc += acc
        all_cnt += cnt

    # append category totals after the per-type entries so they print last
    for cat, value in overall_acc.items():
        group_acc[cat] = value
        group_cnt[cat] = overall_cnt[cat]

    for qtype in group_acc:
        print(map_name[qtype], end='\t')
    print('')
    for qtype, acc in group_acc.items():
        print(_percent(acc, group_cnt[qtype]), end='\t')
    print('')
    print('Acc: {}'.format(_percent(all_acc, all_cnt)))
def main(result_file, dataset_dir, mode='val'):
    """Evaluate *result_file* against the '<mode>.csv' split in *dataset_dir*.

    Prints which result file is being evaluated, then delegates the
    scoring and printing to ``accuracy_metric``.
    """
    split_csv = osp.join(dataset_dir, mode + '.csv')
    print('Evaluating {}'.format(result_file))
    accuracy_metric(split_csv, result_file)
if __name__ == "__main__":
    # Script entry point: evaluate one model's predictions on one split.
    res_dir = 'results/nextqa/'
    dataset_dir = 'dataset/nextqa/'
    mode = 'test'
    model_prefix = 'HQGA-bert-16c20b-2L05GCN-FCV-AC-VM'
    result_file = osp.join(res_dir, '{}-{}.json'.format(model_prefix, mode))
    # dataset_dir must be passed explicitly: calling main(result_file, mode)
    # would bind `mode` to dataset_dir and silently evaluate on 'val'.
    main(result_file, dataset_dir, mode=mode)