-
Notifications
You must be signed in to change notification settings - Fork 6
/
test.py
129 lines (110 loc) · 4 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
from tqdm import tqdm
import torch
import cv2
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn.functional as F
from options import args_parser
from src.models import CVAE_imagenet
import copy
from PIL import Image
from torchvision import datasets, transforms
import os
from torch.utils import data
from torchvision import transforms as T
from sklearn.model_selection import train_test_split
import time
import pickle
from tensorboardX import SummaryWriter
from src import models
import sys
sys.path.append('.')
class Logger(object):
    """Tee-style writer: every message goes to ``stream`` and is appended to a log file.

    It is meant to be assigned to ``sys.stdout`` so that console output is
    also saved to disk.

    Args:
        filename: Path of the log file. It is opened in append mode as UTF-8.
        stream: Underlying terminal stream. The default is the ``sys.stdout``
            object that existed when this class was defined.
    """

    def __init__(self, filename='test.log', stream=sys.stdout):
        self.terminal = stream
        # Append mode, so repeated runs accumulate in the same log.
        # Explicit UTF-8 avoids UnicodeEncodeError on non-UTF-8 locales.
        self.log = open(filename, 'a', encoding='utf-8')

    def write(self, message):
        """Write ``message`` to both the terminal stream and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both destinations.

        The previous no-op meant explicit flushes, e.g. ``print(..., flush=True)``,
        never pushed buffered text into the log file.
        """
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file. The terminal stream is left open."""
        self.log.close()
def test_inference(args, model, test_dataset):
""" Returns the test accuracy and loss.
"""
model.eval()
loss, total, correct = 0.0, 0.0, 0.0
device = 'cuda' if args.gpu else 'cpu'
criterion = nn.CrossEntropyLoss().to(device)
testloader = DataLoader(test_dataset, batch_size=32,
shuffle=False)
for batch_idx, (images, labels) in enumerate(testloader):
images, labels = images.to(device), labels.to(device)
# Inference
outputs = model(images)
batch_loss = criterion(outputs[0], labels)
loss += batch_loss.item()
# Prediction
_, pred_labels = torch.max(outputs[0], 1)
pred_labels = pred_labels.view(-1)
correct += torch.sum(torch.eq(pred_labels, labels)).item()
total += len(labels)
accuracy = correct / total
return accuracy, loss
# Mirror console output into test.log, but only when this file runs as a
# script. Importing the module (e.g. to reuse DogCat or test_inference) then
# neither hijacks sys.stdout nor creates a log file.
if __name__ == '__main__':
    sys.stdout = Logger(stream=sys.stdout)
class DogCat(data.Dataset):
    """Flat image-folder dataset labelled from file names.

    Every regular file directly inside ``root`` is one sample. The label is 1
    when the file's base name contains the substring ``'real'``, else 0.

    Args:
        root: Directory holding the images.
        transforms: Callable applied to each loaded PIL image. When ``None``,
            the RGB PIL image is returned as-is.
        train: Accepted for API compatibility; it is not used.
        test: Stored as ``self.test``; it does not affect loading or labelling.
    """

    def __init__(self, root, transforms=None, train=True, test=False):
        self.test = test
        self.transforms = transforms
        # Sort so sample order is reproducible (os.listdir order is arbitrary).
        # Skip sub-directories, which cannot be opened as images.
        candidates = (os.path.join(root, name) for name in os.listdir(root))
        self.imgs = sorted(p for p in candidates if os.path.isfile(p))

    @staticmethod
    def _label(img_path):
        """Return 1 if the file's base name contains ``'real'``, else 0."""
        # basename instead of split('/') so Windows separators also work.
        return 1 if 'real' in os.path.basename(img_path) else 0

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``."""
        img_path = self.imgs[index]
        # The context manager closes the file handle. convert('RGB') yields the
        # 3 channels expected by this script's 3-channel Normalize transforms,
        # even for grayscale, palette or RGBA inputs.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)
        return image, self._label(img_path)

    def __len__(self):
        """Return the number of image files found under ``root``."""
        return len(self.imgs)
def _build_model(args):
    """Return the network selected by ``args.model`` and ``args.dataset``.

    Calls ``sys.exit`` with an error message for any unsupported combination,
    instead of continuing with no model defined.
    """
    if args.model == 'cnn':
        # NOTE(review): CNNMnist / CNNFashion_Mnist / CNNCifar were referenced
        # but never imported in this file; this assumes src.models defines
        # them -- confirm.
        if args.dataset == 'mnist':
            return models.CNNMnist(args=args)
        if args.dataset == 'fmnist':
            return models.CNNFashion_Mnist(args=args)
        if args.dataset == 'cifar':
            return models.CNNCifar(args=args)
    elif args.model == 'FedForgery':
        if args.dataset == 'forgery_dataset':
            return CVAE_imagenet(d=64, k=128, num_classes=2)
    else:
        sys.exit('Error: unrecognized model')
    sys.exit('Error: unsupported dataset {!r} for model {!r}'.format(
        args.dataset, args.model))


if __name__ == '__main__':
    args = args_parser()
    transform_test = transforms.Compose([
        transforms.Resize((296, 296)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    test_dataset = DogCat(r'testset', transforms=transform_test, train=False)  # testset

    # BUILD MODEL
    global_model = _build_model(args)

    # Use the same device rule as test_inference (which moves batches based on
    # args.gpu), so the model and its inputs always live on the same device.
    device = torch.device('cuda' if args.gpu else 'cpu')
    path = 'pretrained/model.pth'
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    global_model.load_state_dict(torch.load(path, map_location=device))
    global_model.to(device)

    test_acc, test_loss = test_inference(args, global_model, test_dataset)
    print('Test on', len(test_dataset), 'samples')
    print("Test Accuracy: {:.2f}%".format(100*test_acc))