-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_bimanual_axis.py
172 lines (147 loc) · 7.67 KB
/
test_bimanual_axis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""
Script to test the PointNet model on the bimanual axis prediction task.
Outputs are saved in <log-dir>/eval
Run as:
python test_bimanual_axis.py --obj tissue --log_dir pointnet_reg --normal --split val
python test_bimanual_axis.py --obj tissue --log_dir pointnet_reg --normal --split test
"""
import argparse
import os
import os.path as osp
from pathlib import Path
from data_utils.BimanualDataLoader import PartNormalDataset, pc_normalize
from visualizer.bimanual_utils import visualize_pcl_axis
import torch
import logging
import sys
import importlib
from tqdm import tqdm
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models'))
def point_unnormalize(pc, point):
    """Map `point` from the unit-sphere-normalized frame of `pc` back to world coordinates.

    Inverts the usual centroid-shift + max-radius scaling applied by pc_normalize:
    the centroid and scale are recomputed from the (un-normalized) cloud `pc`.

    Args:
        pc: (N, 3) array of points in the original (world) frame.
        point: (3,) or (..., 3) point expressed in the normalized frame.

    Returns:
        The point expressed in the original frame of `pc`.
    """
    center = np.mean(pc, axis=0)
    # Max distance from the centroid — the radius used during normalization.
    scale = np.max(np.linalg.norm(pc - center, axis=1))
    return point * scale + center
def parse_args():
    """Build and parse the command-line options for evaluation.

    Returns:
        argparse.Namespace with data/model/eval settings (only --log_dir is required).
    """
    parser = argparse.ArgumentParser('PointNet')
    add = parser.add_argument
    # Data / task selection
    add('--data_dir', type=str, default='data/bimanual', help='data directory')
    add('--obj', type=str, default='tissue', help='object to evaluate')
    add('--split', type=str, default='val', help='Choose from: val, test')
    add('--task', type=str, default='axis', help='Choose from: contact, axis')
    add('--use_q', action='store_true', default=False, help='use q in axis prediction')
    add('--use_s', action='store_true', default=False, help='use s in axis prediction')
    # Experiment / runtime settings
    add('--log_dir', type=str, required=True, help='experiment root')
    add('--batch_size', type=int, default=16, help='batch size in validation')
    add('--gpu', type=str, default='0', help='specify gpu device')
    add('--num_point', type=int, default=2048, help='point Number')
    add('--normal', action='store_true', default=False, help='use normals')
    add('--num_votes', type=int, default=3, help='aggregate segmentation scores with voting')
    # Loss weights (must match the values used at training time)
    add('--mat_diff_loss_scale', type=float, default=0.001, help='weight for matching different loss')
    add('--axis_loss_scale', type=float, default=1.0, help='weight for axis loss')
    return parser.parse_args()
def main(args):
    """Evaluate a trained PointNet axis-regression checkpoint.

    Loads the best checkpoint from log/axis_reg/<obj>/<log_dir>/checkpoints and,
    depending on args.split:
      - 'test': runs inference on a single raw CSV point cloud and saves one
        visualization to <experiment_dir>/eval/test.png;
      - 'val':  runs the full validation set, logs mean losses to
        <experiment_dir>/eval/metrics.txt, and saves one visualization per sample
        under <experiment_dir>/eval/viz/.

    Args:
        args: argparse.Namespace from parse_args().
    """
    def log_string(msg):
        # Mirror every message to both the eval log file and stdout.
        logger.info(msg)
        print(msg)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    experiment_dir = osp.join('log/axis_reg', args.obj, args.log_dir)
    eval_dir = osp.join(experiment_dir, 'eval')
    os.makedirs(eval_dir, exist_ok=True)
    savedir = osp.join(eval_dir, 'viz')
    os.makedirs(savedir, exist_ok=True)

    '''LOG'''
    # NOTE: a stray `args = parse_args()` used to re-parse sys.argv here, silently
    # discarding the `args` the caller passed in; it has been removed.
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/log.txt' % eval_dir)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    datapath = osp.join(args.data_dir, args.obj)
    # Output dimension: 6 when both the direction and a point on the axis are
    # regressed, otherwise 3.
    k = 6 if (args.use_q and args.use_s) else 3

    '''MODEL LOADING'''
    # The model module name is recovered from the training log filename
    # (assumes exactly one log file in <experiment_dir>/logs).
    model_name = os.listdir(experiment_dir + '/logs')[0].split('.')[0]
    MODEL = importlib.import_module(model_name)
    classifier = MODEL.get_model(k, normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss(mat_diff_loss_scale=args.mat_diff_loss_scale, axis_loss_scale=args.axis_loss_scale).cuda()
    checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
    classifier.load_state_dict(checkpoint['model_state_dict'])
    classifier = classifier.eval()

    if args.split == 'test':
        with torch.no_grad():
            data = np.loadtxt(os.path.join(datapath, 'test_bottle_11.csv'), delimiter=',').astype(np.float32)  # dim nx6: [x,y,z,nx,ny,nz]
            if not args.normal:
                points = data[:, 0:3]
            else:
                points = data[:, 0:6]
            # Keep the raw xyz so predictions can be mapped back to world coordinates.
            points_unnormalized = points[:, :3].copy()
            points[:, 0:3] = pc_normalize(points[:, 0:3])
            # Resample to a fixed point count (with replacement so smaller clouds work too).
            choice = np.random.choice(len(points), args.num_point, replace=True)
            points = points[choice, :]
            points = torch.from_numpy(points).float().cuda()
            points = points.transpose(1, 0).unsqueeze(0)  # dim 1x6xn
            axis_pred, _ = classifier(points)  # dim 1x3 or 1x6
            # visualization
            points = points.transpose(2, 1).cpu().numpy()  # dim 1xnx6
            axisp = axis_pred[0].cpu().data.numpy()  # dim 3 or 6
            savepath = osp.join(eval_dir, 'test.png')
            if args.use_q and args.use_s:
                # First 3 values: direction; last 3: point on axis (normalized frame).
                axisp_unnormalized = point_unnormalize(points_unnormalized, axisp[3:])
                print("axisp: ", axisp[:3], axisp_unnormalized)
            elif args.use_q:
                axisp_unnormalized = point_unnormalize(points_unnormalized, axisp)
                print("axisp: ", axisp_unnormalized)
            visualize_pcl_axis([axisp], args.num_point, points[0, :, :3], savepath, args.use_q, args.use_s)
    elif args.split == 'val':
        VAL_DATASET = PartNormalDataset(root=datapath, npoints=args.num_point, task=args.task, split='val', normal_channel=args.normal, use_q=args.use_q, use_s=args.use_s)
        valDataLoader = torch.utils.data.DataLoader(VAL_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)
        log_string("The number of val data is: %d" % len(VAL_DATASET))
        with torch.no_grad():
            val_metrics = {}
            val_losses = {'total': [], 'axis': [], 'mat_diff': []}
            classifier = classifier.eval()
            for batch_id, (points, target) in tqdm(enumerate(valDataLoader), total=len(valDataLoader), smoothing=0.9):
                cur_batch_size, NUM_POINT, _ = points.size()
                points, target = points.float().cuda(), target.float().cuda()
                points = points.transpose(2, 1)
                axis_pred, trans_feat = classifier(points)
                loss_dict = criterion(axis_pred, target, trans_feat)
                val_losses['total'].append(loss_dict['total'].item())
                val_losses['axis'].append(loss_dict['axis'].item())
                val_losses['mat_diff'].append(loss_dict['mat_diff'].item())
                # save visualization of predictions
                points = points.transpose(2, 1).cpu().numpy()
                axis_pred = axis_pred.cpu().data.numpy()
                target = target.cpu().data.numpy()
                for i in range(cur_batch_size):
                    axisp = axis_pred[i, :]
                    axisl = target[i, :]
                    # BUG FIX: filenames previously used only the within-batch index,
                    # so every batch overwrote the previous one. Use a global sample
                    # index instead (valid because shuffle=False and all batches but
                    # the last are full).
                    sample_idx = batch_id * args.batch_size + i
                    savepath = osp.join(savedir, f'{sample_idx:04d}.png')
                    visualize_pcl_axis([axisl, axisp], NUM_POINT, points[i, :, :3], savepath, use_q=args.use_q, use_s=args.use_s)
            val_metrics['total_loss'] = np.mean(val_losses['total'])
            val_metrics['axis_loss'] = np.mean(val_losses['axis'])
            val_metrics['mat_diff_loss'] = np.mean(val_losses['mat_diff'])
            log_string('Val loss\t total: {:.4f} axis: {:.4f} mat diff: {:.4f}' .format(
                val_metrics['total_loss'], val_metrics['axis_loss'], val_metrics['mat_diff_loss']))
            # save metrics
            with open(osp.join(eval_dir, 'metrics.txt'), 'w') as f:
                for key, value in val_metrics.items():
                    f.write(f'{key}: {value}\n')
if __name__ == '__main__':
    # Script entry point: parse CLI options and run evaluation.
    main(parse_args())