# evaluate.py (forked from erobic/negative_analysis_of_grounding)
import pickle

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

# GraphQAIMGDataset is assumed to live in dataset.py alongside Dictionary.
from dataset import Dictionary, GraphQAIMGDataset
from models.models import Model_explain2
import opts


def instance_bce_with_logits(logits, labels):
    # Mean BCE over all elements, rescaled by the number of answer classes so the
    # value corresponds to the per-question sum of per-class BCE terms,
    # averaged over the batch.
    assert logits.dim() == 2
    loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
    loss *= labels.size(1)
    return loss
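
# A minimal numeric sketch of the rescaling above: for a batch of one question
# with labels.size(1) == 3 and per-element BCE values [[0.1, 0.2, 0.3]], the
# default 'mean' reduction gives 0.2, and multiplying by 3 recovers the
# per-question sum of 0.6.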


def compute_score_with_logits(logits, labels):
    # VQA-style soft accuracy: each question is credited with the soft label
    # weight of its single highest-scoring (argmax) answer.
    logits = torch.max(logits, 1)[1].data  # argmax over answer candidates
    one_hots = torch.zeros(*labels.size()).cuda()
    one_hots.scatter_(1, logits.view(-1, 1), 1)
    scores = (one_hots * labels)
    return scores
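
# A minimal sketch of the scoring above: for one question with three candidate
# answers, logits = [[0.2, 1.5, 0.1]] and soft labels = [[0.0, 0.6, 1.0]], the
# argmax picks index 1, so the returned row is [0.0, 0.6, 0.0] and the question
# contributes 0.6 to the accuracy once summed over dim 1.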


def compute_score_with_k_logits(logits, labels, k=5):
    # Top-k variant: a question is credited with the best soft label weight
    # found among its k highest-scoring answers.
    logits = torch.sort(logits, 1)[1].data  # answer indices sorted by ascending logit
    scores = torch.zeros((labels.size(0), k))
    for i in range(k):
        one_hots = torch.zeros(*labels.size()).cuda()
        one_hots.scatter_(1, logits[:, -i - 1].view(-1, 1), 1)
        scores[:, i] = (one_hots * labels).sum(1)  # soft-label credit of the i-th best answer
    scores = scores.max(1)[0]
    return scores
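
# A minimal sketch of the top-k credit: with logits = [[0.2, 1.5, 0.1]], soft
# labels = [[1.0, 0.6, 0.0]] and k=2, the two highest logits are at indices 1
# and 0, with label weights 0.6 and 1.0, so the question scores
# max(0.6, 1.0) = 1.0 (versus 0.6 under the top-1 scoring above).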


def evaluate(model, dataloader):
    score = 0
    V_loss = 0
    # Maps each question id to its answer type ('yes/no', 'number', 'other').
    qid2type = pickle.load(open('qid2type.pkl', 'rb'))
    upper_bound = 0
    num_data = 0
    score_yesno = 0
    score_number = 0
    score_other = 0
    total_yesno = 0
    total_number = 0
    total_other = 0
    for objs, q, a, hintscore, _, qids in tqdm(iter(dataloader)):
        objs = objs.cuda().float().requires_grad_()
        q = q.cuda().long()
        a = a.cuda()  # soft ground-truth answer scores
        hintscore = hintscore.cuda().float()
        pred, _, ansidx = model(q, objs)
        # loss = instance_bce_with_logits(pred, a)
        # V_loss += loss.item() * objs.size(0)
        batch_score = compute_score_with_logits(pred, a.data).cpu().numpy().sum(1)
        score += batch_score.sum()
        upper_bound += (a.max(1)[0]).sum()
        num_data += pred.size(0)
        qids = qids.detach().cpu().int().numpy()
        # Accumulate accuracy separately per question type.
        for j in range(len(qids)):
            qid = qids[j]
            typ = qid2type[qid]
            if typ == 'yes/no':
                score_yesno += batch_score[j]
                total_yesno += 1
            elif typ == 'other':
                score_other += batch_score[j]
                total_other += 1
            elif typ == 'number':
                score_number += batch_score[j]
                total_number += 1
            else:
                print('Unexpected question type: %s' % typ)
    score = score / len(dataloader.dataset)
    V_loss /= len(dataloader.dataset)
    score_yesno /= total_yesno
    score_other /= total_other
    score_number /= total_number
    return score, score_yesno, score_other, score_number
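
# compute_score_with_k_logits is defined above but never called here; a minimal
# sketch of how it could be wired into the batch loop (reusing the same pred and
# a tensors, with a score_k accumulator added alongside score) would be:
#     batch_score_k = compute_score_with_k_logits(pred, a.data).cpu().numpy()
#     score_k += batch_score_k.sum()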


if __name__ == '__main__':
    opt = opts.parse_opt()
    dictionary = Dictionary.load_from_file(f'{opt.data_dir}/dictionary.pkl')
    opt.ntokens = dictionary.ntoken

    model = Model_explain2(opt)
    model = nn.DataParallel(model).cuda()

    eval_dset = GraphQAIMGDataset('v2cp_test', dictionary, opt)
    eval_loader = DataLoader(eval_dset, opt.batch_size, shuffle=False, num_workers=0)

    # Copy over only the weights that exist in both the checkpoint and the model.
    states_ = torch.load('saved_models/%s/model-best.pth' % opt.load_model_states)
    states = model.state_dict()
    for k in states_.keys():
        if k in states:
            states[k] = states_[k]
            print('copying %s' % k)
        else:
            print('ignoring %s' % k)
    model.load_state_dict(states)

    model.eval()
    score, score_yesno, score_other, score_number = evaluate(model, eval_loader)
    print('Overall: %.3f\n' % score)
    print('Yes/No: %.3f\n' % score_yesno)
    print('Number: %.3f\n' % score_number)
    print('Other: %.3f\n' % score_other)
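
# A typical invocation, assuming opts.parse_opt exposes command-line flags for
# the options read above and that qid2type.pkl sits in the working directory:
#     python evaluate.py --data_dir data --load_model_states <run_name>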