-
Notifications
You must be signed in to change notification settings - Fork 25
/
test.py
69 lines (52 loc) · 1.94 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import warnings
warnings.filterwarnings("ignore")
import os
import numpy as np
from tqdm import tqdm
from imageio import imsave
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils import data
from torchvision.utils import save_image
from models.generator.generator import Generator
from datasets.dataset import create_image_dataset
from options.test_options import TestOptions
from utils.misc import sample_data, postprocess
# Evaluate a pre-trained inpainting generator on the test dataset and save the
# mask-composited outputs as numbered PNG files under ``opts.result_root``.

is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')
if is_cuda:
    print('Cuda is available')
    # The cuDNN switch is ``enabled``; assigning ``cudnn.enable`` would only
    # create an unused module attribute.
    cudnn.enabled = True
    cudnn.benchmark = True

# NOTE(review): ``parse`` is accessed without a call, so TestOptions.parse is
# presumably a property returning the parsed options -- confirm in
# options/test_options.py.
opts = TestOptions().parse
os.makedirs(opts.result_root, exist_ok=True)

# model & load model
generator = Generator(image_in_channels=3, edge_in_channels=2, out_channels=3)
if not opts.pre_trained:
    # Evaluating randomly initialised weights only produces meaningless
    # images, so stop instead of continuing.
    raise SystemExit('Please provide pre-trained model!')
# Load onto the CPU first so a checkpoint saved from a GPU run also loads on
# a CPU-only machine; the generator is moved to ``device`` afterwards.
checkpoint = torch.load(opts.pre_trained, map_location='cpu')
generator.load_state_dict(checkpoint['generator'])
generator = generator.to(device)

# dataset
image_dataset = create_image_dataset(opts)
image_data_loader = data.DataLoader(
    image_dataset,
    batch_size=opts.batch_size,
    shuffle=True,
    num_workers=opts.num_workers,
    drop_last=False,
)
# NOTE(review): sample_data presumably wraps the loader in an endless
# iterator, since next() is called opts.number_eval times below -- confirm.
image_data_loader = sample_data(image_data_loader)

print('start test...')
generator.eval()
with torch.no_grad():
    for index in tqdm(range(opts.number_eval)):
        batch = next(image_data_loader)
        ground_truth, mask, edge, gray_image = (t.to(device) for t in batch)

        # Hide the hole region from every input. The composite below keeps
        # ground_truth where mask is 1 and the prediction where mask is 0.
        input_image = ground_truth * mask
        input_edge = edge * mask
        input_gray_image = gray_image * mask

        # Edge and gray maps are concatenated along dim=1 (presumably the
        # channel axis, matching edge_in_channels=2); only the first of the
        # generator's three outputs is used.
        output, _, _ = generator(
            input_image, torch.cat((input_edge, input_gray_image), dim=1), mask
        )
        output_comp = ground_truth * mask + output * (1 - mask)
        output_comp = postprocess(output_comp)

        # Files are numbered by iteration, not by dataset index (the loader
        # shuffles); with batch_size > 1 save_image tiles the batch into one
        # grid image per file.
        save_image(
            output_comp,
            os.path.join(opts.result_root, '{:05d}.png'.format(index)),
        )