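"""Evaluate a pretrained stereo model on ETH3D two-view stereo pairs.

For every pair, the script pads the images up to a multiple of 32, runs the
network, crops the prediction back to the original size, and accumulates two
metrics over the valid, non-occluded ground-truth pixels: mean absolute
disparity error and the fraction of pixels with an error above 1 px.
Colorized disparity maps are written to ./demo/ETH3D/.
"""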
import argparse
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from tqdm import trange

from datasets import ETH3D_loader as et
from datasets.readpfm import readPFM
from models import __models__

# Pin evaluation to the first GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
parser = argparse.ArgumentParser(description='Accurate and Real-Time Stereo Matching via Context and Geometry Interaction (CGI-Stereo)')
parser.add_argument('--model', default='CGI_Stereo', help='select a model structure', choices=__models__.keys())
parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
parser.add_argument('--datapath', default="/data/ETH3D/", help='data path')
parser.add_argument('--loadckpt', default='./pretrained_models/CGI_Stereo/sceneflow.ckpt', help='load the weights from a specific checkpoint')

# parse arguments
args = parser.parse_args()
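# Example invocation (the values shown are the defaults above; adjust the
# paths to your own dataset and checkpoint locations):
#   python test_eth3d.py --model CGI_Stereo --maxdisp 192 \
#       --datapath /data/ETH3D/ --loadckpt ./pretrained_models/CGI_Stereo/sceneflow.ckpt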
# Collect the ETH3D file lists: left/right images, GT disparities, occlusion masks.
all_limg, all_rimg, all_disp, all_mask = et.et_loader(args.datapath)

# Build the model; wrapping it in DataParallel before loading lets
# 'module.'-prefixed checkpoint keys (if present) match directly.
model = __models__[args.model](args.maxdisp)
model = nn.DataParallel(model)
model.cuda()
model.eval()

os.makedirs('./demo/ETH3D/', exist_ok=True)

state_dict = torch.load(args.loadckpt)
model.load_state_dict(state_dict['model'])

# Running sums of the per-image metrics.
pred_mae = 0
pred_op = 0
for i in trange(len(all_limg)):
    limg = Image.open(all_limg[i]).convert('RGB')
    rimg = Image.open(all_rimg[i]).convert('RGB')

    # Pad each image up to the next multiple of 32 so the network's
    # downsampling stages divide evenly.
    w, h = limg.size
    wi, hi = (w // 32 + 1) * 32, (h // 32 + 1) * 32
    limg = limg.crop((w - wi, h - hi, w, h))
    rimg = rimg.crop((w - wi, h - hi, w, h))
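    # This pads rather than crops: PIL's crop() fills out-of-bounds regions
    # with zeros, so negative left/upper coordinates extend the image to
    # (wi, hi). For example, w = 941 gives wi = (941 // 32 + 1) * 32 = 960;
    # note that a width already divisible by 32 is still padded by a full 32.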
    # Normalize with the standard ImageNet statistics.
    limg_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(limg)
    rimg_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(rimg)
    limg_tensor = limg_tensor.unsqueeze(0).cuda()
    rimg_tensor = rimg_tensor.unsqueeze(0).cuda()
    # Ground-truth disparity; infinite values mark invalid pixels.
    disp_gt, _ = readPFM(all_disp[i])
    disp_gt = np.ascontiguousarray(disp_gt, dtype=np.float32)
    disp_gt[disp_gt == np.inf] = 0

    # Non-occlusion mask: pixels with value 255 are the ones evaluated below.
    occ_mask = np.ascontiguousarray(Image.open(all_mask[i]))
    with torch.no_grad():
        pred_disp = model(limg_tensor, rimg_tensor)[-1]

    # Crop the padding off so the prediction aligns with the original image.
    pred_disp = pred_disp[:, hi - h:, wi - w:]
    predict_np = pred_disp.squeeze().cpu().numpy()
    # Score only valid, non-occluded pixels.
    op_thresh = 1
    mask = (disp_gt > 0) & (occ_mask == 255)
    # mask = disp_gt > 0  # alternative: evaluate all valid pixels, occlusions included
    pred_error = np.abs(predict_np * mask.astype(np.float32) - disp_gt * mask.astype(np.float32))
    pred_op += np.sum(pred_error > op_thresh) / np.sum(mask)
    pred_mae += np.mean(pred_error[mask])
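    # Note: both metrics are averaged per image here and then over images at
    # the end, i.e. a macro average rather than a pixel-level average.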
    # Save a colorized disparity map, named <scene-dir><image-file>.
    filename = os.path.join('./demo/ETH3D/', all_limg[i].split('/')[-2] + all_limg[i].split('/')[-1])
    # Scale disparities to uint16 (x4 x256), map to 8-bit, and apply a JET colormap.
    pred_np_save = np.round(predict_np * 4 * 256).astype(np.uint16)
    cv2.imwrite(filename, cv2.applyColorMap(cv2.convertScaleAbs(pred_np_save, alpha=0.01), cv2.COLORMAP_JET),
                [int(cv2.IMWRITE_PNG_COMPRESSION), 0])
# Average the per-image metrics over the dataset.
print(pred_mae / len(all_limg))  # mean absolute disparity error (px)
print(pred_op / len(all_limg))   # fraction of pixels with error > 1 px