forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pointnet2.py
447 lines (383 loc) · 13.1 KB
/
pointnet2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.geometry import (
farthest_point_sampler,
) # dgl.geometry.pytorch -> dgl.geometry
from torch.autograd import Variable
"""
Parts of this code are adapted from
https://github.com/yanx27/Pointnet_Pointnet2_pytorch
"""
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    Uses the expansion ||s - d||^2 = ||s||^2 + ||d||^2 - 2 * <s, d>, so the
    whole distance matrix comes out of one batched matrix product.

    Input:
    src: source points, [B, N, C]
    dst: target points, [B, M, C]
    Return:
    dist: squared distance for every (source, target) pair, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    # <s, d> for every pair: [B, N, C] x [B, C, M] -> [B, N, M]
    cross = torch.matmul(src, dst.transpose(1, 2))
    # Squared norms, shaped to broadcast along the opposite axis.
    src_sq = (src**2).sum(-1).view(batch, n_src, 1)
    dst_sq = (dst**2).sum(-1).view(batch, 1, n_dst)
    return -2 * cross + src_sq + dst_sq
def index_points(points, idx):
    """
    Gather rows of ``points`` per batch element according to ``idx``.

    Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

    Input:
    points: per-point data, [B, N, C]
    idx: integer indices into the N axis, [B, S] or [B, S, K]
    Return:
    new_points: gathered data with the shape of ``idx`` plus a trailing C
    axis, e.g. [B, S, C] or [B, S, K, C]
    """
    n_batch = points.shape[0]
    idx_shape = list(idx.shape)
    # Column of batch ids: [B, 1, 1, ...] with the same rank as idx.
    column_shape = [idx_shape[0]] + [1] * (len(idx_shape) - 1)
    # Tile that column over every non-batch axis of idx.
    tile_shape = [1] + idx_shape[1:]
    batch_ids = torch.arange(n_batch, dtype=torch.long, device=points.device)
    batch_ids = batch_ids.view(column_shape).repeat(tile_shape)
    # Advanced indexing pairs each idx entry with its own batch id.
    return points[batch_ids, idx, :]
class FixedRadiusNearNeighbors(nn.Module):
    """
    Ball Query - for each centroid, pick up to ``n_neighbor`` points that lie
    within ``radius`` of it.
    """

    def __init__(self, radius, n_neighbor):
        super().__init__()
        self.radius = radius
        self.n_neighbor = n_neighbor

    def forward(self, pos, centroids):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch

        Input:
        pos: point coordinates, [B, N, C]
        centroids: indices of the query centers into the N axis, [B, S]
        Return:
        group_idx: neighbor indices per centroid, [B, S, n_neighbor]

        In-ball points are taken in ascending index order (not by distance).
        Slots left over when a ball holds fewer than ``n_neighbor`` points
        are filled with the first selected index of that ball.
        """
        n_batch, n_points, _ = pos.shape
        centers = index_points(pos, centroids)
        _, n_centers, _ = centers.shape

        # Start with every point as a candidate for every centroid.
        candidates = torch.arange(
            n_points, dtype=torch.long, device=pos.device
        )
        candidates = candidates.view(1, 1, n_points).repeat(
            [n_batch, n_centers, 1]
        )

        # Points outside the ball get the sentinel index N, which sorts last.
        outside = square_distance(centers, pos) > self.radius**2
        candidates[outside] = n_points
        nearest = candidates.sort(dim=-1)[0][:, :, : self.n_neighbor]

        # Replace any remaining sentinel with the ball's first valid index.
        filler = (
            nearest[:, :, 0]
            .view(n_batch, n_centers, 1)
            .repeat([1, 1, self.n_neighbor])
        )
        is_sentinel = nearest == n_points
        nearest[is_sentinel] = filler[is_sentinel]
        return nearest
class FixedRadiusNNGraph(nn.Module):
    """
    Build a batched DGL graph whose edges run from each ball-query neighbor
    to the centroid it was grouped under.
    """

    def __init__(self, radius, n_neighbor):
        super().__init__()
        self.radius = radius
        self.n_neighbor = n_neighbor
        self.frnn = FixedRadiusNearNeighbors(radius, n_neighbor)

    def forward(self, pos, centroids, feat=None):
        """
        Input:
        pos: point coordinates, [B, N, C]
        centroids: centroid indices into the N axis, one row per sample
        feat: optional per-point features, indexed like ``pos``
        Return:
        bg: ``dgl.batch`` of one graph per sample; node data holds "pos",
        "center" (1 for centroid nodes, 0 otherwise) and, when ``feat`` is
        given, "feat"
        """
        neighbors = self.frnn(pos, centroids)
        n_batch, n_points, _ = pos.shape
        graphs = []
        for b in range(n_batch):
            # Flag the centroids among the original N points.
            is_center = torch.zeros(n_points, device=pos.device)
            is_center[centroids[b]] = 1

            # One edge per (neighbor -> centroid) pair, flattened.
            src = neighbors[b].contiguous().view(-1)
            dst = centroids[b].view(-1, 1).repeat(1, self.n_neighbor).view(-1)

            # Relabel the original point ids that actually occur to
            # 0..len(node_ids)-1 so the graph only holds the nodes it uses.
            node_ids, local_ids = torch.unique(
                torch.cat([src, dst]), return_inverse=True
            )
            n_edges = src.shape[0]
            graph = dgl.graph((local_ids[:n_edges], local_ids[n_edges:]))

            graph.ndata["pos"] = pos[b][node_ids]
            graph.ndata["center"] = is_center[node_ids]
            if feat is not None:
                graph.ndata["feat"] = feat[b][node_ids]
            graphs.append(graph)
        return dgl.batch(graphs)
class RelativePositionMessage(nn.Module):
    """
    Edge message function: the source position relative to the destination,
    followed by the source features when the graph carries any.
    """

    def __init__(self, n_neighbor):
        super().__init__()
        # Stored for reference only; forward() does not read it.
        self.n_neighbor = n_neighbor

    def forward(self, edges):
        """
        Input:
        edges: edge batch exposing ``src`` and ``dst`` node data
        Return:
        dict with "agg_feat": the relative position, concatenated along
        dim 1 with the source "feat" when that key is present
        """
        offset = edges.src["pos"] - edges.dst["pos"]
        if "feat" not in edges.src:
            return {"agg_feat": offset}
        return {"agg_feat": torch.cat([offset, edges.src["feat"]], 1)}
class PointNetConv(nn.Module):
    """
    Feature aggregation: a shared point-wise MLP (1x1 Conv2d + BatchNorm2d +
    ReLU per layer) followed by max pooling over the grouped points.

    sizes: channel widths of the MLP; sizes[0] is the input width
    batch_size: batch size assumed by forward() when it reshapes the mailbox
    """

    def __init__(self, sizes, batch_size):
        super(PointNetConv, self).__init__()
        self.batch_size = batch_size
        self.conv = nn.ModuleList()
        self.bn = nn.ModuleList()
        for i in range(1, len(sizes)):
            self.conv.append(nn.Conv2d(sizes[i - 1], sizes[i], 1))
            self.bn.append(nn.BatchNorm2d(sizes[i]))

    def _apply_mlp(self, h):
        """Run every conv -> batch norm -> ReLU stage on a [B, C, *, *] map."""
        for conv, bn in zip(self.conv, self.bn):
            h = F.relu(bn(conv(h)))
        return h

    def forward(self, nodes):
        """
        Reduce function: pool the messages collected in each node's mailbox.

        Input:
        nodes: node batch whose mailbox["agg_feat"] is
        [batch_size * S, n_neighbor, D]
        Return:
        dict with "new_feat": pooled features, [batch_size * S, D']
        """
        shape = nodes.mailbox["agg_feat"].shape
        # [B * S, K, D] -> [B, S, K, D] -> [B, D, K, S] (channels first).
        h = (
            nodes.mailbox["agg_feat"]
            .view(self.batch_size, -1, shape[1], shape[2])
            .permute(0, 3, 2, 1)
        )
        h = self._apply_mlp(h)
        # Max over the neighbor axis: [B, D', K, S] -> [B, D', S].
        h = torch.max(h, 2)[0]
        feat_dim = h.shape[1]
        h = h.permute(0, 2, 1).reshape(-1, feat_dim)
        return {"new_feat": h}

    def group_all(self, pos, feat):
        """
        Feature aggregation and pooling for the non-sampling layer

        Input:
        pos: point coordinates, [B, N, C]
        feat: per-point features, [B, N, D], or None
        Return:
        new_pos: all-zero placeholder position, [B, 1, C]
        h: one pooled feature vector per sample, [B, D']
        """
        h = pos if feat is None else torch.cat([pos, feat], 2)
        B, N, _ = h.shape
        # Allocate from `pos` so the placeholder shares its device and dtype;
        # a bare torch.zeros() would always land on the CPU as float32.
        new_pos = pos.new_zeros(B, 1, pos.shape[-1])
        # [B, N, D] -> [B, D, N, 1]
        h = h.permute(0, 2, 1).view(B, -1, N, 1)
        h = self._apply_mlp(h)
        h = torch.max(h[:, :, :, 0], 2)[0]  # [B,D]
        return new_pos, h
class SAModule(nn.Module):
    """
    The Set Abstraction Layer

    With ``group_all=False`` the layer samples ``npoints`` centroids, groups
    neighbors inside ``radius`` around each one and pools them with a
    PointNetConv. With ``group_all=True`` it pools the whole cloud at once
    and ``npoints`` / ``radius`` / ``n_neighbor`` are not used.
    """

    def __init__(
        self,
        npoints,
        batch_size,
        radius,
        mlp_sizes,
        n_neighbor=64,
        group_all=False,
    ):
        super().__init__()
        self.group_all = group_all
        # Sampling/grouping helpers only exist for the sampling variant.
        if not group_all:
            self.npoints = npoints
            self.frnn_graph = FixedRadiusNNGraph(radius, n_neighbor)
        self.message = RelativePositionMessage(n_neighbor)
        self.conv = PointNetConv(mlp_sizes, batch_size)
        self.batch_size = batch_size

    def forward(self, pos, feat):
        """
        Input:
        pos: point coordinates, [B, N, C]
        feat: per-point features or None
        Return:
        (pos_res, feat_res): for the sampling variant, the positions and
        pooled features of the centroid nodes, each viewed as
        [batch_size, -1, dim]; for ``group_all`` whatever
        ``PointNetConv.group_all`` returns
        """
        if not self.group_all:
            centroids = farthest_point_sampler(pos, self.npoints)
            graph = self.frnn_graph(pos, centroids, feat)
            graph.update_all(self.message, self.conv)

            # Keep only the centroid nodes.
            is_center = graph.ndata["center"] == 1
            node_pos = graph.ndata["pos"]
            node_feat = graph.ndata["new_feat"]
            pos_res = node_pos[is_center].view(
                self.batch_size, -1, node_pos.shape[-1]
            )
            feat_res = node_feat[is_center].view(
                self.batch_size, -1, node_feat.shape[-1]
            )
            return pos_res, feat_res

        return self.conv.group_all(pos, feat)
class SAMSGModule(nn.Module):
    """
    The Set Abstraction Multi-Scale grouping Layer

    One set of centroids is sampled and then grouped once per scale; entry i
    of ``radius_list``, ``n_neighbor_list`` and ``mlp_sizes_list`` configures
    scale i. The per-scale features are concatenated on the last axis.
    """

    def __init__(
        self, npoints, batch_size, radius_list, n_neighbor_list, mlp_sizes_list
    ):
        super().__init__()
        self.batch_size = batch_size
        self.group_size = len(radius_list)
        self.npoints = npoints
        self.frnn_graph_list = nn.ModuleList()
        self.message_list = nn.ModuleList()
        self.conv_list = nn.ModuleList()
        for scale in range(self.group_size):
            radius = radius_list[scale]
            n_neighbor = n_neighbor_list[scale]
            self.frnn_graph_list.append(FixedRadiusNNGraph(radius, n_neighbor))
            self.message_list.append(RelativePositionMessage(n_neighbor))
            self.conv_list.append(
                PointNetConv(mlp_sizes_list[scale], batch_size)
            )

    def forward(self, pos, feat):
        """
        Input:
        pos: point coordinates, [B, N, C]
        feat: per-point features or None
        Return:
        pos_res: centroid positions taken from the first scale's graph,
        viewed as [batch_size, -1, C]
        feat_res: per-scale pooled centroid features concatenated on dim 2
        """
        centroids = farthest_point_sampler(pos, self.npoints)
        per_scale_feats = []
        scales = zip(self.frnn_graph_list, self.message_list, self.conv_list)
        for scale, (build_graph, message, reduce) in enumerate(scales):
            graph = build_graph(pos, centroids, feat)
            graph.update_all(message, reduce)
            is_center = graph.ndata["center"] == 1
            if scale == 0:
                # Positions are read once, from the first scale's graph.
                node_pos = graph.ndata["pos"]
                pos_res = node_pos[is_center].view(
                    self.batch_size, -1, node_pos.shape[-1]
                )
            node_feat = graph.ndata["new_feat"]
            per_scale_feats.append(
                node_feat[is_center].view(
                    self.batch_size, -1, node_feat.shape[-1]
                )
            )
        feat_res = torch.cat(per_scale_feats, 2)
        return pos_res, feat_res
class PointNet2FP(nn.Module):
    """
    The Feature Propagation Layer

    Upsamples features from a sparse point set to a denser one by
    inverse-distance interpolation, then applies a shared MLP
    (1x1 Conv1d + BatchNorm1d + ReLU per layer).

    input_dims: channel width fed to the first conv (skip + interpolated)
    sizes: output widths of the MLP layers (list or tuple)
    """

    def __init__(self, input_dims, sizes):
        super(PointNet2FP, self).__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        # list(...) so a tuple of sizes is accepted as well as a list.
        sizes = [input_dims] + list(sizes)
        for i in range(1, len(sizes)):
            self.convs.append(nn.Conv1d(sizes[i - 1], sizes[i], 1))
            self.bns.append(nn.BatchNorm1d(sizes[i]))

    def forward(self, x1, x2, feat1, feat2):
        """
        Adapted from https://github.com/yanx27/Pointnet_Pointnet2_pytorch
        Input:
        x1: input points position data, [B, N, C]
        x2: sampled input points position data, [B, S, C]
        feat1: input points data, [B, N, D], or None
        feat2: sampled points data, [B, S, D]
        Return:
        new_feat: upsampled points data, [B, D', N]
        """
        B, N, _ = x1.shape
        _, S, _ = x2.shape
        if S == 1:
            # A single source point: broadcast its feature to every target.
            interpolated_feat = feat2.repeat(1, N, 1)
        else:
            # Use the 3 nearest sampled points, or fewer when S < 3.
            k = min(3, S)
            dists, idx = square_distance(x1, x2).sort(dim=-1)
            dists, idx = dists[:, :, :k], idx[:, :, :k]  # [B, N, k]
            # Round-off can push a true zero distance slightly negative,
            # which would flip the sign of (or blow up) the weight.
            dists = dists.clamp(min=0)
            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_feat = torch.sum(
                index_points(feat2, idx) * weight.unsqueeze(-1), dim=2
            )
        if feat1 is not None:
            # Skip connection: prepend the dense features.
            new_feat = torch.cat([feat1, interpolated_feat], dim=-1)
        else:
            new_feat = interpolated_feat
        new_feat = new_feat.permute(0, 2, 1)  # [B, D, N]
        for conv, bn in zip(self.convs, self.bns):
            new_feat = F.relu(bn(conv(new_feat)))
        return new_feat
class PointNet2SSGCls(nn.Module):
    """
    PointNet++ classifier with single-scale grouping: two sampling set
    abstraction layers, one group-all layer, then a three-layer MLP head.

    output_classes: number of logits produced per sample
    batch_size: batch size forwarded to every set abstraction layer
    input_dims: channel width of the first layer's input
    dropout_prob: dropout probability used in the MLP head
    """

    def __init__(
        self, output_classes, batch_size, input_dims=3, dropout_prob=0.4
    ):
        super().__init__()
        self.input_dims = input_dims

        # Set abstraction stack; the "+ 3" widths account for the relative
        # position that is concatenated with the incoming features.
        self.sa_module1 = SAModule(
            npoints=512,
            batch_size=batch_size,
            radius=0.2,
            mlp_sizes=[input_dims, 64, 64, 128],
        )
        self.sa_module2 = SAModule(
            npoints=128,
            batch_size=batch_size,
            radius=0.4,
            mlp_sizes=[128 + 3, 128, 128, 256],
        )
        self.sa_module3 = SAModule(
            npoints=None,
            batch_size=batch_size,
            radius=None,
            mlp_sizes=[256 + 3, 256, 512, 1024],
            group_all=True,
        )

        # Classification head.
        self.mlp1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(dropout_prob)

        self.mlp2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(dropout_prob)

        self.mlp_out = nn.Linear(256, output_classes)

    def forward(self, x):
        """
        Input:
        x: point cloud, [B, N, 3] or [B, N, 3 + D]; channels past the first
        three are treated as per-point features
        Return:
        out: class logits, [B, output_classes]
        """
        has_features = x.shape[-1] > 3
        pos = x[:, :, :3] if has_features else x
        feat = x[:, :, 3:] if has_features else None

        for sa_layer in (self.sa_module1, self.sa_module2):
            pos, feat = sa_layer(pos, feat)
        _, h = self.sa_module3(pos, feat)

        h = self.drop1(F.relu(self.bn1(self.mlp1(h))))
        h = self.drop2(F.relu(self.bn2(self.mlp2(h))))
        return self.mlp_out(h)
class PointNet2MSGCls(nn.Module):
    """
    PointNet++ classifier with multi-scale grouping: two multi-scale set
    abstraction layers, one group-all layer, then a three-layer MLP head.

    output_classes: number of logits produced per sample
    batch_size: batch size forwarded to every set abstraction layer
    input_dims: channel width of the first layer's input
    dropout_prob: dropout probability used in the MLP head
    """

    def __init__(
        self, output_classes, batch_size, input_dims=3, dropout_prob=0.4
    ):
        super().__init__()
        self.input_dims = input_dims

        # First layer: three scales producing 64 + 128 + 128 = 320 channels.
        scale1_mlps = [
            [input_dims, 32, 32, 64],
            [input_dims, 64, 64, 128],
            [input_dims, 64, 96, 128],
        ]
        self.sa_msg_module1 = SAMSGModule(
            512, batch_size, [0.1, 0.2, 0.4], [16, 32, 128], scale1_mlps
        )

        # Second layer: three scales producing 128 + 256 + 256 = 640 channels.
        scale2_mlps = [
            [320 + 3, 64, 64, 128],
            [320 + 3, 128, 128, 256],
            [320 + 3, 128, 128, 256],
        ]
        self.sa_msg_module2 = SAMSGModule(
            128, batch_size, [0.2, 0.4, 0.8], [32, 64, 128], scale2_mlps
        )

        # Global pooling over everything that is left.
        self.sa_module3 = SAModule(
            None, batch_size, None, [640 + 3, 256, 512, 1024], group_all=True
        )

        # Classification head.
        self.mlp1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(dropout_prob)

        self.mlp2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(dropout_prob)

        self.mlp_out = nn.Linear(256, output_classes)

    def forward(self, x):
        """
        Input:
        x: point cloud, [B, N, 3] or [B, N, 3 + D]; channels past the first
        three are treated as per-point features
        Return:
        out: class logits, [B, output_classes]
        """
        has_features = x.shape[-1] > 3
        pos = x[:, :, :3] if has_features else x
        feat = x[:, :, 3:] if has_features else None

        for sa_layer in (self.sa_msg_module1, self.sa_msg_module2):
            pos, feat = sa_layer(pos, feat)
        _, h = self.sa_module3(pos, feat)

        h = self.drop1(F.relu(self.bn1(self.mlp1(h))))
        h = self.drop2(F.relu(self.bn2(self.mlp2(h))))
        return self.mlp_out(h)