-
Notifications
You must be signed in to change notification settings - Fork 39
/
eval.py
65 lines (48 loc) · 2.11 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
import numpy as np
import argparse
import collections
import torchnet as tnt
from utils import util
# Command-line options for this evaluation script.
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=8, type=int, help='Batch size for evaluation')
parser.add_argument('--parallel', action='store_true',
                    help='Wrap the model in nn.DataParallel')
parser.add_argument('--workers', type=int, default=4,
                    help='Number of DataLoader worker processes')
# `choices` rejects unsupported values at parse time with a usage error,
# instead of letting them fall through dataset/model construction.
parser.add_argument('--mode', default='video', choices=['video', 'clip'], help='video|clip')
parser.add_argument('--model', default='r50_nl', choices=['r50', 'r50_nl'], help='r50|r50_nl')
args = parser.parse_args()
def test():
net.eval()
topk = [1, 5]
loss_meters = collections.defaultdict(lambda: tnt.meter.AverageValueMeter())
for idx, batch in enumerate(testloader):
batch = util.batch_cuda(batch)
pred, loss_dict = net(batch)
loss_dict = {k:v.mean() for k,v in loss_dict.items() if v.numel()>0}
loss = sum(loss_dict.values())
for k, v in loss_dict.items():
loss_meters[k].add(v.item())
prec_scores = util.accuracy(pred, batch['label'], topk=topk)
for k, prec in zip(topk, prec_scores):
loss_meters['P%s'%k].add(prec.item(), pred.shape[0])
stats = ' | '.join(['%s: %.3f'%(k, v.value()[0]) for k,v in loss_meters.items()])
print ('%d/%d.. %s'%(idx, len(testloader), stats))
print ('(test) %s'%stats)
#----------------------------------------------------------------------------------------------------------------------------------------#
from data import kinetics
from models import resnet
# Build the Kinetics validation set for the chosen --mode, a non-shuffling
# loader over it and the chosen --model, then run test() without autograd.
if args.mode == 'video':
    testset = kinetics.KineticsMultiCrop(root='data/kinetics/', split='val', clip_len=32)
elif args.mode == 'clip':
    testset = kinetics.Kinetics(root='data/kinetics/', split='val', clip_len=32)
else:
    # Without this branch `testset` would be unbound and the next line would
    # fail with an unhelpful NameError.
    raise ValueError('Unknown --mode %r (expected video|clip)' % args.mode)

testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

# NOTE(review): no checkpoint is loaded in this script. Presumably the
# resnet.i3_res50* constructors load trained weights themselves; confirm,
# otherwise this evaluates an untrained network.
if args.model == 'r50':
    net = resnet.i3_res50(num_classes=len(testset.labels))
elif args.model == 'r50_nl':
    net = resnet.i3_res50_nl(num_classes=len(testset.labels))
else:
    raise ValueError('Unknown --model %r (expected r50|r50_nl)' % args.model)

net.cuda()
if args.parallel:
    net = nn.DataParallel(net)

with torch.no_grad():
    test()