-
Notifications
You must be signed in to change notification settings - Fork 1
/
eval_scannet.py
170 lines (133 loc) · 6.29 KB
/
eval_scannet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
''' Evaluation script for the ScanNet dataset
Dynamics-aware Adversarial Attack of Adaptive Neural Networks
Author: Ziyi Wu, An Tao
Email: dazitu616@gmail.com, ta19@mails.tsinghua.edu.cn
Date: 2022/1/13
Required Inputs:
--data_path (str): Data path to the dataset.
Important Optional Inputs:
--attacked_coords (str): Evaluate the model performance with attacked point cloud coordinates.
The format of path is `outputs/scannet/budget_<your budget>/<your exp name>/coord`.
--exp_name (str): Assign an experiment name. Default is `Logs_<date>_<time>`.
--save_preds (store_true): Whether to save the class prediction results. Default is `False`.
--save_probs (store_true): Whether to save the probability values of classes. Default is `False`.
--visual (store_true): Whether to save the visualization results in `.ply` files. Default is `False`.
Example Usage:
python eval_scannet.py --data_path <data path>
'''
import os
from datetime import datetime
import numpy as np
np.seterr(divide='ignore',invalid='ignore')
import torch
import MinkowskiEngine as ME
import utils
from config.scannet import parser
from models.res16unet import Res16UNet34C
# Class names used for the per-class report (taken from the shared utils module).
VALID_CLASS_NAMES = utils.SCANNET_VALID_CLASS_NAMES


def load_file(file_name, voxel_size, attacked_coords):
    """Read one ScanNet room and voxelize it.

    Colors and labels always come from ``<file_name>.ply``. Coordinates come
    from the same ``.ply`` file unless ``attacked_coords`` is given, in which
    case they are read from ``<attacked_coords>/<room name>.txt``.

    Args:
        file_name: Path to the room, without the ``.ply`` extension.
        voxel_size: Quantization size forwarded to ``utils.sparse_quantize``.
        attacked_coords: Directory holding attacked coordinates, or ``None``.

    Returns:
        Tuple ``(inverse_idx, quan_coords, quan_feats, labels)``. ``labels``
        holds the per-point labels read from the ``.ply`` file as int32.

    NOTE(review): ``PlyData`` (presumably ``plyfile.PlyData``) is not
    imported anywhere in this file, so calling this function raises
    ``NameError`` as written. The ``__main__`` section does not call it; it
    uses ``utils.generate_input_sparse_tensor_eval`` instead.
    """
    elem = PlyData.read(file_name + '.ply').elements[0].data

    if attacked_coords is None:
        coords = np.array([elem['x'], elem['y'], elem['z']],
                          dtype=np.float32).T
    else:
        # Attacked coordinates are stored per room, keyed by the last path component.
        room_name = file_name.split('/')[-1]
        coords = np.loadtxt(os.path.join(attacked_coords, room_name + '.txt'))

    rgb = np.array([elem['red'], elem['green'], elem['blue']],
                   dtype=np.float32).T / 255.
    labels = np.array(elem['label'], dtype=np.int32)

    # Input features are the colors scaled by 1/255 and shifted by -0.5.
    _, inverse_idx, quan_coords, quan_feats = utils.sparse_quantize(
        coords, rgb - 0.5, None,
        return_index=True, return_inverse=True, quantization_size=voxel_size)
    return inverse_idx, quan_coords, quan_feats, labels
def generate_input_sparse_tensor(file_name, voxel_size=0.02, attacked_coords=None):
    """Build collated sparse-tensor inputs for a single room.

    Loads the room with ``load_file`` and collates it as a one-element batch
    via ``ME.utils.sparse_collate``. During training, a data loader performs
    this batching in parallel instead.

    Returns:
        Tuple ``(inverse_idx, coordinates, features, labels)``:
        ``inverse_idx`` is a 1-tuple wrapping the room's inverse index;
        ``features`` is cast to float; ``labels`` are the room's per-point
        labels as returned by ``load_file``.
    """
    inv, coords, feats, point_labels = load_file(file_name, voxel_size, attacked_coords)
    # sparse_collate takes one entry per batch element; this batch holds one room.
    coordinates, features = ME.utils.sparse_collate((coords,), (feats,), None)
    return (inv,), coordinates, features.float(), point_labels
def _read_room_list(data_path, split):
    """Return the room names listed in ``<data_path>/scannetv2_<split>.txt``.

    Only the ``'train'`` and ``'val'`` splits are supported. Each line is
    stripped of surrounding whitespace and blank lines are skipped. This means
    a missing trailing newline or Windows line endings cannot corrupt a room
    name, as blindly dropping the last character did.

    Raises:
        ValueError: If ``split`` is not ``'train'`` or ``'val'``.
    """
    if split not in ('train', 'val'):
        raise ValueError(
            "Unsupported split {!r}; expected 'train' or 'val'.".format(split))
    list_file = os.path.join(data_path, 'scannetv2_{}.txt'.format(split))
    with open(list_file, 'r') as f:
        return [line.strip() for line in f if line.strip()]


def _summarize_metrics(intersection, union, target, eps=1e-10):
    """Convert accumulated per-class counts into IoU / accuracy metrics.

    Args:
        intersection, union, target: Per-class count arrays, as returned by
            ``utils.intersectionAndUnion``.
        eps: Added to denominators so classes with an empty union/target
            score 0 instead of dividing by zero.

    Returns:
        Tuple ``(iou_class, accuracy_class, mIoU, mAcc, allAcc)``.
    """
    iou_class = intersection / (union + eps)
    accuracy_class = intersection / (target + eps)
    m_iou = np.mean(iou_class)
    m_acc = np.mean(accuracy_class)
    all_acc = np.sum(intersection) / (np.sum(target) + eps)
    return iou_class, accuracy_class, m_iou, m_acc, all_acc


if __name__ == '__main__':
    ### Initialization ###
    config = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if config.exp_name is None:
        dt = datetime.now()
        config.exp_name = 'Logs_' + dt.strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join('outputs/scannet/eval', config.split, config.exp_name)
    os.makedirs(save_path, exist_ok=True)
    io = utils.IOStream(save_path + '/run.log')

    # Define a model and load the weights. map_location lets a checkpoint
    # saved on GPU load when `device` has fallen back to the CPU.
    model = Res16UNet34C(3, 20, config).to(device)
    model_dict = torch.load(config.weights, map_location=device)
    model.load_state_dict(model_dict['state_dict'])
    model.eval()

    ### Evaluation ###
    all_rooms = _read_room_list(config.data_path, config.split)
    room_num = len(all_rooms)
    num_classes = len(VALID_CLASS_NAMES)
    io.cprint('ScanNet Class Number: {}'.format(num_classes))

    # Per-room results are collected in lists and stacked once at the end,
    # instead of re-stacking ever-growing arrays for every room.
    pred_chunks, label_chunks = [], []
    probs_pcl = None
    # Start evaluation for each room
    for idx, room_name in enumerate(all_rooms):
        with torch.no_grad():
            data = os.path.join(config.data_path, room_name)
            inverse_idx, coords_vox, feats_vox, labels_pcl = \
                utils.generate_input_sparse_tensor_eval(
                    data,
                    voxel_size=config.voxel_size,
                    attacked_coords=config.attacked_coords,
                    dataset='scannet')
            labels_pcl = utils.convert_label_scannet(labels_pcl)

            # Feed-forward pass; predictions are per voxel and are mapped back
            # to the original points through inverse_idx.
            sinput = ME.SparseTensor(feats_vox, coords=coords_vox).to(device)
            soutput = model(sinput)
            preds_vox = soutput.F.max(1)[1].cpu().numpy()
            preds_pcl = preds_vox[inverse_idx]
            if config.save_probs:
                probs_vox = torch.nn.functional.softmax(soutput.F, dim=1).cpu().numpy()
                probs_pcl = probs_vox[inverse_idx]

            # 255 is passed as the ignore label (presumably the value
            # convert_label_scannet assigns to unlabeled points -- confirm).
            intersection, union, target = utils.intersectionAndUnion(
                preds_pcl, labels_pcl, num_classes, 255)
            # Classes absent from this room give 0/0 = nan; nanmean skips them.
            mIoU = np.nanmean(intersection / union)
            print('Room: {:>3}/{:>3} | mIoU: {:.4F}'.format(idx + 1, room_num, mIoU))

            # Save results
            pred_chunks.append(preds_pcl)
            label_chunks.append(labels_pcl)
            torch.cuda.empty_cache()
            if config.save_preds or config.save_probs:
                utils.save_prediction(config, save_path, room_name, preds_pcl, probs_pcl, dataset='scannet')
            # Visualization
            if config.visual:
                utils.visualize(config, room_name, None, preds_pcl, save_path)

    pred_all = np.hstack(pred_chunks) if pred_chunks else np.array([])
    label_all = np.hstack(label_chunks) if label_chunks else np.array([])
    intersection, union, target = \
        utils.intersectionAndUnion(
            pred_all, label_all, num_classes, 255)
    iou_class, accuracy_class, mIoU, mAcc, allAcc = _summarize_metrics(
        intersection, union, target)
    io.cprint('Evaluation Result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.
              format(mIoU, mAcc, allAcc))
    for i in range(num_classes):
        io.cprint('Class_{} Result: IoU/Accuracy {:.4f}/{:.4f}, Name: {}.'.
                  format(i, iou_class[i], accuracy_class[i], VALID_CLASS_NAMES[i]))