-
Notifications
You must be signed in to change notification settings - Fork 165
/
pyramid.py
127 lines (105 loc) · 3.98 KB
/
pyramid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import torch.nn as nn
import torch.nn.functional as F
from lib.exceptions import EmptyTensorError
from lib.utils import interpolate_dense_features, upscale_positions
def process_multiscale(image, model, scales=(.5, 1, 2)):
    """Detect and describe keypoints over an image pyramid.

    The image is resized to every factor in ``scales`` (in the given order).
    At each level the dense feature map is computed, the feature map carried
    over from the previous level is resized and added to it, and detections
    are extracted. Locations already detected at an earlier level are
    suppressed at later levels through the ``banned`` mask.

    Args:
        image: 4-D tensor whose ``size()`` unpacks as
            ``(batch, channels, height, width)``; ``batch`` must be 1.
        model: Object exposing the callables ``dense_feature_extraction``
            (which also has a ``num_channels`` attribute), ``detection`` and
            ``localization``.
        scales: Iterable of scale factors forwarded to ``F.interpolate``.

    Returns:
        A ``(keypoints, scores, descriptors)`` tuple of NumPy arrays:

        * ``keypoints`` -- one row per keypoint with three columns: the two
          position coordinates rescaled by ``h_init / h_level`` and
          ``w_init / w_level``, followed by ``1 / scale`` of the level the
          keypoint was found at.
        * ``scores`` -- the dense feature value at each detection divided by
          ``level_index + 1``.
        * ``descriptors`` -- one row per keypoint, normalized with
          ``F.normalize`` before being transposed.

    Raises:
        AssertionError: If the batch dimension of ``image`` is not 1.
    """
    b, _, h_init, w_init = image.size()
    device = image.device
    assert b == 1, 'process_multiscale expects a batch of exactly one image'

    all_keypoints = torch.zeros([3, 0])
    all_descriptors = torch.zeros([
        model.dense_feature_extraction.num_channels, 0
    ])
    all_scores = torch.zeros(0)

    previous_dense_features = None
    banned = None
    for idx, scale in enumerate(scales):
        current_image = F.interpolate(
            image, scale_factor=scale,
            mode='bilinear', align_corners=True
        )
        _, _, h_level, w_level = current_image.size()

        dense_features = model.dense_feature_extraction(current_image)
        del current_image

        _, _, h, w = dense_features.size()

        # Sum the feature maps.
        if previous_dense_features is not None:
            dense_features += F.interpolate(
                previous_dense_features, size=[h, w],
                mode='bilinear', align_corners=True
            )
            # Rebind rather than `del`: the name must stay bound in case this
            # level is skipped by the `continue` below, otherwise the next
            # iteration's `is not None` check raises UnboundLocalError.
            previous_dense_features = None

        # Recover detections.
        detections = model.detection(dense_features)
        if banned is not None:
            # Resize the mask of earlier detections to this level and zero
            # out detections at those locations.
            banned = F.interpolate(banned.float(), size=[h, w]).byte()
            detections = torch.min(detections, (1 - banned))
            banned = torch.max(
                torch.max(detections, dim=1)[0].unsqueeze(1), banned
            )
        else:
            banned = torch.max(detections, dim=1)[0].unsqueeze(1)
        fmap_pos = torch.nonzero(detections[0]).t().cpu()
        del detections

        # Recover displacements.
        displacements = model.localization(dense_features)[0]
        displacements_i = displacements[
            0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ]
        displacements_j = displacements[
            1, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ]
        del displacements

        # Keep only detections whose displacement is below 0.5 on both axes.
        mask = (
            (torch.abs(displacements_i) < 0.5)
            & (torch.abs(displacements_j) < 0.5)
        )
        valid_displacements = torch.stack([
            displacements_i[mask],
            displacements_j[mask]
        ], dim=0).cpu()
        # `fmap_pos` lives on the CPU while `mask` is on the device of the
        # displacements; a CPU tensor cannot be indexed with a CUDA mask.
        fmap_pos = fmap_pos[:, mask.to(fmap_pos.device)]
        del mask, displacements_i, displacements_j

        fmap_keypoints = fmap_pos[1:, :].float() + valid_displacements
        del valid_displacements

        try:
            raw_descriptors, _, ids = interpolate_dense_features(
                fmap_keypoints.to(device),
                dense_features[0]
            )
        except EmptyTensorError:
            # Nothing usable at this level; move on to the next scale. No
            # features are carried over, matching what happens when the very
            # first level is skipped.
            continue

        # NOTE(review): `ids` presumably is an index tensor created on
        # `device` by the helper -- confirm. Bring it to the CPU so it can
        # index the CPU-resident position tensors.
        ids = torch.as_tensor(ids, device=fmap_pos.device)
        fmap_pos = fmap_pos[:, ids]
        fmap_keypoints = fmap_keypoints[:, ids]
        del ids

        keypoints = upscale_positions(fmap_keypoints, scaling_steps=2)
        del fmap_keypoints

        descriptors = F.normalize(raw_descriptors, dim=0).cpu()
        del raw_descriptors

        # Map positions from the resized level back to the input resolution.
        keypoints[0, :] *= h_init / h_level
        keypoints[1, :] *= w_init / w_level

        keypoints = keypoints.cpu()

        # Third row records the inverse of the scale the keypoint came from.
        keypoints = torch.cat([
            keypoints,
            torch.ones([1, keypoints.size(1)]) * 1 / scale,
        ], dim=0)

        # Later levels are down-weighted by their 1-based index.
        scores = dense_features[
            0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ].cpu() / (idx + 1)
        del fmap_pos

        all_keypoints = torch.cat([all_keypoints, keypoints], dim=1)
        all_descriptors = torch.cat([all_descriptors, descriptors], dim=1)
        all_scores = torch.cat([all_scores, scores], dim=0)
        del keypoints, descriptors

        previous_dense_features = dense_features
        del dense_features

    keypoints = all_keypoints.t().numpy()
    scores = all_scores.numpy()
    descriptors = all_descriptors.t().numpy()
    return keypoints, scores, descriptors