DMSHN_test.py
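# Test/inference script: runs a pretrained DMSHN model over a directory of input images,
# resizes each prediction back to the original resolution, and writes the rendered bokeh
# outputs to outputs/DMSHN/<index>.png.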
from __future__ import absolute_import, division, print_function
import cv2
import os
import sys
import glob
import argparse
import numpy as np
import PIL.Image as pil
import matplotlib as mpl
import matplotlib.cm as cm
import torch
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.utils import save_image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity  # skimage.measure.compare_psnr/compare_ssim were removed in scikit-image 0.18
from tqdm import tqdm
import math
import numbers
import matplotlib.pyplot as plt
from DMSHN import DMSHN
device = torch.device("cuda:0")

# The network is fed a fixed input resolution; predictions are resized back afterwards.
feed_width = 1536
feed_height = 1024

bokehnet = DMSHN().to(device)
# bokehnet = nn.DataParallel(bokehnet)
bokehnet.load_state_dict(torch.load('checkpoints/DMSHN/dmshn.pth', map_location=device))
bokehnet.eval()  # inference mode (no-op if the model has no batchnorm/dropout layers)

os.makedirs('outputs/DMSHN', exist_ok=True)
with torch.no_grad():
    for i in tqdm(range(4400, 4694)):
        image_path = '/media/data2/saikat/bokeh_data/Training/original/' + str(i) + '.jpg'  # change input path

        # Load image and preprocess: resize to the network's fixed input resolution
        input_image = pil.open(image_path).convert('RGB')
        original_width, original_height = input_image.size
        input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS)
        input_image = transforms.ToTensor()(input_image).unsqueeze(0)

        # PREDICTION: run the network, then upsample back to the original resolution
        input_image = input_image.to(device)
        bok_pred = bokehnet(input_image)
        bok_pred = F.interpolate(bok_pred, (original_height, original_width), mode='bilinear', align_corners=False)

        save_image(bok_pred, './outputs/DMSHN/' + str(i) + '.png')

        del bok_pred
        del input_image
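
# --- Optional evaluation sketch (not part of the original script) ---
# The imports above pull in PSNR/SSIM helpers that the loop never calls. The lines below are
# a minimal sketch of how the saved predictions could be scored against ground-truth bokeh
# images. The ground-truth directory and its '<index>.jpg' naming are assumptions; adjust
# them to the actual dataset layout. Relies on the numpy/PIL/skimage imports above and on
# scikit-image >= 0.19 for the channel_axis argument.
gt_dir = '/media/data2/saikat/bokeh_data/Training/bokeh/'  # assumed ground-truth path

psnr_scores, ssim_scores = [], []
for i in range(4400, 4694):
    pred = np.array(pil.open('./outputs/DMSHN/' + str(i) + '.png').convert('RGB'))
    gt = np.array(pil.open(gt_dir + str(i) + '.jpg').convert('RGB'))
    psnr_scores.append(peak_signal_noise_ratio(gt, pred, data_range=255))
    ssim_scores.append(structural_similarity(gt, pred, channel_axis=-1, data_range=255))

print('PSNR: %.2f  SSIM: %.4f' % (np.mean(psnr_scores), np.mean(ssim_scores)))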